Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/framework.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ tensorboard
tqdm
transformers>=4.33,<4.48
transformers_stream_generator
trl>=0.11.0
trl>=0.11,<0.12
10 changes: 9 additions & 1 deletion swift/trainers/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,6 @@ def _save_checkpoint(self, model, trial, metrics=None):
def concatenated_forward(
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:

model_kwargs = batch.copy()
labels = model_kwargs.pop('labels', None)
if self.is_encoder_decoder:
Expand Down Expand Up @@ -808,6 +807,15 @@ def get_batch_logps(self, logits: torch.FloatTensor, labels: torch.LongTensor, *
labels = labels.clone() # fix trl bug
return super().get_batch_logps(logits, labels, *args, **kwargs)

def compute_loss(self, model, inputs, return_outputs=None, num_items_in_batch=None):
res = super().compute_loss(model, inputs, return_outputs=return_outputs)
# compat transformers>=4.46.*
if num_items_in_batch is not None:
loss = res[0] if return_outputs else res
loss /= self.args.gradient_accumulation_steps
return (loss, res[1:]) if return_outputs else loss
return res


# monkey patching
trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNew
Expand Down
1 change: 1 addition & 0 deletions swift/trainers/rlhf_trainer/cpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from swift.trainers import PushToMsHubMixin, RLHFTrainerMixin, SwiftMixin

del HFCPOTrainer.__init__
del HFCPOTrainer.get_batch_samples


class CPOTrainer(RLHFTrainerMixin, PushToMsHubMixin, SwiftMixin, HFCPOTrainer):
Expand Down
1 change: 1 addition & 0 deletions swift/trainers/rlhf_trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from swift.trainers import PushToMsHubMixin, RLHFTrainerMixin, SwiftMixin

del HFDPOTrainer.__init__
del HFDPOTrainer.get_batch_samples


class DPOTrainer(RLHFTrainerMixin, PushToMsHubMixin, SwiftMixin, HFDPOTrainer):
Expand Down
1 change: 1 addition & 0 deletions swift/trainers/rlhf_trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
logger = get_logger()

del HFKTOTrainer.__init__
del HFKTOTrainer.get_batch_samples


def _add_kl_dataset(dataset: LLMDataset, total_batch_size: int, seed: Optional[int] = None) -> None:
Expand Down
1 change: 1 addition & 0 deletions swift/trainers/rlhf_trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from swift.trainers import PushToMsHubMixin, RLHFTrainerMixin, SwiftMixin

del HFORPOTrainer.__init__
del HFORPOTrainer.get_batch_samples


class ORPOTrainer(RLHFTrainerMixin, PushToMsHubMixin, SwiftMixin, HFORPOTrainer):
Expand Down
14 changes: 8 additions & 6 deletions swift/trainers/rlhf_trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non
self.use_reward_data_collator = True # disable warning
super().__init__(model, *_args, **kwargs)

def compute_loss(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
def compute_loss(self,
model: Union[PreTrainedModel, nn.Module],
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
num_items_in_batch=None) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
model_kwargs = inputs.copy()
labels = model_kwargs.pop('labels', None)
if self.is_encoder_decoder:
Expand All @@ -43,6 +42,9 @@ def compute_loss(
dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1)).squeeze()
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean().to(
self.args.device)
# compat transformers>=4.46.*
if num_items_in_batch is not None:
loss /= self.args.gradient_accumulation_steps
if return_outputs:
return loss, {
'rewards_chosen': chosen_rewards,
Expand Down
Loading