Adding support for fused moe kernels in sampler. #1385
Merged
copybara-service[bot] merged 1 commit into main on Apr 22, 2026
Conversation
Force-pushed: 4629f76 to d4b75a1
Force-pushed: d4b75a1 to 3578a75
Force-pushed: 3578a75 to e036a98
Force-pushed: d44997f to 817eece
Force-pushed: 817eece to 957a534
Force-pushed: 957a534 to 432334f
wang2yn84 reviewed Apr 22, 2026
tensor_parallel_size=rollout_config.tensor_parallel_size,
data_parallel_size=rollout_config.data_parallel_size,
expert_parallel_size=rollout_config.expert_parallel_size,
rollout_chunk_size=rollout_config.rollout_vllm_reshard_chunk_size,
Collaborator
rollout_chunk_size doesn't seem to need any special handling, just plumbing it through; shall we pass it via rollout_vllm_kwargs?
Collaborator
Author
You are right, we only need to plumb it through. However, it is not a vLLM engine argument, so it shouldn't be passed to the LLM() constructor. If we can remove it before passing rollout_vllm_kwargs to the constructor, that would work; otherwise, adding a new argument for it may be cleaner. LMK your thoughts :)
Collaborator
My bad, I was looking at the wrong place. Plumbing it through is the right way to go.
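For context, a minimal sketch of the plumbing discussed above. All names here (RolloutConfig, make_sampler_kwargs) are hypothetical and only illustrate keeping rollout_chunk_size as an explicit argument instead of folding it into rollout_vllm_kwargs, since it is not a vLLM engine argument and the LLM() constructor would not accept it:

```python
from dataclasses import dataclass

# Hypothetical names for illustration only; the real Tunix/vLLM classes differ.
# The point: rollout_chunk_size is plumbed through as its own argument rather
# than riding along in rollout_vllm_kwargs, which go straight to the vLLM LLM()
# constructor and must contain only engine arguments.
@dataclass
class RolloutConfig:
  tensor_parallel_size: int
  data_parallel_size: int
  expert_parallel_size: int
  rollout_vllm_reshard_chunk_size: int

def make_sampler_kwargs(cfg: RolloutConfig) -> dict:
  return dict(
      tensor_parallel_size=cfg.tensor_parallel_size,
      data_parallel_size=cfg.data_parallel_size,
      expert_parallel_size=cfg.expert_parallel_size,
      # Kept separate from rollout_vllm_kwargs; consumed only by the reshard path.
      rollout_chunk_size=cfg.rollout_vllm_reshard_chunk_size,
  )
```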
deleting dst buffers during reshard.
Force-pushed: 432334f to fa8a5a6
niting added a commit to niting/maxtext that referenced this pull request on Apr 23, 2026
This is required to support fused moe in MaxText. See Tunix PR: google/tunix#1385.
This PR introduces full support for the Fused MoE kernel integrated into MaxText in AI-Hypercomputer/maxtext#3627.

More specifically, this PR adds the ability to fuse MoE kernel weights in the MaxText model during weight sync so that they match the shapes required by the tpu-inference Fused MoE kernel.

Additionally, this PR adds some resharding optimizations that help with large models. First, it removes the call to jax.clear_caches() in vllm_sampler.py, so the JAX compilation caches no longer need to be cleared, which speeds up both training and rollout steps. The downside is more memory fragmentation on the sampler TPUs after the KV cache is cleared. To avoid fragmentation OOMs, we introduce chunked resharding, which reduces the peak HBM consumed during reshard operations for large models.
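For readers unfamiliar with the technique, here is a minimal sketch of chunked resharding, assuming plain JAX arrays and a per-parameter target sharding. This is not the actual Tunix implementation; chunked_reshard and its arguments are hypothetical, and the real chunking granularity may differ:

```python
# A minimal sketch of chunked resharding (hypothetical names, not the Tunix code).
# The idea: avoid materialising source and destination copies of the whole model
# at once. Move `chunk_size` parameters at a time and free source buffers as you go.
import jax

def chunked_reshard(params: dict, target_shardings: dict, chunk_size: int) -> dict:
  """Reshards `params` onto `target_shardings`, `chunk_size` arrays at a time."""
  names = list(params.keys())
  out = {}
  for start in range(0, len(names), chunk_size):
    chunk = names[start:start + chunk_size]
    # device_put with a Sharding moves each array onto the target mesh/layout.
    resharded = jax.device_put(
        [params[n] for n in chunk], [target_shardings[n] for n in chunk]
    )
    # Wait for the transfer so the source buffers can be released safely.
    resharded = jax.block_until_ready(resharded)
    for n, arr in zip(chunk, resharded):
      out[n] = arr
      # Free the source buffer eagerly so only one chunk's worth of duplicated
      # HBM is live at any point during the reshard.
      params[n].delete()
  return out
```

In this sketch, chunk_size plays the role of the rollout_vllm_reshard_chunk_size option discussed above: smaller chunks lower the peak HBM held in duplicate during resharding, at the cost of more, smaller transfers.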