In [7]:
"""
从零开始理解 BERT — Manim 动画教程 (无 LaTeX 依赖版)
=====================================================
使用方法 (Jupyter Notebook):
    from manim import *
    %manim -qm -v WARNING BertTutorial

分场景渲染：
    %manim -ql -v WARNING Scene1_WhatIsBert
    %manim -ql -v WARNING Scene2_MaskedLanguageModel
    %manim -ql -v WARNING Scene3_SelfAttention
    %manim -ql -v WARNING Scene4_Training
    %manim -ql -v WARNING Scene5_Results
"""

from manim import *
import numpy as np

# ============================================================
#  Global color palette
# ============================================================
# Hex color strings shared by every scene in this file.
BG_COLOR = "#0f0f1a"        # scene background, applied through `config` below
TEXT_COLOR = "#e0e0e0"      # default body text and matrix entries
ACCENT_BLUE = "#4fc3f7"     # scene titles and default highlight color
ACCENT_PINK = "#f48fb1"     # emphasized keywords, the MASK token
ACCENT_GREEN = "#81c784"
ACCENT_ORANGE = "#ffb74d"   # step subtitles
ACCENT_PURPLE = "#ce93d8"
ACCENT_YELLOW = "#fff176"
DIM_COLOR = "#555577"       # secondary text, axes, connector arrows

# Module-level side effect: sets the background for every scene rendered
# after this cell has run.
config.background_color = BG_COLOR


# ============================================================
#  工具函数 (全部用 Text，不用 MathTex)
# ============================================================
def make_matrix_mob(rows, color=TEXT_COLOR, font_size=28, h_buff=0.9, v_buff=0.55):
    """Draw a matrix with plain ``Text`` entries and hand-built square brackets.

    Args:
        rows: iterable of rows; every value is converted with ``str()``.
        color: color used for both the entries and the bracket strokes.
        font_size: font size of each entry.
        h_buff: horizontal distance between neighbouring column centers.
        v_buff: vertical distance between neighbouring row centers.

    Returns:
        ``VGroup(left_bracket, entries, right_bracket)``. ``entries`` is a flat
        VGroup in row-major order, so entry (r, c) sits at index
        ``r * n_cols + c``.
    """
    # Lay every entry out on a regular grid, then center the grid on ORIGIN.
    cells = VGroup(*[
        Text(str(value), font_size=font_size, color=color).move_to(
            RIGHT * col_idx * h_buff + DOWN * row_idx * v_buff
        )
        for row_idx, row in enumerate(rows)
        for col_idx, value in enumerate(row)
    ])
    cells.center()

    # Bracket geometry is derived from the size of the centered entry grid.
    half_height = (cells.height + 0.4) / 2
    side_offset = cells.width / 2 + 0.25
    stroke = dict(color=color, stroke_width=2)

    def bracket(serif_dir, side_dir):
        """One bracket: a vertical bar plus two short serifs pointing inward."""
        top = UP * half_height
        bottom = DOWN * half_height
        return VGroup(
            Line(top, bottom, **stroke),
            Line(top, top + serif_dir * 0.12, **stroke),
            Line(bottom, bottom + serif_dir * 0.12, **stroke),
        ).shift(side_dir * side_offset)

    return VGroup(bracket(RIGHT, LEFT), cells, bracket(LEFT, RIGHT))


def labeled_matrix(label_str, rows, label_color=ACCENT_BLUE, val_color=TEXT_COLOR, font_size=28):
    """Build a matrix preceded by a bold text label such as ``'Q ='``.

    Args:
        label_str: text placed to the left of the matrix.
        rows: matrix rows, forwarded to ``make_matrix_mob``.
        label_color: color of the label.
        val_color: color of the matrix entries and brackets.
        font_size: entry font size; the label is drawn 4 points larger.

    Returns:
        ``VGroup(label, matrix)`` where ``matrix`` is the group returned by
        ``make_matrix_mob``.
    """
    matrix_body = make_matrix_mob(rows, color=val_color, font_size=font_size)
    caption = Text(label_str, font_size=font_size + 4, color=label_color, weight=BOLD)
    caption.next_to(matrix_body, LEFT, buff=0.25)
    return VGroup(caption, matrix_body)


def highlight_box(mobject, color=ACCENT_BLUE, buff=0.15, corner_radius=0.1):
    """Return a thin rounded ``SurroundingRectangle`` drawn around ``mobject``.

    Args:
        mobject: the mobject to frame.
        color: rectangle color.
        buff: padding between the mobject and the rectangle.
        corner_radius: radius of the rounded corners.
    """
    frame_options = {
        "color": color,
        "buff": buff,
        "corner_radius": corner_radius,
        "stroke_width": 2,
    }
    return SurroundingRectangle(mobject, **frame_options)


# ============================================================
#  场景 1: BERT 是什么？
# ============================================================
class Scene1_WhatIsBert(Scene):
    """Scene 1: title card, four example sentences sharing one highlighted
    keyword, then the "context-dependent vectors" answer and a pointer to
    Self-Attention."""

    def construct(self):
        # --- Title card ---
        heading = Text(
            "从零开始理解 BERT",
            font_size=52, color=ACCENT_BLUE, weight=BOLD,
        )
        tagline = Text(
            "Bidirectional Encoder Representations from Transformers",
            font_size=22, color=DIM_COLOR,
        )
        tagline.next_to(heading, DOWN, buff=0.4)

        self.play(Write(heading), run_time=1.5)
        self.play(FadeIn(tagline, shift=UP * 0.3), run_time=0.8)
        self.wait(1.5)
        self.play(FadeOut(heading), FadeOut(tagline))

        # --- The motivating question ---
        question = Text(
            "同一个词，不同的含义",
            font_size=40, color=ACCENT_ORANGE, weight=BOLD,
        )
        question.to_edge(UP, buff=0.6)
        self.play(FadeIn(question, shift=DOWN * 0.3))

        # (text before keyword, keyword, text after keyword)
        examples = [
            ("你来就来，还带礼物。你几个", "意思", "啊？"),
            ("别多想。我就", "意思", "一下"),
            ("我看不懂这个词是什么", "意思", ""),
            ("你这么说就没", "意思", "了"),
        ]

        def sentence_row(prefix, keyword, suffix):
            """Three Text pieces chained left to right; index 1 is the keyword."""
            head = Text(prefix, font_size=28, color=TEXT_COLOR)
            focus = Text(keyword, font_size=28, color=ACCENT_PINK, weight=BOLD)
            tail = Text(suffix, font_size=28, color=TEXT_COLOR)
            focus.next_to(head, RIGHT, buff=0.05)
            tail.next_to(focus, RIGHT, buff=0.05)
            return VGroup(head, focus, tail)

        sentence_rows = VGroup(*[sentence_row(*parts) for parts in examples])
        sentence_rows.arrange(DOWN, buff=0.4, aligned_edge=LEFT)
        sentence_rows.next_to(question, DOWN, buff=0.8)

        # Reveal each sentence, then briefly frame its keyword.
        for row in sentence_rows:
            self.play(FadeIn(row, shift=RIGHT * 0.3), run_time=0.6)
            keyword_frame = highlight_box(row[1], color=ACCENT_PINK)
            self.play(Create(keyword_frame), run_time=0.3)
            self.wait(0.3)
            self.play(FadeOut(keyword_frame), run_time=0.2)

        self.wait(0.8)

        # --- The answer ---
        answer = Text(
            "→ 上下文相关的向量表示",
            font_size=34, color=ACCENT_GREEN, weight=BOLD,
        )
        answer.next_to(sentence_rows, DOWN, buff=0.7)
        self.play(FadeIn(answer, shift=UP * 0.3))
        self.wait(0.8)

        teaser = Text(
            "靠什么实现？ Self-Attention!",
            font_size=30, color=ACCENT_YELLOW,
        )
        teaser.next_to(answer, DOWN, buff=0.4)
        self.play(FadeIn(teaser, shift=UP * 0.2))
        self.wait(2)

        # Clear everything still on screen.
        fade_all = [FadeOut(mob) for mob in self.mobjects]
        self.play(*fade_all, run_time=0.8)


# ============================================================
#  场景 2: Masked Language Model
# ============================================================
class Scene2_MaskedLanguageModel(Scene):
    """Scene 2: show a sentence as character tiles, cover one tile with a
    MASK block, list the four processing steps and display a (hard-coded)
    prediction for the masked character."""

    def construct(self):
        heading = Text("BERT 的训练核心：遮词预测", font_size=40, color=ACCENT_BLUE, weight=BOLD)
        heading.to_edge(UP, buff=0.5)
        self.play(Write(heading), run_time=1)

        # --- Original sentence, one tile per character ---
        def char_tile(char):
            """A rounded square with the character centered in it: VGroup(box, text)."""
            frame = RoundedRectangle(
                width=0.7, height=0.7,
                corner_radius=0.1,
                stroke_color=ACCENT_BLUE, stroke_width=1.5,
                fill_color="#1a1a2e", fill_opacity=0.8,
            )
            glyph = Text(char, font_size=30, color=TEXT_COLOR)
            glyph.move_to(frame)
            return VGroup(frame, glyph)

        sentence_chars = ["你", "这", "么", "说", "就", "没", "意", "思", "了", "啊"]
        tiles = VGroup(*[char_tile(ch) for ch in sentence_chars])
        tiles.arrange(RIGHT, buff=0.12)
        tiles.next_to(heading, DOWN, buff=1.0)

        tile_entrances = [FadeIn(tile, shift=UP * 0.2) for tile in tiles]
        self.play(LaggedStart(*tile_entrances, lag_ratio=0.08), run_time=1.2)
        self.wait(0.5)

        # --- Cover the character at index 3 ("说") ---
        masked_index = 3
        masked_tile = tiles[masked_index]

        cover_rect = RoundedRectangle(
            width=0.7, height=0.7,
            corner_radius=0.1,
            stroke_color=ACCENT_PINK, stroke_width=2.5,
            fill_color=ACCENT_PINK, fill_opacity=0.9,
        )
        cover_text = Text("MASK", font_size=16, color="#0f0f1a", weight=BOLD)
        cover_text.move_to(cover_rect)
        cover = VGroup(cover_rect, cover_text)
        cover.move_to(masked_tile)

        # Fade the cover in while the character underneath becomes transparent.
        self.play(
            FadeIn(cover, scale=1.3),
            masked_tile[1].animate.set_opacity(0),
            run_time=0.8,
        )
        self.wait(0.8)

        # --- The four processing steps: (number, name, description) ---
        step_specs = [
            ("①", "查词表", "每个字 → 向量"),
            ("②", "加位置", "告诉模型字的位置"),
            ("③", "Self-Attention", "字与字互相关注"),
            ("④", "预测", "MASK → 概率分布"),
        ]

        step_rows = VGroup()
        for number, name, description in step_specs:
            number_text = Text(number, font_size=26, color=ACCENT_ORANGE, weight=BOLD)
            name_text = Text(name, font_size=26, color=ACCENT_BLUE, weight=BOLD)
            description_text = Text(description, font_size=20, color=DIM_COLOR)
            name_text.next_to(number_text, RIGHT, buff=0.15)
            # The description hangs under the "number + name" line, left-aligned.
            header_line = VGroup(number_text, name_text)
            description_text.next_to(header_line, DOWN, buff=0.1, aligned_edge=LEFT)
            step_rows.add(VGroup(number_text, name_text, description_text))

        step_rows.arrange(DOWN, buff=0.35, aligned_edge=LEFT)
        step_rows.next_to(tiles, DOWN, buff=1.0)
        step_rows.shift(LEFT * 2)

        for step_row in step_rows:
            self.play(FadeIn(step_row, shift=RIGHT * 0.3), run_time=0.5)
            self.wait(0.2)

        # --- Prediction result ---
        result_arrow = Arrow(
            start=RIGHT * 1.5, end=RIGHT * 3.5,
            color=ACCENT_GREEN, stroke_width=3,
        )
        result_arrow.next_to(step_rows, RIGHT, buff=0.5)
        result_arrow.shift(DOWN * 0.5)

        prediction = Text(
            '"说"  → 98.5%',
            font_size=28, color=ACCENT_GREEN, weight=BOLD,
        )
        prediction.next_to(result_arrow, RIGHT, buff=0.3)

        self.play(GrowArrow(result_arrow), run_time=0.6)
        self.play(FadeIn(prediction, shift=LEFT * 0.2), run_time=0.5)

        self.wait(2)
        # Clear everything still on screen.
        fade_all = [FadeOut(mob) for mob in self.mobjects]
        self.play(*fade_all, run_time=0.8)


# ============================================================
#  场景 3: Self-Attention 手算 (核心)
# ============================================================
class Scene3_SelfAttention(Scene):
    """Scene 3: a worked Self-Attention example on the two tokens "one cat".

    ``construct`` draws a persistent title and then runs five sub-steps in
    order. Every number on screen is a hard-coded string; nothing is
    computed at render time.
    """

    def construct(self):
        # The title stays on screen; every sub-step positions itself under it.
        self.scene_title = Text(
            "手算 Self-Attention",
            font_size=42, color=ACCENT_BLUE, weight=BOLD,
        )
        self.scene_title.to_edge(UP, buff=0.4)
        self.play(Write(self.scene_title), run_time=1)

        for step in (
            self.show_setup,
            self.show_qkv,
            self.show_scores,
            self.show_softmax,
            self.show_weighted_sum,
        ):
            step()

    # ------ 3a: input setup ------
    def show_setup(self):
        """Show the input description and the 2x3 matrix X with row labels."""
        intro = Text(
            '输入: "one cat"    维度 d = 3',
            font_size=28, color=TEXT_COLOR,
        )
        intro.next_to(self.scene_title, DOWN, buff=0.6)
        self.play(FadeIn(intro))

        # The X matrix
        x_matrix = labeled_matrix(
            "X =",
            [["1", "0", "2"], ["-1", "2", "0"]],
            label_color=ACCENT_BLUE,
        )
        x_matrix.next_to(intro, DOWN, buff=0.7)

        # labeled_matrix returns (label, matrix); the matrix is
        # (left bracket, entries, right bracket) with entries in row-major order.
        matrix_part = x_matrix[1]
        cells = matrix_part[1]

        one_tag = Text("one →", font_size=20, color=ACCENT_GREEN)
        cat_tag = Text("cat →", font_size=20, color=ACCENT_PINK)

        # Align each tag with the first cell of its row (indices 0 and 3 for
        # a 3-column matrix).
        for row_tag, first_cell in ((one_tag, cells[0]), (cat_tag, cells[3])):
            row_tag.next_to(matrix_part, LEFT, buff=0.5)
            row_tag.set_y(first_cell.get_center()[1])

        self.play(
            FadeIn(x_matrix),
            FadeIn(one_tag), FadeIn(cat_tag),
            run_time=1.0,
        )
        self.wait(1.5)

        self.play(
            FadeOut(intro), FadeOut(x_matrix),
            FadeOut(one_tag), FadeOut(cat_tag),
            run_time=0.5,
        )

    # ------ 3b: Q, K, V ------
    def show_qkv(self):
        """Show the Q, K and V matrices side by side."""
        step_heading = Text(
            "步骤 1-2：计算 Q, K, V", font_size=26, color=ACCENT_ORANGE
        )
        step_heading.next_to(self.scene_title, DOWN, buff=0.5)
        self.play(FadeIn(step_heading))

        # Formula reminder
        formula_hint = Text(
            "Q = X·Wq    K = X·Wk    V = X·Wv",
            font_size=22, color=DIM_COLOR,
        )
        formula_hint.next_to(step_heading, DOWN, buff=0.3)
        self.play(FadeIn(formula_hint))

        # (label, values, label color)
        qkv_specs = [
            ("Q =", [["1", "0", "2"], ["-1", "2", "0"]], ACCENT_BLUE),
            ("K =", [["1", "2", "0"], ["-1", "0", "2"]], ACCENT_GREEN),
            ("V =", [["2", "0", "4"], ["-2", "4", "0"]], ACCENT_PINK),
        ]
        qkv_row = VGroup(*[
            labeled_matrix(tag, values, label_color=tag_color, font_size=24)
            for tag, values, tag_color in qkv_specs
        ])
        qkv_row.arrange(RIGHT, buff=0.7)
        qkv_row.next_to(formula_hint, DOWN, buff=0.6)

        # Shrink the row if it would run off the frame.
        max_row_width = 13
        if qkv_row.width > max_row_width:
            qkv_row.scale_to_fit_width(max_row_width)

        for matrix in qkv_row:
            self.play(FadeIn(matrix, shift=UP * 0.2), run_time=0.6)
            self.wait(0.2)

        self.wait(1)
        self.play(
            FadeOut(step_heading), FadeOut(formula_hint), FadeOut(qkv_row),
            run_time=0.5,
        )

    # ------ 3c: attention scores ------
    def show_scores(self):
        """Show the four dot products of Q·Kᵀ and the scaled score matrix."""
        step_heading = Text(
            "步骤 3-4：注意力分数 + 缩放", font_size=26, color=ACCENT_ORANGE
        )
        step_heading.next_to(self.scene_title, DOWN, buff=0.5)
        self.play(FadeIn(step_heading))

        score_formula = Text(
            "Score = Q · Kᵀ", font_size=30, color=TEXT_COLOR
        )
        score_formula.next_to(step_heading, DOWN, buff=0.5)
        self.play(Write(score_formula), run_time=0.8)

        # The four dot products; the off-diagonal (larger) ones are highlighted.
        dot_product_lines = [
            ("one↔one: 1×1+0×2+2×0 = 1", DIM_COLOR),
            ("one↔cat: 1×(-1)+0×0+2×2 = 3", ACCENT_YELLOW),
            ("cat↔one: (-1)×1+2×2+0×0 = 3", ACCENT_YELLOW),
            ("cat↔cat: (-1)×(-1)+2×0+0×2 = 1", DIM_COLOR),
        ]
        dot_products = VGroup(*[
            Text(line_text, font_size=20, color=line_color)
            for line_text, line_color in dot_product_lines
        ])
        dot_products.arrange_in_grid(2, 2, buff=(0.6, 0.3))
        dot_products.next_to(score_formula, DOWN, buff=0.5)

        for dot_product in dot_products:
            self.play(FadeIn(dot_product, shift=UP * 0.1), run_time=0.35)

        self.wait(0.5)

        # Scaled result
        divide_text = Text("÷ √3 ≈", font_size=26, color=TEXT_COLOR)
        scaled_scores = labeled_matrix(
            "",
            [["0.58", "1.73"], ["1.73", "0.58"]],
            label_color=ACCENT_BLUE, val_color=ACCENT_BLUE, font_size=26,
        )
        scaled_row = VGroup(divide_text, scaled_scores)
        scaled_row.arrange(RIGHT, buff=0.2)
        scaled_row.next_to(dot_products, DOWN, buff=0.5)

        self.play(FadeIn(scaled_row, shift=UP * 0.2), run_time=0.8)
        self.wait(1.5)

        self.play(
            FadeOut(step_heading), FadeOut(score_formula),
            FadeOut(dot_products), FadeOut(scaled_row),
            run_time=0.5,
        )

    # ------ 3d: Softmax ------
    def show_softmax(self):
        """Show the softmax arithmetic, the weight matrix A and an insight box."""
        step_heading = Text(
            "步骤 5：Softmax → 注意力权重", font_size=26, color=ACCENT_ORANGE
        )
        step_heading.next_to(self.scene_title, DOWN, buff=0.5)
        self.play(FadeIn(step_heading))

        # Softmax arithmetic
        exp_line = Text("e⁰·⁵⁸ ≈ 1.79    e¹·⁷³ ≈ 5.64", font_size=22, color=DIM_COLOR)
        weight_line = Text("权重: 1.79/7.43 ≈ 0.24    5.64/7.43 ≈ 0.76", font_size=22, color=DIM_COLOR)
        arithmetic = VGroup(exp_line, weight_line)
        arithmetic.arrange(DOWN, buff=0.2)
        arithmetic.next_to(step_heading, DOWN, buff=0.5)
        self.play(FadeIn(arithmetic), run_time=0.6)

        # Attention weight matrix
        weight_matrix = labeled_matrix(
            "A =",
            [["0.24", "0.76"], ["0.76", "0.24"]],
            label_color=ACCENT_PURPLE, val_color=TEXT_COLOR, font_size=30,
        )
        weight_matrix.next_to(arithmetic, DOWN, buff=0.5)
        self.play(FadeIn(weight_matrix), run_time=0.8)

        # Insight box: a fixed-size panel centered behind two lines of text.
        panel = RoundedRectangle(
            width=10, height=1.5,
            corner_radius=0.15,
            stroke_color=ACCENT_GREEN, stroke_width=1.5,
            fill_color="#1a2e1a", fill_opacity=0.6,
        )
        insight_lines = VGroup(
            Text('计算 "one" 的新向量时：', font_size=22, color=TEXT_COLOR),
            Text(
                "保留 24% 自己  +  吸收 76% 来自 cat",
                font_size=24, color=ACCENT_GREEN, weight=BOLD,
            ),
        )
        insight_lines.arrange(DOWN, buff=0.12)
        panel.move_to(insight_lines)
        insight_card = VGroup(panel, insight_lines)
        insight_card.next_to(weight_matrix, DOWN, buff=0.5)

        self.play(FadeIn(insight_card, shift=UP * 0.2), run_time=0.8)
        self.wait(2)

        self.play(
            FadeOut(step_heading), FadeOut(arithmetic),
            FadeOut(weight_matrix), FadeOut(insight_card),
            run_time=0.5,
        )

    # ------ 3e: weighted sum ------
    def show_weighted_sum(self):
        """Compare the vectors before and after attention, then clear the scene."""
        step_heading = Text(
            "步骤 6：加权求和 → 新向量", font_size=26, color=ACCENT_ORANGE
        )
        step_heading.next_to(self.scene_title, DOWN, buff=0.5)
        self.play(FadeIn(step_heading))

        def vector_panel(caption, caption_color, one_text, cat_text):
            """A caption above the "one" and "cat" vector lines."""
            return VGroup(
                Text(caption, font_size=22, color=caption_color),
                Text(one_text, font_size=24, color=ACCENT_BLUE),
                Text(cat_text, font_size=24, color=ACCENT_PINK),
            ).arrange(DOWN, buff=0.2)

        # Before / arrow / after, laid out left to right.
        before_panel = vector_panel(
            "原始向量", DIM_COLOR,
            "one:  [ 1,   0,   2 ]",
            "cat:  [-1,   2,   0 ]",
        )
        transition = Arrow(LEFT * 0.3, RIGHT * 0.8, color=ACCENT_YELLOW, stroke_width=3)
        after_panel = vector_panel(
            "Attention 之后", ACCENT_GREEN,
            "one:  [-1.04,  3.04,  0.96]",
            "cat:  [ 1.04,  0.96,  3.04]",
        )

        side_by_side = VGroup(before_panel, transition, after_panel)
        side_by_side.arrange(RIGHT, buff=0.5)
        side_by_side.next_to(step_heading, DOWN, buff=0.7)

        self.play(FadeIn(before_panel, shift=RIGHT * 0.3), run_time=0.6)
        self.wait(0.5)
        self.play(GrowArrow(transition), run_time=0.4)
        self.play(FadeIn(after_panel, shift=LEFT * 0.3), run_time=0.6)

        # Notes on what changed
        change_notes = VGroup(
            Text(
                "one 的中间维度: 0 → 3.04  (吸收了 cat 的特征!)",
                font_size=20, color=ACCENT_BLUE,
            ),
            Text(
                "cat 的第三维度: 0 → 3.04  (吸收了 one 的特征!)",
                font_size=20, color=ACCENT_PINK,
            ),
        )
        change_notes.arrange(DOWN, buff=0.15)
        change_notes.next_to(side_by_side, DOWN, buff=0.5)
        self.play(FadeIn(change_notes), run_time=0.6)
        self.wait(0.5)

        # Take-away
        takeaway = Text(
            "核心本质：通过互相关注，将彼此的特征融入自己",
            font_size=26, color=ACCENT_YELLOW, weight=BOLD,
        )
        takeaway.next_to(change_notes, DOWN, buff=0.5)
        self.play(FadeIn(takeaway, shift=UP * 0.2), run_time=0.8)

        self.wait(2.5)
        # Clear everything still on screen, including the persistent title.
        fade_all = [FadeOut(mob) for mob in self.mobjects]
        self.play(*fade_all, run_time=0.8)


# ============================================================
#  场景 4: 训练流程
# ============================================================
class Scene4_Training(Scene):
    """Scene 4: the training loop as an 8-step flow chart, followed by a
    loss curve drawn from a hard-coded list of loss values.

    The axis range, tick labels and the start/end annotations of the curve
    are all derived from ``loss_values`` so they cannot drift from the data.
    """

    @staticmethod
    def _connector(start, end):
        """Thin dim arrow used for every step-to-step link in the flow chart."""
        return Arrow(
            start, end,
            color=DIM_COLOR, stroke_width=2, buff=0.08,
            max_tip_length_to_length_ratio=0.15,
        )

    def construct(self):
        title = Text("BERT 训练流程", font_size=42, color=ACCENT_BLUE, weight=BOLD)
        title.to_edge(UP, buff=0.4)
        self.play(Write(title), run_time=1)

        # --- Flow chart ---
        steps = [
            ("随机选句子", ACCENT_BLUE),
            ("遮住一个词", ACCENT_PINK),
            ("查 Embedding", ACCENT_GREEN),
            ("Self-Attention", ACCENT_ORANGE),
            ("预测被遮词", ACCENT_PURPLE),
            ("计算 Loss", ACCENT_YELLOW),
            ("反向传播", ACCENT_PINK),
            ("更新参数", ACCENT_GREEN),
        ]
        row_len = 4  # steps per row; the chart is laid out as two rows

        nodes = VGroup()
        for text, color in steps:
            box = RoundedRectangle(
                width=2.5, height=0.7,
                corner_radius=0.1,
                stroke_color=color, stroke_width=2,
                fill_color=color, fill_opacity=0.12
            )
            label = Text(text, font_size=20, color=color, weight=BOLD)
            label.move_to(box)
            nodes.add(VGroup(box, label))

        row1 = VGroup(*nodes[:row_len]).arrange(RIGHT, buff=0.35)
        row2 = VGroup(*nodes[row_len:]).arrange(RIGHT, buff=0.35)
        flow = VGroup(row1, row2).arrange(DOWN, buff=0.8)
        flow.next_to(title, DOWN, buff=0.7)

        for node in nodes:
            self.play(FadeIn(node, scale=0.9), run_time=0.3)

        # Connectors between consecutive steps. Within a row the arrow runs
        # right edge -> left edge; the link from the last node of row 1 to the
        # first node of row 2 runs bottom edge -> top edge.
        arrows = VGroup()
        for i in range(len(nodes) - 1):
            if i == row_len - 1:
                arrows.add(self._connector(nodes[i].get_bottom(), nodes[i + 1].get_top()))
            else:
                arrows.add(self._connector(nodes[i].get_right(), nodes[i + 1].get_left()))

        self.play(
            LaggedStart(*[GrowArrow(a) for a in arrows], lag_ratio=0.1),
            run_time=1.0
        )

        # Loop-back arrow from the last step to the first one
        loop_arrow = CurvedArrow(
            nodes[len(nodes) - 1].get_bottom() + DOWN * 0.1,
            nodes[0].get_bottom() + DOWN * 0.1,
            color=ACCENT_YELLOW, stroke_width=2,
            angle=-TAU / 3
        )
        repeat_label = Text("重复直到收敛", font_size=18, color=ACCENT_YELLOW)
        repeat_label.next_to(loop_arrow, DOWN, buff=0.15)

        self.play(Create(loop_arrow), FadeIn(repeat_label), run_time=0.8)
        self.wait(1.5)

        self.play(
            *[FadeOut(m) for m in [flow, arrows, loop_arrow, repeat_label]],
            run_time=0.5
        )

        # --- Loss curve ---
        loss_title = Text("训练 Loss 下降曲线", font_size=28, color=ACCENT_ORANGE)
        loss_title.next_to(title, DOWN, buff=0.6)
        self.play(FadeIn(loss_title))

        loss_values = [3.86, 1.21, 0.66, 0.29, 0.21, 0.19, 0.19, 0.21, 0.08, 0.13]
        last_epoch = len(loss_values) - 1

        axes = Axes(
            x_range=[0, last_epoch, 1],
            y_range=[0, 4.2, 1],
            x_length=8,
            y_length=3.5,
            axis_config={
                "color": DIM_COLOR,
                "stroke_width": 1.5,
                "include_tip": False,
                "include_numbers": False,
            },
        )
        axes.next_to(loss_title, DOWN, buff=0.5)

        # Axis numbers are added by hand as plain Text (built-in numbers are
        # switched off above so that LaTeX is not needed).
        x_nums = VGroup()
        for v in range(0, last_epoch + 1, 2):
            lbl = Text(str(v), font_size=16, color=DIM_COLOR)
            lbl.next_to(axes.x_axis.number_to_point(v), DOWN, buff=0.15)
            x_nums.add(lbl)
        y_nums = VGroup()
        for v in [0, 1, 2, 3, 4]:
            lbl = Text(str(v), font_size=16, color=DIM_COLOR)
            lbl.next_to(axes.y_axis.number_to_point(v), LEFT, buff=0.15)
            y_nums.add(lbl)

        x_label = Text("Epoch (×50)", font_size=18, color=DIM_COLOR)
        y_label = Text("Loss", font_size=18, color=DIM_COLOR)
        x_label.next_to(axes.x_axis, DOWN, buff=0.35)
        y_label.next_to(axes.y_axis, LEFT, buff=0.35)

        self.play(Create(axes), FadeIn(x_nums), FadeIn(y_nums),
                  FadeIn(x_label), FadeIn(y_label), run_time=0.8)

        def curve_point(epoch):
            """Scene coordinates of the data point for ``epoch``."""
            return axes.c2p(epoch, loss_values[epoch])

        dots = VGroup(*[
            Dot(curve_point(epoch), radius=0.06, color=ACCENT_BLUE)
            for epoch in range(len(loss_values))
        ])
        line_segments = VGroup(*[
            Line(
                curve_point(epoch), curve_point(epoch + 1),
                color=ACCENT_BLUE, stroke_width=2.5
            )
            for epoch in range(last_epoch)
        ])

        self.play(
            LaggedStart(*[FadeIn(d, scale=1.5) for d in dots], lag_ratio=0.1),
            run_time=1.2
        )
        self.play(
            LaggedStart(*[Create(s) for s in line_segments], lag_ratio=0.1),
            run_time=1.0
        )

        # Annotate the first and last value; the text comes from the data
        # itself rather than from separately typed literals.
        start_label = Text(f"{loss_values[0]:.2f}", font_size=18, color=ACCENT_PINK)
        end_label = Text(f"{loss_values[-1]:.2f}", font_size=18, color=ACCENT_GREEN)
        start_label.next_to(dots[0], UP, buff=0.15)
        end_label.next_to(dots[-1], UP, buff=0.15)
        self.play(FadeIn(start_label), FadeIn(end_label), run_time=0.5)

        self.wait(2)
        self.play(*[FadeOut(m) for m in self.mobjects], run_time=0.8)


# ============================================================
#  场景 5: 预测结果 + 总结
# ============================================================
class Scene5_Results(Scene):
    """Scene 5: three masked-word test sentences, each with a small bar chart
    of the top predictions and a verdict symbol, followed by a summary slide.

    The prediction percentages are hard-coded; nothing is computed here.
    """

    def construct(self):
        title = Text("训练结果验证", font_size=42, color=ACCENT_BLUE, weight=BOLD)
        title.to_edge(UP, buff=0.4)
        self.play(Write(title), run_time=1)

        test_cases = [
            {
                "pre": "the man eats the ",
                "post": "",
                "results": [("meat", 99.51), ("ball", 0.31), ("bread", 0.05)],
                "verdict": "✓",
                "v_color": ACCENT_GREEN,
            },
            {
                "pre": "the man read the ",
                "post": "",
                "results": [("book", 47.85), ("meat", 20.08), ("house", 11.64)],
                "verdict": "~",
                "v_color": ACCENT_YELLOW,
            },
            {
                "pre": "the man ",
                "post": " the book",
                "results": [("man", 71.53), ("plays", 22.49), ("woman", 5.23)],
                "verdict": "✗",
                "v_color": ACCENT_PINK,
            },
        ]

        # A Text mobject's geometry consists of glyph outlines only, so the
        # spaces at the ends of "pre"/"post" (and any space padding of a
        # label) add no width. Gaps are therefore set explicitly.
        word_gap = 0.12      # gap between [MASK] and the neighbouring words
        label_gap = 0.15     # gap between the label column and the bars
        max_bar_width = 4.5  # width of a 100% bar

        all_groups = VGroup()

        for tc in test_cases:
            # Sentence line: prefix, [MASK], suffix
            pre_t = Text(tc["pre"], font_size=22, color=TEXT_COLOR)
            mask_t = Text("[MASK]", font_size=22, color=ACCENT_PINK, weight=BOLD)
            post_t = Text(tc["post"], font_size=22, color=TEXT_COLOR)
            mask_t.next_to(pre_t, RIGHT, buff=word_gap)
            post_t.next_to(mask_t, RIGHT, buff=word_gap)
            sent_line = VGroup(pre_t, mask_t, post_t)

            # Bar chart. All labels are built first so the label column can be
            # as wide as the widest word; every bar then starts at the same x
            # and the bar lengths are directly comparable.
            word_labels = [
                Text(word, font_size=18, color=TEXT_COLOR)
                for word, _ in tc["results"]
            ]
            label_col_width = max((lbl.width for lbl in word_labels), default=0)

            bars = VGroup()
            for word_t, (_, pct) in zip(word_labels, tc["results"]):
                # Very small percentages still get a visible sliver.
                bar_w = max(pct / 100 * max_bar_width, 0.05)
                bar = RoundedRectangle(
                    width=bar_w, height=0.25,
                    corner_radius=0.05,
                    fill_color=ACCENT_BLUE, fill_opacity=0.7,
                    stroke_width=0
                )
                pct_t = Text(f"{pct:.1f}%", font_size=16, color=ACCENT_BLUE)
                bar.next_to(word_t, RIGHT, buff=label_gap)
                # Push the bar right by the unused part of the label column.
                bar.shift(RIGHT * (label_col_width - word_t.width))
                pct_t.next_to(bar, RIGHT, buff=0.1)
                bars.add(VGroup(word_t, bar, pct_t))

            bars.arrange(DOWN, buff=0.06, aligned_edge=LEFT)
            bars.next_to(sent_line, DOWN, buff=0.15, aligned_edge=LEFT)

            # Verdict symbol
            verdict = Text(tc["verdict"], font_size=36, color=tc["v_color"], weight=BOLD)
            verdict.next_to(bars, RIGHT, buff=0.6)

            group = VGroup(sent_line, bars, verdict)
            all_groups.add(group)

        all_groups.arrange(DOWN, buff=0.45, aligned_edge=LEFT)
        all_groups.next_to(title, DOWN, buff=0.5)

        # Shrink the block if it would run off the frame.
        if all_groups.width > 12:
            all_groups.scale_to_fit_width(12)

        for g in all_groups:
            sent_line, bars, verdict = g[0], g[1], g[2]
            self.play(FadeIn(sent_line), run_time=0.4)
            for bar_row in bars:
                self.play(FadeIn(bar_row, shift=RIGHT * 0.3), run_time=0.2)
            self.play(FadeIn(verdict, scale=1.3), run_time=0.3)
            self.wait(0.4)

        self.wait(1.5)

        # --- Summary slide ---
        self.play(*[FadeOut(m) for m in self.mobjects], run_time=0.8)

        final_title = Text("总结", font_size=44, color=ACCENT_BLUE, weight=BOLD)
        final_title.to_edge(UP, buff=0.8)
        self.play(FadeIn(final_title))

        # (heading, description, heading color)
        points = [
            ("BERT 的核心", "通过遮词预测，学习上下文相关的向量表示", ACCENT_GREEN),
            ("Self-Attention", "让每个词「看到」并融合其他词的信息", ACCENT_ORANGE),
            ("训练本质", "不断调参，使预测越来越准", ACCENT_PURPLE),
        ]

        point_groups = VGroup()
        for label, desc, color in points:
            label_t = Text(label, font_size=28, color=color, weight=BOLD)
            desc_t = Text(desc, font_size=22, color=TEXT_COLOR)
            desc_t.next_to(label_t, DOWN, buff=0.1, aligned_edge=LEFT)
            point_groups.add(VGroup(label_t, desc_t))

        point_groups.arrange(DOWN, buff=0.5, aligned_edge=LEFT)
        point_groups.next_to(final_title, DOWN, buff=0.8)

        for pg in point_groups:
            self.play(FadeIn(pg, shift=RIGHT * 0.3), run_time=0.5)
            self.wait(0.3)

        self.wait(1)

        author = Text("教程作者：郝鸿涛", font_size=20, color=DIM_COLOR)
        author.to_edge(DOWN, buff=0.5)
        self.play(FadeIn(author, shift=UP * 0.2))
        self.wait(2)


# ============================================================
#  完整版
# ============================================================
class BertTutorial(Scene):
    """Full-length tutorial that plays all five chapters inside a single scene.

    ``construct`` calls ``_run_scene1`` .. ``_run_scene5`` in order. Chapters
    1-4 finish by fading out every mobject on screen, so the following chapter
    starts from an empty frame; chapter 5 ends on the summary and author card.

    All on-screen text is built with ``Text`` (no LaTeX). Colours
    (``ACCENT_*``, ``TEXT_COLOR``, ``DIM_COLOR``) and the ``labeled_matrix`` /
    ``highlight_box`` helpers are module-level names defined elsewhere in this
    file.
    """

    def construct(self):
        """Run the five chapters back to back."""
        self._run_scene1()
        self._run_scene2()
        self._run_scene3()
        self._run_scene4()
        self._run_scene5()

    def _fade_out_everything(self, run_time=0.8):
        """Fade out every mobject currently in the scene.

        Args:
            run_time: Duration of the combined fade-out, in seconds.

        Nothing is played when the scene holds no mobjects, so ``self.play`` is
        never called with an empty animation list.
        """
        fades = [FadeOut(m) for m in self.mobjects]
        if fades:
            self.play(*fades, run_time=run_time)

    # --- Scene 1 ---
    def _run_scene1(self):
        """Chapter 1: title card, then one keyword shown in four different sentences."""
        # Title card
        main_title = Text(
            "从零开始理解 BERT", font_size=52,
            color=ACCENT_BLUE, weight=BOLD
        )
        subtitle = Text(
            "Bidirectional Encoder Representations from Transformers",
            font_size=22, color=DIM_COLOR
        )
        subtitle.next_to(main_title, DOWN, buff=0.4)
        self.play(Write(main_title), run_time=1.5)
        self.play(FadeIn(subtitle, shift=UP * 0.3), run_time=0.8)
        self.wait(1.5)
        self.play(FadeOut(main_title), FadeOut(subtitle))

        # Motivating question
        question = Text("同一个词，不同的含义", font_size=40, color=ACCENT_ORANGE, weight=BOLD)
        question.to_edge(UP, buff=0.6)
        self.play(FadeIn(question, shift=DOWN * 0.3))

        # Each sentence is (text before keyword, keyword, text after keyword).
        sentences = [
            ("你来就来，还带礼物。你几个", "意思", "啊？"),
            ("别多想。我就", "意思", "一下"),
            ("我看不懂这个词是什么", "意思", ""),
            ("你这么说就没", "意思", "了"),
        ]
        sent_groups = VGroup()
        for pre, keyword, post in sentences:
            pre_t = Text(pre, font_size=28, color=TEXT_COLOR)
            key_t = Text(keyword, font_size=28, color=ACCENT_PINK, weight=BOLD)
            post_t = Text(post, font_size=28, color=TEXT_COLOR)
            key_t.next_to(pre_t, RIGHT, buff=0.05)
            post_t.next_to(key_t, RIGHT, buff=0.05)
            sent_groups.add(VGroup(pre_t, key_t, post_t))
        sent_groups.arrange(DOWN, buff=0.4, aligned_edge=LEFT)
        sent_groups.next_to(question, DOWN, buff=0.8)

        # Reveal each sentence, then briefly box its keyword (index 1 of the line).
        for line in sent_groups:
            self.play(FadeIn(line, shift=RIGHT * 0.3), run_time=0.6)
            box = highlight_box(line[1], color=ACCENT_PINK)
            self.play(Create(box), run_time=0.3)
            self.wait(0.3)
            self.play(FadeOut(box), run_time=0.2)
        self.wait(0.8)

        # Takeaway and hand-over to the self-attention chapters
        solution = Text("→ 上下文相关的向量表示", font_size=34, color=ACCENT_GREEN, weight=BOLD)
        solution.next_to(sent_groups, DOWN, buff=0.7)
        self.play(FadeIn(solution, shift=UP * 0.3))
        self.wait(0.8)
        arrow_text = Text("靠什么实现？ Self-Attention!", font_size=30, color=ACCENT_YELLOW)
        arrow_text.next_to(solution, DOWN, buff=0.4)
        self.play(FadeIn(arrow_text, shift=UP * 0.2))
        self.wait(2)
        self._fade_out_everything()

    # --- Scene 2 ---
    def _run_scene2(self):
        """Chapter 2: masked-word prediction on a ten-character sentence."""
        title = Text("BERT 的训练核心：遮词预测", font_size=40, color=ACCENT_BLUE, weight=BOLD)
        title.to_edge(UP, buff=0.5)
        self.play(Write(title), run_time=1)

        # One rounded box per character
        words = ["你", "这", "么", "说", "就", "没", "意", "思", "了", "啊"]
        word_boxes = VGroup()
        for w in words:
            t = Text(w, font_size=30, color=TEXT_COLOR)
            box = RoundedRectangle(width=0.7, height=0.7, corner_radius=0.1,
                stroke_color=ACCENT_BLUE, stroke_width=1.5,
                fill_color="#1a1a2e", fill_opacity=0.8)
            t.move_to(box)
            word_boxes.add(VGroup(box, t))
        word_boxes.arrange(RIGHT, buff=0.12)
        word_boxes.next_to(title, DOWN, buff=1.0)
        self.play(LaggedStart(*[FadeIn(wb, shift=UP * 0.2) for wb in word_boxes], lag_ratio=0.08), run_time=1.2)
        self.wait(0.5)

        # Cover the character at mask_idx with a MASK tile and hide its text
        # (element 1 of each word box is the Text).
        mask_idx = 3
        target_box = word_boxes[mask_idx]
        mask_cover = RoundedRectangle(width=0.7, height=0.7, corner_radius=0.1,
            stroke_color=ACCENT_PINK, stroke_width=2.5, fill_color=ACCENT_PINK, fill_opacity=0.9)
        mask_label = Text("MASK", font_size=16, color="#0f0f1a", weight=BOLD)
        mask_label.move_to(mask_cover)
        mask_group = VGroup(mask_cover, mask_label)
        mask_group.move_to(target_box)
        self.play(FadeIn(mask_group, scale=1.3), target_box[1].animate.set_opacity(0), run_time=0.8)
        self.wait(0.8)

        # Pipeline overview: (number, step name, one-line description)
        steps_data = [
            ("①", "查词表", "每个字 → 向量"),
            ("②", "加位置", "告诉模型字的位置"),
            ("③", "Self-Attention", "字与字互相关注"),
            ("④", "预测", "MASK → 概率分布"),
        ]
        step_groups = VGroup()
        for num, name, desc in steps_data:
            num_t = Text(num, font_size=26, color=ACCENT_ORANGE, weight=BOLD)
            name_t = Text(name, font_size=26, color=ACCENT_BLUE, weight=BOLD)
            desc_t = Text(desc, font_size=20, color=DIM_COLOR)
            name_t.next_to(num_t, RIGHT, buff=0.15)
            desc_t.next_to(VGroup(num_t, name_t), DOWN, buff=0.1, aligned_edge=LEFT)
            step_groups.add(VGroup(num_t, name_t, desc_t))
        step_groups.arrange(DOWN, buff=0.35, aligned_edge=LEFT)
        step_groups.next_to(word_boxes, DOWN, buff=1.0).shift(LEFT * 2)
        for sg in step_groups:
            self.play(FadeIn(sg, shift=RIGHT * 0.3), run_time=0.5)
            self.wait(0.2)

        # Prediction result. The displayed character is read from `words` so it
        # always matches the masked position.
        predict_arrow = Arrow(start=RIGHT * 1.5, end=RIGHT * 3.5, color=ACCENT_GREEN, stroke_width=3)
        predict_arrow.next_to(step_groups, RIGHT, buff=0.5).shift(DOWN * 0.5)
        result_text = Text(f'"{words[mask_idx]}"  → 98.5%', font_size=28, color=ACCENT_GREEN, weight=BOLD)
        result_text.next_to(predict_arrow, RIGHT, buff=0.3)
        self.play(GrowArrow(predict_arrow), run_time=0.6)
        self.play(FadeIn(result_text, shift=LEFT * 0.2), run_time=0.5)
        self.wait(2)
        self._fade_out_everything()

    # --- Scene 3 ---
    def _run_scene3(self):
        """Chapter 3: self-attention worked by hand on the input "one cat" (d = 3).

        Sub-steps: 3a input matrix, 3b Q/K/V, 3c scores and scaling,
        3d softmax weights, 3e weighted sum and before/after comparison.
        The chapter title stays on screen across all sub-steps.
        """
        scene_title = Text("手算 Self-Attention", font_size=42, color=ACCENT_BLUE, weight=BOLD)
        scene_title.to_edge(UP, buff=0.4)
        self.play(Write(scene_title), run_time=1)

        # 3a: setup
        setup_text = Text('输入: "one cat"    维度 d = 3', font_size=28, color=TEXT_COLOR)
        setup_text.next_to(scene_title, DOWN, buff=0.6)
        self.play(FadeIn(setup_text))

        x_mat = labeled_matrix("X =", [["1", "0", "2"], ["-1", "2", "0"]], label_color=ACCENT_BLUE)
        x_mat.next_to(setup_text, DOWN, buff=0.7)
        # NOTE(review): assumes labeled_matrix returns (label, matrix body) and that
        # the body's element 1 is the row-major group of entries - confirm against
        # the helper. Index 3 is then the first entry of the second row.
        mat_body = x_mat[1]
        entries = mat_body[1]
        top_row_y = entries[0].get_center()[1]
        bot_row_y = entries[3].get_center()[1]
        one_label = Text("one →", font_size=20, color=ACCENT_GREEN)
        cat_label = Text("cat →", font_size=20, color=ACCENT_PINK)
        one_label.next_to(mat_body, LEFT, buff=0.5).set_y(top_row_y)
        cat_label.next_to(mat_body, LEFT, buff=0.5).set_y(bot_row_y)

        self.play(FadeIn(x_mat), FadeIn(one_label), FadeIn(cat_label), run_time=1.0)
        self.wait(1.5)
        self.play(FadeOut(setup_text), FadeOut(x_mat), FadeOut(one_label), FadeOut(cat_label), run_time=0.5)

        # 3b: QKV
        subtitle = Text("步骤 1-2：计算 Q, K, V", font_size=26, color=ACCENT_ORANGE)
        subtitle.next_to(scene_title, DOWN, buff=0.5)
        self.play(FadeIn(subtitle))
        explain = Text("Q = X·Wq    K = X·Wk    V = X·Wv", font_size=22, color=DIM_COLOR)
        explain.next_to(subtitle, DOWN, buff=0.3)
        self.play(FadeIn(explain))

        data = [
            ("Q =", [["1", "0", "2"], ["-1", "2", "0"]], ACCENT_BLUE),
            ("K =", [["1", "2", "0"], ["-1", "0", "2"]], ACCENT_GREEN),
            ("V =", [["2", "0", "4"], ["-2", "4", "0"]], ACCENT_PINK),
        ]
        matrices = VGroup()
        for label_str, vals, color in data:
            m = labeled_matrix(label_str, vals, label_color=color, font_size=24)
            matrices.add(m)
        matrices.arrange(RIGHT, buff=0.7)
        # Shrink before positioning: scaling happens about the group's centre, so
        # doing it after next_to would change the configured gap below `explain`.
        if matrices.width > 13:
            matrices.scale_to_fit_width(13)
        matrices.next_to(explain, DOWN, buff=0.6)
        for m in matrices:
            self.play(FadeIn(m, shift=UP * 0.2), run_time=0.6)
            self.wait(0.2)
        self.wait(1)
        self.play(FadeOut(subtitle), FadeOut(explain), FadeOut(matrices), run_time=0.5)

        # 3c: scores
        subtitle2 = Text("步骤 3-4：注意力分数 + 缩放", font_size=26, color=ACCENT_ORANGE)
        subtitle2.next_to(scene_title, DOWN, buff=0.5)
        self.play(FadeIn(subtitle2))
        formula = Text("Score = Q · Kᵀ", font_size=30, color=TEXT_COLOR)
        formula.next_to(subtitle2, DOWN, buff=0.5)
        self.play(Write(formula), run_time=0.8)

        # The four dot products, laid out as a 2x2 grid; the cross-token pairs
        # are drawn in yellow, the self pairs in the dim colour.
        calcs = VGroup(
            Text("one↔one: 1×1+0×2+2×0 = 1", font_size=20, color=DIM_COLOR),
            Text("one↔cat: 1×(-1)+0×0+2×2 = 3", font_size=20, color=ACCENT_YELLOW),
            Text("cat↔one: (-1)×1+2×2+0×0 = 3", font_size=20, color=ACCENT_YELLOW),
            Text("cat↔cat: (-1)×(-1)+2×0+0×2 = 1", font_size=20, color=DIM_COLOR),
        )
        calcs.arrange_in_grid(2, 2, buff=(0.6, 0.3))
        calcs.next_to(formula, DOWN, buff=0.5)
        for c in calcs:
            self.play(FadeIn(c, shift=UP * 0.1), run_time=0.35)
        self.wait(0.5)

        # Scores divided by sqrt(d) with d = 3
        scale_label = Text("÷ √3 ≈", font_size=26, color=TEXT_COLOR)
        scaled_mat = labeled_matrix("", [["0.58", "1.73"], ["1.73", "0.58"]],
            label_color=ACCENT_BLUE, val_color=ACCENT_BLUE, font_size=26)
        scale_group = VGroup(scale_label, scaled_mat).arrange(RIGHT, buff=0.2)
        scale_group.next_to(calcs, DOWN, buff=0.5)
        self.play(FadeIn(scale_group, shift=UP * 0.2), run_time=0.8)
        self.wait(1.5)
        self.play(FadeOut(subtitle2), FadeOut(formula), FadeOut(calcs), FadeOut(scale_group), run_time=0.5)

        # 3d: softmax
        subtitle3 = Text("步骤 5：Softmax → 注意力权重", font_size=26, color=ACCENT_ORANGE)
        subtitle3.next_to(scene_title, DOWN, buff=0.5)
        self.play(FadeIn(subtitle3))
        calc_text = VGroup(
            Text("e⁰·⁵⁸ ≈ 1.79    e¹·⁷³ ≈ 5.64", font_size=22, color=DIM_COLOR),
            Text("权重: 1.79/7.43 ≈ 0.24    5.64/7.43 ≈ 0.76", font_size=22, color=DIM_COLOR),
        ).arrange(DOWN, buff=0.2)
        calc_text.next_to(subtitle3, DOWN, buff=0.5)
        self.play(FadeIn(calc_text), run_time=0.6)

        a_mat = labeled_matrix("A =", [["0.24", "0.76"], ["0.76", "0.24"]],
            label_color=ACCENT_PURPLE, val_color=TEXT_COLOR, font_size=30)
        a_mat.next_to(calc_text, DOWN, buff=0.5)
        self.play(FadeIn(a_mat), run_time=0.8)

        # Call-out box: the box is centred on the two text lines, then the pair
        # is positioned as one group.
        insight_box = RoundedRectangle(width=10, height=1.5, corner_radius=0.15,
            stroke_color=ACCENT_GREEN, stroke_width=1.5, fill_color="#1a2e1a", fill_opacity=0.6)
        insight_1 = Text('计算 "one" 的新向量时：', font_size=22, color=TEXT_COLOR)
        insight_2 = Text("保留 24% 自己  +  吸收 76% 来自 cat", font_size=24, color=ACCENT_GREEN, weight=BOLD)
        insight = VGroup(insight_1, insight_2).arrange(DOWN, buff=0.12)
        insight_box.move_to(insight)
        insight_group = VGroup(insight_box, insight)
        insight_group.next_to(a_mat, DOWN, buff=0.5)
        self.play(FadeIn(insight_group, shift=UP * 0.2), run_time=0.8)
        self.wait(2)
        self.play(FadeOut(subtitle3), FadeOut(calc_text), FadeOut(a_mat), FadeOut(insight_group), run_time=0.5)

        # 3e: weighted sum
        subtitle4 = Text("步骤 6：加权求和 → 新向量", font_size=26, color=ACCENT_ORANGE)
        subtitle4.next_to(scene_title, DOWN, buff=0.5)
        self.play(FadeIn(subtitle4))

        # Before / after comparison, arranged left-to-right around an arrow
        before_title = Text("原始向量", font_size=22, color=DIM_COLOR)
        before_one = Text("one:  [ 1,   0,   2 ]", font_size=24, color=ACCENT_BLUE)
        before_cat = Text("cat:  [-1,   2,   0 ]", font_size=24, color=ACCENT_PINK)
        before_group = VGroup(before_title, before_one, before_cat).arrange(DOWN, buff=0.2)
        arrow = Arrow(LEFT * 0.3, RIGHT * 0.8, color=ACCENT_YELLOW, stroke_width=3)
        after_title = Text("Attention 之后", font_size=22, color=ACCENT_GREEN)
        after_one = Text("one:  [-1.04,  3.04,  0.96]", font_size=24, color=ACCENT_BLUE)
        after_cat = Text("cat:  [ 1.04,  0.96,  3.04]", font_size=24, color=ACCENT_PINK)
        after_group = VGroup(after_title, after_one, after_cat).arrange(DOWN, buff=0.2)
        comparison = VGroup(before_group, arrow, after_group).arrange(RIGHT, buff=0.5)
        comparison.next_to(subtitle4, DOWN, buff=0.7)

        self.play(FadeIn(before_group, shift=RIGHT * 0.3), run_time=0.6)
        self.wait(0.5)
        self.play(GrowArrow(arrow), run_time=0.4)
        self.play(FadeIn(after_group, shift=LEFT * 0.3), run_time=0.6)

        note1 = Text("one 的中间维度: 0 → 3.04  (吸收了 cat 的特征!)", font_size=20, color=ACCENT_BLUE)
        note2 = Text("cat 的第三维度: 0 → 3.04  (吸收了 one 的特征!)", font_size=20, color=ACCENT_PINK)
        notes = VGroup(note1, note2).arrange(DOWN, buff=0.15)
        notes.next_to(comparison, DOWN, buff=0.5)
        self.play(FadeIn(notes), run_time=0.6)
        self.wait(0.5)
        conclusion = Text("核心本质：通过互相关注，将彼此的特征融入自己", font_size=26, color=ACCENT_YELLOW, weight=BOLD)
        conclusion.next_to(notes, DOWN, buff=0.5)
        self.play(FadeIn(conclusion, shift=UP * 0.2), run_time=0.8)
        self.wait(2.5)
        self._fade_out_everything()

    # --- Scene 4 ---
    def _run_scene4(self):
        """Chapter 4: the training loop as a flow chart, then the loss curve."""
        title = Text("BERT 训练流程", font_size=42, color=ACCENT_BLUE, weight=BOLD)
        title.to_edge(UP, buff=0.4)
        self.play(Write(title), run_time=1)

        # Eight flow-chart nodes, shown as two rows of four
        steps = [
            ("随机选句子", ACCENT_BLUE), ("遮住一个词", ACCENT_PINK),
            ("查 Embedding", ACCENT_GREEN), ("Self-Attention", ACCENT_ORANGE),
            ("预测被遮词", ACCENT_PURPLE), ("计算 Loss", ACCENT_YELLOW),
            ("反向传播", ACCENT_PINK), ("更新参数", ACCENT_GREEN),
        ]
        nodes = VGroup()
        for text, color in steps:
            box = RoundedRectangle(width=2.5, height=0.7, corner_radius=0.1,
                stroke_color=color, stroke_width=2, fill_color=color, fill_opacity=0.12)
            label = Text(text, font_size=20, color=color, weight=BOLD)
            label.move_to(box)
            nodes.add(VGroup(box, label))
        row1 = VGroup(*nodes[:4]).arrange(RIGHT, buff=0.35)
        row2 = VGroup(*nodes[4:]).arrange(RIGHT, buff=0.35)
        flow = VGroup(row1, row2).arrange(DOWN, buff=0.8)
        flow.next_to(title, DOWN, buff=0.7)
        for node in nodes:
            self.play(FadeIn(node, scale=0.9), run_time=0.3)

        # Connectors: along row 1, from the end of row 1 down to the start of
        # row 2, then along row 2. All share one style.
        arrow_style = dict(color=DIM_COLOR, stroke_width=2, buff=0.08,
                           max_tip_length_to_length_ratio=0.15)
        arrows = VGroup()
        for i in range(3):
            arrows.add(Arrow(nodes[i].get_right(), nodes[i + 1].get_left(), **arrow_style))
        arrows.add(Arrow(nodes[3].get_bottom(), nodes[4].get_top(), **arrow_style))
        for i in range(4, 7):
            arrows.add(Arrow(nodes[i].get_right(), nodes[i + 1].get_left(), **arrow_style))
        self.play(LaggedStart(*[GrowArrow(a) for a in arrows], lag_ratio=0.1), run_time=1.0)

        # Curved arrow from the last node back to the first: the loop repeats
        loop_arrow = CurvedArrow(nodes[7].get_bottom() + DOWN * 0.1,
            nodes[0].get_bottom() + DOWN * 0.1, color=ACCENT_YELLOW, stroke_width=2, angle=-TAU / 3)
        repeat_label = Text("重复直到收敛", font_size=18, color=ACCENT_YELLOW)
        repeat_label.next_to(loop_arrow, DOWN, buff=0.15)
        self.play(Create(loop_arrow), FadeIn(repeat_label), run_time=0.8)
        self.wait(1.5)
        # Clear the flow chart but keep the chapter title for the loss plot
        self.play(*[FadeOut(m) for m in [flow, arrows, loop_arrow, repeat_label]], run_time=0.5)

        # Loss curve
        loss_title = Text("训练 Loss 下降曲线", font_size=28, color=ACCENT_ORANGE)
        loss_title.next_to(title, DOWN, buff=0.6)
        self.play(FadeIn(loss_title))
        loss_values = [3.86, 1.21, 0.66, 0.29, 0.21, 0.19, 0.19, 0.21, 0.08, 0.13]
        epochs = list(range(len(loss_values)))
        # The x range follows the number of data points instead of repeating it.
        axes = Axes(x_range=[0, len(loss_values) - 1, 1], y_range=[0, 4.2, 1], x_length=8, y_length=3.5,
            axis_config={"color": DIM_COLOR, "stroke_width": 1.5, "include_tip": False, "include_numbers": False})
        axes.next_to(loss_title, DOWN, buff=0.5)
        # Tick labels are drawn with Text because the axes' built-in numbers are
        # switched off above (include_numbers=False).
        x_nums = VGroup()
        for v in [0, 2, 4, 6, 8]:
            lbl = Text(str(v), font_size=16, color=DIM_COLOR)
            lbl.next_to(axes.x_axis.number_to_point(v), DOWN, buff=0.15)
            x_nums.add(lbl)
        y_nums = VGroup()
        for v in [0, 1, 2, 3, 4]:
            lbl = Text(str(v), font_size=16, color=DIM_COLOR)
            lbl.next_to(axes.y_axis.number_to_point(v), LEFT, buff=0.15)
            y_nums.add(lbl)
        x_label = Text("Epoch (×50)", font_size=18, color=DIM_COLOR)
        y_label = Text("Loss", font_size=18, color=DIM_COLOR)
        x_label.next_to(axes.x_axis, DOWN, buff=0.35)
        y_label.next_to(axes.y_axis, LEFT, buff=0.35)
        self.play(Create(axes), FadeIn(x_nums), FadeIn(y_nums), FadeIn(x_label), FadeIn(y_label), run_time=0.8)

        # Data points first, then the connecting segments
        dots = VGroup(*[Dot(axes.c2p(e, l), radius=0.06, color=ACCENT_BLUE) for e, l in zip(epochs, loss_values)])
        segs = VGroup(*[Line(axes.c2p(epochs[i], loss_values[i]), axes.c2p(epochs[i+1], loss_values[i+1]),
            color=ACCENT_BLUE, stroke_width=2.5) for i in range(len(epochs)-1)])
        self.play(LaggedStart(*[FadeIn(d, scale=1.5) for d in dots], lag_ratio=0.1), run_time=1.2)
        self.play(LaggedStart(*[Create(s) for s in segs], lag_ratio=0.1), run_time=1.0)
        # First/last value labels are formatted from the data so they cannot
        # drift out of sync with `loss_values`.
        start_l = Text(f"{loss_values[0]:.2f}", font_size=18, color=ACCENT_PINK)
        end_l = Text(f"{loss_values[-1]:.2f}", font_size=18, color=ACCENT_GREEN)
        start_l.next_to(dots[0], UP, buff=0.15)
        end_l.next_to(dots[-1], UP, buff=0.15)
        self.play(FadeIn(start_l), FadeIn(end_l), run_time=0.5)
        self.wait(2)
        self._fade_out_everything()

    # --- Scene 5 ---
    def _run_scene5(self):
        """Chapter 5: three masked-word test cases with bar charts, then the summary."""
        title = Text("训练结果验证", font_size=42, color=ACCENT_BLUE, weight=BOLD)
        title.to_edge(UP, buff=0.4)
        self.play(Write(title), run_time=1)

        # pre/post: text around the [MASK] token; results: (word, percent) rows;
        # verdict/v_color: symbol shown next to the bars and its colour.
        test_cases = [
            {"pre": "the man eats the ", "post": "",
             "results": [("meat", 99.51), ("ball", 0.31), ("bread", 0.05)],
             "verdict": "✓", "v_color": ACCENT_GREEN},
            {"pre": "the man read the ", "post": "",
             "results": [("book", 47.85), ("meat", 20.08), ("house", 11.64)],
             "verdict": "~", "v_color": ACCENT_YELLOW},
            {"pre": "the man ", "post": " the book",
             "results": [("man", 71.53), ("plays", 22.49), ("woman", 5.23)],
             "verdict": "✗", "v_color": ACCENT_PINK},
        ]

        # Each entry of all_groups is VGroup(sentence line, bar rows, verdict).
        all_groups = VGroup()
        for tc in test_cases:
            pre_t = Text(tc["pre"], font_size=22, color=TEXT_COLOR)
            mask_t = Text("[MASK]", font_size=22, color=ACCENT_PINK, weight=BOLD)
            post_t = Text(tc["post"], font_size=22, color=TEXT_COLOR)
            mask_t.next_to(pre_t, RIGHT, buff=0.04)
            post_t.next_to(mask_t, RIGHT, buff=0.04)
            sent_line = VGroup(pre_t, mask_t, post_t)

            bars = VGroup()
            for word, pct in tc["results"]:
                word_t = Text(f"{word:<6}", font_size=18, color=TEXT_COLOR)
                # 100% maps to a 4.5-unit bar; the floor keeps tiny values visible
                bar_w = max(pct / 100 * 4.5, 0.05)
                bar = RoundedRectangle(width=bar_w, height=0.25, corner_radius=0.05,
                    fill_color=ACCENT_BLUE, fill_opacity=0.7, stroke_width=0)
                pct_t = Text(f"{pct:.1f}%", font_size=16, color=ACCENT_BLUE)
                bar.next_to(word_t, RIGHT, buff=0.15)
                pct_t.next_to(bar, RIGHT, buff=0.1)
                bars.add(VGroup(word_t, bar, pct_t))
            bars.arrange(DOWN, buff=0.06, aligned_edge=LEFT)
            bars.next_to(sent_line, DOWN, buff=0.15, aligned_edge=LEFT)
            verdict = Text(tc["verdict"], font_size=36, color=tc["v_color"], weight=BOLD)
            verdict.next_to(bars, RIGHT, buff=0.6)
            all_groups.add(VGroup(sent_line, bars, verdict))

        all_groups.arrange(DOWN, buff=0.45, aligned_edge=LEFT)
        # Shrink before positioning: scaling happens about the group's centre, so
        # doing it after next_to would change the configured gap below the title.
        if all_groups.width > 12:
            all_groups.scale_to_fit_width(12)
        all_groups.next_to(title, DOWN, buff=0.5)

        for g in all_groups:
            self.play(FadeIn(g[0]), run_time=0.4)
            for bar_row in g[1]:
                self.play(FadeIn(bar_row, shift=RIGHT * 0.3), run_time=0.2)
            self.play(FadeIn(g[2], scale=1.3), run_time=0.3)
            self.wait(0.4)
        self.wait(1.5)

        # --- Summary ---
        self._fade_out_everything()

        final_title = Text("总结", font_size=44, color=ACCENT_BLUE, weight=BOLD)
        final_title.to_edge(UP, buff=0.8)
        self.play(FadeIn(final_title))
        # (heading, description, heading colour)
        points = [
            ("BERT 的核心", "通过遮词预测，学习上下文相关的向量表示", ACCENT_GREEN),
            ("Self-Attention", "让每个词「看到」并融合其他词的信息", ACCENT_ORANGE),
            ("训练本质", "不断调参，使预测越来越准", ACCENT_PURPLE),
        ]
        point_groups = VGroup()
        for label, desc, color in points:
            label_t = Text(label, font_size=28, color=color, weight=BOLD)
            desc_t = Text(desc, font_size=22, color=TEXT_COLOR)
            desc_t.next_to(label_t, DOWN, buff=0.1, aligned_edge=LEFT)
            point_groups.add(VGroup(label_t, desc_t))
        point_groups.arrange(DOWN, buff=0.5, aligned_edge=LEFT)
        point_groups.next_to(final_title, DOWN, buff=0.8)
        for pg in point_groups:
            self.play(FadeIn(pg, shift=RIGHT * 0.3), run_time=0.5)
            self.wait(0.3)
        self.wait(1)
        # Author credit at the bottom edge
        author = Text("教程作者：郝鸿涛", font_size=20, color=DIM_COLOR)
        author.to_edge(DOWN, buff=0.5)
        self.play(FadeIn(author, shift=UP * 0.2))
        self.wait(2)

In [8]:
# Render the complete BertTutorial scene inline: -qm selects medium quality,
# -v WARNING limits Manim's log output to warnings and above.
%manim -qm -v WARNING BertTutorial

                                                                                                                                      