Skip to content

Fix some MoE routers#43445

Merged
vasqu merged 8 commits intomainfrom
fix-broken-eager-experts
Jan 27, 2026
Merged

Fix some MoE routers#43445
vasqu merged 8 commits intomainfrom
fix-broken-eager-experts

Conversation

@IlyasMoutawwakil
Copy link
Member

@IlyasMoutawwakil IlyasMoutawwakil commented Jan 23, 2026

What does this PR do?

Same as #43288
Also fixes phimoe and its integration tests

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

)
return selected_experts, routing_weights.to(hidden_states.dtype)

return selected_experts, routing_weights.to(hidden_states.dtype)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's interesting that this dead code is not caught by styling

router_logits,
jitter_eps=self.router_jitter_noise,
training=self.training,
router_logits, jitter_eps=self.router_jitter_noise, training=self.training, top_k=self.top_k
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

top_k was not passed and defaulted to 2

@IlyasMoutawwakil
Copy link
Member Author

run-slow: hunyuan_v1_moe, phimoe

@github-actions
Copy link
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/hunyuan_v1_moe", "models/phimoe"]
quantizations: []

@github-actions
Copy link
Contributor

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@IlyasMoutawwakil
Copy link
Member Author

both models integration tests pass locally now

Comment on lines +554 to +558
# Phimoe uses nn.LayerNorm
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True)
self.post_attention_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Phimoe uses nn.LayerNorm with bias

Copy link
Contributor

@vasqu vasqu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for checking another set of models, I guess we have them all fixed with this?

The phimoe one is an interesting case 😅

expert_ids = top_k_index.reshape(-1)
token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1)

# Resolve routing weights per selected sample, allowing top_k_weights to be either:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Glad to remove this 🙏

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah much cleaner !

Comment on lines +336 to +340
# Phimoe uses nn.LayerNorm
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True)
self.post_attention_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wow, that's an insane find - this model must have been broken for a long time

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeh ppl must've continued to use the remote code one

Comment on lines 90 to -92
def test_model_generation(self):
# we will compele this when model file change over
# pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a PR here #43411 which fixes some wrong RoPE init, will this still work with that fix?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i guess, here I only removed a comment, I might have misunderstood but the initialization added in the PR is already in the init of the class, why is it necessary in init_weights as well ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apparently, it doesn't matter anymore what you init in __init__ and _init_weights will overwrite in any case - meaning that if there is custom logic in init, it will not be applied. I want to refactor this so that it is no longer the case or rather that we depend on another init function for rope that allows users to do whatever they want to as init

Comment on lines 147 to 148
[-3.4844, -2.4688, -1.1719, 0.5703, -0.4902, -0.0942, 0.7773, -0.2539, 0.3223, -1.0234],
[-0.9805, 0.0811, -0.5273, 2.3438, 0.6914, 3.0781, 0.3164, 0.2197, 0.5312, -2.1094],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess that's the gpu diff between t4 and a10

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes, rerunning with eager experts impl to see if grouped is also contributing to the diff.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, grouped and eager are equivalent in terms of logits here (on A100). should I revert this change ? or maybe use the expectation class.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should check against A10 runners - let's revert for now and check with run-slow first. If it needs a change, I can update it

@IlyasMoutawwakil
Copy link
Member Author

run-slow: hunyuan_v1_moe, phimoe

@github-actions
Copy link
Contributor

This comment contains run-slow, running the specified jobs:

models: ["models/hunyuan_v1_moe", "models/phimoe"]
quantizations: []

@github-actions
Copy link
Contributor

CI Results

Workflow Run ⚙️

✅ No failing test specific to this PR 🎉 !

@IlyasMoutawwakil
Copy link
Member Author

the ci is struggling with model loading again 😭, even though device_map="auto"

@github-actions
Copy link
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: hunyuan_v1_moe, phimoe

sample_weights = sample_weights.reshape(-1, 1) # (S, 1)

# Reshape for easier indexing
# S is the number of selected tokens-experts pairs (S = num_tokens * num_top_k)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clarification taken from #43439

@vasqu vasqu merged commit 0dbb56e into main Jan 27, 2026
21 of 26 checks passed
@vasqu vasqu deleted the fix-broken-eager-experts branch January 27, 2026 13:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants