# LoRA 微调实战

**SOTA 教育标准实现** | 参数高效微调技术

---

## 学习目标

1. 理解LoRA低秩分解的数学原理
2. 掌握LoRA层的实现方法
3. 学会将LoRA应用到预训练模型
4. 实现权重合并与保存

## 目录

1. [环境配置](#1-环境配置)
2. [LoRA原理](#2-lora原理)
3. [LoRA实现](#3-lora实现)
4. [应用到模型](#4-应用到模型)
5. [权重合并](#5-权重合并)
6. [验证测试](#6-验证测试)

---

## 1. 环境配置

In [None]:
# 导入必要的库
import math
import re
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed the global RNG so the random tensors created below are reproducible.
torch.manual_seed(42)
# Use the GPU when CUDA reports one as available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'使用设备: {device}')

---

## 2. LoRA原理

### 2.1 核心思想

LoRA用低秩分解近似权重更新：

$$W' = W + \Delta W = W + BA$$

其中：
- $W \in \mathbb{R}^{d \times k}$: 原始权重（冻结）
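- $B \in \mathbb{R}^{d \times r}$: 低秩矩阵（可训练，初始化为零）
- $A \in \mathbb{R}^{r \times k}$: 低秩矩阵（可训练，随机初始化）
- $r \ll \min(d, k)$: 秩

由于 $B$ 初始为零，训练开始时 $\Delta W = 0$，模型输出与原始模型一致。实际实现中，增量还会乘以缩放因子 $\alpha / r$，即 $W' = W + \frac{\alpha}{r} BA$。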

### 2.2 参数量对比

In [None]:
# Compare the size of a dense (d x k) weight with the size of its LoRA factors.
d, k = 4096, 4096
r_values = [4, 8, 16, 32, 64]
dense_params = d * k

print(f'原始参数量: {d * k:,}')
print('\nLoRA参数量对比:')
print('-' * 40)
for r in r_values:
    # B is (d x r) and A is (r x k), so together they hold r * (d + k) values.
    lora_params = r * (d + k)
    ratio = lora_params / dense_params * 100
    print(f'r={r:2d}: {lora_params:>8,} ({ratio:.2f}%)')

---

## 3. LoRA实现

In [None]:
class LoRALinear(nn.Module):
    """Linear layer wrapped with a trainable low-rank (LoRA) adapter.

    The wrapped layer's weight and bias are frozen; only ``lora_A`` and
    ``lora_B`` receive gradients. The forward pass computes::

        h = W x + (alpha / r) * B A x

    Args:
        original_layer: The ``nn.Linear`` to adapt. It is frozen in place and
            kept as the ``original_layer`` submodule.
        r: Rank of the decomposition; must be positive.
        alpha: Numerator of the scaling factor ``alpha / r``.
        dropout: Dropout probability applied to the adapter input only.
            Values ``<= 0`` disable dropout.

    Raises:
        ValueError: If ``r`` is not positive.
    """

    def __init__(self, original_layer: nn.Linear, r: int = 8,
                 alpha: int = 16, dropout: float = 0.0) -> None:
        super().__init__()
        if r <= 0:
            # Fail early with a clear message instead of a ZeroDivisionError
            # (r == 0) or an invalid tensor shape (r < 0) further down.
            raise ValueError(f'LoRA rank r must be positive, got {r}')

        self.original_layer = original_layer
        self.r = r
        self.alpha = alpha
        self.scaling = alpha / r

        in_features = original_layer.in_features
        out_features = original_layer.out_features

        # Freeze the pretrained parameters; only the adapter is trained.
        original_layer.weight.requires_grad = False
        if original_layer.bias is not None:
            original_layer.bias.requires_grad = False

        # Create the adapter on the same device (and floating dtype) as the
        # wrapped weight so that wrapping a layer that was already moved to
        # GPU or cast to another precision does not break forward().
        weight = original_layer.weight
        factory = {'device': weight.device}
        if weight.dtype.is_floating_point:
            factory['dtype'] = weight.dtype
        self.lora_A = nn.Parameter(torch.zeros(r, in_features, **factory))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r, **factory))
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # A is random and B is zero, so B @ A == 0 and the wrapped layer
        # initially behaves exactly like the original one.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

        # True while the adapter delta is folded into original_layer.weight.
        self.merged = False

    def _delta_weight(self) -> torch.Tensor:
        """Return the dense weight update ``(alpha / r) * B @ A``."""
        return (self.lora_B @ self.lora_A) * self.scaling

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the frozen layer plus, when not merged, the scaled adapter."""
        base_out = self.original_layer(x)
        if self.merged:
            # The delta already lives in the base weight; adding it again
            # would count it twice.
            return base_out
        lora_out = self.dropout(x) @ self.lora_A.T @ self.lora_B.T
        return base_out + lora_out * self.scaling

    def merge(self) -> None:
        """Fold the adapter delta into the wrapped weight (no-op if merged)."""
        if not self.merged:
            weight = self.original_layer.weight
            with torch.no_grad():
                weight.add_(self._delta_weight().to(weight.dtype))
            self.merged = True

    def unmerge(self) -> None:
        """Subtract the adapter delta from the wrapped weight (no-op if not merged)."""
        if self.merged:
            weight = self.original_layer.weight
            with torch.no_grad():
                weight.sub_(self._delta_weight().to(weight.dtype))
            self.merged = False

In [None]:
# Smoke-test the LoRA layer on a random batch.
linear = nn.Linear(768, 768)
lora_linear = LoRALinear(linear, r=8, alpha=16)

x = torch.randn(2, 10, 768)
out = lora_linear(x)

# Count every parameter, then only those that still require gradients.
total = sum(map(torch.numel, lora_linear.parameters()))
trainable = sum(
    torch.numel(p) for p in lora_linear.parameters() if p.requires_grad
)

print(f'输出形状: {out.shape}')
print(f'可训练参数: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)')
print(f'✓ LoRA层测试通过')

---

## 4. 应用到模型

In [None]:
def apply_lora(model: nn.Module, target_modules: list,
               r: int = 8, alpha: int = 16,
               dropout: float = 0.0) -> nn.Module:
    """Wrap the matching ``nn.Linear`` submodules of ``model`` with LoRA.

    The model is modified in place. After wrapping, every parameter whose
    name does not contain ``'lora_'`` has ``requires_grad`` set to False.

    Args:
        model: Model to adapt.
        target_modules: Regular-expression patterns; a linear submodule is
            wrapped when any pattern is found (``re.search``) in its dotted
            name. A single string is treated as a one-element list.
        r: LoRA rank passed to each ``LoRALinear``.
        alpha: Scaling numerator passed to each ``LoRALinear``.
        dropout: Adapter dropout probability passed to each ``LoRALinear``.

    Returns:
        The same ``model`` object, with LoRA applied.
    """
    if isinstance(target_modules, str):
        # Iterating a bare string would use each character as a pattern.
        target_modules = [target_modules]
    patterns = [re.compile(target) for target in target_modules]

    # Names of submodules that are already LoRA wrappers. Their inner
    # ``original_layer`` is an nn.Linear whose name still matches the
    # patterns, so a second apply_lora() call would otherwise nest adapters.
    already_wrapped = {
        name for name, module in model.named_modules()
        if isinstance(module, LoRALinear)
    }

    # Collect first, replace afterwards: the module tree must not change
    # while named_modules() is being iterated.
    modules_to_modify = {}
    for name, module in model.named_modules():
        # The root module (empty name) has no parent attribute to rebind.
        if not name or not isinstance(module, nn.Linear):
            continue
        parent_name = name.rpartition('.')[0]
        if parent_name in already_wrapped:
            continue
        if any(pattern.search(name) for pattern in patterns):
            modules_to_modify[name] = module

    # Rebind each selected attribute on its parent to the LoRA wrapper.
    for name, module in modules_to_modify.items():
        parent_name, _, attr = name.rpartition('.')
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, attr,
                LoRALinear(module, r=r, alpha=alpha, dropout=dropout))

    # Freeze everything except the adapter matrices.
    for name, param in model.named_parameters():
        if 'lora_' not in name:
            param.requires_grad = False

    return model

In [None]:
# 示例：简单Transformer模型
class SimpleTransformer(nn.Module):
    """Toy module holding transformer-style linear layers.

    It only registers the layers (four square projections and a two-layer
    feed-forward stack with a GELU in between); no forward pass is defined.
    """

    def __init__(self, d_model: int = 768) -> None:
        super().__init__()
        # Attention projections, created in q/k/v/o order.
        for proj_name in ('q_proj', 'k_proj', 'v_proj', 'o_proj'):
            setattr(self, proj_name, nn.Linear(d_model, d_model))

        # Feed-forward stack with a 4x wider hidden layer.
        hidden_dim = d_model * 4
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, d_model),
        )

# Build the toy model and attach LoRA adapters to the query/value projections.
model = apply_lora(SimpleTransformer(), ['q_proj', 'v_proj'], r=8, alpha=16)

# Report how many parameters are still trainable after freezing.
total = sum(map(torch.numel, model.parameters()))
trainable = sum(torch.numel(p) for p in model.parameters() if p.requires_grad)
print(f'应用LoRA后: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)')

---

## 5. 权重合并

In [None]:
def merge_lora(model: nn.Module) -> None:
    """Call ``merge()`` on every ``LoRALinear`` found in ``model``."""
    adapters = (m for m in model.modules() if isinstance(m, LoRALinear))
    for adapter in adapters:
        adapter.merge()

def save_lora(model: nn.Module, path: str) -> None:
    """Save only the LoRA adapter parameters of ``model`` to ``path``.

    A parameter is treated as an adapter parameter when its dotted name
    contains ``'lora_'``. Tensors are detached and moved to the CPU before
    saving so the file can be loaded on a machine without the original
    device (for example a GPU checkpoint on a CPU-only host).

    Args:
        model: Model whose adapter parameters are saved.
        path: Destination file passed to ``torch.save``.
    """
    lora_state = {
        name: param.detach().cpu()
        for name, param in model.named_parameters()
        if 'lora_' in name
    }
    torch.save(lora_state, path)
    print(f'已保存 {len(lora_state)} 个LoRA参数')

# Save the adapter weights first, while they are still separate from the
# base weights.
save_lora(model, '/tmp/lora_weights.pt')

# Fold every adapter into its wrapped linear layer.
merge_lora(model)
print('✓ LoRA权重已合并')

---

## 6. 验证测试

In [None]:
def test_lora_equivalence() -> None:
    """Check that a LoRA layer gives the same output before and after merge().

    Raises:
        AssertionError: If the adapter has no effect on the output, or if the
            outputs before and after merging differ by 1e-5 or more.
    """
    linear = nn.Linear(64, 64)
    lora = LoRALinear(linear, r=4, alpha=8)

    # lora_B is zero-initialised, so a fresh adapter contributes nothing and
    # the comparison below would pass even if merge() were broken. Give B
    # small non-zero values so the test exercises the merge arithmetic.
    nn.init.normal_(lora.lora_B, std=0.02)

    x = torch.randn(1, 10, 64)

    with torch.no_grad():
        out_base = linear(x)

        # Output before merging (base path + adapter path).
        out_before = lora(x)

        # Output after merging (adapter folded into the base weight).
        lora.merge()
        out_after = lora(x)

    # Guard against a vacuous test: the adapter must change the output.
    delta = (out_before - out_base).abs().max().item()
    assert delta > 1e-4, f'LoRA增量过小，等价性测试无效: {delta}'

    diff = (out_before - out_after).abs().max().item()
    assert diff < 1e-5, f'合并前后输出差异过大: {diff}'
    print(f'合并前后最大差异: {diff:.2e}')
    print(f'✓ test_lora_equivalence 通过')

test_lora_equivalence()

---

## 总结

| 特性 | 说明 |
|:-----|:-----|
| **参数高效** | 只训练0.1-1%参数 |
| **无推理开销** | 权重可合并 |
| **易于切换** | 可加载不同适配器 |
| **效果接近全量微调** | 性能损失小 |