From 28c0ae9ddedf3f3b9e822cc054c673f6e3775a61 Mon Sep 17 00:00:00 2001
From: "huangjintao.hjt"
Date: Tue, 5 Nov 2024 18:02:58 +0800
Subject: [PATCH 1/2] fix trl compat

---
 requirements/framework.txt                  | 2 +-
 swift/trainers/mixin.py                     | 8 +++++++-
 swift/trainers/rlhf_trainer/cpo_trainer.py  | 1 +
 swift/trainers/rlhf_trainer/dpo_trainer.py  | 1 +
 swift/trainers/rlhf_trainer/kto_trainer.py  | 1 +
 swift/trainers/rlhf_trainer/orpo_trainer.py | 1 +
 6 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/requirements/framework.txt b/requirements/framework.txt
index c03df1a427..d0d5c6aa33 100644
--- a/requirements/framework.txt
+++ b/requirements/framework.txt
@@ -20,4 +20,4 @@ tensorboard
 tqdm
 transformers>=4.33,<4.48
 transformers_stream_generator
-trl>=0.11.0
+trl>=0.11,<0.12
diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py
index 6edc2cc24b..714f620459 100644
--- a/swift/trainers/mixin.py
+++ b/swift/trainers/mixin.py
@@ -770,7 +770,6 @@ def _save_checkpoint(self, model, trial, metrics=None):
     def concatenated_forward(
         self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
-
         model_kwargs = batch.copy()
         labels = model_kwargs.pop('labels', None)
         if self.is_encoder_decoder:
@@ -808,6 +807,13 @@ def get_batch_logps(self, logits: torch.FloatTensor, labels: torch.LongTensor, *
         labels = labels.clone()  # fix trl bug
         return super().get_batch_logps(logits, labels, *args, **kwargs)
 
+    def compute_loss(self, model, inputs, return_outputs=None, num_items_in_batch=None):
+        res = super().compute_loss(model, inputs, return_outputs=return_outputs)
+        # compat transformers>=4.46.*
+        if num_items_in_batch is not None:
+            res /= self.args.gradient_accumulation_steps
+        return res
+
 
 # monkey patching
 trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNew
diff --git a/swift/trainers/rlhf_trainer/cpo_trainer.py b/swift/trainers/rlhf_trainer/cpo_trainer.py
index 3826e7b105..a760cc040c 100644
--- a/swift/trainers/rlhf_trainer/cpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/cpo_trainer.py
@@ -9,6 +9,7 @@
 from swift.trainers import PushToMsHubMixin, RLHFTrainerMixin, SwiftMixin
 
 del HFCPOTrainer.__init__
+del HFCPOTrainer.get_batch_samples
 
 
 class CPOTrainer(RLHFTrainerMixin, PushToMsHubMixin, SwiftMixin, HFCPOTrainer):
diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py
index 8f26047d73..e6e853a70f 100644
--- a/swift/trainers/rlhf_trainer/dpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/dpo_trainer.py
@@ -9,6 +9,7 @@
 from swift.trainers import PushToMsHubMixin, RLHFTrainerMixin, SwiftMixin
 
 del HFDPOTrainer.__init__
+del HFDPOTrainer.get_batch_samples
 
 
 class DPOTrainer(RLHFTrainerMixin, PushToMsHubMixin, SwiftMixin, HFDPOTrainer):
diff --git a/swift/trainers/rlhf_trainer/kto_trainer.py b/swift/trainers/rlhf_trainer/kto_trainer.py
index 1e1a94c37d..49ab4bfee7 100644
--- a/swift/trainers/rlhf_trainer/kto_trainer.py
+++ b/swift/trainers/rlhf_trainer/kto_trainer.py
@@ -15,6 +15,7 @@
 logger = get_logger()
 
 del HFKTOTrainer.__init__
+del HFKTOTrainer.get_batch_samples
 
 
 def _add_kl_dataset(dataset: LLMDataset, total_batch_size: int, seed: Optional[int] = None) -> None:
diff --git a/swift/trainers/rlhf_trainer/orpo_trainer.py b/swift/trainers/rlhf_trainer/orpo_trainer.py
index aa7808105f..15651d163b 100644
--- a/swift/trainers/rlhf_trainer/orpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/orpo_trainer.py
@@ -9,6 +9,7 @@
 from swift.trainers import PushToMsHubMixin, RLHFTrainerMixin, SwiftMixin
 
 del HFORPOTrainer.__init__
+del HFORPOTrainer.get_batch_samples
 
 
 class ORPOTrainer(RLHFTrainerMixin, PushToMsHubMixin, SwiftMixin, HFORPOTrainer):

From 4a3bc2abef57d2d2a770afe01c23aeb8769b5a6c Mon Sep 17 00:00:00 2001
From: "huangjintao.hjt"
Date: Tue, 5 Nov 2024 19:50:13 +0800
Subject: [PATCH 2/2] fix RM

---
 swift/trainers/mixin.py                       |  4 +++-
 swift/trainers/rlhf_trainer/reward_trainer.py | 14 ++++++++------
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py
index 714f620459..f3f8a8b067 100644
--- a/swift/trainers/mixin.py
+++ b/swift/trainers/mixin.py
@@ -811,7 +811,9 @@ def compute_loss(self, model, inputs, return_outputs=None, num_items_in_batch=No
         res = super().compute_loss(model, inputs, return_outputs=return_outputs)
         # compat transformers>=4.46.*
         if num_items_in_batch is not None:
-            res /= self.args.gradient_accumulation_steps
+            loss = res[0] if return_outputs else res
+            loss /= self.args.gradient_accumulation_steps
+            return (loss, res[1:]) if return_outputs else loss
         return res
 
 
diff --git a/swift/trainers/rlhf_trainer/reward_trainer.py b/swift/trainers/rlhf_trainer/reward_trainer.py
index 84d3f40e3e..c77e21d71a 100644
--- a/swift/trainers/rlhf_trainer/reward_trainer.py
+++ b/swift/trainers/rlhf_trainer/reward_trainer.py
@@ -24,12 +24,11 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non
         self.use_reward_data_collator = True  # disable warning
         super().__init__(model, *_args, **kwargs)
 
-    def compute_loss(
-        self,
-        model: Union[PreTrainedModel, nn.Module],
-        inputs: Dict[str, Union[torch.Tensor, Any]],
-        return_outputs=False,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
+    def compute_loss(self,
+                     model: Union[PreTrainedModel, nn.Module],
+                     inputs: Dict[str, Union[torch.Tensor, Any]],
+                     return_outputs=False,
+                     num_items_in_batch=None) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
         model_kwargs = inputs.copy()
         labels = model_kwargs.pop('labels', None)
         if self.is_encoder_decoder:
@@ -43,6 +42,9 @@ def compute_loss(
             dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1)).squeeze()
         loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean().to(
             self.args.device)
+        # compat transformers>=4.46.*
+        if num_items_in_batch is not None:
+            loss /= self.args.gradient_accumulation_steps
         if return_outputs:
             return loss, {
                 'rewards_chosen': chosen_rewards,
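
Note on the compat pattern: the mixin.py and reward_trainer.py hunks all apply the same shim. With transformers>=4.46 (per the patch comments), compute_loss is called with an extra num_items_in_batch argument, and the override divides the loss by gradient_accumulation_steps so the effective loss scale stays unchanged; patch 2 extends this to the (loss, outputs) tuple returned when return_outputs is set. The following is a minimal, self-contained sketch of that pattern only; DummyArgs, DummyBase, and the example numbers are illustrative stand-ins, not part of the patch or of the transformers/trl API.

# compat_sketch.py -- illustrative only, assumes torch is installed
from dataclasses import dataclass

import torch


@dataclass
class DummyArgs:
    # stand-in for TrainingArguments
    gradient_accumulation_steps: int = 4


class DummyBase:
    """Stand-in for an old-style trainer whose compute_loss predates num_items_in_batch."""

    def compute_loss(self, model, inputs, return_outputs=False):
        loss = torch.tensor(2.0)  # pretend this came from a forward pass
        return (loss, {'logits': torch.zeros(1)}) if return_outputs else loss


class CompatTrainer(DummyBase):
    def __init__(self, args: DummyArgs):
        self.args = args

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Call the old-style parent without the new keyword argument.
        res = super().compute_loss(model, inputs, return_outputs=return_outputs)
        # When the caller passes num_items_in_batch (newer transformers),
        # rescale the loss by the gradient accumulation steps, unwrapping
        # the (loss, outputs) tuple if return_outputs was requested.
        if num_items_in_batch is not None:
            loss = res[0] if return_outputs else res
            loss = loss / self.args.gradient_accumulation_steps
            return (loss, *res[1:]) if return_outputs else loss
        return res


if __name__ == '__main__':
    trainer = CompatTrainer(DummyArgs(gradient_accumulation_steps=4))
    # Old-style call path: loss is returned untouched -> tensor(2.)
    print(trainer.compute_loss(None, {}))
    # New-style call path: loss is rescaled -> tensor(0.5000)
    print(trainer.compute_loss(None, {}, num_items_in_batch=8))
    # Tuple path: (tensor(0.5000), {'logits': ...})
    print(trainer.compute_loss(None, {}, return_outputs=True, num_items_in_batch=8))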