# Mini-GPT: Implementation from Scratch (SOTA Standards)

**Rigorous implementation** | Covers weight initialization, temperature sampling, and causal-mask visualization

---

## 1. Theoretical Background

### 1.1 Core of the GPT Architecture

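**Decoder-Only Transformer**:
- No cross-attention (unlike the decoder of the original encoder-decoder Transformer)
- Causal self-attention (causal mask) prevents information from leaking backwards from future positions
- Autoregressive generation: each token can depend only on the tokens before it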
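A quick way to see the "depends only on the past" property: in a causal attention layer, perturbing the last token must leave the outputs at all earlier positions unchanged. The sketch below uses a single randomly initialized head written directly with PyTorch; the function name `causal_attention` and all sizes are illustrative assumptions, not part of the model defined later.

```python
import math

import torch
import torch.nn.functional as F


def causal_attention(x: torch.Tensor, wq: torch.Tensor, wk: torch.Tensor, wv: torch.Tensor) -> torch.Tensor:
    """Single-head causal self-attention on x of shape (T, d)."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / math.sqrt(k.shape[-1])
    T = x.shape[0]
    future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # True above the diagonal
    scores = scores.masked_fill(future, float("-inf"))  # position i cannot see j > i
    return F.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
T, d = 6, 8
wq, wk, wv = (torch.randn(d, d) for _ in range(3))
x = torch.randn(T, d)
x_perturbed = x.clone()
x_perturbed[-1] += 10.0  # change only the last (most "future") token

out = causal_attention(x, wq, wk, wv)
out_perturbed = causal_attention(x_perturbed, wq, wk, wv)
print(torch.allclose(out[:-1], out_perturbed[:-1]))  # earlier positions are unaffected
print(torch.allclose(out[-1], out_perturbed[-1]))    # the last position does change
```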

### 1.2 Weight Initialization Strategy ⭐ (GPT-2 Paper)

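**Problem**: with a naive random initialization, activations and gradients can be unstable early in training.

**Solution**:

**Standard layers** (Linear/Embedding):
$$W \sim \mathcal{N}(0, 0.02^2)$$

**Residual projection layers** (Residual Projections):
$$W \sim \mathcal{N}\left(0, \frac{0.02^2}{2 \times n_{layers}}\right)$$

**Intuition**: every sublayer adds its output onto the residual stream. If those contributions are roughly independent, the variance of the stream grows linearly with the number of additions $N$, so its standard deviation grows like $\sqrt{N}$. Each of the $n_{layers}$ blocks makes two additions (attention and MLP), so $N = 2 n_{layers}$; dividing the residual projections' standard deviation by $\sqrt{2 n_{layers}}$ (equivalently, their variance by $2 n_{layers}$) keeps the total contribution independent of depth.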
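To make the intuition concrete, here is a toy simulation. It is an assumption-laden sketch, not part of the model: it ignores LayerNorm, nonlinearities and correlations between layers, feeds each residual projection a fresh unit-variance input, and measures how much variance 2 × n_layer additions put on an initially zero residual stream. The helper `added_residual_variance` and its default sizes are hypothetical.

```python
import math

import torch


def added_residual_variance(
    n_layer: int, d: int = 256, init_std: float = 0.02, scaled: bool = True, batch: int = 1024
) -> float:
    """Toy model: variance that 2 * n_layer random residual projections add to a zero stream."""
    g = torch.Generator().manual_seed(0)
    n_add = 2 * n_layer  # one attention + one MLP residual addition per block
    std = init_std / math.sqrt(n_add) if scaled else init_std
    x = torch.zeros(batch, d)
    for _ in range(n_add):
        h = torch.randn(batch, d, generator=g)  # stand-in for a unit-variance sublayer input
        w = torch.randn(d, d, generator=g) * std  # residual projection weight ~ N(0, std^2)
        x = x + h @ w.T
    return x.var().item()


for n_layer in (6, 24):
    plain = added_residual_variance(n_layer, scaled=False)
    scaled = added_residual_variance(n_layer, scaled=True)
    print(f"n_layer={n_layer:2d} | std=0.02: {plain:.3f} | std=0.02/sqrt(2n): {scaled:.3f}")
```

In this toy setting the unscaled variance grows roughly as 2 · n_layer · d · 0.02², i.e. linearly with depth, while the scaled version stays near d · 0.02² regardless of n_layer.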

### 1.3 Sampling Strategies

**Temperature Sampling**:
$$P'(x_i) = \frac{\exp(\log P(x_i) / T)}{\sum_j \exp(\log P(x_j) / T)}$$

- **T = 0.1**: sharper distribution; nearly deterministic, more conservative
- **T = 1.0**: the original distribution
- **T = 2.0**: flatter distribution; more random/chaotic
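The formula is written in terms of log-probabilities, while the code in this notebook divides raw logits by $T$ (`logits / temperature`). The two agree because $\log P(x_i)$ differs from the logit $z_i$ only by the constant $\log \sum_j \exp(z_j)$, which cancels in the normalization. A small numerical check with arbitrary example logits:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.0, -1.0])  # arbitrary example logits

for T in (0.1, 1.0, 2.0):
    from_logits = F.softmax(logits / T, dim=-1)  # what the sampling code computes
    from_logprobs = F.softmax(F.log_softmax(logits, dim=-1) / T, dim=-1)  # the formula above
    print(f"T={T}: same={torch.allclose(from_logits, from_logprobs)}, max prob={from_logits.max().item():.4f}")
```

The largest probability shrinks as $T$ grows, matching the bullets above.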

**Top-K Sampling**:
1. Set the probability of every token outside the K most probable to 0
2. Renormalize
3. Sample

This avoids sampling extremely unlikely "bad" tokens.
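The three steps above operate on probabilities; `generate()` below instead sets the logits outside the top K to `-inf` before the softmax. Both routes give the same distribution, as this sketch with random logits checks (K = 3 and the variable names are arbitrary choices for illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(10)
k = 3

# Route 1 (steps 1-2 above): zero out probabilities outside the top K, then renormalize.
probs = F.softmax(logits, dim=-1)
kth_largest = torch.topk(probs, k).values[-1]
probs_route1 = torch.where(probs >= kth_largest, probs, torch.zeros_like(probs))
probs_route1 = probs_route1 / probs_route1.sum()

# Route 2 (as in generate()): mask logits outside the top K with -inf, then softmax.
kth_logit = torch.topk(logits, k).values[-1]
probs_route2 = F.softmax(logits.masked_fill(logits < kth_logit, float("-inf")), dim=-1)

print(torch.allclose(probs_route1, probs_route2))
print(int((probs_route2 > 0).sum()))  # number of tokens that can still be sampled

# Step 3: sample from the filtered distribution
next_token = torch.multinomial(probs_route2, num_samples=1)
```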

---

## 2. Code Implementation

In [None]:
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Optional

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

In [None]:
@dataclass
class GPTConfig:
    """GPT 配置类。"""

    vocab_size: int = 65
    block_size: int = 256
    n_embd: int = 384
    n_head: int = 6
    n_layer: int = 6
    dropout: float = 0.2

    # Initialization parameter (GPT-2 style)
    init_std: float = 0.02

    def __post_init__(self):
        assert self.n_embd % self.n_head == 0
        self.head_size = self.n_embd // self.n_head

In [None]:
class CausalSelfAttention(nn.Module):
    """因果自注意力（单个模块包含多头注意力）。

    核心思想:
        使用因果掩码确保位置 i 只能看到位置 0~i-1，
        防止自回归生成时的信息泄漏。
    """

    def __init__(self, config: GPTConfig) -> None:
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.head_size = config.n_embd // config.n_head
        self.scale = 1.0 / math.sqrt(self.head_size)

        # Fused QKV projection
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # Output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)

        self.dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Causal mask (a buffer: moves with .to(device) but is not a trainable parameter)
        mask = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))

    def forward(self, x: Tensor) -> Tensor:
        B, T, C = x.shape

        # QKV projection, then split into heads: (B, T, C) -> (B, n_head, T, head_size)
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)  # C == n_embd; this module stores no self.n_embd
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)

        # Attention scores: einsum 'b h i d, b h j d -> b h i j'
        scores = torch.einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

        # Apply the causal mask: future positions (j > i) get a large negative score
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, -1e9)

        # Softmax + dropout on the attention weights
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum: einsum 'b h i j, b h j d -> b h i d'
        out = torch.einsum("b h i j, b h j d -> b h i d", attn_weights, v)
        # Merge heads: (B, n_head, T, head_size) -> (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        return self.resid_dropout(self.c_proj(out))
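As a sanity check on the einsum formulation, the same masked attention can be compared with PyTorch's fused `F.scaled_dot_product_attention(..., is_causal=True)` (available in PyTorch 2.0+). This sketch re-implements only the score/mask/softmax/weighted-sum core on random tensors (no projections, no dropout); the shapes are arbitrary.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 2, 4, 16, 8
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# einsum version, mirroring CausalSelfAttention.forward above
scores = torch.einsum("b h i d, b h j d -> b h i j", q, k) / math.sqrt(D)
mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)
scores = scores.masked_fill(mask == 0, -1e9)
out_einsum = torch.einsum("b h i j, b h j d -> b h i d", F.softmax(scores, dim=-1), v)

# fused reference implementation (default scale is 1/sqrt(D))
out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(out_einsum, out_sdpa, atol=1e-5))
```

In practice the fused call is usually faster and more memory-efficient; the explicit einsum version is kept in this notebook because it makes each step visible.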

In [None]:
class FeedForward(nn.Module):
    """位置前馈网络，使用 GELU 激活。"""

    def __init__(self, config: GPTConfig) -> None:
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        self.gelu = nn.GELU()

    def forward(self, x: Tensor) -> Tensor:
        x = self.gelu(self.c_fc(x))
        x = self.c_proj(x)
        return self.dropout(x)

In [None]:
class Block(nn.Module):
    """Transformer Block (Pre-Norm)。"""

    def __init__(self, config: GPTConfig) -> None:
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = FeedForward(config)

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
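The block above is Pre-Norm: LayerNorm is applied to each sublayer's input, and the residual path `x + ...` itself is never normalized. The contrast with the Post-Norm layout of the original Transformer is easiest to see with a sublayer that outputs zeros (a hypothetical `ZeroSublayer` used only for illustration): the Pre-Norm residual then reduces to the identity, while the Post-Norm residual still renormalizes the stream. This shows the structural difference only; it is not a measurement of gradient stability.

```python
import torch
import torch.nn as nn


class ZeroSublayer(nn.Module):
    """Illustrative stand-in for attention/MLP that contributes nothing."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.zeros_like(x)


class PreNormResidual(nn.Module):
    def __init__(self, d: int, sublayer: nn.Module) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(d)
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sublayer(self.ln(x))  # as in Block.forward: identity path untouched


class PostNormResidual(nn.Module):
    def __init__(self, d: int, sublayer: nn.Module) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(d)
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ln(x + self.sublayer(x))  # LayerNorm sits on the residual path


torch.manual_seed(0)
x = 3.0 * torch.randn(2, 5, 8) + 1.0
pre, post = PreNormResidual(8, ZeroSublayer()), PostNormResidual(8, ZeroSublayer())
print(torch.equal(pre(x), x))      # residual stream passes through unchanged
print(torch.allclose(post(x), x))  # residual stream is renormalized
```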

In [None]:
class GPT(nn.Module):
    """完整的 GPT 语言模型。

    核心思想:
        Decoder-Only Transformer + 因果掩码 + 自回归生成

    初始化策略 (GPT-2):
        - 标准层: N(0, 0.02)
        - 残差投影: N(0, 0.02 / sqrt(2 * n_layers))
    """

    def __init__(self, config: GPTConfig) -> None:
        super().__init__()
        self.config = config

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.block_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)

        self.h = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

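        # Weight tying: the token embedding and lm_head share one weight matrix
        self.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

        # Scaled-down initialization for residual projections (attn.c_proj and mlp.c_proj)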
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                nn.init.normal_(p, mean=0.0, std=config.init_std / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module: nn.Module) -> None:
        """GPT-2 风格的权重初始化。"""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
        elif isinstance(module, nn.LayerNorm):
            nn.init.zeros_(module.bias)
            nn.init.ones_(module.weight)

    def forward(
        self, idx: Tensor, targets: Optional[Tensor] = None
    ) -> tuple[Tensor, Optional[Tensor]]:
        B, T = idx.shape
        assert (
            T <= self.config.block_size
        ), f"Sequence length {T} exceeds block_size {self.config.block_size}"

        pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
        x = self.wte(idx) + self.wpe(pos)
        x = self.drop(x)

        for block in self.h:
            x = block(x)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
    ) -> Tensor:
        """自回归生成文本。

        Args:
            idx: 上下文 (B, T)
            max_new_tokens: 生成 token 数量
            temperature: 温度参数，<1 更保守，>1 更随机
            top_k: 只从概率最高的 K 个 token 中采样
        """
        for _ in range(max_new_tokens):
            idx_crop = idx[:, -self.config.block_size :]
            logits, _ = self(idx_crop)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("Inf")

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)

        return idx
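Weight tying (`self.wte.weight = self.lm_head.weight`) makes the embedding and the output head share a single `nn.Parameter`, so the `vocab_size × n_embd` matrix is stored, updated and counted once. A standalone sketch with the default `GPTConfig` sizes (65 × 384); the `Tied` module is a hypothetical minimal stand-in for `GPT`, not part of the notebook's model.

```python
import torch.nn as nn


class Tied(nn.Module):
    def __init__(self, vocab_size: int, n_embd: int, tie: bool) -> None:
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # weight shape (vocab_size, n_embd)
        if tie:
            self.wte.weight = self.lm_head.weight  # same Parameter object, shared gradient


tied, untied = Tied(65, 384, tie=True), Tied(65, 384, tie=False)
print(tied.wte.weight is tied.lm_head.weight)
print(sum(p.numel() for p in tied.parameters()))    # parameters() de-duplicates shared tensors
print(sum(p.numel() for p in untied.parameters()))
```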

---

## 3. Causal Mask Visualization

In [None]:
def visualize_causal_mask(seq_len: int = 10) -> None:
    """可视化因果掩码如何阻止信息泄漏。"""
    mask = torch.tril(torch.ones(seq_len, seq_len))

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: mask matrix
    ax1 = axes[0]
    im1 = ax1.imshow(mask.numpy(), cmap="Blues", origin="upper")
    ax1.set_xticks(range(seq_len))
    ax1.set_yticks(range(seq_len))
    ax1.set_xlabel("Key 位置 (j)")
    ax1.set_ylabel("Query 位置 (i)")
    ax1.set_title("因果掩码矩阵\n白色=屏蔽, 蓝色=可见")

    # 标注关键点
    for i in range(seq_len):
        for j in range(seq_len):
            if mask[i, j] == 1:
                ax1.text(j, i, "✓", ha="center", va="center", color="green", fontsize=10)
            else:
                ax1.text(j, i, "✗", ha="center", va="center", color="red", fontsize=10)

    # Right: information-flow view
    ax2 = axes[1]
    im2 = ax2.imshow(mask.numpy(), cmap="Greens", origin="upper")
    ax2.set_xticks(range(seq_len))
    ax2.set_yticks(range(seq_len))
    ax2.set_xlabel("来源位置")
    ax2.set_ylabel("目标位置")
    ax2.set_title("信息流向图\n位置 i 可以从哪些位置获取信息")

    for i in range(seq_len):
        ax2.text(
            i,
            i,
            f"t={i}",
            ha="center",
            va="center",
            color="black" if mask[i, i] > 0.5 else "white",
            fontsize=8,
            fontweight="bold",
        )

    plt.tight_layout()
    plt.show()

    print("\n关键观察:")
    print("  - 下三角 (含对角线) 为 1 (蓝色/绿色): 允许的注意力连接")
    print("  - 上三角为 0 (白色): 被屏蔽的连接")
    print("  - 位置 0 只能看到自己")
    print(f"  - 位置 {seq_len-1} 可以看到所有位置 0~{seq_len-1}")


visualize_causal_mask()

In [None]:
def explain_information_leakage() -> None:
    """解释因果掩码如何防止信息泄漏。"""
    print("=" * 60)
    print("信息泄漏 (Information Leakage) 解释")
    print("=" * 60)
    print()
    print("问题场景: 自回归生成")
    print("  假设我们要生成句子: 'The cat sat on the ...'")
    print()
    print("  步骤 1: 已生成 'The'")
    print("  步骤 2: 预测下一个词 → 得到 'cat'")
    print("  步骤 3: 预测下一个词 → 得到 'sat'")
    print()
    print("如果没有因果掩码:")
    print("  - 'The' 在步骤 3 时能看到 'sat' (未来信息！)")
    print("  - 这相当于考试时偷看答案")
    print("  - 模型无法学会真正的预测")
    print()
    print("有因果掩码:")
    print("  - 'The' 只能看到自己")
    print("  - 'cat' 只能看到 'The' 和 'cat'")
    print("  - 每个位置只能基于历史信息进行预测")
    print("  - 保证训练和生成的一致性")
    print("=" * 60)


explain_information_leakage()

---

## 4. Tests and Verification

In [None]:
def test_forward_pass() -> None:
    """验证前向传播形状正确性。"""
    config = GPTConfig(vocab_size=100, block_size=32, n_embd=64, n_head=4, n_layer=2)
    model = GPT(config)

    B, T = 2, 16
    idx = torch.randint(0, config.vocab_size, (B, T))
    targets = torch.randint(0, config.vocab_size, (B, T))

    logits, loss = model(idx, targets)

    assert logits.shape == (B, T, config.vocab_size)
    assert loss is not None

    print("[PASS] test_forward_pass")
    print(f"  Logits shape: {logits.shape}")
    print(f"  Loss: {loss.item():.4f}")


test_forward_pass()

In [None]:
def test_initialization() -> None:
    """验证权重初始化是否符合 GPT-2 标准。"""
    config = GPTConfig(n_layer=6)
    model = GPT(config)

    # 检查标准层的初始化
    for name, param in model.named_parameters():
        if "c_proj.weight" in name:
            expected_std = config.init_std / math.sqrt(2 * config.n_layer)
            actual_std = param.std().item()
            print(f"{name}: std={actual_std:.6f} (expected ~{expected_std:.6f})")
        elif "wte.weight" in name or "c_attn.weight" in name:
            print(f"{name}: std={param.std().item():.6f} (expected ~{config.init_std:.6f})")

    print("[PASS] test_initialization")


test_initialization()

---

## 5. Comparing Sampling Strategies

In [None]:
def demo_temperature_sampling() -> None:
    """演示不同温度对采样结果的影响。"""
    # 模拟一个 logit 分布 (10 个 token)
    logits = torch.tensor([2.0, 1.5, 1.0, 0.5, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0])
    tokens = [f"token_{i}" for i in range(len(logits))]

    temperatures = [0.1, 0.5, 1.0, 2.0]

    fig, axes = plt.subplots(1, len(temperatures), figsize=(16, 4))

    for ax, temp in zip(axes, temperatures):
        probs = F.softmax(logits / temp, dim=0)
        ax.bar(range(len(tokens)), probs.numpy(), color="steelblue")
        ax.set_xticks(range(len(tokens)))
        ax.set_xticklabels(tokens, rotation=45, ha="right")
        ax.set_ylabel("Probability")
        ax.set_title(f"Temperature = {temp}")
        ax.set_ylim(0, 1)

        # Highlight the most probable token
        top_idx = probs.argmax().item()
        ax.bar(top_idx, probs[top_idx].item(), color="orange")

    plt.tight_layout()
    plt.show()

    print("\n温度影响分析:")
    print("  T=0.1: 分布极尖锐，几乎总是选择 token_0")
    print("  T=0.5: 分布较集中，偏向高概率 token")
    print("  T=1.0: 原始分布")
    print("  T=2.0: 分布平滑，低概率 token 也有机会")


demo_temperature_sampling()

In [None]:
def demo_top_k_sampling() -> None:
    """演示 Top-K 采样。"""
    logits = torch.randn(20)
    k_values = [1, 3, 5, 10]

    fig, axes = plt.subplots(1, len(k_values), figsize=(16, 4))

    for ax, k in zip(axes, k_values):
        probs = F.softmax(logits, dim=0)

        # Top-K 处理
        v, _ = torch.topk(logits, k)
        logits_masked = logits.clone()
        logits_masked[logits < v[-1]] = -float("Inf")
        probs_topk = F.softmax(logits_masked, dim=0)

        colors = ["orange" if logits[i] >= v[-1] else "lightgray" for i in range(len(logits))]
        ax.bar(range(len(logits)), probs.numpy(), color="lightgray")
        ax.bar(range(len(logits)), probs_topp.numpy(), color="orange")
        ax.set_title(f"Top-K={k}")
        ax.set_xlabel("Token Index")
        ax.set_ylabel("Probability")

    plt.tight_layout()
    plt.show()

    print("\nTop-K 采样分析:")
    print("  K=1: 等价于贪心采样 (Greedy)")
    print("  K=3: 只从前 3 个高概率 token 中采样")
    print("  K=10: 从前 10 个中采样，多样性增加")
    print("  优势: 避免采样到极低概率的'坏' token")


demo_top_k_sampling()

---

## 6. End-to-End Generation Demo

In [None]:
def demo_generation_with_strategies() -> None:
    """对比不同采样策略的生成结果。"""
    # 使用简单词表
    vocab = {i: c for i, c in enumerate("abcdefghijklmnopqrstuvwxyz ")}
    char_to_idx = {c: i for i, c in vocab.items()}

    config = GPTConfig(
        vocab_size=len(vocab),
        block_size=32,
        n_embd=64,
        n_head=2,
        n_layer=2,
    )

    # Randomly initialized model (untrained)
    model = GPT(config).eval()

    prompt = "the "
    context = torch.tensor([[char_to_idx[c] for c in prompt]], dtype=torch.long)

    strategies = [
        ("Conservative (T=0.1, K=1)", {"temperature": 0.1, "top_k": 1}),
        ("Balanced (T=1.0, K=5)", {"temperature": 1.0, "top_k": 5}),
        ("Creative (T=2.0, K=10)", {"temperature": 2.0, "top_k": 10}),
    ]

    print("生成结果对比 (未训练模型，仅供格式演示):\n")
    for name, kwargs in strategies:
        generated = model.generate(context, max_new_tokens=20, **kwargs)
        text = "".join([vocab[i] for i in generated[0].tolist()])
        print(f"{name}:")
        print(f"  {text}")
        print()


demo_generation_with_strategies()

---

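## 7. KV Cache: Efficient Inference ⭐⭐

### 7.1 Core Idea

**Problem**: in naive autoregressive generation, every new token requires recomputing K and V for all previous tokens.

**Solution**: cache the K and V of previous tokens and compute K and V only for the new token. The price is memory: the cache holds K and V for every past position.

**Complexity** (attention cost as a function of the number of generated tokens $n$):
- Without a cache: generating $n$ tokens costs $O(n^3)$
- With a cache: generating $n$ tokens costs $O(n^2)$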
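These figures count attention-score work only, as a function of sequence length (the model-width factor and the linear projections are ignored). At step $t$ the new token attends over $t$ positions: without a cache every step recomputes the full $t \times t$ score matrix, with a cache only the newest query row ($1 \times t$) is computed. A small counting sketch (the function name is hypothetical):

```python
def attention_score_counts(n: int) -> tuple[int, int]:
    """Count query-key dot products needed to generate n tokens one at a time."""
    no_cache = sum(t * t for t in range(1, n + 1))  # recompute all t x t scores at step t
    with_cache = sum(t for t in range(1, n + 1))  # only the new query row (1 x t) at step t
    return no_cache, with_cache


for n in (4, 64, 1024):
    no_cache, with_cache = attention_score_counts(n)
    print(f"n={n:5d}  no cache: {no_cache:>12,}  with cache: {with_cache:>9,}")
```

The sums have closed forms $n(n+1)(2n+1)/6 \approx n^3/3$ and $n(n+1)/2 \approx n^2/2$, which is where $O(n^3)$ and $O(n^2)$ come from.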

In [None]:
class KVCache:
    """KV Cache 用于高效自回归生成。

    核心思想:
        缓存历史 token 的 Key 和 Value，避免重复计算。

    复杂度:
        - 无缓存: O(n^3) 生成 n 个 token
        - 有缓存: O(n^2) 生成 n 个 token
    """

    def __init__(self, max_batch_size: int, max_seq_len: int, n_heads: int, head_dim: int) -> None:
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.n_heads = n_heads
        self.head_dim = head_dim

        # 预分配缓存空间
        self.k_cache = torch.zeros(max_batch_size, n_heads, max_seq_len, head_dim)
        self.v_cache = torch.zeros(max_batch_size, n_heads, max_seq_len, head_dim)
        self.seq_len = 0

    def update(self, k: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:
        """更新缓存并返回完整的 K, V。

        Args:
            k: 新的 Key (batch, n_heads, new_len, head_dim)
            v: 新的 Value (batch, n_heads, new_len, head_dim)

        Returns:
            完整的 K, V (包含历史和新的)
        """
        batch_size, n_heads, new_len, head_dim = k.shape

        # 更新缓存
        self.k_cache[:batch_size, :, self.seq_len : self.seq_len + new_len, :] = k
        self.v_cache[:batch_size, :, self.seq_len : self.seq_len + new_len, :] = v
        self.seq_len += new_len

        # 返回完整的 K, V
        return (
            self.k_cache[:batch_size, :, : self.seq_len, :],
            self.v_cache[:batch_size, :, : self.seq_len, :],
        )

    def reset(self) -> None:
        """重置缓存。"""
        self.seq_len = 0


class CausalSelfAttentionWithKVCache(nn.Module):
    """支持 KV Cache 的因果自注意力。"""

    def __init__(self, config: GPTConfig) -> None:
        super().__init__()
        self.n_head = config.n_head
        self.head_size = config.n_embd // config.n_head
        self.scale = 1.0 / math.sqrt(self.head_size)

        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        x: Tensor,
        kv_cache: Optional[KVCache] = None,
        start_pos: int = 0,
    ) -> Tensor:
        B, T, C = x.shape
        # start_pos is unused here: the KV cache itself tracks the current position

        # QKV projection
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)
        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)

        # With a KV cache: append the new K, V, then attend over all cached positions
        if kv_cache is not None:
            k, v = kv_cache.update(k, v)

        # Attention scores
        scores = torch.einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

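        # Causal mask: the T new queries occupy the last T of seq_len positions,
        # so query i may attend only to keys j <= seq_len - T + i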
        seq_len = k.shape[2]
        mask = torch.triu(torch.ones(T, seq_len, device=x.device), diagonal=seq_len - T + 1)
        scores = scores.masked_fill(mask.bool().unsqueeze(0).unsqueeze(0), -1e9)

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        return self.c_proj(out)


# Test the KV cache
def test_kv_cache() -> None:
    cache = KVCache(max_batch_size=2, max_seq_len=64, n_heads=4, head_dim=32)

    # Simulate token-by-token generation
    for i in range(5):
        k_new = torch.randn(2, 4, 1, 32)  # one new token per step
        v_new = torch.randn(2, 4, 1, 32)
        k_full, v_full = cache.update(k_new, v_new)
        assert k_full.shape == (2, 4, i + 1, 32)
        print(f"Step {i+1}: K shape = {k_full.shape}, cached_len = {cache.seq_len}")

    print("[PASS] KV cache test passed")


test_kv_cache()
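The shape test above does not check numerics. Here is a self-contained sketch of the property that matters: decoding one token at a time against cached K, V reproduces full causal attention over the whole sequence. It uses plain tensors rather than the classes above; the helper name `attend` and all sizes are illustrative.

```python
import math

import torch
import torch.nn.functional as F


def attend(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Attention of T_q queries against S keys, where the queries are the last T_q positions."""
    T_q, S = q.shape[-2], k.shape[-2]
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    # Same mask as CausalSelfAttentionWithKVCache: query i may see keys j <= S - T_q + i
    mask = torch.triu(torch.ones(T_q, S, dtype=torch.bool), diagonal=S - T_q + 1)
    return F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1) @ v


torch.manual_seed(0)
B, H, T, D = 2, 4, 10, 16
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

full = attend(q, k, v)  # all T positions at once (no cache)

k_cache, v_cache, steps = [], [], []
for t in range(T):  # one token per step, appending its K, V to the cache
    k_cache.append(k[:, :, t : t + 1])
    v_cache.append(v[:, :, t : t + 1])
    steps.append(attend(q[:, :, t : t + 1], torch.cat(k_cache, dim=2), torch.cat(v_cache, dim=2)))
incremental = torch.cat(steps, dim=2)

print(torch.allclose(full, incremental, atol=1e-5))
```

With a single new query, the mask's diagonal offset equals the cache length, so nothing is masked: every cached position is in the past.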

---

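## 8. Nucleus (Top-P) Sampling ⭐

### 8.1 Core Idea

**Problem with Top-K**: K is fixed, so it cannot adapt to the shape of the distribution (too many candidates when the distribution is sharp, too few when it is flat).

**Nucleus sampling**: dynamically choose the smallest set of highest-probability tokens whose cumulative probability reaches $p$, then renormalize and sample from that set.

$$V^{(p)} = \arg\min_{V' \subseteq V} |V'| \quad \text{s.t.} \quad \sum_{v \in V'} P(v) \geq p$$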
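A worked example of the definition: sort tokens by decreasing probability and keep them until the kept mass first reaches $p$. The probabilities and the helper `nucleus` are made up for illustration.

```python
import torch


def nucleus(probs: torch.Tensor, p: float) -> list[int]:
    """Indices of the smallest set of highest-probability tokens whose total mass is >= p."""
    sorted_probs, order = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=0)
    n_keep = int((cumulative < p).sum().item()) + 1  # first position where the mass reaches p
    return order[:n_keep].tolist()


sharp = torch.tensor([0.85, 0.10, 0.03, 0.02])
flat = torch.tensor([0.30, 0.26, 0.24, 0.20])
print(nucleus(sharp, 0.9))  # [0, 1]
print(nucleus(flat, 0.9))   # [0, 1, 2, 3]
```

The same threshold keeps 2 tokens for the sharp distribution and all 4 for the flat one, which is exactly the adaptivity a fixed K lacks.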

In [None]:
def top_p_sampling(logits: Tensor, top_p: float = 0.9, temperature: float = 1.0) -> Tensor:
    """Nucleus (Top-P) Sampling。

    核心思想:
        动态选择累积概率达到 p 的最小 token 集合，
        比 Top-K 更灵活，能适应不同的概率分布。

    Args:
        logits: 模型输出 (batch, vocab_size)
        top_p: 累积概率阈值 (0.9 表示保留 90% 概率质量)
        temperature: 温度参数

    Returns:
        采样的 token 索引 (batch, 1)
    """
    # 温度缩放
    logits = logits / temperature

    # 排序
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = F.softmax(sorted_logits, dim=-1)

    # 计算累积概率
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # 找到累积概率超过 top_p 的位置
    sorted_indices_to_remove = cumulative_probs > top_p
    # 保留第一个超过阈值的 token
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # 将被移除的 token 的 logits 设为 -inf
    sorted_logits[sorted_indices_to_remove] = -float("Inf")

    # 恢复原始顺序
    logits = torch.gather(sorted_logits, -1, sorted_indices.argsort(-1))

    # 采样
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


# Visualize Top-P vs Top-K
def compare_sampling_methods() -> None:
    """Compare Top-K and Top-P sampling."""
    torch.manual_seed(42)

    # Two different distributions
    # Distribution 1: sharp (one token dominates)
    logits_sharp = torch.tensor([[5.0, 1.0, 0.5, 0.2, 0.1, -1.0, -2.0, -3.0, -4.0, -5.0]])
    # Distribution 2: flat (many tokens with similar probabilities)
    logits_flat = torch.tensor([[1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]])

    fig, axes = plt.subplots(2, 3, figsize=(15, 8))

    distributions = [(logits_sharp, "Sharp distribution"), (logits_flat, "Flat distribution")]
    for row, (logits, name) in enumerate(distributions):
        probs = F.softmax(logits, dim=-1).squeeze()

        # Original distribution
        axes[row, 0].bar(range(10), probs.numpy(), color="steelblue")
        axes[row, 0].set_title(f"{name}\nOriginal distribution")
        axes[row, 0].set_ylim(0, 1)

        # Top-K=3
        k = 3
        v, _ = torch.topk(logits, k)
        mask = logits < v[:, -1:]
        logits_topk = logits.clone()
        logits_topk[mask] = -float("Inf")
        probs_topk = F.softmax(logits_topk, dim=-1).squeeze()
        colors = ["orange" if p > 0 else "lightgray" for p in probs_topk]
        axes[row, 1].bar(range(10), probs_topk.numpy(), color=colors)
        axes[row, 1].set_title(f"Top-K=3\n保留 {(probs_topk > 0).sum().item()} 个 token")
        axes[row, 1].set_ylim(0, 1)

        # Top-P=0.9: keep the smallest prefix of sorted tokens whose mass reaches 0.9
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumsum = torch.cumsum(sorted_probs, dim=0)
        n_keep = (cumsum < 0.9).sum().item() + 1
        probs_topp = probs.clone()
        threshold = sorted_probs[n_keep - 1]
        probs_topp[probs < threshold] = 0
        probs_topp = probs_topp / probs_topp.sum()
        colors = ["green" if p > 0 else "lightgray" for p in probs_topp]
        axes[row, 2].bar(range(10), probs_topp.numpy(), color=colors)
        axes[row, 2].set_title(f"Top-P=0.9\nkeeps {(probs_topp > 0).sum().item()} tokens")
        axes[row, 2].set_ylim(0, 1)

    plt.suptitle("Top-K vs Top-P sampling", fontsize=14)
    plt.tight_layout()
    plt.show()

    print("\n观察:")
    print("  - 尖锐分布: Top-K=3 和 Top-P=0.9 效果相似")
    print("  - 平坦分布: Top-P 自动保留更多 token，更灵活")


compare_sampling_methods()

---

## 9. Full Training Loop (Shakespeare Dataset)

In [None]:
def train_mini_gpt(epochs: int = 3, batch_size: int = 32, block_size: int = 64) -> GPT:
    """在 Shakespeare 数据集上训练 Mini-GPT。"""
    # 下载数据
    import urllib.request

    url = (
        "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    )
    try:
        with urllib.request.urlopen(url, timeout=10) as response:
            text = response.read().decode("utf-8")
    except Exception:
        # Fall back to synthetic data (e.g. when offline)
        text = "To be or not to be, that is the question. " * 1000
        print("Using synthetic data for the demo")

    print(f"Dataset size: {len(text):,} characters")

    # Build a character-level vocabulary
    chars = sorted(list(set(text)))
    vocab_size = len(chars)
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}
    encode = lambda s: [stoi[c] for c in s]
    decode = lambda l: "".join([itos[i] for i in l])

    print(f"词表大小: {vocab_size}")

    # 编码数据
    data = torch.tensor(encode(text), dtype=torch.long)
    n = int(0.9 * len(data))
    train_data = data[:n]
    val_data = data[n:]

    # Sample a random batch of (input, next-token target) windows
    def get_batch(split: str) -> tuple[Tensor, Tensor]:
        data_split = train_data if split == "train" else val_data
        ix = torch.randint(len(data_split) - block_size, (batch_size,))
        x = torch.stack([data_split[i : i + block_size] for i in ix])
        y = torch.stack([data_split[i + 1 : i + block_size + 1] for i in ix])
        return x.to(device), y.to(device)

    # Create the model
    config = GPTConfig(
        vocab_size=vocab_size,
        block_size=block_size,
        n_embd=128,
        n_head=4,
        n_layer=4,
        dropout=0.1,
    )
    model = GPT(config).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")

    # 训练循环
    steps_per_epoch = 100
    print("\n开始训练...")

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for step in range(steps_per_epoch):
            xb, yb = get_batch("train")
            _, loss = model(xb, yb)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / steps_per_epoch

        # Validation (a single random batch, so the estimate is noisy)
        model.eval()
        with torch.no_grad():
            xv, yv = get_batch("val")
            _, val_loss = model(xv, yv)

        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_loss:.4f} | Val Loss: {val_loss.item():.4f}")

    # Generation example
    print("\nSample generation:")
    model.eval()
    context = torch.tensor([encode("To be or not")], dtype=torch.long, device=device)
    generated = model.generate(context, max_new_tokens=100, temperature=0.8, top_k=40)
    print(decode(generated[0].tolist()))

    return model


# Run training (fewer epochs and smaller batches for the demo)
trained_model = train_mini_gpt(epochs=2, batch_size=16, block_size=32)
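A useful sanity check when reading the training log: a freshly initialized model predicts roughly uniformly over the vocabulary, so its cross-entropy should start near $\ln(\text{vocab\_size})$, about 4.17 for a 65-character vocabulary (the `GPTConfig` default). The sketch below verifies that identity with uniform logits; the batch of random targets is arbitrary. A starting loss far above this value is worth investigating as a possible initialization or data bug.

```python
import math

import torch
import torch.nn.functional as F

vocab_size = 65  # GPTConfig's default; in train_mini_gpt the real value is computed from the data
logits = torch.zeros(8, vocab_size)  # equal logits = uniform predictive distribution
targets = torch.randint(0, vocab_size, (8,))

loss = F.cross_entropy(logits, targets)
print(f"uniform loss = {loss.item():.4f}, ln(vocab_size) = {math.log(vocab_size):.4f}")
```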

---

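## 10. Summary

| Feature | Description |
|:-----|:-----|
| **Causal mask** | Prevents information leakage during autoregressive generation |
| **Weight initialization** | GPT-2 style, with scaled-down residual projections |
| **Temperature sampling** | Controls how conservative generation is (T<1 conservative, T>1 creative) |
| **Top-K sampling** | Keeps a fixed number of the most probable tokens |
| **Top-P (Nucleus)** | Dynamically keeps the smallest set of tokens whose cumulative probability reaches p |
| **KV Cache** | Caches past K, V to speed up inference, $O(n^3) \to O(n^2)$ |
| **Pre-Norm** | More stable gradient flow; standard in GPT-2 and later |

**Further reading**: Speculative Decoding, Continuous Batching, PagedAttention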