
Reduce MoE activation memory in DualPipeV #16

Merged

haok1402 merged 1 commit into mlc-ai:main from MasterJH5574:04-12-memory-opt on Apr 13, 2026

Conversation

@MasterJH5574 (Member) commented Apr 13, 2026

  • Remove EpilogOuts/epilog_b (dead code; logits not needed after loss)
  • Remove Stage2/Stage4 args and outs (only ctx needed for a2a backward)
  • Add padded_index_gather: avoids saving input for backward + reduces CUDA allocator fragmentation via fixed-alignment output buffers
  • Free stage-boundary tensor storage early via untyped_storage().resize_(0) after async all-to-all consumers complete (sorted_tokens after Stage 2, gathered_tokens after Stage 3, stage3 moe_outs after Stage 4)
  • Add fwd_comm_deferred_free to ExecutionCtx for clean deferred-free in the overlapped path without changing stage function signatures
  • Add layer_partition() for memory-aware pipeline stage assignment
  • Add per-layer saved-tensor profiling with weight/activation distinction

Co-authored-by: Hao Kang 89672451+haok1402@users.noreply.github.com
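The fixed-alignment idea behind padded_index_gather can be sketched in isolation. This is a hypothetical helper: the alignment value and function name are assumptions for illustration, not the PR's API.

```python
# Hypothetical sketch of the fixed-alignment sizing behind
# padded_index_gather: round the gathered row count up to a fixed
# multiple so the CUDA caching allocator sees a few recurring buffer
# sizes instead of a new size every step, reducing fragmentation.
ALIGNMENT = 256  # illustrative row alignment, not the PR's constant

def padded_rows(num_rows: int, alignment: int = ALIGNMENT) -> int:
    """Round num_rows up to the nearest multiple of alignment."""
    return ((num_rows + alignment - 1) // alignment) * alignment
```

Rows past num_rows in such a padded buffer would be zero-filled so downstream consumers see well-defined data.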

@gemini-code-assist Bot left a comment

Code Review

This pull request implements significant memory optimizations and architectural enhancements for the DualPipeV pipeline, including a fused cross-entropy Triton kernel, a memory-efficient padded index gather operator, and a balanced layer partitioning algorithm. It also introduces manual storage management and deferred tensor freeing to reduce peak memory consumption. Review feedback highlights critical safety issues where early storage resizing or deferred freeing could cause crashes when expert parallelism is disabled, and suggests performance refinements for the Triton kernels.
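The deferred tensor freeing mentioned above can be pictured with a minimal, hypothetical analog of the ExecutionCtx bookkeeping. The method names here are assumptions; the real code releases CUDA storage via untyped_storage().resize_(0) once the async all-to-all's consumer has run.

```python
class ExecutionCtx:
    """Toy stand-in for the execution context's deferred-free list."""

    def __init__(self):
        # Callbacks that release stage-boundary tensor storage; they run
        # only after the overlapped forward communication is consumed.
        self.fwd_comm_deferred_free = []

    def defer_free(self, release_fn):
        self.fwd_comm_deferred_free.append(release_fn)

    def run_deferred_frees(self):
        for release_fn in self.fwd_comm_deferred_free:
            release_fn()
        self.fwd_comm_deferred_free.clear()
```

This keeps the stage function signatures unchanged: stages only append release callbacks, and the overlapped scheduler decides when it is safe to run them.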

Comment thread pithtrain/dualpipe/execution.py Outdated
Comment thread pithtrain/dualpipe/modeling.py Outdated
Comment thread pithtrain/dualpipe/modeling.py Outdated
Comment thread pithtrain/dualpipe/overlap.py Outdated
Comment thread pithtrain/dualpipe/overlap.py Outdated
Comment thread pithtrain/operators/cross_entropy.py
Comment thread pithtrain/operators/cross_entropy.py Outdated
@haok1402 (Collaborator)

The cross-entropy loss with a 4k sequence length looks correct after a couple of steps from the released checkpoint:

2026-04-12 22:39:00 | INFO | step 00000011/00000015 | step-time 17.691 sec | cross-entropy-loss 2.5770 | learning-rate 1.000000e-06 | gradient-norm 696.0355 | tokens-per-second 237,090 | peak-gpu-memory 58.31 GB
2026-04-12 22:39:18 | INFO | step 00000012/00000015 | step-time 17.741 sec | cross-entropy-loss 2.5923 | learning-rate 1.000000e-06 | gradient-norm 1003.7461 | tokens-per-second 236,419 | peak-gpu-memory 58.37 GB
2026-04-12 22:39:36 | INFO | step 00000013/00000015 | step-time 18.392 sec | cross-entropy-loss 2.6027 | learning-rate 1.000000e-06 | gradient-norm 1149.4879 | tokens-per-second 228,045 | peak-gpu-memory 58.51 GB
2026-04-12 22:39:54 | INFO | step 00000014/00000015 | step-time 17.032 sec | cross-entropy-loss 2.5203 | learning-rate 1.000000e-06 | gradient-norm 713.4630 | tokens-per-second 246,259 | peak-gpu-memory 58.26 GB
2026-04-12 22:40:12 | INFO | step 00000015/00000015 | step-time 18.045 sec | cross-entropy-loss 2.5940 | learning-rate 1.000000e-06 | gradient-norm 937.0814 | tokens-per-second 232,441 | peak-gpu-memory 58.43 GB

Comment thread pithtrain/dualpipe/layer_partition.py
Comment thread pithtrain/dualpipe/execution.py
Comment thread pithtrain/operators/all_to_all.py
Comment thread tests/test_layer_partition.py
Comment thread pithtrain/models/deepseek_v2_lite.py
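The layer_partition() addition for memory-aware stage assignment can be illustrated with a greedy sketch over per-layer memory costs. The signature and cut rule below are assumptions; the actual algorithm in the PR may differ.

```python
def layer_partition(costs: list[float], num_stages: int) -> list[list[int]]:
    """Split layers into contiguous stages with roughly balanced cost.

    Greedy sketch: close a stage once it reaches the average target,
    while always leaving enough layers for the remaining stages.
    """
    total = sum(costs)
    target = total / num_stages
    stages, cur, acc = [], [], 0.0
    for i, c in enumerate(costs):
        cur.append(i)
        acc += c
        remaining_layers = len(costs) - i - 1
        remaining_stages = num_stages - len(stages) - 1
        if (len(stages) < num_stages - 1
                and remaining_layers >= remaining_stages
                and (acc >= target or remaining_layers == remaining_stages)):
            stages.append(cur)
            cur, acc = [], 0.0
    stages.append(cur)
    return stages
```

A greedy pass like this is simple but not optimal; a production version might instead minimize the maximum stage cost exactly (e.g. with dynamic programming or binary search over a capacity).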
@haok1402 haok1402 merged commit 331e748 into mlc-ai:main Apr 13, 2026
1 check passed
haok1402 pushed a commit that referenced this pull request Apr 24, 2026
Route the two GPT-OSS expert GEMMs (gate_up_proj, down_proj) through
GroupLinearFunc.apply instead of F.grouped_mm, so the wgrad-delay path
added in #28 applies to gpt-oss too. The per-expert bias add stays on
the caller side (bias[group_ids]) rather than being folded into the
autograd Function: index_add_ for bgrad is cheap compared to the
grouped_mm wgrad, so deferring it buys little but complicates both the
Function signature and the gpt-oss call sites.

Add test_gpt_oss_experts_weight_grad_store_matches_direct as an
integration check: it runs GptOssExperts with WeightGradStore on vs
off and asserts that forward output and input/weight/bias gradients
all match (weights tightly through deterministic grouped_mm, biases
with ~5% bf16 slack since CUDA index_add_ is non-deterministic). It
also verifies that gate_up_proj/down_proj grads are deferred before
flush/pop while the bias grads remain eager, documenting the split.

Drive-by: fix the stale assertion in test_scatter_for_grouped_gemm
that was left over from #16. The scatter output has been rounded up
to _GEMM_ALLOC_ALIGNMENT (for CUDA allocator locality) since that
commit and the tail rows are zeroed by the kernel, but the test
still required out.shape[0] == offs[-1]. Replace it with the actual
contract: shape is at least offs[-1], aligned to
_GEMM_ALLOC_ALIGNMENT, with the over-allocated tail all-zero.
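The corrected contract can be restated as a small checker over a toy list-of-rows stand-in for the scatter output. The alignment value in the test is illustrative, not the real _GEMM_ALLOC_ALIGNMENT.

```python
def check_scatter_contract(out, offs_last, alignment):
    """Assert the scatter-output contract: at least offs_last rows,
    row count a multiple of alignment, over-allocated tail all-zero."""
    assert len(out) >= offs_last
    assert len(out) % alignment == 0
    assert all(v == 0.0 for row in out[offs_last:] for v in row)
    return True
```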