
Enable wgrad-delay for gpt-oss experts via GroupLinearFunc #29

Merged

haok1402 merged 1 commit into mlc-ai:main from MasterJH5574:04-24-gpt-oss-grouped-mm on Apr 24, 2026

Conversation

@MasterJH5574
Member

Route the two GPT-OSS expert GEMMs (gate_up_proj, down_proj) through GroupLinearFunc.apply instead of F.grouped_mm, so the wgrad-delay path added in #28 applies to gpt-oss too. The per-expert bias add stays on the caller side (bias[group_ids]) rather than being folded into the autograd Function: index_add_ for bgrad is cheap compared to the grouped_mm wgrad, so deferring it buys little but complicates both the Function signature and the gpt-oss call sites.
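For concreteness, here is a minimal sketch of the delay pattern described above, reduced to a single GEMM rather than the repo's grouped_mm, with illustrative WeightGradStore internals; only the shape of the mechanism is meant to match the actual code.

```python
import torch

class WeightGradStore:
    # Illustrative stand-in: a FIFO of deferred wgrad closures.
    _queue = []

    @classmethod
    def put(cls, fn):
        cls._queue.append(fn)

    @classmethod
    def flush_and_pop(cls):
        while cls._queue:
            cls._queue.pop(0)()

class GroupLinearFuncSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_x = grad_out @ weight  # dgrad stays eager: upstream layers need it now

        def wgrad():
            # Deferred wgrad, accumulated manually at flush time.
            g = grad_out.t() @ x
            weight.grad = g if weight.grad is None else weight.grad + g

        WeightGradStore.put(wgrad)
        return grad_x, None  # None: autograd itself never writes weight.grad
```

At the gpt-oss call sites the bias stays outside the Function, e.g. `h = GroupLinearFuncSketch.apply(x, w) + bias[group_ids]`, so bgrad flows through ordinary autograd (an index_add_ under the hood) while the GEMM wgrads wait in the store.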

Add test_gpt_oss_experts_weight_grad_store_matches_direct as an integration check: it runs GptOssExperts with WeightGradStore on vs off and asserts that forward output and input/weight/bias gradients all match (weights tightly through deterministic grouped_mm, biases with ~5% bf16 slack since CUDA index_add_ is non-deterministic). It also verifies that gate_up_proj/down_proj grads are deferred before flush/pop while the bias grads remain eager, documenting the split.
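The on/off comparison reduces to the following pattern (a sketch with hypothetical fixture and flag names — `build_experts_and_input` and `use_weight_grad_store` are illustrative, not the repo's API; the actual test drives GptOssExperts directly):

```python
import torch

def grads_with_store(enabled: bool):
    model, x = build_experts_and_input(seed=0)   # hypothetical fixture
    out = model(x, use_weight_grad_store=enabled)
    out.sum().backward()
    if enabled:
        WeightGradStore.flush_and_pop()          # realize deferred wgrads
    return out, x.grad, model.gate_up_proj.grad, model.gate_up_proj_bias.grad

out_a, xg_a, wg_a, bg_a = grads_with_store(False)
out_b, xg_b, wg_b, bg_b = grads_with_store(True)
torch.testing.assert_close(out_b, out_a)
torch.testing.assert_close(xg_b, xg_a)
torch.testing.assert_close(wg_b, wg_a)                        # deterministic grouped_mm
torch.testing.assert_close(bg_b, bg_a, rtol=5e-2, atol=5e-2)  # index_add_ slack
```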

Drive-by: fix the stale assertion in test_scatter_for_grouped_gemm that was left over from #16. The scatter output has been rounded up to _GEMM_ALLOC_ALIGNMENT (for CUDA allocator locality) since that commit, and the tail rows are zeroed by the kernel, but the test still required out.shape[0] == offs[-1]. Replace it with the actual contract: shape is at least offs[-1], aligned to _GEMM_ALLOC_ALIGNMENT, with the over-allocated tail all-zero.
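Spelled out as assertions, the updated contract looks roughly like this (`out`, `offs`, and `_GEMM_ALLOC_ALIGNMENT` as in the test):

```python
n_valid = int(offs[-1])
assert out.shape[0] >= n_valid
assert out.shape[0] % _GEMM_ALLOC_ALIGNMENT == 0
assert (out[n_valid:] == 0).all()  # kernel zeroes the over-allocated tail rows
```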


@gemini-code-assist (Bot) left a comment


Code Review

This pull request transitions the GptOssExperts model to use GroupLinearFunc.apply instead of direct F.grouped_mm calls, facilitating deferred weight gradient storage. Additionally, it updates the scatter_for_grouped_gemm test to support memory alignment and introduces an end-to-end test to ensure gradient correctness with WeightGradStore. I have no feedback to provide.

haok1402 merged commit 7e952c4 into mlc-ai:main on Apr 24, 2026
1 check passed
