In [1]:
# | output: false
import matplotlib.pyplot as plt
import torch
from transformers import AutoModel, AutoTokenizer

# Set the default style for matplotlib plots
plt.style.use("seaborn-v0_8-whitegrid")

# Load the Qwen3 model and tokenizer
qwen3_model = (
    AutoModel.from_pretrained("Qwen/Qwen3-Embedding-4B", torch_dtype=torch.bfloat16)
    .cuda()
    .eval()
)
qwen3_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Embedding-4B")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [14]:
# | echo: false
query = "Which flora is found in Yosemite?"
text = "Black bears roam the area, while wildflowers and towering conifers thrive."
# Tokenize the input text and query
query_tokenized = qwen3_tokenizer(query, return_tensors="pt").to("cuda")
text_tokenized = qwen3_tokenizer(text, return_tensors="pt").to("cuda")

# Convert back to text
tokenized_query = [qwen3_tokenizer.decode(t) for t in query_tokenized["input_ids"][0]]
tokenized_document = [qwen3_tokenizer.decode(t) for t in text_tokenized["input_ids"][0]]
print("Query text:", tokenized_query)
print("Text text:", tokenized_document)

Query text: ['Which', ' flora', ' is', ' found', ' in', ' Yosemite', '?', '<|endoftext|>']
Text text: ['Black', ' bears', ' roam', ' the', ' area', ',', ' while', ' wild', 'flowers', ' and', ' towering', ' con', 'ifers', ' thrive', '.', '<|endoftext|>']


In [15]:
# | echo: false
import torch.nn.functional as F

# Compute embeddings for the query and text
with torch.inference_mode():
    query_embedding = (
        qwen3_model(**query_tokenized).last_hidden_state.squeeze().detach().cpu()
    )
    text_embedding = (
        qwen3_model(**text_tokenized).last_hidden_state.squeeze().detach().cpu()
    )

In [16]:
# | echo: false
import pickle

# Save the embeddings to a file
with open("assets/project_video_scene_embeddings.pkl", "wb") as f:
    pickle.dump(
        {
            "query_embedding": query_embedding,
            "document_embedding": text_embedding,
            "tokenized_query": tokenized_query,
            "tokenized_document": tokenized_document,
        },
        f,
    )

## Main animation

In [39]:
# | echo: false
from manim import *  # noqa: F403
import pickle

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import torch
import itertools


class ProjectVideo(MovingCameraScene):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.shown_embeddings = 6

    def get_color_from_cmap(self, value, cmap_name="hot", cmap_range=(0, 1)):
        cmap = cm.get_cmap(cmap_name)
        norm_value = mcolors.Normalize(vmin=cmap_range[0], vmax=cmap_range[1])(value)
        rgba = cmap(norm_value)
        return ManimColor(rgba[:3])  # Only use RGB (ignore A)

    def embed_text(
        self,
        text,
        tokenized_text,
        embedding,
        move_camera,
        title,
        font_sizes=(54, 36, 20),
        speed=1.0,
    ):
        title_tex = MarkupText(
            title,
            font_size=36,
            color=WHITE,
        )
        title_tex.move_to(
            self.camera.frame.get_corner(UL) + DOWN * 0.5 + RIGHT * 0.5, aligned_edge=UL
        )

        query_tex = MarkupText(
            text,
            font_size=font_sizes[0],
        )
        if not move_camera:
            query_tex.shift(DOWN * 1)
        self.play(
            Write(query_tex),
            FadeIn(title_tex, shift=UP * 0.5),
            run_time=2 / speed,
        )

        tokenize_arrow = Arrow(
            query_tex.get_bottom(),
            query_tex.get_bottom() + DOWN * 1.5,
            buff=0.2,
        )
        anims = [GrowArrow(tokenize_arrow)]
        if move_camera:
            anims.append(self.camera.frame.animate.move_to(tokenize_arrow.get_end()))
        self.play(
            *anims,
            run_time=1 / speed,
        )

        query_tokenized_tex = VGroup()
        for i, token in enumerate(tokenized_text):
            if token == "<|endoftext|>":
                # Escape properly
                token_tex = Tex(
                    r"\texttt{<|endoftext|>}",
                    font_size=int(font_sizes[1] * 0.8),
                )
            else:
                token_tex = Tex(
                    token.strip(),
                    font_size=font_sizes[1],
                )
            token_tex.add(
                SurroundingRectangle(token_tex, color=WHITE, buff=0.1, stroke_width=1)
            )
            query_tokenized_tex.add(token_tex)
        # Consistent height for all SurroundingRectangles
        max_height = max(tex.height for tex in query_tokenized_tex)
        for tex in query_tokenized_tex:
            tex.set_height(max_height)
        query_tokenized_tex.arrange(RIGHT, buff=0.1).move_to(
            tokenize_arrow.get_end() + DOWN * 0.5
        )
        self.play(
            LaggedStartMap(
                Write,
                query_tokenized_tex,
                lag_ratio=0.1,
            ),
            run_time=2 / speed,
        )

        # Now create arrows from each token to the embedding
        embedding_arrows = Group()
        for token_tex in query_tokenized_tex:
            arrow = Arrow(
                token_tex.get_bottom(),
                token_tex.get_bottom() + DOWN * 1,
                buff=0.2,
            )
            embedding_arrows.add(arrow)

        anims = [GrowArrow(arrow) for arrow in embedding_arrows]
        if move_camera:
            anims.append(
                self.camera.frame.animate.move_to(
                    [
                        self.camera.frame.get_x(),
                        embedding_arrows[0].get_end()[1],
                        self.camera.frame.get_z(),
                    ]
                )
            )
        self.play(
            *anims,
            run_time=1 / speed,
        )
        # Now show the embeddings
        query_embedding_tex = Group()
        for i, arrow in enumerate(embedding_arrows):
            embedding_value = embedding[i]
            embedding_vector = VGroup()
            for val in embedding_value[: self.shown_embeddings]:
                val_tex = DecimalNumber(
                    float(val),
                    num_decimal_places=2,
                    font_size=font_sizes[2],
                    include_sign=True,
                )
                embedding_vector.add(val_tex)
            embedding_vector.add(MathTex(r"\vdots", font_size=font_sizes[2]))
            embedding_vector.arrange(DOWN, buff=0.1)
            embedding_vector.next_to(arrow.get_end(), DOWN, buff=0.1)
            query_embedding_tex.add(embedding_vector)
        self.play(
            LaggedStartMap(
                Write,
                query_embedding_tex,
                lag_ratio=0.1,
            ),
            run_time=2 / speed,
        )

        # Now replace the numbers for a square of color
        query_embedding_squares = Group()
        for i, vector in enumerate(query_embedding_tex):
            for j, val in enumerate(vector):
                if j == len(vector) - 1:
                    # Skip the ellipsis
                    query_embedding_squares.add(val.copy())
                else:
                    value = embedding[i][j].item()
                    square = Square(
                        side_length=val.height + 0.05,
                        fill_color=self.get_color_from_cmap(
                            (float(value) + 1) / 2,
                            cmap_name="plasma",
                            cmap_range=(-10, 10),
                        ),
                        fill_opacity=1,
                        stroke_opacity=0,
                    )
                    square.move_to(val.get_center())
                    query_embedding_squares.add(square)
        self.play(
            FadeOut(query_embedding_tex),
            FadeIn(query_embedding_squares),
            FadeOut(title_tex),
            run_time=1 / speed,
        )
        return (
            query_tokenized_tex,
            query_tex,
            tokenize_arrow,
            embedding_arrows,
            query_embedding_squares,
        )

    def create_table(
        self, tokenized_query, tokenized_document, sim_matrix, font_size=36
    ):
        assert len(tokenized_query) == sim_matrix.shape[0]
        assert len(tokenized_document) == sim_matrix.shape[1]
        # Create y-axis labels
        y_axis_labels = VGroup()
        for i, token in enumerate(tokenized_query):
            if token == "<|endoftext|>":
                # Escape properly
                token_tex = Tex(
                    r"\texttt{<|endoftext|>}",
                    font_size=int(font_size * 0.8),
                )
            else:
                token_tex = Tex(
                    token.strip(),
                    font_size=font_size,
                )
            y_axis_labels.add(token_tex)

        y_axis_labels.arrange(DOWN, buff=0.2, aligned_edge=RIGHT).move_to(LEFT * 3)

        # Create x-axis labels
        x_axis_labels = VGroup()

        for i, token in enumerate(tokenized_document):
            x_axis_label = Tex(
                token.strip(),
                font_size=font_size,
            ).rotate(PI / 2)
            rect = Rectangle()
            rect.move_to(
                x_axis_label.get_corner(UP),
                aligned_edge=UP,
            )
            x_axis_label.add(rect)
            x_axis_labels.add(x_axis_label)
        # Make sure that the heights of the x-axis labels are consistent
        max_w = max(tex[0].width for tex in x_axis_labels)
        for tex in x_axis_labels:
            tex[1].width = max_w

        x_axis_labels.arrange(RIGHT, buff=0.2, aligned_edge=UP).next_to(
            y_axis_labels.get_corner(DR), DR, buff=0.2, aligned_edge=UL
        )

        # Create the heatmap
        heatmap = [Group() for _ in range(sim_matrix.shape[0])]
        val_range = (sim_matrix.min(), sim_matrix.max())
        for i in range(sim_matrix.shape[0]):
            for j in range(sim_matrix.shape[1]):
                value = sim_matrix[i, j]
                square = Square(
                    side_length=0.4,
                    fill_color=self.get_color_from_cmap(
                        float(value),
                        cmap_name="hot",
                        cmap_range=val_range,
                    ),
                    fill_opacity=1,
                    stroke_opacity=0,
                )
                square.move_to(
                    [
                        x_axis_labels[j].get_x(),
                        y_axis_labels[i].get_y(),
                        0,
                    ]
                )
                heatmap[i].add(square)
        heatmap = Group(*heatmap)

        return y_axis_labels, x_axis_labels, heatmap

    def construct(self):
        self.renderer.camera.use_z_index = True
        # Add watermark
        watermark = Tex(
            "Carles Onielfa",
            font_size=30,
            color=GRAY,
        )
        # Set to follow the camera
        watermark.add_updater(
            lambda m: m.move_to(
                self.camera.frame.get_corner(UR) + DOWN * 0.5 + LEFT * 0.5,
                aligned_edge=UR,
            )
        )
        self.add(watermark)
        # Definitions
        query = "Which flora is found in Yosemite?"
        document = text
        with open("assets/project_video_scene_embeddings.pkl", "rb") as f:
            embeddings = pickle.load(f)
        query_embedding = embeddings["query_embedding"]
        document_embedding = embeddings["document_embedding"]
        tokenized_query = embeddings["tokenized_query"]
        tokenized_document = embeddings["tokenized_document"]
        sim_matrix = (
            query_embedding[:, : self.shown_embeddings]
            @ document_embedding[:-1, : self.shown_embeddings].T
        )
        table_data = (
            torch.nn.functional.normalize(sim_matrix.to(torch.float16)).numpy() + 1
        )
        # 1. tokenizing query and document and generating embeddings. show the embeddings
        (
            q_tokenized_tex,
            q_tex,
            q_tokenize_arrow,
            q_embedding_arrows,
            q_embedding_squares,
        ) = self.embed_text(
            query,
            tokenized_query,
            query_embedding,
            move_camera=True,
            title="Query Embedding",
        )
        self.wait(1)
        # Translate the query_tokenized_tex to the bottom of the screen while fading out other elements
        self.play(
            q_tokenized_tex.animate.move_to(
                self.camera.frame.get_bottom() + UP
            ),  # Move to the bottom of the screen
            FadeOut(q_tex),
            FadeOut(q_tokenize_arrow),
            FadeOut(q_embedding_arrows),
            FadeOut(q_embedding_squares),
            run_time=1,
        )
        (
            d_tokenized_tex,
            d_tex,
            d_tokenize_arrow,
            d_embedding_arrows,
            d_embedding_squares,
        ) = self.embed_text(
            document,
            tokenized_document,
            document_embedding,
            move_camera=False,
            title="Document Embedding",
            font_sizes=(30, 18, 16),
            speed=1.75,
        )
        self.wait(1)
        self.play(
            FadeOut(d_tex),
            FadeOut(d_tokenize_arrow),
            FadeOut(d_embedding_arrows),
            FadeOut(d_embedding_squares),
        )
        # 2. rotate the tokenized query 90% to arrange in a matrix, display the heatmap of relevance between query and document
        y_axis_labels, x_axis_labels, heatmap = self.create_table(
            tokenized_query, tokenized_document[:-1], table_data
        )
        table_group = Group(
            y_axis_labels,
            x_axis_labels,
            heatmap,
        )
        table_group.move_to(self.camera.frame.get_center() + DOWN * 0.8 + LEFT * 0.5)
        # Fade out the other elements
        anims = []
        # Transform the query tokenized text to the y-axis labels
        anims.extend(
            [
                Transform(
                    q_tokenized_tex[i][0],
                    y_axis_labels[i],
                )
                for i in range(len(q_tokenized_tex))
            ]
            + [FadeOut(q[1]) for q in q_tokenized_tex]
        )
        # Transform the document tokenized text to the x-axis labels
        anims.extend(
            [d.animate.rotate(PI / 2) for d in d_tokenized_tex[:-1]]
            + [
                Transform(
                    d_tokenized_tex[i][0],
                    x_axis_labels[i][0],
                    replace_mobject_with_target_in_scene=True,
                )
                if i < len(d_tokenized_tex) - 1
                else FadeOut(d_tokenized_tex[i])
                for i in range(len(d_tokenized_tex))
            ]
            + [FadeOut(d[1]) for d in d_tokenized_tex[:-1]]
        )
        self.play(
            *anims,
            run_time=1,
        )
        self.wait(1)
        heatmap_title = MarkupText(
            "Compute Token-level Relevance",
            font_size=36,
            color=WHITE,
        )
        heatmap_title.move_to(
            self.camera.frame.get_corner(UL) + DOWN * 0.5 + RIGHT * 0.5, aligned_edge=UL
        )
        self.play(
            LaggedStartMap(
                FadeIn,
                heatmap,
                lag_ratio=0.1,
            ),
            FadeIn(heatmap_title, shift=UP * 0.5),
            run_time=2,
        )
        self.wait(1)
        # 3. flatten to show only the heatmap from the <|endoftext|> token to the document tokens

        # Fade out the all but the last row of the heatmap and the y-axis labels
        flat_heatmap_title = MarkupText(
            "Focus on the similarities for the last query token",
            font_size=36,
            color=WHITE,
        )
        flat_heatmap_title.move_to(
            self.camera.frame.get_corner(UL) + DOWN * 0.5 + RIGHT * 0.5, aligned_edge=UL
        )
        self.play(
            FadeOut(heatmap_title),
            FadeIn(flat_heatmap_title, shift=UP * 0.5),
            LaggedStartMap(
                FadeOut,
                heatmap[:-1],
                lag_ratio=0.1,
            ),
            LaggedStartMap(
                FadeOut,
                [q[0] for q in q_tokenized_tex[:-1]],
                lag_ratio=0.1,
            ),
            run_time=1,
        )

        # 4. animate to show a line plot from the heatmap

        y_values = table_data[-1]  # Last row of the heatmap
        x_values = np.arange(len(y_values)) + 1
        ax = Axes(
            x_range=(0, len(y_values) + 1),
            y_range=(0, y_values.max() + 0.4),
            y_length=self.camera.frame.get_height() * 0.65,
        )
        ax.set_z_index(1)
        # Ax bottom left corner to y_axis labels bottom right corner
        ax.width = x_axis_labels.width + 0.75
        ax.move_to(y_axis_labels.get_corner(DR) + LEFT * 0.22, aligned_edge=DL)

        self.play(
            FadeOut(q_tokenized_tex[-1][0]),
            Create(ax.x_axis),
            Create(ax.y_axis),
        )
        lines = VGroup()
        for i in range(len(x_values) - 1):
            start = ax.c2p(x_values[i], y_values[i])
            end = ax.c2p(x_values[i + 1], y_values[i + 1])
            l = Line(start, end)
            lines.add(l)
            l.set_z_index(0)
        lines.set_z_index(0)

        # Plot the points
        dots = VGroup()
        for x, y in zip(x_values, y_values):
            dot = Dot(
                radius=0.08,
                point=ax.c2p(x, y),
                stroke_width=2,
                stroke_color=WHITE,
                color=self.get_color_from_cmap(
                    y,
                    cmap_name="hot",
                    cmap_range=(table_data.min(), table_data.max()),
                ),
            )
            dots.add(dot)
            dot.set_z_index(2)
        dots.set_z_index(2)

        # Transform the heatmap squares into dots
        self.play(
            *[
                Transform(
                    heatmap[-1][i][0],
                    dots[i],
                    replace_mobject_with_target_in_scene=True,
                )
                for i in range(len(dots))
            ],
            run_time=2,
            lag_ratio=0.1,
        )

        self.play(
            Create(lines),
        )
        # 5. highlight peaks in the line plot and group them into clusters
        peaks_title = MarkupText(
            "Cluster peaks in relevance",
            font_size=36,
            color=WHITE,
        )
        peaks_title.move_to(
            self.camera.frame.get_corner(UL) + DOWN * 0.5 + RIGHT * 0.5, aligned_edge=UL
        )
        self.play(
            FadeOut(flat_heatmap_title),
            FadeIn(peaks_title, shift=UP * 0.5),
            run_time=1,
        )
        # Draw threshold line
        threshold = y_values.max() * 0.893
        threshold_line = ax.get_horizontal_line(
            ax.c2p(x_values[-1], threshold), color=YELLOW, stroke_width=3
        )
        self.play(
            Create(threshold_line), *[dot.animate.set_color(WHITE) for dot in dots]
        )
        self.wait(1.5)
        # Flash the points above the threshold
        anims = []
        peaks: list[Dot] = []
        for i, y in enumerate(y_values):
            if y > threshold:
                peaks.append(dots[i])
                anims.append(
                    Flash(
                        dots[i],
                        color=RED,
                    )
                )
                anims.append(dots[i].animate.set_color(RED).scale(1.5))
        self.play(LaggedStart(*anims, lag_ratio=0.2))
        # Create clusters by scanning the line plot
        scan = Rectangle(
            width=0.01,
            height=(ax.c2p(0, y_values.max() + 0.1) - ax.c2p(0, 0))[1],
            color=RED,
            stroke_width=2,
            fill_opacity=0.1,
        )
        scan.z_index = 0
        scan.move_to(ax.c2p(0, 0), aligned_edge=DOWN)
        # Animate the scan
        self.play(
            Create(scan),
            run_time=0.5,
        )
        destination = np.array(
            [
                peaks[0].get_x(),
                scan.get_y(),
                0,
            ]
        )
        distance = np.linalg.norm(destination - scan.get_center())
        speed = 2.5
        self.play(
            scan.animate.move_to(destination),
            run_time=distance / speed,
            rate_func=linear,
        )
        target_width = peaks[-1].get_x() - peaks[0].get_x()

        self.play(
            scan.animate.stretch_to_fit_width(target_width)
            .move_to(destination, aligned_edge=LEFT)
            .set_run_time(target_width / speed)
            .set_rate_func(linear),
            # Run these in sequence
            LaggedStart(
                *[
                    label[0].animate.set_color(RED)
                    for label in x_axis_labels
                    if label[0].get_corner(UL)[0] >= peaks[0].get_x() - 0.5
                    and label[0].get_corner(UR)[0] <= peaks[-1].get_x() + 0.5
                ],
                lag_ratio=0.05,
                run_time=target_width / speed,
            ),
        )
        # 6. extract highlighted text
        extracted_text_mock = MarkupText(document, font_size=32, color=WHITE)
        extracted_text_mock.move_to(self.camera.frame.get_center())
        cumm_text_length = [0] + list(
            itertools.accumulate([len(t.strip()) for t in tokenized_document[:-1]])
        )
        self.play(
            FadeOut(peaks_title),
            FadeOut(ax),
            FadeOut(dots),
            FadeOut(lines),
            FadeOut(threshold_line),
            FadeOut(scan),
            *[
                x_axis_labels[i][0]
                .animate.rotate(-PI / 2)
                .move_to(
                    extracted_text_mock[cumm_text_length[i] : cumm_text_length[i + 1]]
                )
                for i in range(len(x_axis_labels))
            ],
        )
        # Expand colored text to nearest sentence
        start_idx = 6
        end_idx = len(tokenized_document) - 1
        extend_title = MarkupText(
            "Extend spans to sentences",
            font_size=36,
            color=WHITE,
        )
        extend_title.move_to(
            self.camera.frame.get_corner(UL) + DOWN * 0.5 + RIGHT * 0.5, aligned_edge=UL
        )
        anims_start = []
        tokens_start = [
            label[0]
            for label in x_axis_labels[start_idx:]
            if label[0].get_corner(UR)[0] <= peaks[-1].get_x() + 0.5
        ]
        # Animate each letter separately
        for tok in tokens_start:
            for letter in tok:
                anims_start.append(letter.animate.set_color(RED))
        anims_end = []
        tokens_end = [
            label[0]
            for label in x_axis_labels[:end_idx]
            if label[0].get_corner(UL)[0] >= peaks[0].get_x() - 0.5
        ]
        for tok in tokens_end:
            for letter in tok:
                anims_end.append(letter.animate.set_color(RED))
        self.play(
            FadeIn(extend_title, shift=UP * 0.5),
            LaggedStart(
                *reversed(anims_start),
                lag_ratio=0.05,
            ),
            LaggedStart(
                *anims_end,
                lag_ratio=0.05,
            ),
        )

        q_text = MarkupText(query, font_size=54, color=WHITE)
        q_text.move_to(self.camera.frame.get_top() + UP * 0.5)
        self.play(
            FadeOut(extend_title),
            q_text.animate.move_to(self.camera.frame.get_top() + DOWN * 2),
        )
        self.wait(3)
        self.play(
            FadeOut(q_text),
            *[FadeOut(x_axis_labels[i][0]) for i in range(len(x_axis_labels))],
        )

In [41]:
%%manim -qh -v ERROR ProjectVideo
""

  tex.set_height(max_height)
  cmap = cm.get_cmap(cmap_name)
  y_length=self.camera.frame.get_height() * 0.65,
                                                                                                                                              

In [None]:
Encoder animation

In [50]:
# | echo: false
from manim import *


class DecoderDiagram(MovingCameraScene):
    def construct(self):
        self.camera.background_color = None
        watermark = Tex(
            "Carles Onielfa",
            font_size=30,
            color=GRAY,
        )
        # Set to follow the camera
        watermark.add_updater(
            lambda m: m.move_to(
                self.camera.frame.get_corner(DR) + UP * 0.5 + LEFT * 0.25,
                aligned_edge=UR,
            )
        )
        self.add(watermark)

        Text.set_default(font="monospace")
        RoundedRectangle.set_default(
            corner_radius=0.05,
            stroke_width=2,
        )
        Rectangle.set_default(
            stroke_width=2,
        )
        text = "Help Patrick ... <|endoftext|>"

        token_box_width = 2
        # Reusable embedding_block
        num_embedding_blocks = 4
        embedding_block = VGroup(
            *[
                Rectangle(
                    width=0.89 * token_box_width / num_embedding_blocks,
                    height=0.9 * token_box_width / num_embedding_blocks,
                )
                for _ in range(num_embedding_blocks)
            ]
        )
        embedding_block.arrange(RIGHT, buff=0)

        # ----- INPUT TOKENS BLOCK -----
        input_tokens_block = VGroup()
        for i, word in enumerate(text.split()):
            token = RoundedRectangle(corner_radius=0, width=token_box_width, height=0.5)
            label = Text(word, font_size=16)
            label.move_to(token.get_center())
            token.add(label)
            # Add the token embedding and the position encoding
            token_embedding = embedding_block.copy()
            token_embedding.set_color(ORANGE)
            plus_sign = Text("+", font_size=32)
            position_encoding = embedding_block.copy()
            position_encoding.set_color(TEAL)
            up_arrow = Arrow(
                token_embedding.get_bottom() * 0.75,
                token.get_top() * 0.75,
                buff=0,
            )
            block = VGroup(
                token, up_arrow, token_embedding, plus_sign, position_encoding
            )
            block.arrange(UP, buff=0.1)

            # Add the rectangle to the input tokens block
            input_tokens_block.add(block)

        input_tokens_block.arrange(RIGHT, buff=0.3)
        input_tokens_block.to_edge(DOWN)

        # ----- DECODER BLOCKS -----
        # Create a vertical stack of decoder blocks
        # Each block represents a layer in the decoder
        decoder_block = VGroup()
        num_hidden_layers = 36
        for i in range(4):
            if i == 2:
                # Place dots instead of a block for the third layer
                block = RoundedRectangle(
                    width=input_tokens_block.width,
                    height=0.8,
                    stroke_opacity=0,
                )
                label = Text(f"...", font_size=28)
                label.move_to(block.get_center())
                block.add(label)
            else:
                block = RoundedRectangle(
                    width=input_tokens_block.width, height=0.8, fill_opacity=0.6
                )
                block.set_fill(BLACK)

                block.z_index = 2
            block.move_to(ORIGIN + DOWN * i)
            decoder_block.add(block)
        decoder_block.arrange(DOWN, buff=0.1)

        # ----- HIDDEN STATES -----
        hidden_states = VGroup()

        for block in input_tokens_block:
            # hidden rectangle the same size as the input token block for matching width
            last_hidden_state = RoundedRectangle(
                corner_radius=0.1,
                width=token_box_width,
                height=0.5,
                stroke_width=0,
            )
            b = embedding_block.copy()
            last_hidden_state.add(b)
            hidden_states.add(last_hidden_state)

        hidden_states.arrange(RIGHT, buff=0.3)

        # Add the input tokens block and decoder block to the scene
        llm_stack = VGroup(input_tokens_block, decoder_block, hidden_states)
        llm_stack.arrange(UP, buff=0.5)
        # llm_stack.move_to(llm_stack.get_center() + RIGHT * 1)

        # ----- LABELS -----
        label_font_size = 22
        token = input_tokens_block[0][0]
        token_embedding = input_tokens_block[0][2]
        position_encoding = input_tokens_block[0][4]

        # Token labels
        labels_group = Group()
        token_label = Tex(f"Input Tokens", font_size=label_font_size)
        token_label.move_to(token.get_left() + LEFT * 0.15, aligned_edge=RIGHT)
        labels_group.add(token_label)

        # Token embedding label
        token_embedding_label = Tex(
            r"$\blacksquare$ Token Embeddings",
            font_size=label_font_size,
        )
        token_embedding_label[0][0].set_color(ORANGE)
        token_embedding_label.move_to(
            token_embedding.get_left() + LEFT * 0.15, aligned_edge=RIGHT
        )
        labels_group.add(token_embedding_label)

        # Position encoding label
        position_encoding_label = Tex(
            r"$\blacksquare$ Position Encodings",
            font_size=label_font_size,
        )
        position_encoding_label[0][0].set_color(TEAL)
        position_encoding_label.move_to(
            position_encoding.get_left() + LEFT * 0.15, aligned_edge=RIGHT
        )
        labels_group.add(position_encoding_label)

        # Decoder block labels
        decoder_block_labels = Group()
        for i, block in enumerate(decoder_block):
            if i != 2:
                label = Tex(
                    f"Decoder block {1 if i == 3 else num_hidden_layers - i}",
                    font_size=label_font_size,
                )
            else:
                # Add dummy label for the third block
                label = Tex("")

            decoder_block_labels.add(label)
            label.move_to(block.get_left() + LEFT * 0.15, aligned_edge=RIGHT)

        # Default embedding label
        default_embedding_label = Tex(
            r"$\blacksquare$ Text Embedding\\(default embedding)",
            font_size=label_font_size,
        )
        default_embedding_label[0][0].set_color(YELLOW)
        default_embedding_label.move_to(
            hidden_states.get_right() + RIGHT * 0.4, aligned_edge=LEFT
        )
        default_embedding_label_arrow = Arrow(
            hidden_states.get_right(),
            default_embedding_label.get_left() + LEFT * 0.1,
            buff=0,
        )
        # Hidden states label
        hidden_states_label = Tex(
            r"$\blacksquare + \blacksquare$ Hidden States\\(our embeddings)",
            font_size=label_font_size,
        )
        hidden_states_label[0][0].set_color(PURPLE)
        hidden_states_label[0][2].set_color(YELLOW)
        hidden_states_label.move_to(
            hidden_states.get_left() + LEFT * 0.15, aligned_edge=RIGHT
        )
        labels_group.add(hidden_states_label)

        # ----- OUTPUT ARROWS -----
        output_arrows = Group()
        for input_block, output_block in zip(input_tokens_block, hidden_states):
            # Create an arrow from the input token block to the hidden state block
            arrow = Arrow(
                input_block.get_top(),
                output_block.get_bottom(),
                buff=0.1,
                stroke_width=2,
                max_tip_length_to_length_ratio=0.03,
            )
            arrow.z_index = 0
            output_arrows.add(arrow)

        # ----- ANIMATION -----
        # Show input tokens block
        self.play(FadeIn(*[i[0] for i in input_tokens_block]), run_time=1)
        self.play(Create(token_label, run_time=0.5))
        self.wait(1)
        # Show decoder blocks
        # self.play(
        #     FadeIn(*reversed(decoder_block), lag_ratio=0.2),
        #     run_time=1,
        # )
        # self.play(
        #     *[Create(label) for label in reversed(decoder_block_labels)],
        #     run_time=1,
        #     lag_ratio=0.2,
        # )
        self.play(
            Succession(
                *[
                    (FadeIn(block), Create(label))
                    for block, label in zip(
                        reversed(decoder_block), reversed(decoder_block_labels)
                    )
                ],
                lag_ratio=0.2,
                run_time=1,
            )
        )
        self.wait(1)
        # Show input token embeddings and position encodings
        self.play(
            FadeIn(*[i[2] for i in input_tokens_block]),
            FadeIn(*[i[4] for i in input_tokens_block]),
            run_time=1,
        )
        # Show labels for input token embeddings and position encodings

        self.play(
            Create(token_embedding_label),
            Create(position_encoding_label),
            run_time=0.5,
        )
        self.wait(1)
        # Show arrows
        self.play(
            FadeIn(*[i[1] for i in input_tokens_block]),
            run_time=0.5,
        )
        self.play(
            FadeIn(*[i[3] for i in input_tokens_block]),
            run_time=0.5,
        )
        self.play(
            *[GrowArrow(arrow) for arrow in output_arrows],
            run_time=1,
        )
        self.wait(0.5)
        # Show hidden states
        self.play(
            FadeIn(*hidden_states),
            run_time=1,
        )
        # Show default embedding label and arrow
        # while flickering the last embedding block to yellow
        last_embedding_block = hidden_states[-1]
        self.play(
            last_embedding_block.animate.set_color(YELLOW),
            GrowArrow(default_embedding_label_arrow),
            Create(default_embedding_label),
            run_time=1,
        )
        self.wait(1.5)
        self.play(
            hidden_states[:-1].animate.set_color(PURPLE),
            Create(hidden_states_label),
            run_time=0.5,
        )

        self.wait(5)

In [51]:
%%manim -qh -v WARNING DecoderDiagram
""

                                                                                                                   