# 从0实现一个DPO

## 1.准备数据

DPO 所需的数据主要包含三个字段：
- instruction：指令（问题）
- chosen：偏好的（被选中的）回答
- rejected：不被偏好的（较差的）回答

## 2. 数据集处理

了解 DPO 训练流程的读者可以知道，一般的 DPO 实现需要将 prompt（即 instruction）分别与 chosen、rejected 拼接在一起，得到两条完整的输入序列。

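## 3. Loss

DPO 主要涉及两个模型：policy model（即我们主要要调优的模型）和 reference model（用来约束 policy 的参考模型）。损失函数为：

$$
\mathcal{L}_\text{DPO} = -\log \sigma\left(\beta\left[\left(\log \pi_\theta(y_w \mid x) - \log \pi_\theta(y_l \mid x)\right) - \left(\log \pi_\text{ref}(y_w \mid x) - \log \pi_\text{ref}(y_l \mid x)\right)\right]\right)
$$

其中 $y_w$ 为 chosen 回答，$y_l$ 为 rejected 回答，$\beta$ 为温度系数。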

In [None]:
import torch.nn.functional as F
import torch.nn as nn
import torch

class DPOLoss(nn.Module):
    """Direct Preference Optimization (DPO) loss.

    Measures how much more the policy model prefers the chosen answer over the
    rejected one, relative to the reference model, and penalises that margin
    with ``-log(sigmoid(beta * margin))``.

    Args:
        beta: scale applied to the preference margin before the sigmoid.
    """

    def __init__(self, beta: float = 0.1) -> None:
        super().__init__()
        self.beta = beta

    def forward(
        self,
        policy_chosen_logps: torch.Tensor,
        policy_rejected_logps: torch.Tensor,
        reference_chosen_logps: torch.Tensor,
        reference_rejected_logps: torch.Tensor,
    ):
        """Compute the batch-mean DPO loss and two monitoring rewards.

        Each input is a per-example log-probability, shape (batch_size,),
        for the policy/reference model on the chosen/rejected answer.

        Returns:
            ``(loss, chosen_reward, rejected_reward)``: three scalar tensors,
            each the mean over the batch. The two rewards are detached from
            the autograd graph (tracking only) and are not scaled by beta.
        """
        # Preference margin (chosen minus rejected) under each model.
        policy_margin = policy_chosen_logps - policy_rejected_logps
        reference_margin = reference_chosen_logps - reference_rejected_logps

        # How far the policy's margin has moved past the reference's, scaled.
        scaled_margin = self.beta * (policy_margin - reference_margin)
        per_example_loss = F.logsigmoid(scaled_margin).neg()

        # Log-ratio policy/reference per answer type, used to monitor training.
        chosen_reward = (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_reward = (policy_rejected_logps - reference_rejected_logps).detach()

        # Average everything over the batch dimension.
        return per_example_loss.mean(), chosen_reward.mean(), rejected_reward.mean()

        

计算 log probs，也就是 $\log \pi_\theta(y \mid x)$（对 chosen 取 $y = y_w$，对 rejected 取 $y = y_l$）。注意下面的实现返回的是各 token 对数概率的（掩码）平均值，而不是求和。

In [None]:
def compute_logprobs(logits, labels, mask=None):
    """Return the mean log-probability each sequence assigns to its own tokens.

    The model output at position t is a prediction for token t + 1, so the
    last logit step and the first label/mask step are dropped before scoring.

    Args:
        logits: shape (batch_size, sequence_len, vocab_size)
        labels: token ids, shape (batch_size, sequence_len)
        mask: optional, shape (batch_size, sequence_len); nonzero entries mark
            the tokens that count (e.g. 0 on padding).

    Returns:
        Tensor of shape (batch_size,): the per-sequence average of the
        selected token log-probs, over masked-in tokens when ``mask`` is given
        (a row whose shifted mask sums to 0 yields NaN), otherwise over all
        ``sequence_len - 1`` positions.
    """
    # Shift: drop the first label (nothing predicts it) ...
    labels = labels[:, 1:].clone()
    # ... and the last logit step (it predicts past the end of the sequence).
    logits = logits[:, :-1, :]

    logps = F.log_softmax(logits, dim=-1)

    # Pick, for every position, the log-prob of the actual next token. The
    # lookup must run along the vocab axis (dim=-1); gather needs an index of
    # the same rank as logps, hence unsqueeze(-1) -> (batch, seq-1, 1).
    select_logprobs = torch.gather(
        input=logps,
        dim=-1,
        index=labels.unsqueeze(-1),
    ).squeeze(-1)  # (batch_size, sequence_len - 1)

    if mask is not None:
        mask = mask[:, 1:].clone()
        # Zero out the positions excluded by the mask (e.g. padding).
        select_logprobs = select_logprobs * mask
        # Average only over the tokens that were kept.
        return select_logprobs.sum(-1) / mask.sum(-1)
    return select_logprobs.mean(-1)

clone 示例：直接赋值 `mask1 = mask` 不会复制张量，两个名字指向同一个对象，对 `mask1` 的原地修改会同时改变 `mask`；因此在需要独立副本时要使用 `.clone()`。

In [13]:
mask = torch.tensor([1, 2, 3])
# Plain assignment copies nothing: mask1 is just a second name for the very
# same tensor object, so an in-place update through mask1 shows up in mask.
mask1 = mask
mask1.add_(1)  # the in-place op that `mask1 += 1` performs on a tensor
print(mask)

tensor([2, 3, 4])


`torch.gather` 的 tensor shape 示例（并与 `F.cross_entropy` 的结果对比验证）

In [1]:
import torch.nn.functional as F
import torch
logits = torch.tensor(
    [[2.0, 1.0, 0.1],
     [0.5, 2.5, 0.3]])  # Shape: (2, 3) -> (num_examples, num_classes)
targets = torch.tensor([0, 2])  # Shape: (2,), one class index per example

# Manual cross-entropy using torch.gather
log_softmax_logits = F.log_softmax(logits, dim=1)  # Shape: (2, 3)
# gather requires an index with the same rank as its input, so targets is
# unsqueezed to (2, 1); the gathered result keeps that (2, 1) shape.
selected_log_probs = torch.gather(
    input=log_softmax_logits,
    dim=1,
    index=targets.unsqueeze(1),  # Shape: (2, 1)
)  # Shape: (2, 1) -- squeeze(1) below turns it into (2,)
print(selected_log_probs, selected_log_probs.shape)
print(selected_log_probs.squeeze(1), selected_log_probs.squeeze(1).shape)
manual_loss = -selected_log_probs.mean()  # Averaging over the batch


# PyTorch built-in cross-entropy; should equal manual_loss
cross_entropy_loss = F.cross_entropy(logits, targets)

print(manual_loss, cross_entropy_loss)

tensor([[-0.4170],
        [-2.4200]]) torch.Size([2, 1])
tensor([-0.4170, -2.4200]) torch.Size([2])
tensor(1.4185) tensor(1.4185)


对一个 batch 计算 DPO loss

In [None]:
def compute_batch_loss(batch, policy_model, reference_model, beta):
    """Compute the DPO loss on an input batch"""
    policy_chosen_logps = compute_logprobs(
        logits=policy_model(batch["chosen"]),
        labels=batch["chosen"],
        mask=batch["chosen_mask"]
    )
    policy_rejected_logps = compute_logprobs(
        logits=policy_model(batch["rejected"]),
        labels=batch["rejected"],
        mask=batch["rejected_mask"]
    )