Skip to content

fix: output router capture wrong router logits in qwen moe models#43542

Open
ITcarrot wants to merge 1 commit intohuggingface:mainfrom
ITcarrot:fix-qwen3-moe-capture-wrong-router-logit
Open

fix: output router capture wrong router logits in qwen moe models#43542
ITcarrot wants to merge 1 commit intohuggingface:mainfrom
ITcarrot:fix-qwen3-moe-capture-wrong-router-logit

Conversation

@ITcarrot
Copy link
Copy Markdown

What does this PR do?

This PR fixes a bug in the router implementation of several MoE models (Qwen Moe like models, Olmoe, FlexOlmo).

Previously, the raw router_logits were being overwritten by the result of the softmax operation:

router_logits = F.linear(hidden_states, self.weight)
router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)

This behavior destroys the raw logits information, which is often required for:

  1. Casting returned router_top_value back to input type in router_top_value = router_top_value.to(router_logits.dtype)
  2. Calculating auxiliary losses (e.g., load balancing loss) which typically operate on raw logits.
  3. Returning correct router_logits in the model output if requested.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Cyrilvallez @IlyasMoutawwakil

@github-actions
Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: flex_olmo, olmoe, qwen2_moe, qwen3_moe, qwen3_next, qwen3_omni_moe

Copy link
Copy Markdown
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch yes, the logits should not go through softmax 😅

Could we have a small test added to qwen2 moe (since everyone else inherits from this) + are there any more such cases?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants