# LLaMA Architecture Explained

**SOTA Educational Reference Implementation** | Core Techniques of Modern Large Language Models

---

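## Learning Objectives

1. Understand the advantages of RMSNorm over LayerNorm
2. Master the mathematics behind rotary position embedding (RoPE)
3. Learn how to implement the SwiGLU activation
4. Understand how grouped-query attention (GQA) reduces memory use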

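## Table of Contents

1. [Environment Setup](#1-environment-setup)
2. [RMSNorm](#2-rmsnorm)
3. [Rotary Position Embedding (RoPE)](#3-rotary-position-embedding-rope)
4. [SwiGLU Activation](#4-swiglu-activation)
5. [Grouped-Query Attention (GQA)](#5-grouped-query-attention-gqa)
6. [Validation Tests](#6-validation-tests)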

---

## 1. Environment Setup

In [None]:
# Import the required libraries
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Fix the random seed for reproducibility
torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

---

## 2. RMSNorm

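### 2.1 Mathematical Formulation

RMSNorm drops the mean-centering step of LayerNorm and rescales only by the root mean square of the features, which makes it cheaper to compute: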

$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^d x_i^2 + \epsilon}} \cdot \gamma$$

### 2.2 Advantages

| Property | LayerNorm | RMSNorm |
|:-----|:----------|:--------|
| Mean computation | Required | Not required |
| Compute cost | Higher | Lower |
| Model quality | Baseline | Comparable |

In [None]:
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Drops LayerNorm's mean-centering step and rescales by the RMS only,
    which is cheaper to compute.
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable gain (gamma)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Root mean square over the feature dimension (eps sits inside the sqrt)
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight

# Sanity-check RMSNorm
x = torch.randn(2, 10, 768)
rms = RMSNorm(768)
out = rms(x)
print(f'Input shape: {x.shape}')
print(f'Output shape: {out.shape}')
print('✓ RMSNorm test passed')
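
The shape check above does not show what RMSNorm actually does to the values. The following is a minimal sketch, assuming a gain of 1 and a hypothetical helper `rms_norm` (not part of the class above). It checks two properties: every normalized row has a root mean square close to 1, and on zero-mean input RMSNorm matches LayerNorm without affine parameters, since the variance then equals the mean of squares.

```python
import torch
import torch.nn.functional as F


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """RMSNorm with the gain fixed to 1 (hypothetical helper for illustration)."""
    return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


torch.manual_seed(0)
x = torch.randn(4, 16)

# 1) Each normalized row has a root mean square of approximately 1.
y = rms_norm(x)
row_rms = y.pow(2).mean(dim=-1).sqrt()

# 2) On zero-mean input, RMSNorm and LayerNorm (no affine) agree.
x_centered = x - x.mean(dim=-1, keepdim=True)
diff = (rms_norm(x_centered) - F.layer_norm(x_centered, (16,), eps=1e-6)).abs().max()

print(f'row RMS: {row_rms}')
print(f'max |RMSNorm - LayerNorm| on centered input: {diff.item():.2e}')
```

On input that is not centered the two normalizations differ, which is exactly the mean-subtraction step that RMSNorm skips.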

---

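## 3. Rotary Position Embedding (RoPE)

### 3.1 Core Idea

RoPE encodes position as a rotation applied to the query and key vectors, so that attention scores naturally carry relative position information. For a vector $x$ at position $m$:

$$f(x, m) = x \cdot \cos(m\theta) + \text{rotate}(x) \cdot \sin(m\theta)$$

Here $\theta$ denotes the rotation frequencies (one per pair of dimensions), and $\text{rotate}(x)$ maps each pair $(x_1, x_2)$ to $(-x_2, x_1)$.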
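
To see why this yields relative positions, it can help to write the formula for a single pair of dimensions $(x_1, x_2)$ with one frequency $\theta$. The following is a short derivation sketch under that two-dimensional assumption; the matrix $R$ is notation introduced here for illustration.

$$R(m\theta) = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix}, \qquad f(x, m) = R(m\theta)\, x$$

$$\langle f(q, m), f(k, n) \rangle = q^\top R(m\theta)^\top R(n\theta)\, k = q^\top R\big((n - m)\theta\big)\, k$$

The second line uses $R(a)^\top R(b) = R(b - a)$, so the score depends on the positions $m$ and $n$ only through the offset $n - m$.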

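### 3.2 Advantages

- **Relative position**: after rotation, the attention score between a query and a key depends only on their relative offset
- **Extrapolation**: the rotation is defined for any position, so sequences longer than those seen in training can be encoded, though quality at unseen lengths is not guaranteed
- **Efficient**: adds no learnable parameters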

In [None]:
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE)."""

    def __init__(self, dim: int, max_seq_len: int = 4096, theta: float = 10000.0) -> None:
        super().__init__()
        # Inverse frequencies, one per pair of dimensions
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute cos and sin for every position up to max_seq_len
        t = torch.arange(max_seq_len).float()    # float positions, matching inv_freq's dtype
        freqs = torch.outer(t, inv_freq)         # [max_seq_len, dim / 2]
        emb = torch.cat((freqs, freqs), dim=-1)  # [max_seq_len, dim]
        self.register_buffer('cos', emb.cos())
        self.register_buffer('sin', emb.sin())

    def forward(self, seq_len: int):
        # Return the tables for the first seq_len positions
        return self.cos[:seq_len], self.sin[:seq_len]


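def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Map the two halves (x1, x2) of the last dimension to (-x2, x1)."""
    x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embedding to Q and K.

    Expects q and k of shape [batch, heads, seq_len, head_dim] and
    cos and sin of shape [seq_len, head_dim].
    """
    # Add batch and head dimensions so cos/sin broadcast against q and k
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed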

# Sanity-check RoPE
rope = RotaryEmbedding(64)
cos, sin = rope(10)
print(f'cos shape: {cos.shape}, sin shape: {sin.shape}')
print('✓ RoPE test passed')
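
The relative-position property can also be checked numerically. The sketch below is self-contained and uses a hypothetical helper `rope_rotate` that rotates a single vector with the same half-split layout as `rotate_half`; it is an illustration, not the module above. It compares the score for the same offset at two different absolute positions and confirms that the rotation preserves the vector norm.

```python
import torch


def rope_rotate(x: torch.Tensor, pos: int, theta: float = 10000.0) -> torch.Tensor:
    """Rotate a single vector x of shape [dim] to position `pos` (half-split layout)."""
    dim = x.shape[-1]
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    angles = pos * inv_freq                          # [dim / 2]
    cos = torch.cat((angles.cos(), angles.cos()))    # [dim]
    sin = torch.cat((angles.sin(), angles.sin()))    # [dim]
    x1, x2 = x[: dim // 2], x[dim // 2:]
    rotated = torch.cat((-x2, x1))
    return x * cos + rotated * sin


torch.manual_seed(0)
q = torch.randn(64, dtype=torch.float64)
k = torch.randn(64, dtype=torch.float64)

# Same offset (n - m = 3) at two different absolute positions.
score_a = torch.dot(rope_rotate(q, 2), rope_rotate(k, 5))
score_b = torch.dot(rope_rotate(q, 102), rope_rotate(k, 105))

# A different offset (n - m = 7) generally gives a different score.
score_c = torch.dot(rope_rotate(q, 2), rope_rotate(k, 9))

print(f'offset 3 at positions (2, 5):     {score_a.item():.6f}')
print(f'offset 3 at positions (102, 105): {score_b.item():.6f}')
print(f'offset 7 at positions (2, 9):     {score_c.item():.6f}')
```

Double precision is used here only so that the two offset-3 scores agree to many digits; the property itself does not depend on the dtype.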

---

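## 4. SwiGLU Activation

### 4.1 Formula

$$\text{SwiGLU}(x) = \text{Swish}(xW_1) \otimes xW_3$$
$$\text{Swish}(x) = x \cdot \sigma(x)$$

In the feed-forward block, the gated result is projected back to the model dimension by a third matrix:

$$\text{FFN}(x) = \text{SwiGLU}(x)\, W_2$$

### 4.2 Advantages

- The gating mechanism increases the expressive power of the feed-forward layer
- Reported to perform better than GELU and ReLU feed-forward layers in Transformer experiments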

In [None]:
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block.

    Computes W2(Swish(x W1) * (x W3)): a gated up-projection
    followed by a down-projection back to d_model.
    """

    def __init__(self, d_model: int, d_ff: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # down projection
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # up projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.silu is Swish with beta = 1
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

# Sanity-check SwiGLU
swiglu = SwiGLU(768, 2048)
x = torch.randn(2, 10, 768)
out = swiglu(x)
print(f'Input shape: {x.shape}')
print(f'Output shape: {out.shape}')
print(f'SwiGLU parameters: {sum(p.numel() for p in swiglu.parameters()):,}')
print('✓ SwiGLU test passed')
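
Because SwiGLU uses three weight matrices instead of two, its hidden size is usually reduced so that the parameter count stays comparable to a standard feed-forward layer. The sketch below works through that arithmetic, assuming bias-free projections and the common convention of a hidden size of about two thirds of `4 * d_model`; both helper functions are hypothetical and exist only for this illustration.

```python
def swiglu_ffn_params(d_model: int, d_ff: int) -> int:
    """Three bias-free projections: W1 and W3 (d_model -> d_ff), W2 (d_ff -> d_model)."""
    return 3 * d_model * d_ff


def standard_ffn_params(d_model: int, d_ff: int) -> int:
    """Two bias-free projections: up (d_model -> d_ff) and down (d_ff -> d_model)."""
    return 2 * d_model * d_ff


d_model = 768
swiglu_count = swiglu_ffn_params(d_model, 2048)            # d_ff = (2/3) * 4 * d_model
standard_count = standard_ffn_params(d_model, 4 * d_model)  # d_ff = 4 * d_model

print(f'SwiGLU FFN   (d_ff=2048): {swiglu_count:,}')    # 4,718,592
print(f'Standard FFN (d_ff=3072): {standard_count:,}')  # 4,718,592
```

With `d_model = 768`, a SwiGLU hidden size of 2048 gives exactly the same parameter count as a two-matrix feed-forward layer with hidden size 3072.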

---

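## 5. Grouped-Query Attention (GQA)

### 5.1 Core Idea

GQA lets several query heads share one key/value head, which reduces the memory used by the KV cache:

| Type | KV heads | KV cache memory |
|:-----|:-------|:-----|
| MHA | = number of Q heads | High |
| MQA | 1 | Low |
| GQA | Between the two | Medium |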
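
The table can be made concrete with a back-of-the-envelope calculation. The sketch below assumes a hypothetical configuration (32 layers, head dimension 128, 4096 cached tokens, batch size 1, 2 bytes per element as in fp16); these numbers are chosen for illustration and are not taken from a specific model. The cache stores one key and one value vector per layer, KV head, and token.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int,
                   seq_len: int, batch_size: int = 1, bytes_per_elem: int = 2) -> int:
    """Size of the KV cache in bytes; the leading 2 accounts for keys and values."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


# Hypothetical configuration, for illustration only.
cfg = dict(n_layers=32, head_dim=128, seq_len=4096)

for name, n_kv in [('MHA', 32), ('GQA', 8), ('MQA', 1)]:
    gib = kv_cache_bytes(n_kv_heads=n_kv, **cfg) / 2**30
    print(f'{name} (n_kv_heads={n_kv:>2}): {gib:.4f} GiB')
```

Under these assumptions the cache shrinks linearly with the number of KV heads: 2 GiB for MHA, 0.5 GiB for GQA with 8 KV heads, and 0.0625 GiB for MQA.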

In [None]:
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention (GQA).

    Several query heads share one KV head, which shrinks the KV cache.
    For brevity, this module applies neither RoPE nor a causal mask.
    """

    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int) -> None:
        super().__init__()
        assert d_model % n_heads == 0, 'd_model must be divisible by n_heads'
        assert n_heads % n_kv_heads == 0, 'n_heads must be divisible by n_kv_heads'
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads  # query heads per KV head
        self.head_dim = d_model // n_heads

        self.wq = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(d_model, n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.size()

        # Project to Q, K, V and reshape to [B, heads, T, head_dim]
        q = self.wq(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Repeat each KV head n_rep times so the head count matches Q
        k = k.repeat_interleave(self.n_rep, dim=1)
        v = v.repeat_interleave(self.n_rep, dim=1)

        # Scaled dot-product attention (no mask applied)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, T, -1)

        return self.wo(out)

# Compare the parameter counts of MHA and GQA
d_model, n_heads = 4096, 32
mha = GroupedQueryAttention(d_model, n_heads, n_heads)  # MHA: n_kv_heads == n_heads
gqa = GroupedQueryAttention(d_model, n_heads, 8)        # GQA: 8 KV heads

mha_params = sum(p.numel() for p in mha.parameters())
gqa_params = sum(p.numel() for p in gqa.parameters())
print(f'MHA parameters: {mha_params:,}')
print(f'GQA parameters: {gqa_params:,}')
print(f'Parameter reduction: {(1 - gqa_params / mha_params) * 100:.1f}%')
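
How `repeat_interleave` assigns KV heads to query heads is easy to miss. The following small sketch labels each KV head with its own index and compares `repeat_interleave` with `repeat`, assuming a toy setting of 2 KV heads shared by 6 query heads (values chosen only for illustration).

```python
import torch

n_kv_heads, n_rep = 2, 3  # 6 query heads share 2 KV heads

# Label each KV head with its index: shape [batch=1, n_kv_heads, seq=1, head_dim=1]
kv = torch.arange(n_kv_heads).view(1, n_kv_heads, 1, 1)

# repeat_interleave copies each head n_rep times in place.
interleaved = kv.repeat_interleave(n_rep, dim=1).flatten().tolist()

# repeat tiles the whole block of heads n_rep times.
tiled = kv.repeat(1, n_rep, 1, 1).flatten().tolist()

print(interleaved)  # [0, 0, 0, 1, 1, 1]
print(tiled)        # [0, 1, 0, 1, 0, 1]
```

With `repeat_interleave`, query head `h` attends with KV head `h // n_rep`, so consecutive query heads form one group. Tiling with `repeat` would instead pair query head `h` with KV head `h % n_kv_heads`, which is a different grouping.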

---

## 6. Validation Tests

In [None]:
def test_gqa() -> None:
    """Check that GQA preserves the input shape."""
    gqa = GroupedQueryAttention(768, 12, 4)
    x = torch.randn(2, 10, 768)
    out = gqa(x)

    assert out.shape == x.shape, f'Unexpected output shape: {out.shape}'
    print(f'✓ GQA: {x.shape} -> {out.shape}')
    print('✓ test_gqa passed')

test_gqa()

---

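## Summary

| Technique | Role |
|:-----|:-----|
| **RMSNorm** | Cheaper normalization |
| **RoPE** | Relative position encoding with potential for length extrapolation |
| **SwiGLU** | More expressive feed-forward layer |
| **GQA** | Smaller KV cache and faster inference |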