
[Llama4] Adopt @use_experts_implementation standard interface #45989

Closed
Chao1Han wants to merge 1 commit into huggingface:main from Chao1Han:llama4_use_experts_implementation

Conversation

@Chao1Han Chao1Han commented May 15, 2026

Authored by Claude (Anthropic)
This is an alternative approach to #45976.

Motivation

Llama4's Llama4TextExperts currently uses a non-standard single-argument forward interface where routing weights are pre-multiplied into the hidden states before being passed to experts:

# Current (non-standard)
routed_in = hidden_states.repeat(num_experts, 1) * router_scores
expert_output = self.experts(routed_in)  # single pre-weighted tensor

This differs from all other MoE models in transformers (Gemma4, Qwen3-MoE, DeepSeek, PhiMoE, etc.), which use the standard 3-argument interface via @use_experts_implementation:

# Standard
expert_output = self.experts(hidden_states, top_k_index, top_k_weights)
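
For context, the practical difference is that under the standard contract the routing weights are applied inside the experts module rather than pre-multiplied into the activations by the caller. A minimal, illustrative sketch of what such a 3-argument forward does (ToyExperts and its weight names are hypothetical, not the transformers implementation):

    import torch
    from torch import nn

    class ToyExperts(nn.Module):
        # Illustrative only: a minimal experts module with the standard
        # 3-argument forward. Routing weights are applied inside the module,
        # so callers never pre-multiply them into hidden_states.
        def __init__(self, num_experts, hidden_size, intermediate_size):
            super().__init__()
            self.num_experts = num_experts
            self.w_up = nn.Parameter(torch.randn(num_experts, hidden_size, intermediate_size) * 0.02)
            self.w_down = nn.Parameter(torch.randn(num_experts, intermediate_size, hidden_size) * 0.02)

        def forward(self, hidden_states, top_k_index, top_k_weights):
            # hidden_states: [tokens, hidden]; top_k_index / top_k_weights: [tokens, top_k]
            out = torch.zeros_like(hidden_states)
            for expert_id in range(self.num_experts):
                token_idx, k_idx = torch.where(top_k_index == expert_id)
                if token_idx.numel() == 0:
                    continue
                y = torch.relu(hidden_states[token_idx] @ self.w_up[expert_id]) @ self.w_down[expert_id]
                out.index_add_(0, token_idx, y * top_k_weights[token_idx, k_idx, None])
            return out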

This inconsistency causes problems:

  1. EP/TP hooks in tensor_parallel.py (MoeTensorParallelExperts) assume the standard 3-arg interface — requiring special-case handling for Llama4
  2. No backend switching — Llama4 cannot benefit from batched_mm or grouped_mm optimized implementations
  3. Router output mismatch — The original router returns (router_scores, router_logits) (2 values) while RouterParallel._prepare_output_fn expects (router_logits, router_scores, router_indices) (3 values), breaking EP mode entirely

Changes

modeling_llama4.py

  1. Llama4TextExperts: Add @use_experts_implementation(is_transposed=True) decorator and change forward signature to (self, hidden_states, top_k_index, top_k_weights)

    • is_transposed=True because Llama4 stores weights as [E, H, 2*I] instead of standard [E, 2*I, H]
    • The eager implementation uses sort-by-expert + padded bmm for efficient batched computation (see the sketch after this list)
  2. Llama4Router: Change return from (router_scores, router_logits) to (router_logits, top_k_weights, top_k_index) — consistent with other routers (Gemma4, etc.) and compatible with RouterParallel EP hooks

  3. Llama4TextMoe.forward: Simplify to standard calling convention:

    router_logits, top_k_weights, top_k_index = self.router(hidden_states)
    routed_out = self.experts(hidden_states, top_k_index, top_k_weights)
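
To make the "sort-by-expert + padded bmm" bullet concrete, here is a rough standalone sketch of that strategy under the transposed weight layout ([E, H, ...] rather than [E, ..., H]). For simplicity it uses a single up projection instead of Llama4's fused gate_up weights; the function name and details are illustrative, not the code in this PR:

    import torch

    def sort_by_expert_padded_bmm(hidden_states, top_k_index, top_k_weights, w_up, w_down):
        # Illustrative sketch: group tokens by expert, pad each group to a
        # common length, and run one batched matmul per projection.
        # Transposed layout: w_up is [E, H, I], w_down is [E, I, H].
        num_experts = w_up.shape[0]
        tokens, top_k = top_k_index.shape

        # Flatten (token, slot) pairs and sort them by expert id.
        flat_expert = top_k_index.reshape(-1)
        flat_token = torch.arange(tokens, device=hidden_states.device).repeat_interleave(top_k)
        flat_weight = top_k_weights.reshape(-1)
        order = torch.argsort(flat_expert, stable=True)
        flat_expert, flat_token, flat_weight = flat_expert[order], flat_token[order], flat_weight[order]

        # Pad every expert's token group to the size of the largest group.
        counts = torch.bincount(flat_expert, minlength=num_experts)
        slot = torch.cat([torch.arange(int(c)) for c in counts]).to(hidden_states.device)
        padded = hidden_states.new_zeros(num_experts, int(counts.max()), hidden_states.shape[-1])
        padded[flat_expert, slot] = hidden_states[flat_token]

        # One bmm per projection covers all experts at once.
        up = torch.relu(torch.bmm(padded, w_up))    # [E, max_count, I]
        down = torch.bmm(up, w_down)                # [E, max_count, H]

        # Scatter the weighted expert outputs back to token order.
        out = torch.zeros_like(hidden_states)
        out.index_add_(0, flat_token, down[flat_expert, slot] * flat_weight[:, None])
        return out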

Benefits

  • Consistent interface across all MoE models in transformers
  • Backend switching via experts_implementation="batched_mm" / "grouped_mm" works out of the box (usage sketched after this list)
  • EP/TP compatibility — standard hooks in tensor_parallel.py work without special cases
  • No changes needed to tensor_parallel.py — unlike Fix MoeTensorParalellExperts crash for Llama4 pre-weighted MoE experts (EP support) #45976 which requires modifying MoeTensorParallelExperts to handle the pre-weighted single-tensor convention
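
For the backend-switching benefit, a hypothetical usage sketch; it assumes experts_implementation is surfaced the way attn_implementation is (e.g. a from_pretrained kwarg that lands on the config), which this PR implies but does not show, so the exact plumbing may differ:

    import torch
    from transformers import AutoModelForCausalLM

    # Hypothetical: the checkpoint and the experts_implementation kwarg are
    # examples, not verified against this PR.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-4-Scout-17B-16E-Instruct",
        torch_dtype=torch.bfloat16,
        experts_implementation="batched_mm",  # or "grouped_mm", or None for the eager path
    )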

Comparison with #45976

| | #45976 (fix_llama4_ep) | This PR |
| --- | --- | --- |
| Approach | Fix hooks to handle Llama4's non-standard interface | Make Llama4 conform to the standard interface |
| Changes to tensor_parallel.py | Yes (special case for len(inputs) == 1) | No |
| Changes to modeling_llama4.py | Minimal (router return fix only) | Refactor experts + router + MoE forward |
| experts_implementation support | No | Yes |
| Future maintainability | Each non-standard model needs hook changes | Standard interface, no special cases |

Testing

Verified locally:

  • experts_implementation=None (eager bmm): ✅ correct output
  • experts_implementation="batched_mm": ✅ numerically matches eager (max diff < 1e-8; see the parity sketch after this list)
  • Full model forward pass with from_config(): ✅ logits shape correct
  • _can_set_experts_implementation() heuristic: ✅ correctly detects decorator
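
A self-contained parity check in the spirit of the second item above can be written against the two sketches earlier in this description (ToyExperts and sort_by_expert_padded_bmm); both names are illustrative and do not refer to code in the PR:

    import torch

    torch.manual_seed(0)
    tokens, hidden_size, intermediate_size, num_experts, top_k = 32, 16, 24, 4, 2

    experts = ToyExperts(num_experts, hidden_size, intermediate_size).double()
    hidden_states = torch.randn(tokens, hidden_size, dtype=torch.float64)
    top_k_weights, top_k_index = torch.rand(tokens, num_experts, dtype=torch.float64).topk(top_k, dim=-1)

    with torch.no_grad():
        ref = experts(hidden_states, top_k_index, top_k_weights)  # per-expert loop
        out = sort_by_expert_padded_bmm(                          # padded bmm
            hidden_states, top_k_index, top_k_weights, experts.w_up, experts.w_down
        )

    print((out - ref).abs().max())  # tiny (float64 round-off), analogous to the < 1e-8 check above
    assert torch.allclose(out, ref)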

@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: llama4

Refactor Llama4TextExperts to use the standard @use_experts_implementation
decorator with the 3-argument forward interface (hidden_states, top_k_index,
top_k_weights), aligning with other MoE models (Gemma4, Qwen3-MoE, etc.).

Key changes:
- Add @use_experts_implementation(is_transposed=True) to Llama4TextExperts
- Change experts forward from single pre-weighted tensor to standard 3-arg
- Update Llama4Router to return (router_logits, top_k_weights, top_k_index)
- Simplify Llama4TextMoe.forward to use standard calling convention

Benefits:
- Automatically supports batched_mm, grouped_mm backends via config
- Compatible with existing EP/TP hooks in tensor_parallel.py
- Consistent interface across all MoE models in transformers

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@Chao1Han Chao1Han force-pushed the llama4_use_experts_implementation branch from 2e36317 to e5c8338 on May 15, 2026 09:21