
Add Multi-Token Prediction (MTP) support for Qwen3.5#45638

Open
curnane-lab wants to merge 8 commits into huggingface:main from curnane-lab:feature/qwen35-mtp-support

Conversation

@curnane-lab

Add Multi-Token Prediction (MTP) support for Qwen3.5

This PR adds the Multi-Token Prediction (MTP) architecture and loss computation for Qwen3.5 models, enabling the models to be trained with an auxiliary multi-token prediction objective for improved training efficiency.

Changes

New classes:

  • Qwen3_5MTPLayer: Single MTP transformer layer with attention and MLP
  • Qwen3_5MTP: Top-level MTP module with FC fusion, layers, and norm

New shared helper:

  • _compute_qwen35_mtp_loss(): Shared MTP loss computation function used by both the CausalLM and VL models, eliminating code duplication (a hedged sketch follows below)
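
For orientation, a minimal sketch of what this shared helper plausibly looks like. The signature and the two-position label shift are assumptions inferred from the description in this PR, not the actual code in modular_qwen3_5.py:

```python
import torch.nn.functional as F


def _compute_qwen35_mtp_loss(
    mtp, lm_head, hidden_states, input_ids, labels, embed_tokens, rotary_emb, mtp_loss_weight
):
    # Embed the *next* tokens so the MTP module can fuse them with the
    # backbone hidden states (hence the embed_tokens reference).
    next_embeds = embed_tokens(input_ids[:, 1:])
    mtp_hidden = mtp(hidden_states[:, :-1, :], next_embeds, rotary_emb)
    logits = lm_head(mtp_hidden)
    # Position t is trained to predict token t + 2, i.e. labels shifted by two.
    loss = F.cross_entropy(
        logits[:, :-1, :].reshape(-1, logits.size(-1)),
        labels[:, 2:].reshape(-1),
        ignore_index=-100,
    )
    return mtp_loss_weight * loss
```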

Modified models:

  • Qwen3_5ForCausalLM: Added MTP initialization and loss computation in forward pass
  • Qwen3_5ForConditionalGeneration: Added MTP initialization and loss computation in forward pass

Configuration:

  • Added mtp_num_hidden_layers (default: 0) and mtp_loss_weight (default: 0.0) to both Qwen3_5TextConfig and Qwen3_5Config (usage example after this list)
  • Removed mtp from _keys_to_ignore_on_load_unexpected in Qwen3_5ForCausalLM so MTP weights are properly loaded from checkpoints
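
For illustration, enabling MTP through the new fields might look like this, assuming Qwen3_5TextConfig is exported at the top level like other configs; the values are made up:

```python
from transformers import Qwen3_5TextConfig

# Both fields default to "off" (0 / 0.0), so existing configs are unchanged.
config = Qwen3_5TextConfig(
    mtp_num_hidden_layers=1,  # number of MTP transformer layers to stack
    mtp_loss_weight=0.1,      # scale of the MTP auxiliary loss (illustrative value)
)
```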

Design decisions

  1. Shared loss function: The _compute_qwen35_mtp_loss() helper eliminates code duplication between the text-only and VL models. Both models delegate to this shared function with their respective embed_tokens and rotary_emb references.

  2. MTP loss stays in model files: Following the pattern of other auxiliary losses in transformers (e.g., MoE router losses), MTP loss is computed within the model's forward pass rather than in a separate trainer class.

  3. Backward compatible: With mtp_num_hidden_layers=0 (default), MTP is disabled and the models behave identically to before.

  4. Checkpoint alignment: The MTP module structure aligns with the Qwen3.5 checkpoint format (a module sketch follows this key list):

    • mtp.pre_fc_norm_hidden.*
    • mtp.pre_fc_norm_embedding.*
    • mtp.fc.*
    • mtp.layers.N.*
    • mtp.norm.*
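
A hedged sketch of a module whose state dict would line up with these keys. The norm type, constructor arguments, and the internals of Qwen3_5MTPLayer (stubbed here) are assumptions, not the PR's code:

```python
import torch
from torch import nn


class Qwen3_5MTPLayer(nn.Module):
    """Stand-in for the PR's single MTP transformer layer (attention + MLP)."""

    def __init__(self, config, layer_idx=0):
        super().__init__()
        h = config.hidden_size
        # Placeholder block; the real layer mirrors a Qwen3.5 decoder layer.
        self.block = nn.Sequential(nn.Linear(h, h), nn.SiLU(), nn.Linear(h, h))

    def forward(self, hidden_states):
        return hidden_states + self.block(hidden_states)


class Qwen3_5MTP(nn.Module):
    def __init__(self, config):
        super().__init__()
        h = config.hidden_size
        # mtp.pre_fc_norm_hidden.* / mtp.pre_fc_norm_embedding.*: separate norms
        # for the backbone hidden states and the next-token embeddings.
        self.pre_fc_norm_hidden = nn.RMSNorm(h, eps=config.rms_norm_eps)
        self.pre_fc_norm_embedding = nn.RMSNorm(h, eps=config.rms_norm_eps)
        # mtp.fc.*: fuse the concatenated [hidden; embedding] back to hidden_size.
        self.fc = nn.Linear(2 * h, h, bias=False)
        # mtp.layers.N.*: a small stack of MTP transformer layers.
        self.layers = nn.ModuleList(
            [Qwen3_5MTPLayer(config, layer_idx=i) for i in range(config.mtp_num_hidden_layers)]
        )
        # mtp.norm.*: final norm before the LM head.
        self.norm = nn.RMSNorm(h, eps=config.rms_norm_eps)

    def forward(self, hidden_states, next_token_embeds, rotary_emb=None):
        # The real layers consume rotary position embeddings; the stub ignores them.
        fused = self.fc(
            torch.cat(
                [
                    self.pre_fc_norm_hidden(hidden_states),
                    self.pre_fc_norm_embedding(next_token_embeds),
                ],
                dim=-1,
            )
        )
        for layer in self.layers:
            fused = layer(fused)
        return self.norm(fused)
```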

Testing

Tested with Qwen3.5-MTP model checkpoints to verify weight loading and loss computation.

mingliangfu and others added 8 commits April 24, 2026 23:12
Add MTP architecture and loss computation for Qwen3.5 models, enabling
multi-token prediction during training for improved efficiency.

Changes:
- Add Qwen3_5MTPLayer and Qwen3_5MTP module classes
- Add shared _compute_qwen35_mtp_loss() helper function
- Add MTP support to Qwen3_5ForCausalLM (text-only model)
- Add MTP support to Qwen3_5ForConditionalGeneration (VL model)
- Add mtp_num_hidden_layers and mtp_loss_weight config fields
- Remove mtp from _keys_to_ignore_on_load_unexpected in CausalLM
- Regenerate modeling_qwen3_5.py and configuration_qwen3_5.py
- Add mtp_num_hidden_layers and mtp_loss_weight docstrings to Qwen3_5TextConfig
- Add mtp_num_hidden_layers and mtp_loss_weight docstrings to Qwen3_5Config
- Fix ruff formatting in modular_qwen3_5.py
- Regenerate modeling_qwen3_5.py and configuration_qwen3_5.py
Follow the MoE auxiliary loss pattern (e.g., Mixtral's output_router_logits)
to add proper runtime control and output visibility for MTP loss:

- Add output_mtp_loss config field (default: False) to Qwen3_5TextConfig
  and Qwen3_5Config, controlling whether MTP loss is computed
- Add output_mtp_loss forward parameter to both CausalLM and VL models,
  overriding the config default at runtime
- Qwen3_5ForCausalLM now returns MoeCausalLMOutputWithPast with
  aux_loss=mtp_loss (consistent with MoE models like Mixtral)
- Qwen3_5ForConditionalGeneration now returns Qwen3_5VLCausalLMOutputWithPast
  with mtp_loss field (extends Qwen3VLCausalLMOutputWithPast)
- MTP loss is only computed when output_mtp_loss=True, avoiding
  unnecessary computation when MTP is not needed
- Regenerate modeling_qwen3_5.py and configuration_qwen3_5.py
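
The override itself presumably follows the same one-liner pattern as Mixtral's output_router_logits; a sketch of the resolution inside forward():

```python
# Inside forward(): a runtime argument wins over the config default.
output_mtp_loss = (
    output_mtp_loss if output_mtp_loss is not None else self.config.output_mtp_loss
)
```
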
Using MoeCausalLMOutputWithPast caused test_model_outputs_equivalence
to fail because it has extra fields (aux_loss, router_logits) that
change the tuple length when return_dict=False.

Replace with custom output types that only add the mtp_loss field:
- Qwen3_5CausalLMOutputWithPast extends CausalLMOutputWithPast
- Qwen3_5VLCausalLMOutputWithPast extends Qwen3VLCausalLMOutputWithPast

Both add only mtp_loss: torch.FloatTensor | None = None, keeping
tuple ordering consistent with the parent class for backward
compatibility.

Regenerate modeling_qwen3_5.py and configuration_qwen3_5.py
…n, and documentation

- Fix Qwen3_5CausalLMOutputWithPast to inherit from ModelOutput (not CausalLMOutputWithPast)
  and add the @dataclass decorator so the mtp_loss field is properly recognized
- Add the @dataclass decorator to Qwen3_5VLCausalLMOutputWithPast for proper field handling
- Fix qwen3_5_moe modular conversion: align ModelOutput import with generated code
- Add documentation entries for Qwen3_5CausalLMOutputWithPast, Qwen3_5VLCausalLMOutputWithPast,
  and Qwen3_5MTP in qwen3_5.md
- Remove Qwen3_5MTPLayer from __all__ (internal implementation detail, not public API)
- Regenerate modeling_qwen3_5.py and modeling_qwen3_5_moe.py
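
A minimal sketch of the fixed output type after this commit. The commit only pins down the base class, the decorator, and the mtp_loss field, so the remaining fields and their order are assumptions:

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.utils import ModelOutput


@dataclass
class Qwen3_5CausalLMOutputWithPast(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[tuple] = None
    hidden_states: Optional[tuple] = None
    attentions: Optional[tuple] = None
    # The only addition relative to CausalLMOutputWithPast; appended last so
    # the tuple order with return_dict=False stays backward compatible.
    mtp_loss: Optional[torch.FloatTensor] = None
```
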
… import from qwen3_vl

- Change Qwen3_5VLCausalLMOutputWithPast to inherit from ModelOutput directly
  instead of Qwen3VLCausalLMOutputWithPast, avoiding cross-model import
- Define all fields explicitly (loss, mtp_loss, logits, past_key_values,
  hidden_states, attentions, rope_deltas) with proper docstring
- Remove import of Qwen3VLCausalLMOutputWithPast from qwen3_vl.modular_qwen3_vl
- Update both modular_qwen3_5.py and modeling_qwen3_5.py
- Add safety checks: MTP loss only computed when both labels and input_ids
  are not None, preventing crashes when using inputs_embeds or inference
- Fix docstring: aux_loss -> mtp_loss in output_mtp_loss parameter doc
- Add Qwen3_5MTPLayer to _no_split_modules for proper device splitting
- Restore mtp pattern in _keys_to_ignore_on_load_unexpected for loading
  checkpoints without MTP weights
- Use outputs.last_hidden_state instead of outputs[0] in VL model for
  consistency with CausalLM model
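
Putting the pieces together, training-time usage might look like the following. The checkpoint id is hypothetical, and whether outputs.mtp_loss arrives pre-scaled by mtp_loss_weight is an assumption:

```python
from transformers import AutoTokenizer, Qwen3_5ForCausalLM

model = Qwen3_5ForCausalLM.from_pretrained("Qwen/Qwen3.5-MTP")  # hypothetical id
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-MTP")

inputs = tokenizer("Multi-token prediction example", return_tensors="pt")
# Per the safety checks above, mtp_loss is only computed when output_mtp_loss
# is True and both input_ids and labels are present; otherwise it is None.
outputs = model(**inputs, labels=inputs["input_ids"], output_mtp_loss=True)
loss = outputs.loss + outputs.mtp_loss
loss.backward()
```
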
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: qwen3_5, qwen3_5_moe

@curnane-lab
Author

Note on relationship with PR #45618 (MTPCandidateGenerator)

I noticed that PR #45618 by @ArthurZucker introduces an MTPCandidateGenerator class that integrates MTP into the generate() pipeline for speculative decoding. I want to clarify the complementary relationship between these two PRs:

This PR (#45638) focuses on the training side of MTP:

  • Adds MTP architecture modules (Qwen3_5MTPLayer, Qwen3_5MTP) to the model definition
  • Computes MTP auxiliary loss during training (controlled by output_mtp_loss config/parameter)
  • Follows the MoE auxiliary loss pattern (e.g., Mixtral's output_router_logits) for runtime control
  • Enables training with MTP loss for improved training efficiency

PR #45618 focuses on the inference side of MTP:

  • Introduces MTPCandidateGenerator as a universal CandidateGenerator subclass
  • Enables MTP-based speculative decoding during generate() for inference acceleration
  • Decouples MTP layers from model files into a reusable generator component
  • Adds GenerationMode.MTP_DECODING to the generation pipeline

These two approaches are complementary rather than conflicting — one addresses training-time MTP loss, the other addresses inference-time MTP speculative decoding.

That said, I appreciate the architectural direction of PR #45618 in decoupling MTP from individual model files. Once #45618 is merged, I'm happy to adapt this PR to align with the MTPCandidateGenerator architecture if the maintainers prefer, while preserving the training-time MTP loss computation that this PR enables.

I'd welcome any feedback from the maintainers on the preferred long-term architecture for MTP support in transformers. 🙏

Collaborator

@ArthurZucker left a comment

For training I need to think a bit, good sir! We don't want to "pollute" the code with inference-specific stuff, so we might have a new folder for this.

@curnane-lab
Copy link
Copy Markdown
Author

For training I need to think a bit, good sir! We don't want to "pollute" the code with inference-specific stuff, so we might have a new folder for this.

Thanks for the quick response, @ArthurZucker — completely agree with the principle of keeping inference-specific and training-specific code separated.
A few thoughts to help shape the direction, whenever you've had time to think it through:

  1. Shared MTP module location: If the MTP layers (Qwen3_5MTPLayer, Qwen3_5MTP) live in a dedicated folder (e.g. transformers/mtp/ or under transformers/generation/mtp/), both this PR and #45618 could import from the same place. Training-side code would consume them to compute the auxiliary loss; inference-side code would consume them through MTPCandidateGenerator. Happy to refactor in this direction if that matches what you have in mind.

  2. Training entry point: The training-side surface is small, essentially an output_mtp_loss flag (mirroring output_router_logits in Mixtral) and the loss computation in the forward pass. I can keep this minimal and model-agnostic so it doesn't bleed into inference paths.

  3. Order of merging: I'm fine waiting for #45618 to land first and then rebasing this PR on top of the agreed structure, if that's easier for review. Alternatively, if you'd prefer to land the shared MTP module location as a small prerequisite PR, I can prepare that too.

Could you let me know which layout you'd prefer once you've had a chance to think it over? I'd rather align with your preferred architecture upfront than refactor twice. 🙏
