-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Narrow prefix-preserving check to the actual requirement #5458
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4b3aa51
0894910
730070b
4622d77
8a00354
cd8cbfc
8b35321
160d6a0
8dd341b
7e3ddd7
71bf73f
103d3c9
fffcb67
87131e5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -630,6 +630,10 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: | |
| """ | ||
| Check whether the chat template preserves prefixes when applied. | ||
|
|
||
| A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. This | ||
| property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by comparing | ||
| tokenizations with and without tool messages appended. | ||
|
|
||
| Args: | ||
| tokenizer (`PreTrainedTokenizer`): | ||
| Tokenizer instance to check. | ||
|
|
@@ -638,24 +642,22 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: | |
| `bool`: | ||
| `True` if the chat template preserves prefixes, `False` otherwise. | ||
| """ | ||
| # Use the same dummy messages as _get_tool_suffix_ids to test the exact property it relies on. | ||
| dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] | ||
| messages1 = [ | ||
| {"role": "user", "content": "What color is the sky?"}, | ||
| {"role": "user", "content": "dummy"}, | ||
| {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, | ||
| ] | ||
| messages2 = [ | ||
| {"role": "user", "content": "What color is the sky?"}, | ||
| {"role": "assistant", "content": "It is blue."}, | ||
| ] | ||
| messages3 = [ | ||
| {"role": "user", "content": "What color is the sky?"}, | ||
| {"role": "assistant", "content": "It is blue."}, | ||
| {"role": "user", "content": "And at night?"}, | ||
| {"role": "user", "content": "dummy"}, | ||
| {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls}, | ||
| {"role": "tool", "name": "dummy", "content": "dummy"}, | ||
| ] | ||
|
|
||
| text1 = tokenizer.apply_chat_template(messages1, tokenize=False, add_generation_prompt=True) | ||
| text2 = tokenizer.apply_chat_template(messages2, tokenize=False) | ||
| text3 = tokenizer.apply_chat_template(messages3, tokenize=False) | ||
| text1 = tokenizer.apply_chat_template(messages1, tokenize=False) | ||
| text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) | ||
|
|
||
| return text2.startswith(text1) and text3.startswith(text2) | ||
| return text2.startswith(text1) | ||
|
|
||
|
|
||
| # Modifications: | ||
|
|
@@ -749,33 +751,12 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: | |
| {%- endif %}""" | ||
|
|
||
|
|
||
| # Modifications: | ||
| # - {%- if '</think>' in content %} | ||
| # + {%- if '<think>' in content and '</think>' in content %} | ||
| # Always check for both tags to avoid edge cases where the model generates only one tag, which would otherwise be parsed incorrectly | ||
| # - {{- '<|im_start|>' + message.role + '\n' + content }} | ||
| # + {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }} | ||
| # Always include thinking block during training. It's important to have a prefix-preserving template. | ||
| def _patch_qwen3_5_training_template(template: str) -> str: | ||
| return template.replace( | ||
| "{%- if '</think>' in content %}", | ||
| "{%- if '<think>' in content and '</think>' in content %}", | ||
| ).replace( | ||
| "{{- '<|im_start|>' + message.role + '\\n' + content }}", | ||
| "{{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content + '\\n</think>\\n\\n' + content }}", | ||
| ) | ||
|
|
||
|
|
||
| qwen3_5_training_chat_template_2b_and_below = _patch_qwen3_5_training_template(qwen3_5_chat_template_2b_and_below) | ||
| qwen3_5_training_chat_template_4b_and_above = _patch_qwen3_5_training_template(qwen3_5_chat_template_4b_and_above) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You mention Qwen3.5 doesn't need patching anymore. Does this depend on the transformers version?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No, it's really about the chat template: from transformers import AutoTokenizer, AutoProcessor
dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}]
messages1 = [
{"role": "user", "content": "dummy"},
{"role": "assistant", "content": "", "tool_calls": dummy_tool_calls},
]
messages2 = messages1 + [
{"role": "tool", "name": "dummy", "content": "dummy"},
]
model_id = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
text1 = tokenizer.apply_chat_template(messages1, tokenize=False)
text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
print(f"\n{'='*60}")
print(f"{model_id}")
print(f"Prefix-preserving: {text2.startswith(text1)}")
print(repr(text1))
print(repr(text2))
model_id = "Qwen/Qwen3.5-0.8B"
tokenizer = AutoProcessor.from_pretrained(model_id)
text1 = tokenizer.apply_chat_template(messages1, tokenize=False)
text2 = tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
print(f"\n{'='*60}")
print(f"{model_id}")
print(f"Prefix-preserving: {text2.startswith(text1)}")
print(repr(text1))
print(repr(text2)) |
||
|
|
||
|
|
||
| def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: | ||
| r""" | ||
| Get a prefix-preserving chat template for training, if needed. | ||
|
|
||
| If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 and | ||
| Qwen3.5 supported). Otherwise, returns `None`. | ||
| If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently Qwen3 | ||
| supported). Otherwise, returns `None`. | ||
|
|
||
| Args: | ||
| tokenizer (`PreTrainedTokenizer`): | ||
|
|
@@ -793,27 +774,31 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: | |
|
|
||
| >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") | ||
| >>> messages1 = [ | ||
| ... {"role": "user", "content": "What color is the sky?"}, | ||
| ... {"role": "assistant", "content": "It is blue."}, | ||
| ... {"role": "user", "content": "What is 2 * 3?"}, | ||
| ... { | ||
| ... "role": "assistant", | ||
| ... "content": "", | ||
| ... "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}], | ||
| ... }, | ||
| ... ] | ||
| >>> messages2 = [ | ||
| ... {"role": "user", "content": "What color is the sky?"}, | ||
| ... {"role": "assistant", "content": "It is blue."}, | ||
| ... {"role": "user", "content": "And at night?"}, | ||
| >>> messages2 = messages1 + [ | ||
| ... {"role": "tool", "name": "multiply", "content": "6"}, | ||
| ... ] | ||
| >>> tokenizer.apply_chat_template(messages1, tokenize=False) | ||
| '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nIt is blue.<|im_end|>\n' | ||
| '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n' | ||
|
|
||
| >>> tokenizer.apply_chat_template(messages2, tokenize=False) | ||
| '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' | ||
| >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) | ||
| '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n<|im_start|>user\n<tool_response>\n6\n</tool_response><|im_end|>\n<|im_start|>assistant\n' | ||
|
|
||
| >>> # ^ think tags missing | ||
| >>> # ^ think tags missing | ||
| >>> chat_template = get_training_chat_template(tokenizer) | ||
| >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template) | ||
| '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nIt is blue.<|im_end|>\n' | ||
| '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n' | ||
|
|
||
| >>> tokenizer.apply_chat_template(messages2, tokenize=False, chat_template=chat_template) | ||
| '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' | ||
| >>> tokenizer.apply_chat_template( | ||
| ... messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template | ||
| ... ) | ||
| '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n<|im_start|>user\n<tool_response>\n6\n</tool_response><|im_end|>\n<|im_start|>assistant\n' | ||
| ``` | ||
| """ | ||
| # First check if patching is needed | ||
|
|
@@ -822,10 +807,6 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: | |
|
|
||
| if tokenizer.chat_template == qwen3_chat_template: | ||
| return qwen3_training_chat_template | ||
| if tokenizer.chat_template == qwen3_5_chat_template_2b_and_below: | ||
| return qwen3_5_training_chat_template_2b_and_below | ||
| if tokenizer.chat_template == qwen3_5_chat_template_4b_and_above: | ||
| return qwen3_5_training_chat_template_4b_and_above | ||
| else: | ||
| raise ValueError( | ||
| "The tokenizer's chat template is not prefix-preserving and patching is not supported for this template. " | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -555,31 +555,41 @@ async def _generate_one( | |
|
|
||
| def _get_tool_suffix_ids(self, tool_messages: list[dict[str, Any]]) -> list[int]: | ||
| """Get token IDs for tool result formatting by using a minimal dummy conversation.""" | ||
| dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}] | ||
| dummy_messages = [ | ||
| {"role": "user", "content": ""}, | ||
| {"role": "assistant", "content": ""}, | ||
| {"role": "user", "content": "dummy"}, | ||
| { | ||
| "role": "assistant", | ||
| # "content" is required here because VLM processors crash on tokenize=True without it | ||
| # (KeyError in processing_utils.py). See huggingface/transformers#45290. | ||
| "content": "", | ||
| "tool_calls": dummy_tool_calls, | ||
| }, | ||
| ] | ||
| prefix_ids = self.tokenizer.apply_chat_template( | ||
| dummy_messages, | ||
| return_dict=False, | ||
| tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] | ||
| add_generation_prompt=False, | ||
| tokenize=True, | ||
| chat_template=self.chat_template, | ||
| return_dict=False, | ||
| **self.chat_template_kwargs, | ||
| ) | ||
| full_ids = self.tokenizer.apply_chat_template( | ||
| dummy_messages + tool_messages, | ||
| return_dict=False, | ||
| chat_template=self.chat_template, | ||
| add_generation_prompt=True, | ||
| tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] | ||
| tokenize=True, | ||
| chat_template=self.chat_template, | ||
| return_dict=False, | ||
| **self.chat_template_kwargs, | ||
| ) | ||
|
Comment on lines
556
to
584
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just to align GRPO and async GRPO |
||
|
|
||
| # Some chat templates (notably Qwen3/Qwen3.5) render "...<|im_end|>\n" after an assistant/tool block. | ||
| # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to | ||
| # EOS (not EOS + newline). | ||
| last_eos_idx = max(i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id) | ||
| prefix_ids = prefix_ids[: last_eos_idx + 1] | ||
| # EOS (not EOS + newline). Templates that don't use EOS as end-of-turn (e.g. Gemma uses | ||
| # <turn|>) skip this trimming. | ||
| eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.tokenizer.eos_token_id] | ||
| if eos_positions: | ||
| prefix_ids = prefix_ids[: eos_positions[-1] + 1] | ||
|
Comment on lines
-580
to
+592
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| if full_ids[: len(prefix_ids)] != prefix_ids: | ||
| raise ValueError("Unexpected tokenization: the EOS-trimmed prefix IDs are not a prefix of the full IDs.") | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.