
Adding support for fused moe kernels in sampler. #1385

Merged

copybara-service[bot] merged 1 commit into main from nicogrande/support-fused-moe on Apr 22, 2026
Conversation

@NicoGrande (Collaborator) commented Apr 9, 2026

This PR introduces full support for the Fused MoE Kernel integrated into MaxText in AI-Hypercomputer/maxtext#3627.

More specifically, it adds the ability to fuse MoE kernel weights in the MaxText model during weight sync so that they match the shapes required by the tpu-inference Fused MoE kernel.
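
As a rough illustration of the fusion step (the exact fused layout is defined by the tpu-inference kernel; `fuse_moe_weights` and the shapes in the comments are hypothetical), per-expert projection matrices can be stacked and concatenated into the single array a fused kernel consumes:

```python
import jax.numpy as jnp

def fuse_moe_weights(gate_projs, up_projs):
  """Fuse per-expert gate/up projections into one stacked array.

  Hypothetical layout: each input is a list of num_experts arrays of
  shape [hidden, intermediate]; the fused kernel is assumed to consume
  a single [num_experts, hidden, 2 * intermediate] tensor.
  """
  w_gate = jnp.stack(gate_projs)  # [E, H, I]
  w_up = jnp.stack(up_projs)      # [E, H, I]
  return jnp.concatenate([w_gate, w_up], axis=-1)  # [E, H, 2*I]
```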

Additionally, this PR adds some optimizations for resharding, which are helpful when working with large models. First, it removes the call to jax.clear_caches() in vllm_sampler.py. No longer clearing the JAX compilation caches avoids recompilation, which speeds up both training and rollout steps. The downside is that more memory fragmentation is present on the sampler TPUs after the KV cache is cleared. To avoid fragmentation-induced OOMs, we introduce chunked resharding, which reduces the peak HBM consumed during reshard operations for large models.
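
A minimal sketch of the chunked-resharding idea, assuming a flat list of jax.Array parameters and matching target shardings (`chunked_reshard` and its signature are hypothetical, not the PR's actual code):

```python
import jax

def chunked_reshard(arrays, target_shardings, chunk_size):
  """Reshard a flat list of jax.Arrays a few tensors at a time.

  Resharding everything at once keeps source and destination buffers
  alive simultaneously, roughly doubling peak HBM. Moving chunk_size
  tensors per step and deleting each source buffer once its copy lands
  bounds that transient overhead.
  """
  resharded = []
  for start in range(0, len(arrays), chunk_size):
    srcs = arrays[start:start + chunk_size]
    dsts = [
        jax.device_put(x, s)
        for x, s in zip(srcs, target_shardings[start:start + chunk_size])
    ]
    # Wait for the transfers so the sources are safe to free.
    for d in dsts:
      d.block_until_ready()
    for x in srcs:
      x.delete()  # free source HBM before starting the next chunk
    resharded.extend(dsts)
  return resharded
```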

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and all unit tests pass.
  • I have added all appropriate doc-strings/documentation.
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have signed the Contributor License Agreement.
  • I have followed Contribution Guidelines.

@NicoGrande NicoGrande force-pushed the nicogrande/support-fused-moe branch from 4629f76 to d4b75a1 Compare April 10, 2026 21:34
@NicoGrande NicoGrande force-pushed the nicogrande/support-fused-moe branch from d4b75a1 to 3578a75 Compare April 13, 2026 16:28
@NicoGrande NicoGrande force-pushed the nicogrande/support-fused-moe branch from 3578a75 to e036a98 Compare April 14, 2026 01:01
@NicoGrande NicoGrande force-pushed the nicogrande/support-fused-moe branch from d44997f to 817eece Compare April 21, 2026 23:05
@NicoGrande NicoGrande force-pushed the nicogrande/support-fused-moe branch from 817eece to 957a534 Compare April 21, 2026 23:11
@NicoGrande NicoGrande force-pushed the nicogrande/support-fused-moe branch from 957a534 to 432334f Compare April 22, 2026 01:07
@NicoGrande NicoGrande marked this pull request as ready for review April 22, 2026 15:21
Comment thread on tunix/rl/rollout/vllm_rollout.py (outdated):
tensor_parallel_size=rollout_config.tensor_parallel_size,
data_parallel_size=rollout_config.data_parallel_size,
expert_parallel_size=rollout_config.expert_parallel_size,
rollout_chunk_size=rollout_config.rollout_vllm_reshard_chunk_size,
Collaborator:

rollout_chunk_size doesn't seem to need any special handling, we're just plumbing it through; shall we pass it via rollout_vllm_kwargs?

Collaborator (Author):

You are right, we only need to plumb it through. However, it is not a vLLM engine argument, so it shouldn't be passed to the LLM() constructor. If we can remove it before passing rollout_vllm_kwargs to the constructor, that would work; otherwise, adding a new argument for it may be cleaner. LMK your thoughts :)

Collaborator:

My bad, I was looking at the wrong place. Plumbing it through is the right way to go.
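
For illustration, a minimal sketch of the plumbing that was settled on, with a hypothetical VllmRollout wrapper: the chunk size is consumed locally and never forwarded to the engine, since vllm.LLM() only accepts real engine arguments:

```python
from vllm import LLM

class VllmRollout:
  """Hypothetical wrapper showing the agreed-upon plumbing."""

  def __init__(self, model: str, rollout_chunk_size: int | None = None,
               **rollout_vllm_kwargs):
    # Consumed by the resharding path only; never forwarded to the
    # engine, since vllm.LLM() rejects unknown engine arguments.
    self._rollout_chunk_size = rollout_chunk_size
    self._llm = LLM(model=model, **rollout_vllm_kwargs)
```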

deleting dst buffers during reshard.
@NicoGrande NicoGrande force-pushed the nicogrande/support-fused-moe branch from 432334f to fa8a5a6 Compare April 22, 2026 22:55
@copybara-service copybara-service Bot merged commit 9d25e26 into main Apr 22, 2026
9 checks passed
niting added a commit to niting/maxtext that referenced this pull request Apr 23, 2026
This is required to support fused moe in MaxText. See Tunix PR:
google/tunix#1385.
