From 7131b16b1552f727696e601777601fa3444ae7fb Mon Sep 17 00:00:00 2001
From: tastelikefeet
Date: Mon, 24 Feb 2025 16:36:06 +0800
Subject: [PATCH] fix

---
 swift/trainers/rlhf_trainer/grpo_trainer.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index 7c65f5de9c..dab8332775 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -243,14 +243,13 @@ def __init__(self,
         # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
         # `_get_train_sampler` and `_prepare_inputs`.
         self._buffered_inputs = [None] * args.gradient_accumulation_steps
-
-        self.add_callback(GRPOCallback(self))
+        if self.args.async_generate:
+            self.add_callback(GRPOCallback(self))
 
     @property
     def infer_rank(self):
         rank, local_rank, world_size, local_world_size = get_dist_setting()
         assert local_world_size % self.args.num_infer_workers == 0
-        assert local_world_size + self.args.num_infer_workers == get_device_count()
         step = local_world_size // self.args.num_infer_workers
         for _vllm_rank in range(self.args.num_infer_workers):
             _assigned = _vllm_rank * step
@@ -263,7 +262,6 @@ def infer_rank(self):
     def local_infer_rank(self):
         rank, local_rank, world_size, local_world_size = get_dist_setting()
         assert local_world_size % self.args.num_infer_workers == 0
-        assert local_world_size + self.args.num_infer_workers == get_device_count()
         step = local_world_size // self.args.num_infer_workers
         for _vllm_rank in range(self.args.num_infer_workers):
             _assigned = _vllm_rank * step
@@ -502,6 +500,11 @@ def _generate_and_score_completions(
         mode = 'eval' if self.control.should_evaluate else 'train'
         completion_length = self.accelerator.gather_for_metrics(outputs['completion_mask'].sum(1)).float().mean().item()
         self._metrics[mode]['completion_length'].append(completion_length)
+        # clip ratio
+        response_clip_ratio = torch.gt(
+            self.accelerator.gather_for_metrics(outputs['completion_mask'].sum(1)),
+            self.args.max_completion_length).float().mean().item()
+        self._metrics[mode]['response_clip_ratio'].append(response_clip_ratio)
         reward_per_func = rewards_per_func.mean(0)
         for i, reward_func in enumerate(self.reward_funcs):
             if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models