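# DPO Training Tutorial

Author: 林立

GitHub: https://github.com/efsotr

Training run (wandb report): https://api.wandb.ai/links/nlp-amct/uifw66p5

This tutorial mainly uses Visual Studio Code, Anaconda, and a single NVIDIA A40 GPU.

If your hardware or software environment differs, adapt the steps as needed, for example with the help of an AI assistant, a search engine, or a colleague.

The model is `Qwen/Qwen2.5-0.5B-Instruct`, and the dataset is the `text-to-text` subset of `PKU-Alignment/Align-Anything`.

Expected training time: 17 min 48 s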

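## Environment setup

Create a new environment with Anaconda: `conda create -n dpo python=3.10 -c conda-forge`

In Jupyter Notebook, select the newly created `dpo` environment as the kernel.

Then either run the installation cell below, or install the required libraries manually in the order listed in the cell.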

In [None]:
%conda install -c nvidia/label/cuda-12.4 cuda # can be commented out if CUDA >= 12.4 is already installed
%pip install torch==2.6.0
%pip install transformers==4.51.3
%pip install deepspeed==0.16.5
%pip install accelerate==1.6.0
%pip install datasets==2.21.0
%pip install wandb
%pip install flash-attn==2.7.4.post1 --no-build-isolation # optional speed-up (the model-loading cell below requests flash_attention_2); requires Compute Capability >= 8.0

Log in to your wandb account.

In [None]:
!wandb login Your-wandb-key # replace Your-wandb-key with your own API key

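## Running the full pipeline

Restart the Jupyter Notebook kernel first, then use "Run cell and below" on the first code cell below to execute the complete DPO training pipeline.

### Set environment variables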

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # index of the GPU to use
os.environ["WANDB_PROJECT"] = "text2text_dpo"
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" # use a Hugging Face mirror to download the model and dataset
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_RANK"] = "0"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"  # any port that is not already in use
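
The `RANK`, `WORLD_SIZE`, `LOCAL_RANK`, `MASTER_ADDR`, and `MASTER_PORT` variables describe a one-process "distributed" job, which appears to be what lets DeepSpeed initialize inside a notebook without a separate launcher. If port 29500 is already taken, one option is to ask the operating system for a free one. The sketch below is only an illustration; `find_free_port` is a hypothetical helper that is not part of the tutorial.

```python
import socket

def find_free_port(host: str = "127.0.0.1") -> int:
    """Ask the OS for a TCP port on `host` that is unused right now (hypothetical helper)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))  # port 0 lets the OS pick any free port
        return s.getsockname()[1]

port = find_free_port()
print(port)
```

The result could be used as `os.environ["MASTER_PORT"] = str(find_free_port())`. Note that the port is only guaranteed to be free at the moment of the check; another process could still claim it before training starts.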

### Configure parameters

#### DeepSpeed configuration

In [4]:
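import json

# Values set to "auto" are filled in by the Hugging Face Trainer from TrainingArguments.
ds_config = {
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": 0
    },
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_clipping": "auto",
    "gradient_accumulation_steps": "auto",
}

# Use a context manager so the file is flushed and closed before DeepSpeed reads it.
with open("ds_config_zero0.json", "w") as f:
    json.dump(ds_config, f, indent=2)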

#### Trainer arguments

In [None]:
from transformers import TrainingArguments
training_args = TrainingArguments(
    output_dir="text_to_text_dpo",
    bf16=True,
    do_train=True,
    deepspeed="ds_config_zero0.json",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=1e-6,
    adam_beta1=0.9,
    adam_beta2=0.95,
    optim="adamw_torch_fused",
    lr_scheduler_type="cosine",
    weight_decay=0,
    warmup_ratio=0.03,
    num_train_epochs=1,
    remove_unused_columns=False,
    log_level="info",
    logging_strategy="steps",
    logging_steps=1,
    save_strategy="epoch",
    save_only_model=False,
    run_name="text_to_text_dpo",
    report_to="wandb",
    seed=42,
)
training_args.__post_init__()

[2025-05-25 11:04:01,021] [INFO] [comm.py:658:init_distributed] cdb=None
[2025-05-25 11:04:01,022] [INFO] [comm.py:689:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
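
To get a feel for what `lr_scheduler_type="cosine"` with `warmup_ratio=0.03` does to `learning_rate=1e-6`, here is a minimal sketch of a linear-warmup, cosine-decay schedule. It mirrors the general shape of such a scheduler and is not the library's own code; `cosine_lr` is a hypothetical helper, the total of 950 steps is taken from the training log later in this tutorial, and rounding the warmup length up is an assumption.

```python
import math

def cosine_lr(step: int, total_steps: int, warmup_steps: int, peak_lr: float) -> float:
    """Linear warmup from 0 to peak_lr, then cosine decay back to 0 (illustrative)."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

total_steps = 950                             # from the training log in this tutorial
warmup_steps = math.ceil(0.03 * total_steps)  # assumption: warmup length is rounded up
for step in (0, warmup_steps, total_steps // 2, total_steps):
    print(step, cosine_lr(step, total_steps, warmup_steps, peak_lr=1e-6))
```

Under these assumptions the learning rate climbs for the first 29 steps, peaks at `1e-6`, and then decays smoothly to zero by the final step.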


#### Key parameters

In [None]:
from transformers import AutoTokenizer, PreTrainedTokenizer
model_path = "Qwen/Qwen2.5-0.5B-Instruct" # path to a locally downloaded model, or a model id on the Hugging Face Hub (here: a model id)
max_model_len = 2048
dpo_beta = 0.1 # key DPO hyperparameter: scales the policy/reference log-ratio in the loss
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_path)

### Data preparation

In [7]:
from datasets import load_dataset
train_dataset = load_dataset('PKU-Alignment/Align-Anything',name='text-to-text',split="train")

In [8]:
def get_input_and_label(conversation):
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")[0]
    prompt_len = len(tokenizer.apply_chat_template(conversation[:-1], add_generation_prompt=True))
    labels = input_ids.clone()
    labels[:prompt_len] = -100
    return input_ids, labels

def process(raw_sample):
    better_id = int(raw_sample['overall_response'])
    better_response = raw_sample[f'response_{better_id}']
    worse_response = raw_sample[f'response_{3 - better_id}']
    prompt = raw_sample['question']

    better_conversation = [
        {'role': 'user', 'content': prompt},
        {'role': 'assistant', 'content': better_response},
    ]

    worse_conversation = [
        {'role': 'user', 'content': prompt},
        {'role': 'assistant', 'content': worse_response},
    ]

    chosen_input, chosen_labels = get_input_and_label(better_conversation)
    rejected_input, rejected_labels = get_input_and_label(worse_conversation)
    
    return {
        "chosen_input": chosen_input,
        "chosen_labels": chosen_labels,
        "rejected_input": rejected_input, 
        "rejected_labels": rejected_labels,
    }

processed_dataset = train_dataset.map(process, num_proc=8, remove_columns=train_dataset.column_names)
processed_dataset = processed_dataset.filter(lambda ex: ex["chosen_input"] != ex["rejected_input"] and 
                                             max(len(ex["chosen_input"]), len(ex["rejected_input"])) <= max_model_len, num_proc=8)
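
The two ideas in the cell above are (1) `overall_response` names the preferred answer, so `3 - better_id` selects the other one, and (2) prompt tokens are replaced by `-100` in the labels so that only response tokens contribute to the log-probabilities. A minimal sketch with made-up token ids, independent of the real tokenizer (`pick_pair` and `mask_prompt` are hypothetical helpers):

```python
IGNORE_INDEX = -100  # label value used in this tutorial for positions excluded from the loss

def pick_pair(better_id: int) -> tuple[int, int]:
    """Return (chosen, rejected) response numbers when responses are numbered 1 and 2."""
    return better_id, 3 - better_id

def mask_prompt(input_ids: list[int], prompt_len: int) -> list[int]:
    """Copy input_ids into labels and hide the first prompt_len positions."""
    labels = list(input_ids)
    labels[:prompt_len] = [IGNORE_INDEX] * prompt_len
    return labels

# Made-up ids: the first three stand for the prompt, the last two for the response.
input_ids = [11, 12, 13, 21, 22]
labels = mask_prompt(input_ids, prompt_len=3)
print(pick_pair(2))  # (2, 1)
print(labels)        # [-100, -100, -100, 21, 22]
```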

In [9]:
processed_dataset

Dataset({
    features: ['chosen_input', 'chosen_labels', 'rejected_input', 'rejected_labels'],
    num_rows: 30408
})
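
With 30,408 preference pairs left after filtering, the length of the run can be estimated from the batch settings in the Trainer arguments. The arithmetic below is a rough sketch; the exact rounding the Trainer applies to the final incomplete accumulation group is an assumption, although the result agrees with the "Total optimization steps = 950" line in the training log.

```python
import math

num_examples = 30408        # rows left after filtering (dataset output above)
per_device_batch_size = 8   # per_device_train_batch_size
grad_accum_steps = 4        # gradient_accumulation_steps
world_size = 1              # single GPU

micro_batches = math.ceil(num_examples / (per_device_batch_size * world_size))
optimizer_steps = micro_batches // grad_accum_steps  # assumption: incomplete group is dropped
pairs_per_step = per_device_batch_size * grad_accum_steps * world_size
print(micro_batches, optimizer_steps, pairs_per_step)  # 3801 950 32
```

Each pair contributes a chosen and a rejected sequence, so a micro-batch of 8 pairs means 16 sequences per forward pass.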

In [None]:
import torch
import torch.nn.functional as F
def get_key(exs, key):
    return [ex[key] for ex in exs]

def data_collator(exs):
    inputs = get_key(exs, "chosen_input") + get_key(exs, "rejected_input")
    labels = get_key(exs, "chosen_labels") + get_key(exs, "rejected_labels")
    pad = tokenizer.pad_token_id
    max_len = max(map(len, inputs))
    input_ids = torch.tensor([ex + [pad] * (max_len - len(ex)) for ex in inputs])
    labels = torch.tensor([ex + [-100] * (max_len - len(ex)) for ex in labels])
    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": input_ids.ne(pad),
    }
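
The collator above places all chosen sequences first and all rejected sequences second, then right-pads everything to the longest sequence in the batch. A minimal pure-Python sketch of that layout, with made-up token ids and a placeholder pad id (the tutorial itself uses `tokenizer.pad_token_id`; `pad_to_longest` is a hypothetical helper):

```python
def pad_to_longest(seqs: list[list[int]], pad_value: int) -> list[list[int]]:
    """Right-pad every sequence to the length of the longest one."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_value] * (max_len - len(s)) for s in seqs]

PAD_ID = 0  # placeholder pad id for this illustration only
chosen = [[5, 6, 7], [8, 9]]
rejected = [[5, 6], [8, 9, 10, 11]]

batch = pad_to_longest(chosen + rejected, PAD_ID)   # chosen rows first, rejected rows second
attention_mask = [[tok != PAD_ID for tok in row] for row in batch]
n = len(chosen)
pairs = [(batch[i], batch[i + n]) for i in range(n)]  # row i is paired with row i + n
print(batch)
```

Because of this ordering, splitting the batch into two equal halves recovers the chosen and rejected groups, with row `i` of the first half paired with row `i` of the second.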

### Custom DPO Trainer

In [None]:
from transformers import Trainer

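@torch.compile(dynamic=True)
def getlogp(logits, labels):
    # Log-probability of each label token: logits[label] - logsumexp(logits).
    assert logits.shape[:-1] == labels.shape
    logsumexp = torch.logsumexp(logits, dim=-1, keepdim=True)
    logp = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1))
    return (logp - logsumexp).squeeze(-1)

def packed_getlogp(logits, labels):
    # Per-position log-probabilities; positions whose shifted label is -100 stay 0.
    assert logits.shape[:-1] == labels.shape
    logp = torch.zeros_like(labels, dtype=logits.dtype)
    # Shift labels left by one so the logits at position t are scored against token t + 1.
    shift_labels = torch.nn.functional.pad(labels, pad=(0, 1), value=-100)[:, 1:]
    mask = shift_labels != -100
    # Only compute log-probabilities where a real target exists.
    logp[mask] = getlogp(logits[mask], shift_labels[mask])
    return logp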

class CustomTrainer(Trainer):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model_accepts_loss_kwargs = False
        self.logs = {}

    def get_batch_samples(self, *args, **kwargs):
        self.logs = {
            "chosen_reward": 0.0,
            "rejected_reward": 0.0,
            "chosen_logp": 0.0,
            "rejected_logp": 0.0,
            "accuracy": 0.0,
        }
        return super().get_batch_samples(*args, **kwargs)

    def log(self, logs: dict[str, float], start_time = None) -> None:
        if self.logs != {}:
            for name in ["chosen_reward", "rejected_reward", "chosen_logp", "rejected_logp", "accuracy"]:
                self.logs[name] /= self.args.gradient_accumulation_steps
            self.logs["margin_reward"] = self.logs["chosen_reward"] - self.logs["rejected_reward"]
            logs.update(self.logs)
            self.logs = {}

        return super().log(logs, start_time)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        # Sequence log-probabilities under the frozen reference model (no gradients).
        with torch.no_grad():
            ref_logp = packed_getlogp(self.ref_model(**inputs).logits, labels).sum(dim=1)
        # Sequence log-probabilities under the policy being trained.
        outputs = model(**inputs)
        logp = packed_getlogp(outputs.logits, labels).sum(dim=1)
        # The collator stacks chosen rows first and rejected rows second, so chunk(2) separates them.
        chosen, rejected = (logp - ref_logp).mul(dpo_beta).chunk(2)
        # DPO loss: -log sigmoid(chosen reward - rejected reward), averaged over pairs.
        loss = - F.logsigmoid(chosen - rejected).mean()
        chosen_logp, rejected_logp = logp.detach().chunk(2)
        # Accumulate metrics; log() averages them over the gradient accumulation steps.
        self.logs["chosen_reward"] += chosen.detach().mean().item()
        self.logs["rejected_reward"] += rejected.detach().mean().item()
        self.logs["chosen_logp"] += chosen_logp.detach().mean().item()
        self.logs["rejected_logp"] += rejected_logp.detach().mean().item()
        self.logs["accuracy"] += (chosen.detach() > rejected.detach()).float().mean().item()
        return loss
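
The core of `compute_loss` can be checked by hand on a single preference pair. The sketch below is a pure-Python illustration of the same formulas, not the tutorial's implementation: `log_softmax_at` computes a token log-probability as `logit - logsumexp`, and `dpo_loss` computes `-log sigmoid(beta * [(policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)])`. Both function names and all numbers are made up for illustration.

```python
import math

def log_softmax_at(logits: list[float], index: int) -> float:
    """log p(index) = logits[index] - logsumexp(logits), computed stably."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[index] - lse

def dpo_loss(policy_chosen: float, policy_rejected: float,
             ref_chosen: float, ref_rejected: float, beta: float = 0.1) -> float:
    """DPO loss for one pair, given summed sequence log-probabilities."""
    chosen_reward = beta * (policy_chosen - ref_chosen)
    rejected_reward = beta * (policy_rejected - ref_rejected)
    margin = chosen_reward - rejected_reward
    return math.log1p(math.exp(-margin))  # equals -log(sigmoid(margin))

# Policy identical to the reference: both rewards are 0, so the loss is log(2).
print(round(dpo_loss(-12.0, -15.0, -12.0, -15.0), 4))  # 0.6931
# Policy assigns more probability to the chosen answer than the reference does.
print(dpo_loss(-10.0, -15.0, -12.0, -15.0) < math.log(2))  # True
```

Since the policy and the reference model are both loaded from the same checkpoint, the loss at the start of training is expected to sit near `log(2) ≈ 0.693`, which is in line with the first logged values of roughly 0.69 to 0.71.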

### Memory and runtime optimization with `torch.compile` and `flash attention 2`

In [None]:
# Optimization: remove pad tokens before the forward pass
def wrap(fn):
    
    def wrapped_fn(input_ids, attention_mask, position_ids=None, **kwargs):
        assert attention_mask is not None
        assert not kwargs.get("output_hidden_states", False)
        
        attention_mask = attention_mask.bool()

        # get shape
        B, L = input_ids.shape
        input_ids = input_ids[attention_mask][None] # (1, T)

        # get position ids
        if position_ids is None:
            position_ids = torch.arange(0, L, device=attention_mask.device)[None].repeat(B, 1)
        position_ids = position_ids[attention_mask][None]

        indices = torch.nonzero(attention_mask.view(-1), as_tuple=True)[0]

        # get flash attention 2 kwargs
        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0), value=0)
        attention_mask = None
        kwargs["cu_seq_lens_q"] = cu_seqlens
        kwargs["cu_seq_lens_k"] = cu_seqlens
        kwargs["max_length_q"] = max_seqlen_in_batch
        kwargs["max_length_k"] = max_seqlen_in_batch

        output = fn(input_ids, attention_mask, position_ids, **kwargs) # original forward

        # restore original shape 
        last_hidden_state = output.last_hidden_state # (1, T, H)
        output.last_hidden_state = torch.zeros((B * L, last_hidden_state.size(-1)), 
                                               dtype=last_hidden_state.dtype,
                                               device=last_hidden_state.device)
        output.last_hidden_state[indices] = last_hidden_state[0] # (T, H)
        output.last_hidden_state = output.last_hidden_state.view(B, L, -1)

        return output

    return wrapped_fn

# Enable torch.compile memory optimizations
def enable_nonsac(enable_remove_add=True):
    import torch._functorch.config
    import torch._functorch.partitioners as partitioners

    # reduce activation memory footprint 
    torch._functorch.config.activation_memory_budget = 0.99

    if enable_remove_add:
        # Remove 'add' from the default operation list to eliminate redundant re-computation 
        # due to wrong partitioning of nodes
        def remove_add(fn):
            def wrapped_fn():
                optypes = fn()
                optypes.recomputable_ops.remove(torch.ops.aten.add)
                return optypes
            return wrapped_fn

        partitioners.get_default_op_list = remove_add(partitioners.get_default_op_list)

enable_nonsac()

# Enable torch.compile optimizations for the model
def patch(model):
    model.model.forward = wrap(model.model.forward)
    for layer in model.model.layers:
        layer.forward = torch.compile(layer.forward, dynamic=True)
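
The `wrap` function above drops pad tokens, concatenates the remaining tokens of all sequences into one row of shape `(1, T)`, and passes cumulative sequence boundaries so that attention does not cross from one sequence into the next. A minimal pure-Python sketch of that bookkeeping, assuming right-padded inputs as produced by the collator (`pack_metadata` is a hypothetical helper, not part of the tutorial):

```python
from itertools import accumulate

def pack_metadata(attention_mask: list[list[int]]):
    """Sequence lengths, cumulative offsets, max length, and per-token positions."""
    seqlens = [sum(row) for row in attention_mask]
    cu_seqlens = [0] + list(accumulate(seqlens))  # boundaries of each sequence in the packed row
    position_ids = [pos for row in attention_mask for pos, keep in enumerate(row) if keep]
    return seqlens, cu_seqlens, max(seqlens), position_ids

mask = [[1, 1, 1, 0],
        [1, 1, 0, 0],
        [1, 1, 1, 1]]
seqlens, cu_seqlens, max_len, position_ids = pack_metadata(mask)
print(seqlens)       # [3, 2, 4]
print(cu_seqlens)    # [0, 3, 5, 9]
print(position_ids)  # [0, 1, 2, 0, 1, 0, 1, 2, 3]
```

Here 9 real tokens are processed instead of the 12 slots in the padded batch. The saving grows when sequence lengths within a batch vary widely.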

### Load the model

In [13]:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    use_cache=False,
)
ref_model = AutoModelForCausalLM.from_pretrained(
    model_path,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    use_cache=False,
).cuda()
patch(model)
patch(ref_model)

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


### Start training

In [14]:
trainer = CustomTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    data_collator=data_collator,
    train_dataset=processed_dataset,
)
trainer.ref_model = ref_model
trainer.train()
trainer.save_state()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Using auto half precision backend


[2025-05-25 11:04:15,890] [INFO] [logging.py:107:log_dist] [Rank 0] DeepSpeed info: version=0.16.5, git-hash=unknown, git-branch=unknown
[2025-05-25 11:04:15,892] [INFO] [config.py:734:__init__] Config mesh_device None world_size = 1
[2025-05-25 11:04:16,611] [INFO] [logging.py:107:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False
[2025-05-25 11:04:16,615] [INFO] [logging.py:107:log_dist] [Rank 0] Using client Optimizer as basic optimizer
[2025-05-25 11:04:16,616] [INFO] [logging.py:107:log_dist] [Rank 0] Removing param_group that has no 'params' in the basic Optimizer
[2025-05-25 11:04:16,631] [INFO] [logging.py:107:log_dist] [Rank 0] DeepSpeed Basic Optimizer = AdamW
[2025-05-25 11:04:16,632] [INFO] [logging.py:107:log_dist] [Rank 0] Creating BF16 optimizer
[2025-05-25 11:04:16,948] [INFO] [utils.py:781:see_memory_usage] begin bf16_optimizer
[2025-05-25 11:04:16,950] [INFO] [utils.py:782:see_memory_usage] MA 1.86 GB         Max_MA 1.86 GB         CA 1.94 GB         Max_CA 2 

***** Running training *****
  Num examples = 30,408
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 4
  Total optimization steps = 950
  Number of trainable parameters = 494,032,768
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33m205368424[0m ([33mnlp-amct[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
1,0.7119
2,0.7041
3,0.7002
4,0.6914
5,0.6904
6,0.7119
7,0.6719
8,0.6992
9,0.7266
10,0.6982


Saving model checkpoint to text_to_text_dpo/checkpoint-950
Configuration saved in text_to_text_dpo/checkpoint-950/config.json
Configuration saved in text_to_text_dpo/checkpoint-950/generation_config.json
Model weights saved in text_to_text_dpo/checkpoint-950/model.safetensors
tokenizer config file saved in text_to_text_dpo/checkpoint-950/tokenizer_config.json
Special tokens file saved in text_to_text_dpo/checkpoint-950/special_tokens_map.json


[2025-05-25 11:22:02,242] [INFO] [logging.py:107:log_dist] [Rank 0] [Torch] Checkpoint global_step950 is about to be saved!
[2025-05-25 11:22:02,252] [INFO] [logging.py:107:log_dist] [Rank 0] Saving model checkpoint: text_to_text_dpo/checkpoint-950/global_step950/mp_rank_00_model_states.pt
[2025-05-25 11:22:02,254] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving text_to_text_dpo/checkpoint-950/global_step950/mp_rank_00_model_states.pt...
[2025-05-25 11:22:06,332] [INFO] [torch_checkpoint_engine.py:23:save] [Torch] Saved text_to_text_dpo/checkpoint-950/global_step950/mp_rank_00_model_states.pt.
[2025-05-25 11:22:06,342] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving text_to_text_dpo/checkpoint-950/global_step950/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt...
[2025-05-25 11:22:25,826] [INFO] [torch_checkpoint_engine.py:23:save] [Torch] Saved text_to_text_dpo/checkpoint-950/global_step950/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt.
[2025-05-25 11:22:25,843] [



Training completed. Do not forget to share your model on huggingface.co/models =)


