# 1 PPO奖励模型训练

In [None]:

"""
现有微调框架（例如 LLaMA-Factory）已经实现了 PPO 训练功能，其实现主要依赖 trl 库。TRL（Transformer Reinforcement Learning）是由 Hugging Face 提供的开源库，
专门用于使用强化学习（Reinforcement Learning, RL）训练基于 Transformer 的语言模型。它是一个全面的工具集，支持多种后训练方法，
包括监督微调（Supervised Fine-Tuning, SFT）、奖励建模（Reward Modeling, RM）、近端策略优化（Proximal Policy Optimization, PPO）和直接偏好优化（Direct Preference Optimization, DPO）。
LLaMA-Factory 中对应的 PPO 实现位于 train/ppo/trainer.py 文件。
"""

if is_peft_available():
    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training


class RewardTrainer(Trainer):
    r"""
    Trainer for a pairwise-preference reward model.

    Using an `AutoModelForSequenceClassification` (i.e. a classification model) as the reward
    model is recommended. The model is trained on a dataset of paired examples, where each
    example is a tuple of two sequences, and it should learn to predict which sequence of the
    pair is more relevant to the task at hand.

    The training data must contain at least these four fields:
    - `input_ids_chosen`
    - `attention_mask_chosen`
    - `input_ids_rejected`
    - `attention_mask_rejected`

    An optional `margin` field may also be provided; when present it is subtracted from the
    reward difference inside the loss.
    """

    _tag_names = ["trl", "reward-trainer"]

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module]] = None,      # model to train; AutoModelForSequenceClassification is recommended
        args: Optional[RewardConfig] = None,                            # training configuration; RewardConfig is recommended
        data_collator: Optional[DataCollator] = None,                   # defaults to RewardDataCollatorWithPadding
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,            # required when the default data collator is used
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,                 # defaults to compute_accuracy
        callbacks: Optional[List[TrainerCallback]] = None,              # hooks invoked by the trainer at specific points of the run
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        max_length: Optional[int] = None,                               # deprecated: max sequence length for the default collator
        peft_config: Optional[Dict] = None,                             # if given, the model is wrapped as a PEFT model
    ):
        """
        Validate the configuration, optionally wrap the model with PEFT, choose the default
        metric / data collator, and delegate to the parent `Trainer` constructor.

        Raises:
            ValueError: If both `max_length` and `args.max_length` are set, if `peft_config`
                is passed while PEFT is not installed, or if the default data collator is
                needed but no `tokenizer` was given.
        """
        # NOTE(review): `args` defaults to None, yet several branches below dereference it
        # (`args.max_length`, `args.remove_unused_columns`). Callers relying on the default
        # data collator appear to need to pass a config object -- confirm.

        # (1) Check the training configuration and emit deprecation warnings.
        # An exact type comparison is intentional: only a plain `TrainingArguments` takes
        # this branch, anything else is expected to expose `max_length`.
        if type(args) is TrainingArguments:
            warnings.warn(
                "Using `transformers.TrainingArguments` for `args` is deprecated and will be removed in a future version. Please use `RewardConfig` instead.",
                FutureWarning,
            )
            if max_length is not None:
                warnings.warn(
                    "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
                    FutureWarning,
                )
        else:
            if max_length is not None and args.max_length is not None:
                raise ValueError(
                    "You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once."
                )
            if max_length is not None and args.max_length is None:
                warnings.warn(
                    "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
                    FutureWarning,
                )

        # (2) PEFT setup: when a `peft_config` is supplied, make sure PEFT is installed and
        # wrap the model as a PEFT model (unless it already is one).
        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            if not isinstance(model, PeftModel):
                if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
                    # Older peft releases do not accept `gradient_checkpointing_kwargs`,
                    # so probe the function signature before forwarding it.
                    _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
                        inspect.signature(prepare_model_for_kbit_training).parameters
                    )

                    prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                    if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
                        warnings.warn(
                            "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
                            "please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
                        )
                    elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
                        prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

                    model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

                model = get_peft_model(model, peft_config)

        # (3) Fall back to the default metric when none is specified.
        if compute_metrics is None:
            compute_metrics = compute_accuracy

        # (4) Without an explicit `data_collator`, build a RewardDataCollatorWithPadding and
        # resolve the `max_length` it should use.
        if data_collator is None:
            if tokenizer is None:
                raise ValueError(
                    "max_length or a tokenizer must be specified when using the default RewardDataCollatorWithPadding"
                )
            if type(args) is TrainingArguments:
                if max_length is None:
                    warnings.warn(
                        "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
                        " It will be set to `512` by default, but you should do it yourself in the future.",
                        UserWarning,
                    )
                    max_length = 512
            else:
                if max_length is None and args.max_length is None:
                    warnings.warn(
                        "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig."
                        " It will be set to `512` by default, but you should do it yourself in the future.",
                        UserWarning,
                    )
                    max_length = 512
                if max_length is None and args.max_length is not None:
                    max_length = args.max_length

            data_collator = RewardDataCollatorWithPadding(tokenizer, max_length=max_length)

            if args.remove_unused_columns:
                try:  # for bc before https://github.com/huggingface/transformers/pull/25435
                    args.remove_unused_columns = False
                except FrozenInstanceError:
                    args = replace(args, remove_unused_columns=False)
                # warn users
                warnings.warn(
                    "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
                    " we have set it for you, but you should do it yourself in the future.",
                    UserWarning,
                )

            self.use_reward_data_collator = True
        else:
            self.use_reward_data_collator = False

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Compute the pairwise preference loss `-logsigmoid(chosen - rejected [- margin])`.

        Args:
            model: Model called on the chosen and on the rejected sequences.
            inputs: Batch holding the four `*_chosen` / `*_rejected` fields and, optionally,
                a `margin` entry.
            return_outputs: When true, also return the two reward tensors.

        Returns:
            The mean loss, or `(loss, {"rewards_chosen": ..., "rewards_rejected": ...})`
            when `return_outputs` is true.
        """
        # (1) Score the chosen and the rejected sequences separately; the "logits" entry of
        # the model output is used as the reward.
        rewards_chosen = model(
            input_ids=inputs["input_ids_chosen"],
            attention_mask=inputs["attention_mask_chosen"],
            return_dict=True,
        )["logits"]
        rewards_rejected = model(
            input_ids=inputs["input_ids_rejected"],
            attention_mask=inputs["attention_mask_rejected"],
            return_dict=True,
        )["logits"]

        # (2) Loss: winner - loser (- margin, when supplied), passed through log-sigmoid
        # and averaged.
        # NOTE(review): the subtraction relies on `inputs["margin"]` broadcasting
        # element-wise against the logits; if the logits are (batch, 1) and the margin is
        # (batch,), this would broadcast to (batch, batch) -- confirm against the collator.
        if "margin" in inputs:
            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
        else:
            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()

        # (3) Return the loss and, on request, both reward tensors.
        if return_outputs:
            return loss, {
                "rewards_chosen": rewards_chosen,
                "rewards_rejected": rewards_rejected,
            }
        return loss

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Run one evaluation step without gradient tracking.

        Args:
            model: Model to evaluate.
            inputs: Batch in the same format as for `compute_loss`.
            prediction_loss_only: When true, only the loss is returned.
            ignore_keys: Output keys to drop; defaults to the model config's
                `keys_to_ignore_at_inference` (or an empty list).

        Returns:
            `(loss, None, None)` when `prediction_loss_only` is true, otherwise
            `(loss, logits, labels)` where `labels` is an all-zeros tensor with one entry
            per row of `logits`.
        """
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        with torch.no_grad():
            loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)

        if prediction_loss_only:
            return (loss, None, None)

        loss = loss.detach()
        logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
        logits = nested_detach(logits)
        # Stack accepted against rejected, mean over logits
        # and softmax to get preferences between accepted and rejected to sum to 1
        logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T

        # The chosen reward is stacked first, so index 0 is the target for every pair.
        labels = torch.zeros(logits.shape[0])
        labels = self._prepare_inputs(labels)

        return loss, logits, labels

    def evaluate(self, *args, **kwargs):
        """
        Print a sample of predictions, then run the parent `evaluate`.

        Accepts an extra keyword `num_print_samples` (default `4`), which is removed from
        `kwargs` before they are forwarded to the parent implementation.
        """
        num_print_samples = kwargs.pop("num_print_samples", 4)
        self.visualize_samples(num_print_samples)
        return super().evaluate(*args, **kwargs)

    def visualize_samples(self, num_print_samples: int):
        """
        Visualize the reward model's predictions on the evaluation set.

        Args:
            num_print_samples (`int`, defaults to `4`):
                The number of samples to print. Set to `-1` to print all samples.
        """
        eval_dataloader = self.get_eval_dataloader()
        table = defaultdict(list)
        # Walk the evaluation set, decoding the chosen and rejected sequences back to text.
        for inputs in eval_dataloader:
            _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
            chosen_text = self.tokenizer.batch_decode(inputs["input_ids_chosen"], skip_special_tokens=True)
            rejected_text = self.tokenizer.batch_decode(inputs["input_ids_rejected"], skip_special_tokens=True)
            table["chosen_text"].extend(gather_object(chosen_text))
            table["rejected_text"].extend(gather_object(rejected_text))
            table["logits"].extend(
                gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
            )
            # Stop early once enough rows are collected; a negative limit never stops.
            if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
                break
        df = pd.DataFrame(table)
        # A negative limit means "print everything". Slicing with it (`df[:-1]`) would
        # silently drop the last row, so only slice for non-negative limits.
        shown = df[:num_print_samples] if num_print_samples >= 0 else df
        if self.accelerator.process_index == 0:
            print_rich_table(shown)
            if "wandb" in self.args.report_to:
                import wandb

                if wandb.run is not None:
                    wandb.log({"completions": wandb.Table(dataframe=df)})

# 2 PPO stage

In [None]:
def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None:
    r"""Run the PPO-stage training loop, analogous to `_inner_training_loop()` in the Hugging Face Trainer.

    Every step fetches one batch, produces queries/responses and their rewards one
    mini-batch at a time, performs a PPO update via `self.step`, and then handles stats
    logging, trainer callbacks and periodic checkpointing.

    Args:
        resume_from_checkpoint: Not supported yet; must be left as `None`.

    Raises:
        ValueError: If `resume_from_checkpoint` is not `None`.
    """
    if resume_from_checkpoint is not None:
        raise ValueError("`resume_from_checkpoint` will be supported in the future version.")

    # (1) Derive the schedule: the effective batch size combines the per-device batch,
    # gradient accumulation, the PPO buffer size and the world size.
    per_device_batch = self.args.per_device_train_batch_size
    total_train_batch_size = (
        per_device_batch
        * self.args.gradient_accumulation_steps
        * self.finetuning_args.ppo_buffer_size
        * self.args.world_size
    )
    step_budget = self.args.max_steps
    if step_budget > 0:
        # An explicit step budget wins; the epoch count is effectively unbounded.
        max_steps = step_budget
        steps_in_epoch = step_budget
        num_examples = total_train_batch_size * step_budget
        num_train_epochs = sys.maxsize
    else:
        # Otherwise the step count follows from the epoch count and the dataloader length.
        steps_in_epoch = len(self.dataloader)
        num_examples = len(self.dataset)
        num_train_epochs = self.args.num_train_epochs
        max_steps = math.ceil(num_train_epochs * steps_in_epoch)

    trainer_state = self.state
    trainer_state.max_steps = max_steps
    trainer_state.num_train_epochs = num_train_epochs
    trainer_state.is_local_process_zero = self.is_local_process_zero()
    trainer_state.is_world_process_zero = self.is_world_process_zero()

    # (2) Emit the run summary through the rank-0 logger.
    logger.info_rank0("***** Running training *****")
    logger.info_rank0(f"  Num examples = {num_examples:,}")
    logger.info_rank0(f"  Num Epochs = {num_train_epochs:,}")
    logger.info_rank0(f"  Instantaneous batch size per device = {per_device_batch:,}")
    logger.info_rank0(
        f"  Total train batch size (w. parallel, buffer, distributed & accumulation) = {total_train_batch_size:,}"
    )
    logger.info_rank0(f"  Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
    logger.info_rank0(f"  Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
    logger.info_rank0(f"  Total training steps = {max_steps:,}")
    logger.info_rank0(f"  Number of trainable parameters = {count_parameters(self.model)[0]:,}")

    # (3) Data iterator and running averages for loss / reward.
    dataiter = iter(self.dataloader)
    loss_meter = AverageMeter()
    reward_meter = AverageMeter()
    self.callback_handler.on_train_begin(self.args, self.state, self.control)

    # (4) Main loop: exactly `max_steps` iterations unless a callback requests a stop.
    for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
        # 4.1 Fetch the next batch, restarting the dataloader once it is exhausted.
        try:
            batch = next(dataiter)
        except StopIteration:
            dataiter = iter(self.dataloader)
            batch = next(dataiter)

        # 4.2 Generate queries/responses and score them, one mini-batch at a time.
        self.model.eval()
        self.tokenizer.padding_side = "right"  # change padding side
        queries, responses, rewards = [], [], []
        chunk = self.config.mini_batch_size
        for start in range(0, self.config.batch_size, chunk):
            chunk_queries, chunk_responses = self.get_inputs(batch[start : start + chunk])
            chunk_rewards = self.get_rewards(chunk_queries, chunk_responses)
            queries.extend(chunk_queries)
            responses.extend(chunk_responses)
            rewards.extend(chunk_rewards)

        # 4.3 Run the PPO update and fold the results into the running averages.
        self.model.train()
        stats = self.step(queries, responses, rewards)
        self.tokenizer.padding_side = "left"  # restore padding side
        sample_count = len(rewards)
        loss_meter.update(float(stats["ppo/loss/total"]), n=sample_count)
        reward_meter.update(torch.stack(rewards).mean().item(), n=sample_count)

        # 4.4 Best-effort stats reporting: a failure here must not abort training.
        if self.config.log_with is not None:
            try:
                batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
                batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
                self.log_stats(stats, batch, rewards)
            except Exception:
                logger.warning_rank0("Failed to save stats due to unknown errors.")

        self.state.global_step += 1
        self.callback_handler.on_step_end(self.args, self.state, self.control)

        finished_steps = step + 1

        # 4.5 Periodic console / history logging on the local main process.
        if self.is_local_process_zero() and finished_steps % self.args.logging_steps == 0:
            logs = {
                "loss": round(loss_meter.avg, 4),
                "reward": round(reward_meter.avg, 4),
                "learning_rate": stats["ppo/learning_rate"],
                "epoch": round(step / steps_in_epoch, 2),
            }
            tqdm.write(str(logs))
            # "step" is added only after printing, so it does not appear in the console line.
            logs["step"] = step
            self.state.log_history.append(logs)
            self.callback_handler.on_log(self.args, self.state, self.control, logs)
            loss_meter.reset()
            reward_meter.reset()

        # 4.6 Periodic checkpointing.
        if finished_steps % self.args.save_steps == 0:  # save checkpoint
            checkpoint_dir = os.path.join(
                self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
            )
            self.save_model(checkpoint_dir)
            self.callback_handler.on_save(self.args, self.state, self.control)

        # 4.7 Honour stop requests raised by callbacks.
        if self.control.should_epoch_stop or self.control.should_training_stop:
            break

    self.callback_handler.on_train_end(self.args, self.state, self.control)