From 4f9044296db8e8be95f9a98a5980fd60c7d91a11 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 1 Mar 2025 23:28:31 +0800 Subject: [PATCH] fix generation_config --- swift/llm/train/sft.py | 1 + swift/trainers/mixin.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index e7ae237227..a673c5e472 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -53,6 +53,7 @@ def _prepare_gradient_checkpointing(self): def _prepare_generation_config(self): args = self.args + self.model.origin_generation_config = self.model.generation_config self.model.generation_config = prepare_generation_config(self.model.generation_config, args.get_request_config(), self.tokenizer) logger.info(f'model.generation_config: {self.model.generation_config}') diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index c8221c0ec4..92a9e0045b 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -221,6 +221,8 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): from swift.llm import save_checkpoint additional_saved_files = self.model_meta.additional_saved_files save_checkpoint(None, self.template.processor, output_dir, additional_saved_files=additional_saved_files) + if hasattr(self.model, 'origin_generation_config'): + self.model.origin_generation_config.save_pretrained(output_dir) def _fix_zero3_gather_all_parameters(self) -> None: if is_deepspeed_zero3_enabled() and not hasattr(self.deepspeed, '_zero3_consolidated_16bit_state_dict_origin'):