Add Multi-Token Prediction (MTP) support for Qwen3.5#45638
curnane-lab wants to merge 8 commits into huggingface:main
Conversation
Add MTP architecture and loss computation for Qwen3.5 models, enabling multi-token prediction during training for improved efficiency.

Changes:
- Add `Qwen3_5MTPLayer` and `Qwen3_5MTP` module classes
- Add shared `_compute_qwen35_mtp_loss()` helper function
- Add MTP support to `Qwen3_5ForCausalLM` (text-only model)
- Add MTP support to `Qwen3_5ForConditionalGeneration` (VL model)
- Add `mtp_num_hidden_layers` and `mtp_loss_weight` config fields
- Remove `mtp` from `_keys_to_ignore_on_load_unexpected` in CausalLM
- Regenerate `modeling_qwen3_5.py` and `configuration_qwen3_5.py`
- Add `mtp_num_hidden_layers` and `mtp_loss_weight` docstrings to `Qwen3_5TextConfig`
- Add `mtp_num_hidden_layers` and `mtp_loss_weight` docstrings to `Qwen3_5Config`
- Fix ruff formatting in `modular_qwen3_5.py`
- Regenerate `modeling_qwen3_5.py` and `configuration_qwen3_5.py`
Follow the MoE auxiliary loss pattern (e.g., Mixtral's `output_router_logits`) to add proper runtime control and output visibility for MTP loss:
- Add an `output_mtp_loss` config field (default: `False`) to `Qwen3_5TextConfig` and `Qwen3_5Config`, controlling whether MTP loss is computed
- Add an `output_mtp_loss` forward parameter to both the CausalLM and VL models, overriding the config default at runtime
- `Qwen3_5ForCausalLM` now returns `MoeCausalLMOutputWithPast` with `aux_loss=mtp_loss` (consistent with MoE models like Mixtral)
- `Qwen3_5ForConditionalGeneration` now returns `Qwen3_5VLCausalLMOutputWithPast` with an `mtp_loss` field (extends `Qwen3VLCausalLMOutputWithPast`)
- MTP loss is only computed when `output_mtp_loss=True`, avoiding unnecessary computation when MTP is not needed
- Regenerate `modeling_qwen3_5.py` and `configuration_qwen3_5.py`
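The "forward argument overrides config default" convention used above can be sketched as follows. This is an illustrative toy, not transformers code: `DummyConfig` and the bare `forward` function are stand-ins for the real config class and model forward.

```python
# Illustrative sketch (not transformers code) of the pattern where a
# forward-pass argument overrides the config default for a flag such as
# output_mtp_loss. DummyConfig is a hypothetical stand-in.
class DummyConfig:
    output_mtp_loss = False


def forward(config, output_mtp_loss=None):
    # None means "no runtime preference": fall back to the config default.
    output_mtp_loss = (
        output_mtp_loss if output_mtp_loss is not None else config.output_mtp_loss
    )
    return output_mtp_loss


cfg = DummyConfig()
assert forward(cfg) is False                       # config default applies
assert forward(cfg, output_mtp_loss=True) is True  # runtime override wins
```

The same resolution logic is what lets a caller enable MTP loss for a single forward pass without mutating the config.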
Using `MoeCausalLMOutputWithPast` caused `test_model_outputs_equivalence` to fail because it has extra fields (`aux_loss`, `router_logits`) that change the tuple length when `return_dict=False`. Replace it with custom output types that only add the `mtp_loss` field:
- `Qwen3_5CausalLMOutputWithPast` extends `CausalLMOutputWithPast`
- `Qwen3_5VLCausalLMOutputWithPast` extends `Qwen3VLCausalLMOutputWithPast`

Both add only `mtp_loss: torch.FloatTensor | None = None`, keeping tuple ordering consistent with the parent class for backward compatibility. Regenerate `modeling_qwen3_5.py` and `configuration_qwen3_5.py`.
…n, and documentation
- Fix `Qwen3_5CausalLMOutputWithPast` to inherit from `ModelOutput` (not `CausalLMOutputWithPast`) and add the `@dataclass` decorator so the `mtp_loss` field is properly recognized
- Add the `@dataclass` decorator to `Qwen3_5VLCausalLMOutputWithPast` for proper field handling
- Fix `qwen3_5_moe` modular conversion: align the `ModelOutput` import with the generated code
- Add documentation entries for `Qwen3_5CausalLMOutputWithPast`, `Qwen3_5VLCausalLMOutputWithPast`, and `Qwen3_5MTP` in `qwen3_5.md`
- Remove `Qwen3_5MTPLayer` from `__all__` (internal implementation detail, not public API)
- Regenerate `modeling_qwen3_5.py` and `modeling_qwen3_5_moe.py`
… import from qwen3_vl
- Change `Qwen3_5VLCausalLMOutputWithPast` to inherit from `ModelOutput` directly instead of `Qwen3VLCausalLMOutputWithPast`, avoiding a cross-model import
- Define all fields explicitly (`loss`, `mtp_loss`, `logits`, `past_key_values`, `hidden_states`, `attentions`, `rope_deltas`) with a proper docstring
- Remove the import of `Qwen3VLCausalLMOutputWithPast` from `qwen3_vl.modular_qwen3_vl`
- Update both `modular_qwen3_5.py` and `modeling_qwen3_5.py`
- Add safety checks: MTP loss is only computed when both `labels` and `input_ids` are not `None`, preventing crashes when using `inputs_embeds` or during inference
- Fix docstring: `aux_loss` -> `mtp_loss` in the `output_mtp_loss` parameter doc
- Add `Qwen3_5MTPLayer` to `_no_split_modules` for proper device splitting
- Restore the `mtp` pattern in `_keys_to_ignore_on_load_unexpected` for loading checkpoints without MTP weights
- Use `outputs.last_hidden_state` instead of `outputs[0]` in the VL model for consistency with the CausalLM model
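The `_keys_to_ignore_on_load_unexpected` mechanism restored above can be sketched roughly like this. This is a simplified illustration, not the actual transformers loading code: checkpoint keys that are unexpected by the model are filtered against the ignore patterns before any warning is raised.

```python
import re

# Simplified sketch of how transformers-style loading suppresses warnings
# for unexpected checkpoint keys: any key matching a pattern in
# _keys_to_ignore_on_load_unexpected is dropped from the warning list.
keys_to_ignore_on_load_unexpected = [r"mtp"]

checkpoint_keys = [
    "model.layers.0.self_attn.q_proj.weight",
    "mtp.fc.weight",
    "mtp.layers.0.self_attn.q_proj.weight",
]
model_keys = ["model.layers.0.self_attn.q_proj.weight"]

unexpected = [k for k in checkpoint_keys if k not in model_keys]
warned = [
    k
    for k in unexpected
    if not any(re.search(p, k) for p in keys_to_ignore_on_load_unexpected)
]
assert warned == []  # mtp.* keys are ignored instead of triggering warnings
```

With the pattern in place, loading an MTP-bearing checkpoint into a model configured with `mtp_num_hidden_layers=0` stays warning-free.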
[For maintainers] Suggested jobs to run (before merge): `run-slow: qwen3_5, qwen3_5_moe`
**Note on relationship with PR #45618 (MTPCandidateGenerator)**

I noticed that PR #45618 by @ArthurZucker introduces an `MTPCandidateGenerator`. This PR (#45638) focuses on the training side of MTP, while PR #45618 focuses on the inference side.

These two approaches are complementary rather than conflicting: one addresses training-time MTP loss, the other addresses inference-time MTP speculative decoding. That said, I appreciate the architectural direction of PR #45618 in decoupling MTP from individual model files. Once #45618 is merged, I'm happy to adapt this PR to align with it.

I'd welcome any feedback from the maintainers on the preferred long-term architecture for MTP support in transformers. 🙏
Thanks for the quick response, @ArthurZucker. I completely agree with the principle of keeping inference-specific and training-specific code separated.

**Shared MTP module location.** If the MTP layers (`Qwen3_5MTPLayer`, `Qwen3_5MTP`) live in a dedicated folder (e.g. `transformers/mtp/` or under `transformers/generation/mtp/`), both this PR and #45618 could import from the same place: training-side code would consume them to compute the auxiliary loss, and inference-side code would consume them through `MTPCandidateGenerator`.

Happy to refactor in this direction if that matches what you have in mind. Could you let me know which layout you'd prefer once you've had a chance to think it over? I'd rather align with your preferred architecture upfront than refactor twice. 🙏
Add Multi-Token Prediction (MTP) support for Qwen3.5
This PR adds Multi-Token Prediction (MTP) architecture and loss computation for Qwen3.5 models, enabling multi-token prediction during training for improved efficiency.
## Changes
New classes:
- `Qwen3_5MTPLayer`: single MTP transformer layer with attention and MLP
- `Qwen3_5MTP`: top-level MTP module with FC fusion, layers, and norm

New shared helper:
- `_compute_qwen35_mtp_loss()`: shared MTP loss computation function used by both the CausalLM and VL models, eliminating code duplication

Modified models:
- `Qwen3_5ForCausalLM`: added MTP initialization and loss computation in the forward pass
- `Qwen3_5ForConditionalGeneration`: added MTP initialization and loss computation in the forward pass

Configuration:
- Add `mtp_num_hidden_layers` (default: 0) and `mtp_loss_weight` (default: 0.0) to both `Qwen3_5TextConfig` and `Qwen3_5Config`
- Remove `mtp` from `_keys_to_ignore_on_load_unexpected` in `Qwen3_5ForCausalLM` so MTP weights are properly loaded from checkpoints

## Design decisions
**Shared loss function:** The `_compute_qwen35_mtp_loss()` helper eliminates code duplication between the text-only and VL models. Both models delegate to this shared function with their respective `embed_tokens` and `rotary_emb` references.

**MTP loss stays in model files:** Following the pattern of other auxiliary losses in transformers (e.g., MoE router losses), MTP loss is computed within the model's forward pass rather than in a separate trainer class.
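The core ideas behind a multi-token-prediction loss can be sketched with a toy example. The helper names below (`mtp_labels`, `total_loss`) are hypothetical and do not reflect the actual `_compute_qwen35_mtp_loss()` signature; they only illustrate the label alignment (the depth-k head at position t predicts token t + k + 1) and the weighted combination with the main LM loss.

```python
# Hypothetical sketch of MTP label alignment: the depth-k prediction head
# targets the token (k + 1) positions ahead, so labels are the input
# sequence shifted left by (depth + 1); trailing positions are dropped.
def mtp_labels(input_ids, depth):
    return input_ids[depth + 1 :]


tokens = [10, 11, 12, 13, 14]
assert mtp_labels(tokens, depth=0) == [11, 12, 13, 14]  # standard next-token
assert mtp_labels(tokens, depth=1) == [12, 13, 14]      # predict two ahead


# Hypothetical combination of the main LM loss with the MTP auxiliary
# loss: with the default mtp_loss_weight=0.0 the MTP term vanishes,
# which matches the PR's backward-compatibility claim.
def total_loss(lm_loss, mtp_loss, mtp_loss_weight=0.0):
    return lm_loss + mtp_loss_weight * mtp_loss


assert total_loss(2.0, 1.0) == 2.0             # MTP disabled by default
assert total_loss(2.0, 1.0, mtp_loss_weight=0.5) == 2.5
```

The real implementation additionally runs the MTP transformer layers and cross-entropy over vocabulary logits, but the shifting and weighting shown here are the structural skeleton.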
**Backward compatible:** With `mtp_num_hidden_layers=0` (the default), MTP is disabled and the models behave identically to before.

**Checkpoint alignment:** The MTP module structure aligns with the Qwen3.5 checkpoint format:
- `mtp.pre_fc_norm_hidden.*`
- `mtp.pre_fc_norm_embedding.*`
- `mtp.fc.*`
- `mtp.layers.N.*`
- `mtp.norm.*`

## Testing
Tested with Qwen3.5-MTP model checkpoints to verify weight loading and loss computation.
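As a quick sanity check on the layout above, a hypothetical pattern like the one below can verify that every MTP entry in a state dict follows the documented prefixes (`MTP_KEY_PATTERN` and the sample keys are illustrative, not part of the PR).

```python
import re

# Hypothetical check that MTP keys in a state dict follow the documented
# checkpoint layout: mtp.pre_fc_norm_hidden.*, mtp.pre_fc_norm_embedding.*,
# mtp.fc.*, mtp.layers.N.*, mtp.norm.*
MTP_KEY_PATTERN = re.compile(
    r"^mtp\.(pre_fc_norm_hidden|pre_fc_norm_embedding|fc|layers\.\d+|norm)\."
)

sample_mtp_keys = [
    "mtp.pre_fc_norm_hidden.weight",
    "mtp.pre_fc_norm_embedding.weight",
    "mtp.fc.weight",
    "mtp.layers.0.mlp.gate_proj.weight",
    "mtp.norm.weight",
]
assert all(MTP_KEY_PATTERN.match(k) for k in sample_mtp_keys)
assert not MTP_KEY_PATTERN.match("model.embed_tokens.weight")
```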