In [1]:
# ===============================================================================
# DPO（Direct Preference Optimization）算法实现
# DPO通过人类偏好数据直接优化语言模型，使其生成更符合人类偏好的输出
# 这里面使用了一个偏好prefer以及两个reject的格式
# ===============================================================================
import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, LlamaConfig
from copy import deepcopy

torch.manual_seed(0)


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/Versions/3.9/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/jt/Library/Python/3.9/lib/python/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/jt/Libra

<torch._C.Generator at 0x10a814f30>

In [17]:
# 加载模型
# 创建简化版的Llama模型作为策略模型（将被优化的模型）
policy_model = LlamaForCausalLM(config=LlamaConfig(vocab_size=12, num_hidden_layers=1, hidden_size=32))
# 创建参考模型（通常是SFT模型，在训练过程中保持不变）
reference_model = deepcopy(policy_model)  # 深度复制确保两个模型初始参数完全相同
policy_model


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(10, 32)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=32, out_features=32, bias=False)
          (k_proj): Linear(in_features=32, out_features=32, bias=False)
          (v_proj): Linear(in_features=32, out_features=32, bias=False)
          (o_proj): Linear(in_features=32, out_features=32, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=32, out_features=11008, bias=False)
          (up_proj): Linear(in_features=32, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=32, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((32,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((32,), eps=1e-06)
      )
    )
    (norm): LlamaRMSNorm((32,), eps=1e-06)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_hea

In [9]:
# 超参数
beta = 0.1  # DPO的温度系数，控制策略模型与参考模型的偏离程度，值越小允许偏离越大

# 准备训练数据
# 在DPO中，我们需要提示(prompt)、优选回答(chosen/good)和拒绝回答(rejected/bad)
prompt_ids = [1, 2, 3, 4, 5, 6]  # 输入提示的token IDs
good_response_ids = [7, 8, 9, 2]  # 优质回答的token IDs
# 多个低质量回答的示例，每个都是token IDs的列表
bad_response_ids_list = [[1, 2, 0, 0], [4, 5, 6, 0]]

In [10]:
# 构建模型输入：将提示与回答拼接
# 创建包含多个序列的批次：[提示+优质回答, 提示+低质回答1, 提示+低质回答2, ...]
input_ids = torch.LongTensor(
    [prompt_ids + good_response_ids, *[prompt_ids + bad_response_ids for bad_response_ids in bad_response_ids_list]]
)
input_ids

tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 2],
        [1, 2, 3, 4, 5, 6, 1, 2, 0, 0],
        [1, 2, 3, 4, 5, 6, 4, 5, 6, 0]])

In [12]:
# 准备用于计算语言模型损失的标签
# 在语言模型训练中，标签是输入向右移动一位（预测下一个token）
# -100表示在计算损失时忽略该位置（这里忽略提示部分）
labels = torch.LongTensor(
    [
        [-100] * len(prompt_ids) + good_response_ids,
        *[[-100] * len(prompt_ids) + bad_response_ids for bad_response_ids in bad_response_ids_list]
    ]
) # 向右移动一位，因为我们预测的是下一个token
labels

tensor([[-100, -100, -100, -100, -100, -100,    7,    8,    9,    2],
        [-100, -100, -100, -100, -100, -100,    1,    2,    0,    0],
        [-100, -100, -100, -100, -100, -100,    4,    5,    6,    0]])

In [13]:
labels = labels[:, 1:]  
labels

tensor([[-100, -100, -100, -100, -100,    7,    8,    9,    2],
        [-100, -100, -100, -100, -100,    1,    2,    0,    0],
        [-100, -100, -100, -100, -100,    4,    5,    6,    0]])

In [15]:
# 创建掩码，用于标识哪些位置参与损失计算（即回答部分）
loss_mask = (labels != -100)
print(loss_mask.shape)
loss_mask

torch.Size([3, 9])


tensor([[False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True]])

In [16]:
# 将-100替换为0，因为在gather操作中-100是无效索引
labels[labels == -100] = 0
print(labels.shape)
labels

torch.Size([3, 9])


tensor([[0, 0, 0, 0, 0, 7, 8, 9, 2],
        [0, 0, 0, 0, 0, 1, 2, 0, 0],
        [0, 0, 0, 0, 0, 4, 5, 6, 0]])

In [20]:
output = policy_model(input_ids)
for key, value in output.items():  # 如果是ModelOutput对象，可以用output.__dict__.items()
    if hasattr(value, "shape"):
        print(f"{key}: {value.shape}")
    else:
        print(f"{key}: {type(value)}")

logits: torch.Size([3, 10, 10])
past_key_values: <class 'transformers.cache_utils.DynamicCache'>


In [21]:
# ===============================================================================
# 计算策略模型（policy model）的对数概率
# ===============================================================================
# 前向传播，获取每个token位置的预测logits
logits = policy_model(input_ids)["logits"][:, :-1, :]  # 去掉最后一个位置，与label对齐
print(logits.shape)
logits

torch.Size([3, 9, 10])


tensor([[[-0.0565, -0.0827, -0.1474,  0.0100, -0.0097, -0.1095,  0.0679,
           0.0035, -0.0819, -0.0010],
         [ 0.1082, -0.1835, -0.1182, -0.1179,  0.0160, -0.1471,  0.0077,
          -0.1568, -0.0925,  0.0218],
         [ 0.0287, -0.1253,  0.0064, -0.0516, -0.0354, -0.0353, -0.0429,
          -0.2031, -0.0248, -0.0239],
         [-0.0219,  0.1550,  0.0996,  0.1397, -0.1033,  0.0055, -0.1241,
           0.0612,  0.0729, -0.0935],
         [-0.1054,  0.0074, -0.1446, -0.1313, -0.0460,  0.0428,  0.0211,
           0.0946,  0.1309, -0.0070],
         [ 0.0928, -0.1962, -0.0238, -0.0858, -0.1913, -0.2821, -0.0004,
          -0.1976,  0.0196, -0.1160],
         [-0.1436,  0.0744, -0.0164,  0.0534,  0.0762, -0.0857, -0.2603,
          -0.0457,  0.1537,  0.0267],
         [-0.1182,  0.0489, -0.0797,  0.0039, -0.0279,  0.1024,  0.0110,
           0.0005, -0.0170, -0.0856],
         [-0.0509, -0.0215, -0.0647,  0.0435, -0.0992,  0.0381, -0.1139,
           0.0671, -0.0092, -0.1155]],


In [20]:
# 将logits转换为对数概率，并提取每个位置上正确token的对数概率
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps

tensor([[-6.7857, -6.6947, -6.7697, -6.5169, -6.6811, -7.0314, -6.7662, -6.5559,
         -6.6359],
        [-6.7857, -6.6947, -6.7697, -6.5169, -6.6811, -7.1295, -6.9317, -6.6988,
         -6.4697],
        [-6.7857, -6.6947, -6.7697, -6.5169, -6.6811, -6.7405, -6.8247, -7.2198,
         -6.8045]], grad_fn=<SqueezeBackward1>)

In [21]:
# 仅对回答部分（loss_mask=True的位置）求和，得到每个序列的总对数概率
all_logps = (per_token_logps * loss_mask).sum(-1)
all_logps

tensor([-60.4376, -60.6779, -61.0376], grad_fn=<SumBackward1>)

In [22]:
# 分离优质回答和低质量回答的对数概率
policy_good_logps, policy_bad_logps = all_logps[:1], all_logps[1:]

In [23]:
policy_good_logps

tensor([-60.4376], grad_fn=<SliceBackward0>)

In [24]:
policy_bad_logps

tensor([-60.6779, -61.0376], grad_fn=<SliceBackward0>)

In [22]:
# ===============================================================================
# 计算参考模型（reference model）的对数概率
# ===============================================================================
with torch.no_grad():  # 不计算梯度，因为参考模型不需要更新
    # 重复与策略模型相同的步骤
    logits = reference_model(input_ids)["logits"][:, :-1, :]
    print(logits.shape)
    print("logits:\n",logits)
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    print(per_token_logps.shape)
    print("per_token_logps:\n",per_token_logps)
    all_logps = (per_token_logps * loss_mask).sum(-1)
    print(all_logps.shape)
    print("all_logps\n",all_logps)
    reference_good_logps, reference_bad_logps = all_logps[:1], all_logps[1:]
    print(reference_good_logps.shape)
    print("reference_good_logps:\n",reference_good_logps)
    print(reference_bad_logps.shape)
    print("reference_bad_logps\n",reference_bad_logps)

torch.Size([3, 9, 10])
logits:
 tensor([[[-0.0565, -0.0827, -0.1474,  0.0100, -0.0097, -0.1095,  0.0679,
           0.0035, -0.0819, -0.0010],
         [ 0.1082, -0.1835, -0.1182, -0.1179,  0.0160, -0.1471,  0.0077,
          -0.1568, -0.0925,  0.0218],
         [ 0.0287, -0.1253,  0.0064, -0.0516, -0.0354, -0.0353, -0.0429,
          -0.2031, -0.0248, -0.0239],
         [-0.0219,  0.1550,  0.0996,  0.1397, -0.1033,  0.0055, -0.1241,
           0.0612,  0.0729, -0.0935],
         [-0.1054,  0.0074, -0.1446, -0.1313, -0.0460,  0.0428,  0.0211,
           0.0946,  0.1309, -0.0070],
         [ 0.0928, -0.1962, -0.0238, -0.0858, -0.1913, -0.2821, -0.0004,
          -0.1976,  0.0196, -0.1160],
         [-0.1436,  0.0744, -0.0164,  0.0534,  0.0762, -0.0857, -0.2603,
          -0.0457,  0.1537,  0.0267],
         [-0.1182,  0.0489, -0.0797,  0.0039, -0.0279,  0.1024,  0.0110,
           0.0005, -0.0170, -0.0856],
         [-0.0509, -0.0215, -0.0647,  0.0435, -0.0992,  0.0381, -0.1139,
       

In [26]:
# ===============================================================================
# 计算DPO损失
# DPO的核心思想：增大策略模型对优质回答的概率，同时减小对低质量回答的概率
# ===============================================================================
# 计算DPO的logits：(策略模型相对于参考模型对好回答的提升) - (对坏回答的提升)
logits = (policy_good_logps - reference_good_logps) - (policy_bad_logps - reference_bad_logps)
# 应用logsigmoid函数并乘以beta控制优化强度，取负值（因为要最小化损失）
loss = -F.logsigmoid(beta * logits).mean()  # 对所有样本取平均

# 输出损失值
print(loss)

tensor(0.6931, grad_fn=<NegBackward0>)
