[bugfix] fix jinja-backend train↔infer parity for Qwen3.6 #9277
ArvinZhuang wants to merge 1 commit into
Conversation
Code Review
This pull request implements a mechanism to split ChatML-rendered text into alternating chunks, allowing for selective supervision of assistant turns when using the Jinja backend during training. A regression test for Qwen3.6 is included to ensure consistency between training and inference. Reviewers suggested setting load_model=False in the test to prevent OOM errors and adjusting the answer_len logic for encoder-decoder models to maintain correct data partitioning.
```python
    rendered text, so the labels for system / user / tool tokens were not
    masked — silently training on non-assistant tokens.
    """
    engine = TransformersEngine('Qwen/Qwen3.6-35B-A3B')
```
The model Qwen/Qwen3.6-35B-A3B is very large. Loading it during a template regression test is unnecessary and can lead to OOM errors or significantly slow down the CI environment. Since this test only requires the tokenizer and template logic, it is recommended to set load_model=False.
Suggested change:
```diff
-    engine = TransformersEngine('Qwen/Qwen3.6-35B-A3B')
+    engine = TransformersEngine('Qwen/Qwen3.6-35B-A3B', load_model=False)
```
```python
        if assistant_contexts is not None:
            return assistant_contexts, [1. if i % 2 == 1 else 0.
                                        for i in range(len(assistant_contexts))], answer_len
```
For encoder-decoder models, the answer_len is used to slice the res_context_list into prompt and answer parts (see lines 1450-1453). Previously, _jinja_encode returned a single context with answer_len=1, which meant the entire rendered text was treated as the answer (target). With the new splitting logic, len(assistant_contexts) is greater than 1, but answer_len remains 1. This will cause the prompt/answer split to be incorrect for encoder-decoder models that use ChatML templates (the prompt will contain all but the last chunk). To maintain the previous behavior for encoder-decoder models, answer_len should be adjusted to include all chunks in the decoder part.
Suggested change:
```diff
         if assistant_contexts is not None:
+            if self.is_encoder_decoder:
+                answer_len = len(assistant_contexts)
             return assistant_contexts, [1. if i % 2 == 1 else 0.
                                         for i in range(len(assistant_contexts))], answer_len
```
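For reference, a small illustration of the prompt/answer slicing this comment describes (names and values are hypothetical, not the exact code around `base.py:1450-1453`):

```python
# Hypothetical illustration of the encoder-decoder partitioning described above.
res_context_list = ['<non-assistant chunk>', '<assistant chunk>', '<trailing chunk>']
answer_len = 1  # previous behavior: a single context, so the whole render was the answer

prompt_part = res_context_list[:-answer_len]  # encoder input  -> first two chunks (wrong)
answer_part = res_context_list[-answer_len:]  # decoder target -> only the last chunk

# With the new splitting, setting answer_len = len(assistant_contexts) keeps every
# chunk in the decoder part, preserving the pre-change behavior.
```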
Please do not modify the base Template, as this will affect all models.
…lscope#9276)

Closes modelscope#9276.

## Problem

`Qwen/Qwen3.6-*-A3B` is bound to `TemplateType.qwen3_5` (no separate qwen3_6 template registration), and Qwen3.6's bundled `chat_template.jinja` differs from Qwen3.5's in two places:

- assistant-turn `<think>` reasoning preservation gating (Qwen3.6 honors a `preserve_thinking` kwarg)
- tool_call argument value rendering

Both also share `|trim` filters on user/assistant content that the `agent_template=qwen3_5` swift backend's prompt format does NOT apply. For agent-shaped data where the user prompt ends in `\n` (common for prompts read from a file), the swift backend emits an extra `\n` between the user content and `<|im_end|>` — so train↔inference input_ids drift by ~1 token per user turn.

The cleanest path is to make `template_backend='jinja'` actually usable for SFT — currently `_jinja_encode` returns `loss_scale=[1.]` wholesale, which means under jinja the trainer supervises *every* token (system / user / tool / role markers). That's why most users default to the swift backend even though they want the HF chat_template render.

## Fix (scoped to Qwen3_5Template — does NOT modify base Template)

Override `_jinja_encode` only in `Qwen3_5Template` (which Qwen3.5 + Qwen3.6 all use today). When training, split the rendered text along `<|im_start|>role\n` markers into alternating `[non_assist, assist, ..., trailing_non_assist]` chunks. The trainer's existing `_encode_context_list` then assigns `loss_scale=0` to non-assistant chunks and `loss_scale=1` to assistant chunks. Recombination is byte-exact, so train and inference render identically.

Inference behavior is unchanged (gated on `self.is_training`). The base `Template._jinja_encode` is untouched, so all other models are unaffected.

## Test

Adds `test_qwen3_6_jinja_train_infer_parity` in `tests/test_align/test_template/test_agent.py`. Uses agent-shape data: multi-turn `<think>` + tool_call + tool, user content ending in `\n`. Loads with `load_model=False` (only tokenizer + template needed). Asserts:

1. train input_ids byte-equal to inference apply_chat_template tokens
2. labels mask non-assistant tokens (n_supervised > 0, n_masked > 0)
3. user-prompt content does not leak into supervised labels

Without the fix, Qwen3.6-35B-A3B fails (1) on agent data with `\n`-ended user content, and fails (2) entirely (every token supervised under template_backend=jinja).
8d20e31 to 3c1c41a
@Jintao-Huang Thanks for the quick review — you're right, the base Template should stay untouched; the override is now scoped to `Qwen3_5Template` only.
Also addressed the gemini-code-assist suggestions (including `load_model=False` in the test).
Diff stat:
Thank you for your PR, the issue has been fixed.
Closes #9276.
Problem
`Qwen/Qwen3.6-*-A3B` is bound to `TemplateType.qwen3_5` (no separate `qwen3_6` template registration in `swift/model/models/qwen.py:1158-1161`), and Qwen3.6's bundled `chat_template.jinja` differs from Qwen3.5's in two places (assistant `<think>` preservation gating + tool_call args rendering). Both also apply `|trim` filters on message content that the `agent_template=qwen3_5` swift backend's `prompt=['<|im_start|>user\n{{QUERY}}<|im_end|>\n']` format string does NOT apply.

For agent data where the user prompt ends in `\n` (common when prompts are read from a file), the swift backend emits an extra `\n` between the user content and `<|im_end|>` — so train↔inference `input_ids` drift by ~1 token per user turn.
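To make the drift concrete, here is a hedged illustration (the query string is invented; the swift format string is the one quoted above, and Jinja's `|trim` behaves like Python's `str.strip()`):

```python
query = 'What is the weather in Paris?\n'  # user prompt read from a file, trailing '\n'

# swift backend: prompt = ['<|im_start|>user\n{{QUERY}}<|im_end|>\n'], no trimming
swift_render = '<|im_start|>user\n' + query + '<|im_end|>\n'
# '<|im_start|>user\nWhat is the weather in Paris?\n<|im_end|>\n'   <- extra '\n'

# jinja backend: Qwen3.6's chat_template applies |trim to the message content
jinja_render = '<|im_start|>user\n' + query.strip() + '<|im_end|>\n'
# '<|im_start|>user\nWhat is the weather in Paris?<|im_end|>\n'

assert swift_render != jinja_render  # roughly one extra token per user turn
```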
The byte-equality test added in PR #8161 (`test_qwen3_5`) only covers Qwen3.5-35B-A3B with the `function-calling-chatml` dataset (whose user content has no trailing whitespace), so it does not catch any of the above on Qwen3.6.

The cleanest path is to make `template_backend='jinja'` actually usable for SFT. Currently `_jinja_encode` returns `loss_scale=[1.]` wholesale (`base.py:1078`), so the trainer supervises every token rendered by `apply_chat_template` — system, user, tool, role markers, everything. That's why users default to the swift backend even though they want the HF chat_template render.

Fix
In `_jinja_encode` (training path), split the rendered text along `<|im_start|>role\n` markers into alternating `[non_assist_0, assist_0, non_assist_1, assist_1, ..., trailing_non_assist]` chunks. The existing `_encode_context_list` then assigns `loss_scale=0` to non-assistant chunks and `loss_scale=1` to assistant chunks. Recombination is byte-exact, so train and inference still render identically.

This works for any ChatML-format chat_template (Qwen, ChatGLM, etc.). For non-ChatML templates the helper returns `None` and the caller falls back to the legacy wholesale `loss_scale=[1.]` path (no behavior change for those models).

After this PR, `--template_backend jinja` becomes the recommended setting for Qwen3.6 SFT, fully eliminating the swift-vs-jinja byte drift documented in #9276.
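A minimal sketch of the splitting idea (the function name, regex, and the choice to keep role headers inside the non-assistant chunks are assumptions, not the exact code in this PR):

```python
import re
from typing import List, Optional


def _split_chatml_by_assistant(text: str) -> Optional[List[str]]:
    """Split ChatML-rendered text into alternating
    [non_assist_0, assist_0, ..., trailing_non_assist] chunks.

    Odd indices are assistant turns (loss_scale 1); even indices are everything
    else (loss_scale 0). ''.join(chunks) == text, so recombination is byte-exact.
    Returns None for non-ChatML templates so the caller can fall back to the
    legacy wholesale loss_scale=[1.] path.
    """
    if '<|im_start|>' not in text:
        return None
    # A supervised span starts right after '<|im_start|>assistant\n' and runs
    # through its closing '<|im_end|>' (kept supervised so EOS is learned).
    pattern = re.compile(r'<\|im_start\|>assistant\n(.*?<\|im_end\|>)', re.DOTALL)
    chunks: List[str] = []
    pos = 0
    for m in pattern.finditer(text):
        chunks.append(text[pos:m.start(1)])  # non-assistant text, incl. the role header
        chunks.append(m.group(1))            # assistant content (supervised)
        pos = m.end(1)
    chunks.append(text[pos:])                # trailing non-assistant text (may be '')
    return chunks


# The caller then builds the loss scale exactly as in the diff above:
# loss_scale = [1. if i % 2 == 1 else 0. for i in range(len(chunks))]
```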
Test

Adds `test_qwen3_6_jinja_train_infer_parity` in `tests/test_align/test_template/test_agent.py` using agent-shape data: multi-turn `<think>` reasoning + tool_call + tool response, user content ending in `\n` (which exercises Qwen3.6 chat_template's `|trim` filter). Asserts:
1. train `input_ids` byte-equal to inference `apply_chat_template(...)` tokens
2. labels mask non-assistant tokens (`n_supervised > 0`, `n_masked > 0`)
3. user-prompt content does not leak into supervised labels

Without the fix, Qwen3.6-35B-A3B fails (1) on agent data with `\n`-ended user content, and fails (2) entirely (every token supervised under `template_backend=jinja`).
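Roughly how the three assertions line up with the splitting sketch above (a hedged outline, not the actual test; `tokenizer`, `messages`, and `user_prompt` are placeholders, and the real test drives swift's Template encode path instead):

```python
# Hedged outline of the parity invariants; placeholders: tokenizer, messages, user_prompt.
rendered = tokenizer.apply_chat_template(messages, tokenize=False)  # inference-side render
chunks = _split_chatml_by_assistant(rendered)

# (1) byte-exact recombination -> identical input_ids at train and inference time
assert ''.join(chunks) == rendered

# (2) labels mask non-assistant tokens: some chunks supervised, some masked
loss_scale = [1. if i % 2 == 1 else 0. for i in range(len(chunks))]
assert any(s == 1. for s in loss_scale) and any(s == 0. for s in loss_scale)

# (3) user-prompt content never leaks into a supervised (odd-index) chunk
assert all(user_prompt not in chunks[i] for i in range(1, len(chunks), 2))
```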
Backward compatibility

- Inference behavior is unchanged (the new splitting is gated on `self.is_training`).
- Non-ChatML templates are unaffected (`'<|im_start|>' not in text` → return `None` → wholesale `[1.]` path).
- Encoder-decoder models are handled (`answer_len` is still emitted correctly).