# Transformer 架构详解

**Transformer 教学示例** | 包含 Self-Attention、Multi-Head Attention、FFN、LayerNorm

---

## 1. 核心组件

| 组件 | 公式 | 功能 |
|:-----|:-----|:-----|
| **Self-Attn** | softmax(QK^T/√d_k)V | 捕捉依赖 |
| **FFN** | ReLU(xW_1 + b_1)W_2 + b_2 | 逐位置非线性变换 |

In [None]:
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): none of the visible cells move tensors or modules to `device`
# — confirm whether later code is expected to use it.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

---

## 2. 注意力配置

In [None]:
@dataclass
class AttentionConfig:
    """Hyper-parameters for the attention modules.

    Attributes:
        d_model: Model (embedding) width.
        n_heads: Number of attention heads; must divide ``d_model`` evenly.
        dropout: Dropout probability.

    Raises:
        ValueError: If ``d_model`` or ``n_heads`` is not positive, or if
            ``d_model`` is not a multiple of ``n_heads``.
    """
    d_model: int = 512
    n_heads: int = 8
    dropout: float = 0.1

    def __post_init__(self) -> None:
        # Fail fast: a non-divisible width would make ``d_k`` floor silently
        # and only surface later as a confusing (or silently wrong) reshape.
        if self.d_model <= 0 or self.n_heads <= 0:
            raise ValueError(
                f"d_model and n_heads must be positive, "
                f"got d_model={self.d_model}, n_heads={self.n_heads}"
            )
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )

    @property
    def d_k(self) -> int:
        """Per-head dimension, ``d_model // n_heads``."""
        return self.d_model // self.n_heads


# Default configuration (d_model=512, n_heads=8); `config` is reused when the
# attention and encoder modules are instantiated in later cells.
config = AttentionConfig()
print(f"d_model={config.d_model}, n_heads={config.n_heads}, d_k={config.d_k}")

---

## 3. 缩放点积注意力

In [None]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention.

    Computes ``softmax(Q K^T / sqrt(d_k)) V`` and applies dropout to the
    attention weights before they are multiplied with ``V``.

    Args:
        d_k: Scaling dimension; raw scores are divided by ``sqrt(d_k)``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_k: int, dropout: float = 0.1):
        super().__init__()
        self.d_k = d_k
        self.dropout = nn.Dropout(dropout)

    def forward(
        self, Q: Tensor, K: Tensor, V: Tensor, mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """Apply attention.

        Args:
            Q: Queries; the last dimension is contracted against ``K``.
            K: Keys; the last two dimensions are transposed before the product.
            V: Values, multiplied by the attention weights.
            mask: Optional tensor broadcastable to the score tensor. Positions
                where ``mask == 0`` are excluded from attention.

        Returns:
            A tuple ``(output, attn)``: the weighted values and the
            (post-dropout) attention weights.
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Use the most negative value representable in the score dtype.
            # A hard-coded -1e9 overflows float16 and makes masked_fill raise.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, fill_value)
        attn = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(attn, V), attn


# Smoke test: self-attention where Q, K and V are the same random tensor of
# shape (2, 8, 10, 64); d_k=64 matches the size of the last dimension.
attn = ScaledDotProductAttention(64)
Q = K = V = torch.randn(2, 8, 10, 64)
out, weights = attn(Q, K, V)
print(f"输出: {out.shape}, 权重: {weights.shape}")

---

## 4. 多头注意力

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention.

    Projects query/key/value with bias-free linear layers, splits the result
    into ``n_heads`` heads of width ``d_k``, runs scaled dot-product attention
    per head, then merges the heads and applies an output projection.

    Args:
        config: Provides ``d_model``, ``n_heads``, ``d_k`` and ``dropout``.

    Raises:
        ValueError: If ``config.d_model`` is not divisible by
            ``config.n_heads``.
    """

    def __init__(self, config: AttentionConfig):
        super().__init__()
        # Guard the head split: with a non-divisible width the reshape in
        # forward() would fail or silently yield a wrong sequence length.
        if config.d_model % config.n_heads != 0:
            raise ValueError(
                f"d_model ({config.d_model}) must be divisible by "
                f"n_heads ({config.n_heads})"
            )
        self.n_heads = config.n_heads
        self.d_k = config.d_k

        self.W_q = nn.Linear(config.d_model, config.d_model, bias=False)
        self.W_k = nn.Linear(config.d_model, config.d_model, bias=False)
        self.W_v = nn.Linear(config.d_model, config.d_model, bias=False)
        self.W_o = nn.Linear(config.d_model, config.d_model, bias=False)
        self.attention = ScaledDotProductAttention(config.d_k, config.dropout)

    def forward(
        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        """Run multi-head attention.

        Args:
            query: Tensor of shape ``(B, T_q, d_model)``.
            key: Tensor of shape ``(B, T_k, d_model)``.
            value: Tensor of shape ``(B, T_k, d_model)``.
            mask: Optional mask; zeros mark positions to exclude. Accepted
                ranks: 2-D ``(T_q, T_k)``, 3-D ``(B, T_q, T_k)`` (or any shape
                broadcastable to it), or 4-D, which is passed through
                unchanged and must broadcast against ``(B, n_heads, T_q, T_k)``.

        Returns:
            A tuple ``(output, attn)`` where ``output`` has shape
            ``(B, T_q, d_model)`` and ``attn`` holds the per-head attention
            weights.
        """
        B = query.size(0)
        # (B, T, d_model) -> (B, n_heads, T, d_k)
        Q = self.W_q(query).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(B, -1, self.n_heads, self.d_k).transpose(1, 2)

        if mask is not None:
            # Normalise the mask rank so it broadcasts over batch and heads.
            if mask.dim() == 2:
                mask = mask.unsqueeze(0)  # add the batch axis
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # add the head axis
        x, attn = self.attention(Q, K, V, mask)
        # (B, n_heads, T_q, d_k) -> (B, T_q, d_model)
        x = x.transpose(1, 2).contiguous().view(B, -1, self.n_heads * self.d_k)
        return self.W_o(x), attn


# Smoke test: self-attention over a random (2, 10, 512) input; the returned
# attention weights are discarded. `x` is reused by the encoder-layer test.
mha = MultiHeadAttention(config)
x = torch.randn(2, 10, 512)
out, _ = mha(x, x, x)
print(f"MHA 输出: {out.shape}")

---

## 5. FFN 与 Encoder Layer

In [None]:
class PositionwiseFFN(nn.Module):
    """Position-wise feed-forward network.

    Evaluates ``w2(dropout(relu(w1(x))))``: an expansion from ``d_model`` to
    ``d_ff``, a ReLU, dropout, and a projection back to ``d_model``. Both
    linear layers carry a bias.

    Args:
        d_model: Input and output width.
        d_ff: Hidden width.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.w1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.w2 = nn.Linear(in_features=d_ff, out_features=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: Tensor) -> Tensor:
        """Return the transformed tensor; the last dimension stays ``d_model``."""
        hidden = self.w1(x)
        activated = F.relu(hidden)
        regularised = self.dropout(activated)
        return self.w2(regularised)


class TransformerEncoderLayer(nn.Module):
    """One Transformer encoder layer in post-norm form.

    Two residual sub-layers, each followed by LayerNorm: multi-head
    self-attention, then a position-wise feed-forward network. Dropout is
    applied to each sub-layer output before it is added to the residual.

    Args:
        config: Attention configuration (``d_model``, ``dropout``, heads).
        d_ff: Hidden width of the feed-forward network.
    """

    def __init__(self, config: AttentionConfig, d_ff: int = 2048):
        super().__init__()
        width = config.d_model
        drop_p = config.dropout
        self.attn = MultiHeadAttention(config)
        self.ffn = PositionwiseFFN(width, d_ff, drop_p)
        self.norm1 = nn.LayerNorm(width)
        self.norm2 = nn.LayerNorm(width)
        self.dropout = nn.Dropout(drop_p)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Return the encoded tensor, with the same shape as ``x``."""
        # Sub-layer 1: self-attention (the attention weights are not needed).
        attended = self.attn(x, x, x, mask)[0]
        x = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: feed-forward network.
        transformed = self.ffn(x)
        return self.norm2(x + self.dropout(transformed))


# Smoke test: run the layer on the `x` tensor created in the multi-head
# attention cell, then report the layer's total parameter count.
layer = TransformerEncoderLayer(config)
out = layer(x)
print(f"Encoder Layer 输出: {out.shape}")
print(f"参数量: {sum(p.numel() for p in layer.parameters()):,}")

---

## 6. 可视化

In [None]:
def visualize_attention():
    """Plot four hand-crafted attention patterns as row-normalised heatmaps.

    The patterns (Local, Global, Causal, Sparse) are built over a fixed
    five-token sentence. Each matrix is normalised so every row sums to 1,
    then drawn in its own subplot with the tokens as tick labels.
    """
    tokens = ["The", "cat", "sat", "on", "mat"]
    size = len(tokens)

    def local(q, k):
        # Strong weight on the token itself and its direct neighbours.
        return 1 if abs(q - k) <= 1 else 0.1

    def global_(q, k):
        # Every position is weighted equally.
        return 1

    def causal(q, k):
        # Only the current and earlier positions are visible.
        return 1 if k <= q else 0

    def sparse(q, k):
        # Strong weight on the token itself and on the first token.
        return 1 if q == k or k == 0 else 0.1

    named_patterns = {
        "Local": local,
        "Global": global_,
        "Causal": causal,
        "Sparse": sparse,
    }

    _, axes = plt.subplots(1, 4, figsize=(14, 3))
    for axis, (title, weight_of) in zip(axes, named_patterns.items()):
        raw = np.array(
            [[weight_of(q, k) for k in range(size)] for q in range(size)]
        )
        row_totals = raw.sum(axis=1, keepdims=True)
        normalised = raw / row_totals
        axis.imshow(normalised, cmap='Blues')
        axis.set_xticks(range(size))
        axis.set_yticks(range(size))
        axis.set_xticklabels(tokens, rotation=45)
        axis.set_yticklabels(tokens)
        axis.set_title(title)
    plt.tight_layout()
    plt.show()


# Render the attention-pattern demo figure.
visualize_attention()

---

## 7. 总结

| 组件 | 参数占比 | 功能 |
|:-----|:--------:|:-----|
| **Self-Attn** | ~33% | 捕捉依赖 |
| **FFN** | ~67% | 非线性变换 |