diff --git a/swift/llm/argument/rlhf_args.py b/swift/llm/argument/rlhf_args.py index 68dbd9f7d6..1c53e48ab8 100644 --- a/swift/llm/argument/rlhf_args.py +++ b/swift/llm/argument/rlhf_args.py @@ -43,6 +43,9 @@ class RLHFArguments(TrainArguments): desirable_weight: float = 1.0 undesirable_weight: float = 1.0 + # Use last_round by default + loss_scale: str = 'last_round' + def __post_init__(self): self._init_simpo() self._set_default() diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py index 906a3e1166..a11a32aed5 100644 --- a/swift/llm/train/rlhf.py +++ b/swift/llm/train/rlhf.py @@ -28,9 +28,10 @@ def _prepare_template(self) -> None: mode = 'kto' if args.rlhf_type == 'kto' else 'rlhf' self.template.set_mode(mode) - if args.rlhf_type != 'orpo' or args.model_meta.is_multimodal: + if args.rlhf_type == 'orpo' and not args.model_meta.is_multimodal: # Avoid padding labels during the model's forward pass in multimodal models. - self.template.loss_scale = 'last_round' + args.loss_scale = 'default' + self.template.loss_scale = args.loss_scale @classmethod def prepare_model(cls, args, model, *_args, **kwargs):