From dc699537cb47876d03356fc1601735e9bc41308d Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Sun, 16 Nov 2025 13:17:21 +0800
Subject: [PATCH 1/2] support add_eos

---
 swift/llm/template/base.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index 676158ff83..367828876c 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -1099,6 +1099,7 @@ def _swift_encode(self, inputs: StdTemplateInputs):
         assert len(inputs.messages) > 0, f'inputs.messages: {inputs.messages}'
 
         n_round = len(inputs.messages) // 2
+        add_eos = inputs.extra_kwargs.get('add_eos')
         for i, (query_message, response_message) in enumerate(zip(inputs.messages[::2], inputs.messages[1::2])):
             query_role, query = query_message['role'], query_message['content']
             response_role, response = response_message['role'], response_message['content']
@@ -1139,7 +1140,10 @@ def _swift_encode(self, inputs: StdTemplateInputs):
                                            if isinstance(stop_word, str))
                 # self.is_training needed because we may want to continue generation from
                 # the current response
-                if (self.is_training or self.task_type != 'causal_lm') and not sep_token and not endswith_stop_words:
+                if add_eos is None:
+                    add_eos = (self.is_training
+                               or self.task_type != 'causal_lm') and not sep_token and not endswith_stop_words
+                if add_eos:
                     extra_context_list = template_meta.suffix
                     extra_context_type = ContextType.SUFFIX
                 elif template_meta.response_prefix:

From 25157fb72f06b3d3323eb6ed975d2d14e6d3d562 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Sun, 16 Nov 2025 13:54:02 +0800
Subject: [PATCH 2/2] update

---
 swift/llm/template/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index 367828876c..5993dfee0c 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -1099,7 +1099,6 @@ def _swift_encode(self, inputs: StdTemplateInputs):
         assert len(inputs.messages) > 0, f'inputs.messages: {inputs.messages}'
 
         n_round = len(inputs.messages) // 2
-        add_eos = inputs.extra_kwargs.get('add_eos')
         for i, (query_message, response_message) in enumerate(zip(inputs.messages[::2], inputs.messages[1::2])):
             query_role, query = query_message['role'], query_message['content']
             response_role, response = response_message['role'], response_message['content']
@@ -1140,6 +1139,7 @@ def _swift_encode(self, inputs: StdTemplateInputs):
                                            if isinstance(stop_word, str))
                 # self.is_training needed because we may want to continue generation from
                 # the current response
+                add_eos = inputs.extra_kwargs.get('add_eos')
                 if add_eos is None:
                     add_eos = (self.is_training
                                or self.task_type != 'causal_lm') and not sep_token and not endswith_stop_words
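
Note (not part of the patch series): a minimal usage sketch of the new flag. The patch reads inputs.extra_kwargs.get('add_eos') inside _swift_encode, so a per-sample value can force or suppress the template suffix (EOS) instead of relying on the old training/task-type heuristic. Only the key name 'add_eos' and the fallback semantics below come from the diff; how extra_kwargs is populated for StdTemplateInputs in a given data pipeline is an assumption here.

# Hypothetical illustration (assumed pipeline shape; key name and fallback
# behaviour taken from the patch, everything else is a sketch).
sample = {
    'messages': [
        {'role': 'user', 'content': 'Hello'},
        {'role': 'assistant', 'content': 'Hi there'},
    ],
    'extra_kwargs': {
        # None  -> previous default: (is_training or task_type != 'causal_lm')
        #          and no trailing sep_token / stop word
        # True  -> always append template_meta.suffix (the EOS context)
        # False -> never append it, e.g. to continue generation from the response
        'add_eos': False,
    },
}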