
Fix load_balancing_loss_func incompatible with past_key_values#40908

Open
tkj666 wants to merge 1 commit into huggingface:main from
tkj666:fix-mixtral-past_key_values-and-output_router_logits-incompatible

Conversation

@tkj666

@tkj666 tkj666 commented Sep 16, 2025

What does this PR do?

Changes the way num_hidden_layers, batch_size, and sequence_length are calculated, and slices attention_mask so that the shapes of expert_attention_mask and expert_mask match, making the loss computation compatible with inference using past_key_values.
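The shape mismatch the PR addresses can be illustrated with a small sketch. During cached decoding, the model only runs the new tokens, so the router logits cover just those positions, while attention_mask still covers the full sequence (past plus new). The shapes and the slicing below are illustrative, not the exact transformers implementation:

```python
import torch

# Illustrative dimensions: 7 cached tokens, 1 new token being decoded.
batch_size, past_len, new_len = 2, 7, 1
num_experts = 8

# Router logits only exist for the newly processed positions,
# flattened to (batch_size * new_len, num_experts) per layer.
router_logits = torch.randn(batch_size * new_len, num_experts)

# The attention mask still spans the whole sequence (past + new),
# so its length no longer matches the router logits.
attention_mask = torch.ones(batch_size, past_len + new_len)

# Derive sequence_length from the router logits rather than the mask,
# then slice the mask to the most recent positions so shapes agree.
sequence_length = router_logits.shape[0] // batch_size  # 1, not 8
sliced_mask = attention_mask[:, -sequence_length:]      # (2, 1)
```

Without the slice, the mask would contribute `past_len + new_len` positions per sequence while the logits only have `new_len`, which is the source of the shape error.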

Fixes #30731

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: ernie4_5_moe, gpt_oss, minimax, mixtral, qwen3_moe, qwen3_next

Collaborator

@ArthurZucker ArthurZucker left a comment


hey, sorry what's the motivation? Load balancing loss function is for training 😓

@tkj666
Author

tkj666 commented Oct 1, 2025

hey, sorry what's the motivation? Load balancing loss function is for training 😓

Multiple models that support output_router_logits call load_balancing_loss_func whenever output_router_logits == True, without checking whether the model is training, e.g. https://github.com/huggingface/transformers/blob/32567739740da86ddf96c60a23cf2d0494ce4145/src/transformers/models/qwen3_moe/modular_qwen3_moe.py#L255C2-L262C5 and https://github.com/tkj666/transformers/blob/f8bdbaf5e3b907258e2154ea16797c057090430c/src/transformers/models/mixtral/modular_mixtral.py#L430C9-L436C14

During inference, people may want to inspect the router logits, but currently, because attention_mask and all_router_logits have incompatible shapes when past_key_values is enabled, this raises an error. Since aux_loss does not depend on labels, it is also possible to check aux_loss during inference.
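The point that aux_loss needs no labels can be seen from a simplified sketch of a Switch-Transformers-style load-balancing loss (same idea as transformers' load_balancing_loss_func, with the masking and multi-layer details abridged): it is computed purely from the router logits.

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits, num_experts, top_k=2):
    """Simplified load-balancing loss: no labels involved.

    router_logits: (num_tokens, num_experts), e.g. concatenated over layers.
    """
    probs = F.softmax(router_logits, dim=-1)
    # Which experts each token is routed to (top-k routing).
    _, selected = torch.topk(probs, top_k, dim=-1)          # (tokens, top_k)
    expert_mask = F.one_hot(selected, num_experts)          # (tokens, top_k, E)
    # Fraction of routing assignments received by each expert.
    tokens_per_expert = expert_mask.float().mean(dim=(0, 1))
    # Mean router probability assigned to each expert.
    router_prob_per_expert = probs.mean(dim=0)
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)

# Works on any router logits, including those produced at inference time.
loss = load_balancing_loss(torch.randn(16, 8), num_experts=8)
```

Because the loss is a function of router_logits alone, the only thing blocking its use at inference with a KV cache is the attention_mask shape mismatch that this PR fixes.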

@ArthurZucker ArthurZucker added the Feature request Request for a new feature label Oct 16, 2025


Development

Successfully merging this pull request may close these issues.

Mixtral past_key_values and output_router_logits incompatible
