# Transformer Encoder: An In-Depth Implementation

**Educational reference implementation** | Covers Pre-LN/Post-LN switching, GELU activation, and shape tracing

---

## 1. Theoretical Background

### 1.1 Pre-LN vs. Post-LN Architecture

**Post-LN (original Transformer, BERT-style)**:
$$x = \text{LayerNorm}(x + \text{Sublayer}(x))$$

**Pre-LN (GPT-style, more stable to train)**:
$$x = x + \text{Sublayer}(\text{LayerNorm}(x))$$

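### 1.2 Gradient Flow Analysis ⭐

**Problems with Post-LN**:
- Each residual addition is followed by a LayerNorm, so the gradient has to pass through a LayerNorm at every layer before it reaches the residual path
- Gradients can become unstable in deep networks
- Training usually needs learning-rate warmup to stay stable

**Advantages of Pre-LN**:
- Gradients flow directly along the residual path without passing through a LayerNorm
- Gradient flow is more stable, which makes deeper networks trainable
- Warmup can often be shortened or dropped

**Intuition**: Pre-LN turns the residual connection into a "gradient highway" along which the gradient signal propagates backward unobstructed.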
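
The "gradient highway" picture can be made precise with a short derivation sketch. The notation is introduced here for illustration only: $L$ stacked layers, sublayers $F_l$, layer inputs $x_l$, and a scalar loss $\mathcal{L}$.

For Pre-LN, unrolling $x_{l+1} = x_l + F_l(\text{LayerNorm}(x_l))$ gives $x_L = x_l + \sum_{k=l}^{L-1} F_k(\text{LayerNorm}(x_k))$, so

$$\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L}\left(I + \sum_{k=l}^{L-1} \frac{\partial F_k(\text{LayerNorm}(x_k))}{\partial x_l}\right)$$

The identity term $I$ carries the gradient from the top of the stack down to layer $l$ unchanged.

For Post-LN, $x_{l+1} = \text{LayerNorm}(y_l)$ with $y_l = x_l + F_l(x_l)$, so

$$\frac{\partial x_L}{\partial x_l} = \prod_{k=l}^{L-1} J_{\text{LN}}(y_k)\left(I + \frac{\partial F_k(x_k)}{\partial x_k}\right)$$

where $J_{\text{LN}}$ is the LayerNorm Jacobian and the factors are ordered from layer $L-1$ (leftmost) down to layer $l$ (rightmost). Every factor contains $J_{\text{LN}}$, so there is no pure identity path through the stack.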

### 1.3 GELU Activation Function

**Definition**:
$$\text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left[1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right]$$

**Tanh approximation** (often cheaper to compute):
$$\text{GELU}(x) \approx 0.5x\left(1 + \tanh\left[\sqrt{\frac{2}{\pi}}(x + 0.044715x^3)\right]\right)$$

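**Advantages over ReLU**:
1. **Smoothness**: GELU is differentiable everywhere; ReLU is not differentiable at 0
2. **Probabilistic interpretation**: GELU scales the input by $\Phi(x)$, which is the expected result of keeping the input with probability $\Phi(x)$
3. **Negative inputs**: GELU lets small negative values through; ReLU clips them to zero
4. **In practice**: models such as GPT and BERT use GELU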
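
The two GELU formulas above can be compared numerically. The following is a minimal sketch using only the Python standard library; the function names `gelu_exact`, `gelu_tanh`, and `relu` are chosen here for illustration and are not part of the notebook's code.

```python
import math


def gelu_exact(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def relu(x: float) -> float:
    return max(0.0, x)


# Evaluate on a grid over [-4, 4] and record the largest gap between the two forms.
grid = [i / 100.0 for i in range(-400, 401)]
max_gap = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in grid)

print(f"max |exact - tanh| on [-4, 4]: {max_gap:.2e}")
print(f"GELU(-1.0) = {gelu_exact(-1.0):.4f}, ReLU(-1.0) = {relu(-1.0):.1f}")
```

The gap between the two forms stays small over this range, and the second line shows the "negative inputs" point: GELU returns a small negative value where ReLU returns zero.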

---

## 2. Implementation

In [None]:
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

In [None]:
@dataclass
class EncoderConfig:
    """Encoder configuration; avoids magic numbers."""

    vocab_size: int = 10000
    d_model: int = 512
    n_layers: int = 6
    n_heads: int = 8
    d_ff: int = 2048
    max_len: int = 5000
    dropout: float = 0.1
    pad_idx: int = 0
    norm_first: bool = True  # True=Pre-LN (GPT), False=Post-LN (BERT)
    activation: str = "gelu"  # "gelu" or "relu"
    mask_value: float = -1e9

In [None]:
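class GELU(nn.Module):
    """Gaussian Error Linear Unit activation.

    Core idea: weight the input by the standard normal CDF evaluated at the
    input. This is the expected value of stochastically gating the input
    with probability Phi(x), so the function itself is deterministic.

    Math:
        GELU(x) = x * Phi(x)
        where Phi(x) is the CDF of the standard normal distribution.

    Compared with ReLU:
        - Smooth and differentiable everywhere
        - Lets small negative values through instead of clipping them
        - The usual choice in GPT/BERT-style models
    """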

    def __init__(self, approximate: bool = True) -> None:
        super().__init__()
        self.approximate = approximate

    def forward(self, x: Tensor) -> Tensor:
        if self.approximate:
            # Fast approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
            return F.gelu(x, approximate="tanh")
        else:
            # Exact form: x * Phi(x)
            return F.gelu(x, approximate="none")

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention. In `mask`, True marks positions to be masked out."""

    def __init__(self, config: EncoderConfig) -> None:
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.d_k = config.d_model // config.n_heads
        self.scale = 1.0 / math.sqrt(self.d_k)
        self.mask_value = config.mask_value

        self.W_q = nn.Linear(config.d_model, config.d_model)
        self.W_k = nn.Linear(config.d_model, config.d_model)
        self.W_v = nn.Linear(config.d_model, config.d_model)
        self.W_o = nn.Linear(config.d_model, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        batch_size, seq_len, _ = query.shape

        # Linear projection and head split: (batch, seq, d_model) -> (batch, n_heads, seq, d_k)
        q = self.W_q(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # Attention scores: einsum 'b h i d, b h j d -> b h i j'
        scores = torch.einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

        if mask is not None:
            scores = scores.masked_fill(mask.bool(), self.mask_value)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values: einsum 'b h i j, b h j d -> b h i d'
        context = torch.einsum("b h i j, b h j d -> b h i d", attn_weights, v)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        output = self.W_o(context)
        return output, attn_weights
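
The mask convention in `MultiHeadAttention` (True means "masked out") is easy to get backwards, so here is a small standalone sketch of how a padding mask interacts with the softmax. It uses dummy all-zero scores and an assumed `pad_idx = 0`, matching the `EncoderConfig` default; it does not call the classes above.

```python
import torch
import torch.nn.functional as F

pad_idx = 0  # assumed padding id, matching the default in EncoderConfig
src = torch.tensor([[5, 7, 9, 0, 0],
                    [3, 4, 0, 0, 0]])  # (batch=2, seq=5)

# True marks padded key positions; shape (batch, 1, 1, seq) broadcasts over
# heads and query positions.
mask = (src == pad_idx).unsqueeze(1).unsqueeze(2)

scores = torch.zeros(2, 4, 5, 5)  # dummy scores: (batch, n_heads, seq_q, seq_k)
scores = scores.masked_fill(mask, -1e9)
weights = F.softmax(scores, dim=-1)

print(mask.shape)        # torch.Size([2, 1, 1, 5])
print(weights[0, 0, 0])  # weight is spread over the 3 real tokens only
```

One caveat: `-1e9` is outside the float16 range, so in half precision a fill value such as `torch.finfo(scores.dtype).min` is a common substitute.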

In [None]:
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward network with a GELU/ReLU switch."""

    def __init__(self, config: EncoderConfig) -> None:
        super().__init__()
        self.linear1 = nn.Linear(config.d_model, config.d_ff)
        self.linear2 = nn.Linear(config.d_ff, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

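        # Select the activation function; reject unknown names instead of
        # silently falling back to ReLU
        if config.activation == "gelu":
            self.activation = GELU(approximate=True)
        elif config.activation == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"unsupported activation: {config.activation!r}")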

    def forward(self, x: Tensor) -> Tensor:
        # FFN(x) = W2 * activation(W1 * x + b1) + b2
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

In [None]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding."""

    def __init__(self, config: EncoderConfig) -> None:
        super().__init__()
        self.dropout = nn.Dropout(config.dropout)

        pe = torch.zeros(config.max_len, config.d_model)
        position = torch.arange(0, config.max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, config.d_model, 2).float() * (-math.log(10000.0) / config.d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)
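
To see what the `pe` buffer in `PositionalEncoding` contains, the same formula can be evaluated for a single position in plain Python. This is an illustrative sketch; `sinusoidal_pe` is a helper name introduced here, and it assumes an even `d_model`.

```python
import math
from typing import List


def sinusoidal_pe(pos: int, d_model: int) -> List[float]:
    """Return the encoding vector for one position (d_model assumed even)."""
    pe = []
    for i in range(0, d_model, 2):
        freq = math.exp(-math.log(10000.0) * i / d_model)  # same as div_term
        pe.append(math.sin(pos * freq))  # even index
        pe.append(math.cos(pos * freq))  # odd index
    return pe


d_model = 8
for pos in (0, 1, 50):
    print(pos, [round(v, 3) for v in sinusoidal_pe(pos, d_model)])
```

Position 0 yields alternating 0 and 1 (sin 0 and cos 0). The first dimension pair oscillates fastest, and each later pair uses a lower frequency.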

In [None]:
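class EncoderLayer(nn.Module):
    """A single Transformer encoder layer with a Pre-LN/Post-LN switch.

    Core idea:
        Pre-LN: normalize first, then apply the sublayer; the residual path stays an identity
        Post-LN: apply the sublayer first, then normalize; usually needs warmup for stable training

    Args:
        config: EncoderConfig; its norm_first flag selects the variant
    """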

    def __init__(self, config: EncoderConfig) -> None:
        super().__init__()
        self.norm_first = config.norm_first

        self.self_attn = MultiHeadAttention(config)
        self.ffn = PositionwiseFeedForward(config)

        self.norm1 = nn.LayerNorm(config.d_model)
        self.norm2 = nn.LayerNorm(config.d_model)

        self.dropout1 = nn.Dropout(config.dropout)
        self.dropout2 = nn.Dropout(config.dropout)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        if self.norm_first:
            # Pre-LN (GPT-style): x = x + Sublayer(LayerNorm(x))
            h = self.norm1(x)  # normalize once and reuse for query, key, and value
            attn_out, _ = self.self_attn(h, h, h, mask)
            x = x + self.dropout1(attn_out)
            x = x + self.dropout2(self.ffn(self.norm2(x)))
        else:
            # Post-LN (BERT-style): x = LayerNorm(x + Sublayer(x))
            attn_out, _ = self.self_attn(x, x, x, mask)
            x = self.norm1(x + self.dropout1(attn_out))
            x = self.norm2(x + self.dropout2(self.ffn(x)))

        return x

In [None]:
class Encoder(nn.Module):
    """Full Transformer encoder."""

    def __init__(self, config: EncoderConfig) -> None:
        super().__init__()
        self.config = config
        self.scale = math.sqrt(config.d_model)

        self.embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_idx)
        self.pos_encoding = PositionalEncoding(config)
        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.n_layers)])

        # Pre-LN needs a final LayerNorm: the residual stream is never normalized inside the layers
        self.final_norm = nn.LayerNorm(config.d_model) if config.norm_first else nn.Identity()

    def create_padding_mask(self, src: Tensor) -> Tensor:
        return (src == self.config.pad_idx).unsqueeze(1).unsqueeze(2)

    def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        if mask is None:
            mask = self.create_padding_mask(src)

        x = self.embedding(src) * self.scale
        x = self.pos_encoding(x)

        for layer in self.layers:
            x = layer(x, mask)

        return self.final_norm(x)

---

## 3. Shape Tracer (ShapeTracer)

In [None]:
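class ShapeTracer:
    """Tensor shape tracer for debugging and for understanding data flow.

    Uses PyTorch forward hooks to record the input and output shapes of each
    traced module during the forward pass.
    """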

    def __init__(self) -> None:
        self.traces: List[Dict] = []
        self.hooks: List = []

    def _hook_fn(self, name: str) -> Callable:
        def hook(module: nn.Module, input: Tuple, output) -> None:
            input_shape = input[0].shape if isinstance(input, tuple) and len(input) > 0 else "N/A"
            output_shape = output.shape if isinstance(output, Tensor) else output[0].shape
            self.traces.append(
                {
                    "layer": name,
                    "type": module.__class__.__name__,
                    "input": str(input_shape),
                    "output": str(output_shape),
                }
            )

        return hook

    def attach(
        self,
        model: nn.Module,
        target_types: Tuple = (MultiHeadAttention, PositionwiseFeedForward, nn.LayerNorm),
    ) -> None:
        """Attach hooks to all submodules of the given types."""
        for name, module in model.named_modules():
            if isinstance(module, target_types):
                hook = module.register_forward_hook(self._hook_fn(name))
                self.hooks.append(hook)

    def detach(self) -> None:
        """Remove all hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

    def clear(self) -> None:
        """Clear the recorded traces."""
        self.traces.clear()

    def print_traces(self) -> None:
        """Print the shape trace."""
        width = 120  # wide enough for strings such as "torch.Size([2, 8, 256])"
        print("\n" + "=" * width)
        print("Tensor shape trace")
        print("=" * width)
        print(f"{'Layer':<40} {'Type':<25} {'Input':<26} {'Output':<26}")
        print("-" * width)
        for t in self.traces:
            print(f"{t['layer']:<40} {t['type']:<25} {t['input']:<26} {t['output']:<26}")
        print("=" * width)

---

## 4. Tests and Validation

In [None]:
def test_forward_pass() -> None:
    """Check that the forward pass produces the expected output shape."""
    config = EncoderConfig(d_model=256, n_layers=2, n_heads=4, d_ff=512)
    encoder = Encoder(config)

    batch_size, seq_len = 2, 10
    src = torch.randint(1, config.vocab_size, (batch_size, seq_len))

    output = encoder(src)

    assert output.shape == (batch_size, seq_len, config.d_model)
    print("[PASS] test_forward_pass")


test_forward_pass()

In [None]:
def test_pre_ln_vs_post_ln() -> None:
    """Compare the outputs of a Pre-LN and a Post-LN encoder."""
    torch.manual_seed(42)

    config_pre = EncoderConfig(d_model=128, n_layers=2, n_heads=4, norm_first=True)
    config_post = EncoderConfig(d_model=128, n_layers=2, n_heads=4, norm_first=False)

    # eval() disables dropout so the printed norms are not affected by random masks
    encoder_pre = Encoder(config_pre).eval()
    encoder_post = Encoder(config_post).eval()

    src = torch.randint(1, 1000, (1, 5))

    out_pre = encoder_pre(src)
    out_post = encoder_post(src)

    print(f"Pre-LN output norm: {out_pre.norm():.4f}")
    print(f"Post-LN output norm: {out_post.norm():.4f}")
    assert out_pre.shape == out_post.shape == (1, 5, 128)
    print("[PASS] test_pre_ln_vs_post_ln")


test_pre_ln_vs_post_ln()

In [None]:
def demo_shape_tracing() -> None:
    """Demonstrate shape tracing."""
    config = EncoderConfig(d_model=256, n_layers=2, n_heads=4, d_ff=512)
    encoder = Encoder(config)

    tracer = ShapeTracer()
    tracer.attach(encoder)

    src = torch.randint(1, 1000, (2, 8))
    _ = encoder(src)

    tracer.print_traces()
    tracer.detach()


demo_shape_tracing()

---

## 5. GELU vs. ReLU Visualization

In [None]:
import matplotlib.pyplot as plt
import numpy as np

x = torch.linspace(-4, 4, 200)

gelu_out = F.gelu(x)
relu_out = F.relu(x)

plt.figure(figsize=(10, 5))
plt.plot(x.numpy(), gelu_out.numpy(), label="GELU", linewidth=2)
plt.plot(x.numpy(), relu_out.numpy(), label="ReLU", linewidth=2, linestyle="--")
plt.axhline(y=0, color="gray", linestyle="-", alpha=0.3)
plt.axvline(x=0, color="gray", linestyle="-", alpha=0.3)
plt.xlabel("Input x")
plt.ylabel("Output")
plt.title("GELU vs ReLU activation functions")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("Observation: GELU transitions smoothly for negative inputs, while ReLU clips them to 0")

---

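## 6. RoPE: Rotary Positional Embedding ⭐⭐

### 6.1 Core Idea

**Problem**: Sinusoidal positional encoding is additive (it is summed with the token embeddings), so the attention score does not directly encode relative position.

**RoPE's solution**: encode position inside the attention computation by rotating the query and key vectors.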

$$f_q(x_m, m) = (W_q x_m) e^{im\theta}$$
$$f_k(x_n, n) = (W_k x_n) e^{in\theta}$$

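**Key property**: treating each pair of feature dimensions as a complex number, the inner product $\langle f_q(x_m, m), f_k(x_n, n) \rangle$ (the real part of $f_q \overline{f_k}$) depends on position only through the relative offset $m - n$.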
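
The relative-position property can be checked numerically for a single 2D feature pair using Python's built-in complex numbers. This is a minimal sketch: `theta`, `rotate`, and `score` are names introduced here, and `q`, `k` are arbitrary stand-ins for $W_q x_m$ and $W_k x_n$.

```python
import cmath

theta = 0.3  # assumed rotation frequency for a single 2D pair


def rotate(z: complex, pos: int) -> complex:
    """Rotate a 2D feature pair (stored as a complex number) by pos * theta."""
    return z * cmath.exp(1j * pos * theta)


def score(q: complex, k: complex) -> float:
    """Real inner product of two 2D vectors stored as complex numbers."""
    return (q * k.conjugate()).real


q, k = complex(0.8, -0.2), complex(0.1, 0.5)  # stand-ins for W_q x_m and W_k x_n

# Same offset m - n = 3 at two different absolute positions.
s_near = score(rotate(q, 5), rotate(k, 2))
s_far = score(rotate(q, 105), rotate(k, 102))
# A different offset (m - n = 1) gives a different score.
s_other = score(rotate(q, 5), rotate(k, 4))

print(s_near, s_far, s_other)
```

The first two scores agree up to floating-point error because $e^{im\theta}\,\overline{e^{in\theta}} = e^{i(m-n)\theta}$, so the absolute positions cancel.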

In [None]:
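class RotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding (RoPE), as used in LLaMA and GPT-NeoX.

    Core idea:
        Encode position as a rotation angle so that attention scores depend on
        relative position.

    Math:
        For a vector x at position m, apply the rotation
        RoPE(x, m) = x * cos(m*theta) + rotate_half(x) * sin(m*theta)

    Advantages:
        1. Encodes relative position; tends to extrapolate better than absolute encodings
        2. No extra learned parameters
        3. Fits naturally into the attention computation
    """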

    def __init__(self, dim: int, max_seq_len: int = 4096, base: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Precompute the frequencies: theta_i = base^(-2i/d)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Precompute the cos/sin cache
        self._build_cache(max_seq_len)

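    def _build_cache(self, seq_len: int) -> None:
        """Precompute the cos/sin values."""
        # Match inv_freq's dtype so the outer product below does not mix int64 and float
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)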
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # (seq_len, dim/2)
        emb = torch.cat([freqs, freqs], dim=-1)  # (seq_len, dim)
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def _rotate_half(self, x: Tensor) -> Tensor:
        """Split x into halves (x1, x2) and return (-x2, x1)."""
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, q: Tensor, k: Tensor, seq_len: int) -> Tuple[Tensor, Tensor]:
        """Apply RoPE to Q and K.

        Args:
            q: Query (batch, n_heads, seq_len, head_dim)
            k: Key (batch, n_heads, seq_len, head_dim)
            seq_len: sequence length (must not exceed max_seq_len)

        Returns:
            q_rotated, k_rotated: Q and K after applying RoPE
        """
        cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0)  # (1, 1, seq, dim)
        sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0)

        # RoPE: x * cos + rotate_half(x) * sin
        q_rotated = q * cos + self._rotate_half(q) * sin
        k_rotated = k * cos + self._rotate_half(k) * sin

        return q_rotated, k_rotated


# Test RoPE
def test_rope() -> None:
    rope = RotaryPositionalEmbedding(dim=64, max_seq_len=512)
    q = torch.randn(2, 8, 32, 64)  # (batch, heads, seq, dim)
    k = torch.randn(2, 8, 32, 64)

    q_rot, k_rot = rope(q, k, seq_len=32)
    print(f"Q shape: {q.shape} -> {q_rot.shape}")
    print(f"K shape: {k.shape} -> {k_rot.shape}")
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    print("[PASS] RoPE test passed")


test_rope()

---

## 7. ALiBi: Attention with Linear Biases ⭐⭐

### 7.1 Core Idea

**ALiBi (Attention with Linear Biases)**: add a linear, distance-dependent bias directly to the attention scores.

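$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} - m \cdot |i - j|\right)V$$

Here $m$ is the per-head slope and $i, j$ are the query and key position indices. In the causal setting ($j \le i$) the bias reduces to $-m(i - j)$; the symmetric form $-m|i - j|$ is the one used by the `ALiBi` class in this section, which targets a bidirectional encoder.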

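**Advantages**:
1. **No extra parameters**: the bias matrix is precomputed, not learned
2. **Strong length extrapolation**: a model trained at length 1024 can, in favorable cases, run inference at 8192+
3. **Simple to implement**: just add a bias to the attention scores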
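
The slope schedule and the bias matrix are simple enough to compute by hand. The sketch below uses plain Python and the symmetric form `-slope * |i - j|`, the same form as the `ALiBi` class in this section; `alibi_slopes` and `alibi_bias` are illustrative helper names, and `n_heads` is assumed to be a power of two.

```python
from typing import List


def alibi_slopes(n_heads: int) -> List[float]:
    """Slopes 2^(-8/n), 2^(-16/n), ... (n_heads assumed to be a power of two)."""
    start = 2.0 ** (-8.0 / n_heads)
    return [start ** (h + 1) for h in range(n_heads)]


def alibi_bias(slope: float, seq_len: int) -> List[List[float]]:
    """Symmetric bias matrix with entry [i][j] = -slope * |i - j|."""
    return [[-slope * abs(i - j) for j in range(seq_len)] for i in range(seq_len)]


slopes = alibi_slopes(8)
print(slopes)  # 1/2, 1/4, ..., 1/256
for row in alibi_bias(slopes[0], 4):
    print(row)
```

Heads with large slopes penalize distance strongly and attend locally; heads with small slopes keep a nearly flat bias and can attend far away.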

In [None]:
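class ALiBi(nn.Module):
    """ALiBi: Attention with Linear Biases.

    Core idea:
        Add a linear position bias to the attention scores; no positional
        embeddings need to be learned.

    Math:
        bias[i, j] = -m * |i - j|
        where m is the per-head slope; slopes decrease geometrically across heads.

    Advantages:
        1. No extra parameters
        2. Strong length extrapolation (e.g. train at 1K, infer at 8K+)
        3. Simple to implement
    """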

    def __init__(self, n_heads: int, max_seq_len: int = 4096) -> None:
        super().__init__()
        self.n_heads = n_heads

        # Per-head slopes: 2^(-8/n), 2^(-16/n), ...
        slopes = self._get_slopes(n_heads)
        self.register_buffer("slopes", slopes)

        # Precompute the bias matrix
        bias = self._build_alibi_bias(max_seq_len, slopes)
        self.register_buffer("bias", bias)

    def _get_slopes(self, n_heads: int) -> Tensor:
        """Compute the ALiBi slopes (a geometric sequence)."""

        def get_slopes_power_of_2(n: int) -> list:
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * (ratio**i) for i in range(n)]

        if math.log2(n_heads).is_integer():
            return torch.tensor(get_slopes_power_of_2(n_heads))
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
            slopes = get_slopes_power_of_2(closest_power_of_2)
            extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
                : n_heads - closest_power_of_2
            ]
            return torch.tensor(slopes + extra_slopes)

    def _build_alibi_bias(self, seq_len: int, slopes: Tensor) -> Tensor:
        """Build the ALiBi bias matrix."""
        # Relative positions, entry [i, j] = j - i: [0, 1, 2, ...], [-1, 0, 1, ...], ...
        positions = torch.arange(seq_len)
        relative_positions = positions.unsqueeze(0) - positions.unsqueeze(1)  # (seq, seq)

        # bias = -slope * |relative_position|
        bias = -slopes.unsqueeze(1).unsqueeze(1) * relative_positions.abs().unsqueeze(0)
        return bias  # (n_heads, seq, seq)

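    def forward(self, attn_scores: Tensor) -> Tensor:
        """Add the ALiBi bias to the attention scores.

        Args:
            attn_scores: (batch, n_heads, seq, seq); assumes self-attention with
                equal query and key lengths, at most max_seq_len

        Returns:
            Attention scores with the bias added
        """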
        seq_len = attn_scores.shape[-1]
        return attn_scores + self.bias[:, :seq_len, :seq_len].unsqueeze(0)


# Test ALiBi
def test_alibi() -> None:
    alibi = ALiBi(n_heads=8, max_seq_len=512)
    scores = torch.randn(2, 8, 32, 32)  # (batch, heads, seq, seq)

    scores_with_bias = alibi(scores)
    print(f"Attention score shape: {scores.shape}")
    print(f"ALiBi bias shape: {alibi.bias.shape}")
    print(f"Slopes: {alibi.slopes[:4].tolist()}")
    assert scores_with_bias.shape == scores.shape
    print("[PASS] ALiBi test passed")


test_alibi()

---

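## 8. Gradient Checkpointing ⭐

### 8.1 Core Idea

**Problem**: A deep Transformer must store all intermediate activations for the backward pass, which consumes a large amount of memory.

**Solution**: Store only a subset of the activations and recompute the rest during the backward pass.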

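**Trade-off**:
- **Memory**: with $n$ layers split into about $\sqrt{n}$ checkpointed segments, stored activations drop from $O(n)$ to $O(\sqrt{n})$; checkpointing every layer individually still keeps one input per layer but discards the activations inside each layer
- **Compute**: roughly one extra forward pass, about 33% more compute if the backward pass costs twice as much as the forward pass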
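
Where the $\sqrt{n}$ comes from can be seen with a back-of-the-envelope count. The sketch below uses a deliberately simplified cost model, introduced here as an assumption: each layer output counts as one unit, one checkpoint per segment is kept for the whole backward pass, and the activations of one segment at a time are recomputed. `peak_stored` is a hypothetical helper, not part of the notebook.

```python
import math


def peak_stored(n_layers: int, segment: int) -> int:
    """Layer activations held at peak when checkpointing every `segment` layers.

    Simplified model: each layer output counts as one unit. The checkpoints
    (one per segment) are kept for the whole backward pass, and the
    activations of one segment at a time are recomputed.
    """
    n_checkpoints = math.ceil(n_layers / segment)
    return n_checkpoints + segment


n_layers = 64
for segment in (1, 4, 8, 16, 64):
    print(f"segment={segment:>2}: ~{peak_stored(n_layers, segment)} units")
```

Under this model the count `n / segment + segment` is smallest when the segment length is about $\sqrt{n}$ (8 for 64 layers), which gives roughly $2\sqrt{n}$ units. With a segment length of 1 the number of layer-boundary tensors does not shrink; the saving in that case comes from the activations inside each layer, which this model does not count.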

In [None]:
from torch.utils.checkpoint import checkpoint


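class CheckpointedEncoderLayer(nn.Module):
    """Encoder layer with gradient checkpointing support.

    Core idea:
        Use torch.utils.checkpoint to recompute activations during the backward
        pass, trading compute for memory; useful for training very deep networks.

    Memory savings:
        - Standard: all intermediate activations inside the layer are stored
        - Checkpointed: only the layer input is stored; the internals are recomputed
        - O(sqrt(n_layers)) total storage requires checkpointing segments of
          about sqrt(n_layers) layers rather than single layers
    """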

    def __init__(self, config: EncoderConfig, use_checkpoint: bool = True) -> None:
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.norm_first = config.norm_first

        self.self_attn = MultiHeadAttention(config)
        self.ffn = PositionwiseFeedForward(config)
        self.norm1 = nn.LayerNorm(config.d_model)
        self.norm2 = nn.LayerNorm(config.d_model)
        self.dropout1 = nn.Dropout(config.dropout)
        self.dropout2 = nn.Dropout(config.dropout)

    def _forward_impl(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """The actual forward logic."""
        if self.norm_first:
            h = self.norm1(x)  # normalize once and reuse for query, key, and value
            attn_out, _ = self.self_attn(h, h, h, mask)
            x = x + self.dropout1(attn_out)
            x = x + self.dropout2(self.ffn(self.norm2(x)))
        else:
            attn_out, _ = self.self_attn(x, x, x, mask)
            x = self.norm1(x + self.dropout1(attn_out))
            x = self.norm2(x + self.dropout2(self.ffn(x)))
        return x

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        if self.use_checkpoint and self.training:
            # Checkpoint this layer: its activations are recomputed during the backward pass
            return checkpoint(self._forward_impl, x, mask, use_reentrant=False)
        else:
            return self._forward_impl(x, mask)


# Test gradient checkpointing
def test_gradient_checkpointing() -> None:
    config = EncoderConfig(d_model=256, n_heads=4)

    layer_standard = EncoderLayer(config)
    layer_checkpoint = CheckpointedEncoderLayer(config, use_checkpoint=True)

    x = torch.randn(2, 10, 256, requires_grad=True)

    # Standard layer
    layer_standard.train()
    out_std = layer_standard(x.clone())

    # Checkpointed layer
    layer_checkpoint.train()
    out_ckpt = layer_checkpoint(x.clone())

    print(f"Standard layer output shape: {out_std.shape}")
    print(f"Checkpointed layer output shape: {out_ckpt.shape}")
    print("[PASS] Gradient checkpointing test passed")


test_gradient_checkpointing()
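
The test above only compares output shapes. A stricter check is that checkpointing leaves the gradients unchanged. The following sketch does this on a small stand-in block (an `nn.Sequential` without dropout, chosen here so that both runs are deterministic) rather than on `CheckpointedEncoderLayer` itself.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# A small stand-in block without dropout, so both runs are deterministic.
block = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))

x_plain = torch.randn(4, 16, requires_grad=True)
x_ckpt = x_plain.detach().clone().requires_grad_(True)

# Regular forward/backward: intermediate activations are kept.
block(x_plain).sum().backward()
grad_plain = [p.grad.clone() for p in block.parameters()]
block.zero_grad()

# Checkpointed forward/backward: intermediates are recomputed in backward.
checkpoint(block, x_ckpt, use_reentrant=False).sum().backward()
grad_ckpt = [p.grad.clone() for p in block.parameters()]

same_inputs = torch.allclose(x_plain.grad, x_ckpt.grad)
same_params = all(torch.allclose(a, b) for a, b in zip(grad_plain, grad_ckpt))
print(same_inputs, same_params)
```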

---

## 9. Comparison of Positional Encoding Methods

| Method | Type | Extrapolation | Parameters | Used in |
|:-----|:-----|:-------|:-------|:---------|
| **Sinusoidal** | Absolute | Poor | 0 | Original Transformer |
| **Learned** | Absolute | Poor | O(L*d) | BERT, GPT-2 |
| **RoPE** | Relative | Good | 0 | LLaMA, GPT-NeoX |
| **ALiBi** | Relative | Very good | 0 | BLOOM, MPT |

In [None]:
def visualize_alibi_bias(n_heads: int = 4, seq_len: int = 16) -> None:
    """Visualize the ALiBi bias matrices."""
    alibi = ALiBi(n_heads=n_heads, max_seq_len=seq_len)

    fig, axes = plt.subplots(1, n_heads, figsize=(4 * n_heads, 4))
    if n_heads == 1:
        axes = [axes]

    for i, ax in enumerate(axes):
        bias = alibi.bias[i, :seq_len, :seq_len].numpy()
        im = ax.imshow(bias, cmap="RdBu_r", aspect="auto")
        ax.set_title(f"Head {i+1}\nslope={alibi.slopes[i]:.4f}")
        ax.set_xlabel("Key Position")
        ax.set_ylabel("Query Position")
        plt.colorbar(im, ax=ax)

    plt.suptitle("ALiBi Bias Matrices (more negative with distance)", fontsize=14)
    plt.tight_layout()
    plt.show()


visualize_alibi_bias()

---

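## 10. Summary

| Property | Pre-LN (GPT) | Post-LN (BERT) |
|:-----|:-------------|:---------------|
| Normalization placement | Before the sublayer | After the residual addition |
| Gradient stability | More stable | Needs warmup |
| Training deep models | Easier | Harder |
| Final LayerNorm | Required | Not required |

**Advanced techniques**:
- **GELU**: smooth and differentiable, probabilistic interpretation, standard in GPT/BERT-style models
- **RoPE**: rotary positional embedding; relative position, good extrapolation
- **ALiBi**: linear biases; no parameters, strong extrapolation
- **Gradient checkpointing**: trades compute for memory when training very deep networks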