# DPO (Direct Preference Optimization)

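**Direct Preference Optimization (DPO)** aligns a model with human preferences using preference pairs (chosen / rejected), without training a separate reward model.

- **Chosen**: the response humans prefer (the better one).
- **Rejected**: the response humans do not prefer (the worse one).
- **Reference model**: a frozen model used to compute \(\log \pi_{\text{ref}}(y|x)\).
- **Policy model**: the model being trained; it is optimized so that, measured relative to the reference model, its log-likelihood of chosen exceeds that of rejected.
- **β**: a temperature-like coefficient that controls how far the policy may drift from the reference model; a larger β penalizes deviation more strongly.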
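
For reference, these pieces combine into the objective below, written in the standard DPO form. The symbols \(y_w\) (chosen response), \(y_l\) (rejected response), \(\pi_\theta\) (policy), and \(\sigma\) (logistic sigmoid) are notation introduced here for illustration, not names used in the code.

\[
\mathcal{L}_{\text{DPO}} = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\left[\log \sigma\!\left(\beta\left[\big(\log \pi_\theta(y_w|x) - \log \pi_\theta(y_l|x)\big) - \big(\log \pi_{\text{ref}}(y_w|x) - \log \pi_{\text{ref}}(y_l|x)\big)\right]\right)\right]
\]

When the policy equals the reference model, the bracketed term is zero and the loss is \(\log 2\) for every pair.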

In [None]:
import os
import sys
import json
import tempfile
from pathlib import Path

import torch
import torch.nn.functional as F

os.environ["TOKENIZERS_PARALLELISM"] = "false"

ROOT = Path.cwd()
if ROOT.name == "docs":
    ROOT = ROOT.parent
sys.path.insert(0, str(ROOT))

from transformers import AutoTokenizer
from data_loader.dpo_dataset import DPODataset

## Data format: `dpo.jsonl`

Each line is one JSON object containing `chosen` and `rejected`, each a list of chat messages (the same format as SFT chat data):

```json
{
  "chosen": [
    {"role": "user", "content": "1+1等于几？"},
    {"role": "assistant", "content": "1+1等于2。"}
  ],
  "rejected": [
    {"role": "user", "content": "1+1等于几？"},
    {"role": "assistant", "content": "不知道。"}
  ]
}
```

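- `DPODataset` uses `apply_chat_template` to render chosen/rejected as text, then tokenizes, truncates, and pads.
- **Labels**: loss is computed only on the **assistant reply**; labels for prompt and padding positions are `-100`.
- **Mask**: `mask_chosen` / `mask_rejected` mark which positions contribute to the DPO loss (1 = assistant reply, 0 = everything else).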
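
The labeling rule above can be illustrated with a small sketch. The helper `build_labels_and_mask` and the token ids are hypothetical and are not taken from `DPODataset`; the sketch also ignores the one-token shift between inputs and targets that a real next-token dataset would apply.

```python
import torch

IGNORE_INDEX = -100  # label value for positions excluded from the loss


def build_labels_and_mask(input_ids, reply_start, reply_end):
    """Hypothetical helper: keep labels only on the span [reply_start, reply_end)."""
    labels = torch.full_like(input_ids, IGNORE_INDEX)
    labels[reply_start:reply_end] = input_ids[reply_start:reply_end]
    mask = (labels != IGNORE_INDEX).float()  # 1 = assistant reply, 0 = everything else
    return labels, mask


# 3 prompt tokens, 3 reply tokens, 2 padding tokens (ids are made up)
input_ids = torch.tensor([11, 12, 13, 21, 22, 23, 0, 0])
labels, mask = build_labels_and_mask(input_ids, reply_start=3, reply_end=6)
print(labels.tolist())  # [-100, -100, -100, 21, 22, 23, -100, -100]
print(mask.tolist())    # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0]
```

Only the three reply positions carry a label, so `mask.sum()` equals the number of tokens that will be scored.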

In [None]:
# Create a temporary dpo.jsonl for the demo
demo_data = [
    {
        "chosen": [
            {"role": "user", "content": "1+1等于几？"},
            {"role": "assistant", "content": "1+1等于2。"}
        ],
        "rejected": [
            {"role": "user", "content": "1+1等于几？"},
            {"role": "assistant", "content": "不知道。"}
        ]
    },
    {
        "chosen": [
            {"role": "user", "content": "你好"},
            {"role": "assistant", "content": "你好！有什么可以帮你的？"}
        ],
        "rejected": [
            {"role": "user", "content": "你好"},
            {"role": "assistant", "content": "哦。"}
        ]
    }
]

tmpdir = Path(tempfile.gettempdir()) / "dpo_demo"
tmpdir.mkdir(exist_ok=True)
dpo_jsonl = tmpdir / "dpo_demo.jsonl"
with open(dpo_jsonl, "w", encoding="utf-8") as f:
    for item in demo_data:
        f.write(json.dumps(item, ensure_ascii=False) + "\n")

tokenizer_path = ROOT / "vermind_tokenizer"
tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path), trust_remote_code=True)
ds = DPODataset(str(dpo_jsonl), tokenizer, max_length=256)
print(f"DPODataset length: {len(ds)}")
sample = ds[0]
print(f"keys: {list(sample.keys())}")
print(f"x_chosen shape: {sample['x_chosen'].shape}, mask_chosen sum: {sample['mask_chosen'].sum().item():.0f}")

## DPO loss

1. **logits → per-token log probs**:  
   `log_probs[b,s] = log_softmax(logits[b,s,:])[labels[b,s]]`  
   that is, the log probability the model assigns to the true next token.

2. **Sequence-level log prob**: either **sum** or **mean** (the `aggregate` parameter).  
   - **sum**: `seq_log_prob = (log_probs * mask).sum(dim=1)`, with no length normalization.  
   - **mean**: `seq_log_prob = (log_probs * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-8)`, the average over masked positions.

3. **DPO loss**:  
   - The first half of the batch is treated as chosen and the second half as rejected (training uses `torch.cat([chosen, rejected], dim=0)`).  
   - Define the log-ratios:  
     - \(\pi\) log-ratio = `chosen_policy_log_prob - rejected_policy_log_prob`  
     - ref log-ratio = `chosen_ref_log_prob - rejected_ref_log_prob`  
   - `logits = π_log_ratio - ref_log_ratio`,  
     `loss = -log σ(β * logits)`,  
     so the policy is pushed to prefer chosen more strongly than the reference model does.
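
To see how the `aggregate` choice changes the result, here is a small worked example on one preference pair whose chosen and rejected replies differ in length. The log-prob values and the helper `pair_loss` are made up for illustration; only the arithmetic follows the steps above.

```python
import torch
import torch.nn.functional as F

# One preference pair: row 0 = chosen (2 reply tokens), row 1 = rejected (4 reply tokens)
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                     [1.0, 1.0, 1.0, 1.0]])
policy_lp = torch.tensor([[-1.0, -1.0, 0.0, 0.0],
                          [-1.5, -1.5, -1.5, -1.5]])
ref_lp = torch.tensor([[-1.5, -1.5, 0.0, 0.0],
                       [-1.0, -1.0, -1.0, -1.0]])
beta = 0.1


def pair_loss(aggregate):
    """Hypothetical helper: return (logit, loss) for the single pair above."""
    pol = (policy_lp * mask).sum(dim=1)
    ref = (ref_lp * mask).sum(dim=1)
    if aggregate == "mean":
        lengths = mask.sum(dim=1).clamp_min(1e-8)
        pol, ref = pol / lengths, ref / lengths
    logit = (pol[0] - pol[1]) - (ref[0] - ref[1])
    return logit.item(), (-F.logsigmoid(beta * logit)).item()


logit_sum, loss_sum = pair_loss("sum")     # logit = (-2 + 6) - (-3 + 4) = 3.0
logit_mean, loss_mean = pair_loss("mean")  # logit = (-1 + 1.5) - (-1.5 + 1) = 1.0
print(f"sum:  logit={logit_sum:.1f} loss={loss_sum:.4f}")
print(f"mean: logit={logit_mean:.1f} loss={loss_mean:.4f}")
```

With `sum`, the longer rejected reply contributes more tokens and so dominates the log-ratio; `mean` removes that length effect, which here gives a smaller logit and a loss closer to `log 2`.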

In [None]:
def logits_to_log_probs(logits, labels):
    """logits: (B, S, V), labels: (B, S) -> per-token log probs (B, S)."""
    log_probs = F.log_softmax(logits, dim=2)
    # gather cannot index with -100; clamp ignored labels to 0 (those positions are zeroed by the mask later)
    safe_labels = labels.clamp_min(0)
    return torch.gather(log_probs, dim=2, index=safe_labels.unsqueeze(2)).squeeze(-1)


def dpo_loss(ref_log_probs, policy_log_probs, mask, beta=0.1, aggregate="sum"):
    """Return (loss, chosen_rewards, rejected_rewards), each averaged over the pairs.

    aggregate="sum" keeps the masked sum; "mean" divides by the number of masked tokens.
    The first half of the batch must be chosen and the second half rejected.
    """
    # Sequence-level log probs over the masked (assistant reply) positions
    policy_raw = (policy_log_probs * mask).sum(dim=1)
    ref_raw = (ref_log_probs * mask).sum(dim=1)
    if aggregate == "mean":
        seq_lengths = mask.sum(dim=1).clamp_min(1e-8)
        policy_seq = policy_raw / seq_lengths
        ref_seq = ref_raw / seq_lengths
    else:
        policy_seq, ref_seq = policy_raw, ref_raw
    # Split the concatenated batch back into chosen / rejected halves
    n = ref_log_probs.shape[0] // 2
    policy_chosen, policy_rejected = policy_seq[:n], policy_seq[n:]
    ref_chosen, ref_rejected = ref_seq[:n], ref_seq[n:]
    pi_logratios = policy_chosen - policy_rejected
    ref_logratios = ref_chosen - ref_rejected
    logits = pi_logratios - ref_logratios
    losses = -F.logsigmoid(beta * logits)
    # Implicit rewards are for logging only, so no gradient is tracked
    with torch.no_grad():
        chosen_rewards = beta * (policy_chosen - ref_chosen)
        rejected_rewards = beta * (policy_rejected - ref_rejected)
    return losses.mean(), chosen_rewards.mean(), rejected_rewards.mean()


# Synthetic demo: B preference pairs, with chosen and rejected drawn independently.
# (Duplicating the same rows for both halves would pin the loss at log 2.)
B, S, V = 4, 8, 100
logits = torch.randn(2 * B, S, V)  # first B rows: chosen, last B rows: rejected
labels = torch.randint(0, V, (2 * B, S))
mask = torch.zeros(2 * B, S)
mask[:, :4] = 1.0  # score the first 4 positions of every sequence
lp = logits_to_log_probs(logits, labels)
# Perturb the shared log probs to mimic a reference and a policy that differ slightly
ref_lp = lp + 0.1 * torch.randn_like(lp)
policy_lp = lp + 0.2 * torch.randn_like(lp)
loss_s, cr_s, rr_s = dpo_loss(ref_lp, policy_lp, mask, beta=0.1, aggregate="sum")
loss_m, cr_m, rr_m = dpo_loss(ref_lp, policy_lp, mask, beta=0.1, aggregate="mean")
print("aggregate=sum:  loss:", loss_s.item(), "chosen_reward:", cr_s.item(), "rejected_reward:", rr_s.item())
print("aggregate=mean: loss:", loss_m.item(), "chosen_reward:", cr_m.item(), "rejected_reward:", rr_m.item())

## Training flow overview (train/dpo.py)

- **Reference model**: loaded from `--ref_weight`, put in eval mode and frozen; it only computes `ref_log_probs`.
- **Policy model**: loaded from `--from_weight` or a resumed checkpoint, and trained normally.
- Each step: `x = cat([x_chosen, x_rejected])`, with `y` and `mask` concatenated the same way; the ref model computes `ref_log_probs` and the policy computes `policy_log_probs`; `loss = dpo_loss(...) + aux_loss`, and gradients flow only through the policy.

Run: `python train/dpo.py --data_path /path/to/dpo.jsonl --ref_weight /path/to/sft_checkpoint`  
Optionally pass `--dpo_aggregate sum` (the default) or `--dpo_aggregate mean`.
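
A single training step along these lines might look like the sketch below. It is not the code in `train/dpo.py`: `TinyLM`, `token_log_probs`, the random tensors, and the optimizer settings are all stand-ins chosen for illustration, the `sum` aggregate is hard-coded, and `aux_loss` is left out because its definition is not shown here.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
V, D, B, S = 50, 16, 2, 6  # vocab size, hidden size, pairs per batch, sequence length


class TinyLM(nn.Module):
    """Hypothetical stand-in for a language model: embedding plus linear head."""

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(V, D)
        self.head = nn.Linear(D, V)

    def forward(self, x):
        return self.head(self.emb(x))  # (batch, S, V)


def token_log_probs(logits, labels):
    """Per-token log prob of each label, shape (batch, S)."""
    return torch.gather(F.log_softmax(logits, dim=2), 2, labels.unsqueeze(2)).squeeze(-1)


policy = TinyLM()
ref = copy.deepcopy(policy).eval()  # reference starts as a frozen copy of the policy
for p in ref.parameters():
    p.requires_grad_(False)
optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-3)

# Random stand-ins for the tensors a DPO dataset would provide
x_chosen, x_rejected = torch.randint(0, V, (B, S)), torch.randint(0, V, (B, S))
y_chosen, y_rejected = torch.randint(0, V, (B, S)), torch.randint(0, V, (B, S))
mask_chosen, mask_rejected = torch.ones(B, S), torch.ones(B, S)

x = torch.cat([x_chosen, x_rejected], dim=0)
y = torch.cat([y_chosen, y_rejected], dim=0)
mask = torch.cat([mask_chosen, mask_rejected], dim=0)

with torch.no_grad():  # the reference model never receives gradients
    ref_seq = (token_log_probs(ref(x), y) * mask).sum(dim=1)
policy_seq = (token_log_probs(policy(x), y) * mask).sum(dim=1)

beta = 0.1
logits = (policy_seq[:B] - policy_seq[B:]) - (ref_seq[:B] - ref_seq[B:])
loss = -F.logsigmoid(beta * logits).mean()

optimizer.zero_grad()
loss.backward()  # backpropagates through the policy only
optimizer.step()
print("loss:", loss.item())
```

Because the reference here is an exact copy of the policy, the first-step loss is `log 2` (about 0.6931); it moves away from that value only once the policy has been updated.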

In [None]:
from torch.utils.data import DataLoader

loader = DataLoader(ds, batch_size=2, shuffle=False)
batch = next(iter(loader))
x_chosen = batch["x_chosen"]
x_rejected = batch["x_rejected"]
y_chosen, y_rejected = batch["y_chosen"], batch["y_rejected"]
mask_chosen, mask_rejected = batch["mask_chosen"], batch["mask_rejected"]

x = torch.cat([x_chosen, x_rejected], dim=0)
y = torch.cat([y_chosen, y_rejected], dim=0)
mask = torch.cat([mask_chosen, mask_rejected], dim=0)
print("Training batch: x", x.shape, "y", y.shape, "mask", mask.shape)
print("mask sum (tokens to score):", mask.sum().item())