diff --git a/swift/trainers/rlhf_trainer/ppo_trainer.py b/swift/trainers/rlhf_trainer/ppo_trainer.py
index 20a6012240..2f3ceecb54 100644
--- a/swift/trainers/rlhf_trainer/ppo_trainer.py
+++ b/swift/trainers/rlhf_trainer/ppo_trainer.py
@@ -3,8 +3,6 @@
 from contextlib import contextmanager
 from typing import Optional
 
-import transformers
-from packaging import version
 from torch.utils.data import DataLoader
 from transformers import PreTrainedModel, Trainer
 from trl import PPOTrainer as HFPPOTrainer
@@ -38,8 +36,14 @@ def __init__(self, model: PreTrainedModel, ref_model: PreTrainedModel, *_args, **kwargs):
         with self._patch_dataloader(kwargs['data_collator']):
             new_kwargs = {
                 k: v
-                for k, v in kwargs.items()
-                if k in ['train_dataset', 'data_collator', 'reward_model', 'value_model', 'eval_dataset', 'callbacks']
+                for k, v in kwargs.items() if k in [
+                    'train_dataset',
+                    'data_collator',
+                    'reward_model',
+                    'value_model',
+                    'eval_dataset',
+                    'callbacks',
+                ]
             }
             parameters = inspect.signature(ppo_trainer_init).parameters
             if 'config' in parameters:
@@ -63,7 +67,14 @@ def train(self, *args, **kwargs):
 
     def _save_checkpoint(self, *args, **kwargs):
         kwargs.pop('metrics', None)
-        return super()._save_checkpoint(*args, **kwargs)
+
+        backup_model = self.model
+        try:
+            # Unwrap model if needed
+            self.model = self.accelerator.unwrap_model(self.model)
+            return super()._save_checkpoint(*args, **kwargs)
+        finally:
+            self.model = backup_model
 
     def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
         # https://github.com/huggingface/trl/issues/2122
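
For reference, the `_save_checkpoint` change follows the common Accelerate pattern of temporarily swapping in the unwrapped model so that the base `Trainer._save_checkpoint` serializes the plain `PreTrainedModel` rather than its distributed wrapper, then restoring the wrapped model afterwards. A minimal standalone sketch of that pattern is below; the helper name `unwrapped_model_for_save` is illustrative and not part of the patch, and it only assumes the trainer exposes `model` and an Accelerate `accelerator`, as the HF/TRL trainers do.

```python
from contextlib import contextmanager


@contextmanager
def unwrapped_model_for_save(trainer):
    """Temporarily replace trainer.model with the unwrapped module.

    Illustrative sketch only: assumes `trainer` has `model` and
    `accelerator` attributes, as Hugging Face / TRL trainers do.
    """
    backup_model = trainer.model
    try:
        # accelerator.unwrap_model strips DDP/DeepSpeed/compile wrappers so
        # the underlying PreTrainedModel (with save_pretrained) gets saved.
        trainer.model = trainer.accelerator.unwrap_model(trainer.model)
        yield trainer.model
    finally:
        # Restore the wrapped model so training continues unaffected.
        trainer.model = backup_model
```

With such a helper, the override in the diff could equivalently be written as `with unwrapped_model_for_save(self): return super()._save_checkpoint(*args, **kwargs)`; the try/finally in the patch does the same swap-and-restore inline.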