
Cast deferred wgrads to reduce_dtype before FSDP reduce-scatter #32

Merged
haok1402 merged 1 commit into mlc-ai:main from MasterJH5574:04-25-gpt-oss-wgrad
Apr 25, 2026

Conversation

@MasterJH5574
Member

GroupLinearFunc defers the weight gradient via WeightGradStore. The deferred wgrad runs in _weight_chunk after the per-step FSDP backward callback that would have cast unsharded_param.grad from param_dtype (bf16) to reduce_dtype (fp32), so the trailing weight write stays bf16. With FSDP MixedPrecisionPolicy(reduce_dtype=fp32) and a param group that mixes weights with bias-style params (gpt-oss gate_up_proj_bias / down_proj_bias), foreach_reduce then asserts on "reduce-scatter expects uniform gradient dtype but got {bf16, fp32}". deepseek/qwen3 only have GroupLinear weights in their expert FSDP groups, so all grads were uniform bf16 and the assertion was silent.
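
To make the failure concrete, here is a minimal, self-contained sketch of the uniformity condition that foreach_reduce effectively checks before reduce-scatter; the dict and tensor names below are illustrative, not FSDP's internal data structures:

import torch

reduce_dtype = torch.float32

# After the per-step backward callback, bias-style grads have been cast
# to reduce_dtype, but the deferred GroupLinear wgrad lands afterwards
# and is still in param_dtype (bf16).
grads = {
    "gate_up_proj_bias": torch.zeros(8, dtype=reduce_dtype),
    "down_proj_bias": torch.zeros(8, dtype=reduce_dtype),
    "group_linear.weight": torch.zeros(8, dtype=torch.bfloat16),  # trailing write
}

dtypes = {g.dtype for g in grads.values()}
if len(dtypes) != 1:
    # This is essentially the condition the assertion trips on:
    print(f"reduce-scatter expects uniform gradient dtype but got {dtypes}")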

Fix in DualPipeV.run_post_backward: before invoking each FSDPParamGroup's post_backward, iterate fsdp_params and run accumulate_unsharded_grad_if_needed followed by to_accumulated_grad_if_needed. Order matters: to_accumulated replaces unsharded_accumulated_grad rather than adding, so calling it alone clobbers prior chunks' fp32 contributions; accumulate first folds the trailing bf16 write into the existing fp32 accumulator. Mathematically identical to FSDP's chunk_cat-time cast for the deepseek/qwen3 case (bf16 -> fp32 is exact).
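
A toy demonstration of the ordering argument, using plain tensors rather than the internal FSDPParam methods (the values are chosen to be exactly representable in bf16):

import torch

fp32_accum = torch.full((4,), 2.0, dtype=torch.float32)  # prior chunks, fp32
trailing = torch.full((4,), 0.5, dtype=torch.bfloat16)   # deferred bf16 wgrad

# Cast-only path (analogous to calling to_accumulated alone): the
# accumulator is replaced by the cast of the trailing grad, so the prior
# fp32 contributions are lost.
clobbered = trailing.to(torch.float32)
assert not torch.equal(clobbered, fp32_accum + 0.5)

# Accumulate-then-cast path: the trailing bf16 write is folded into the
# existing fp32 accumulator first; the subsequent cast is then a no-op
# on an already-fp32 tensor.
folded = fp32_accum + trailing.to(torch.float32)
assert torch.equal(folded, torch.full((4,), 2.5))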

Drive-by fix in test_fsdp.py: skip the cosine-diff check on params whose reference gradient max is below 1e-5. Reference accumulates grads in bf16 while DualPipeV accumulates in fp32, so cosine similarity on tiny gradients (e.g. gpt-oss router.bias, ~1e-8) is dominated by bf16 quantization noise rather than signal. Pre-existing issue, surfaced once the dtype assertion was unblocked.
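
A sketch of the guard added to the test, assuming per-parameter reference and DualPipeV gradient tensors; the helper name and threshold argument here are illustrative:

import torch
import torch.nn.functional as F

def check_grad(name, ref_grad, dp_grad, eps=0.01, noise_floor=1e-5):
    if ref_grad.abs().max() < noise_floor:
        # Tiny gradients (e.g. gpt-oss router.bias, ~1e-8) are dominated
        # by bf16 quantization noise in the reference; skip the check.
        return
    cos = F.cosine_similarity(
        ref_grad.float().flatten(), dp_grad.float().flatten(), dim=0
    )
    assert 1.0 - cos.item() < eps, f"{name}: cosine diff {1.0 - cos.item()}"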

Validated: test_fsdp on deepseek-v2-lite (largest diff 0.0007), qwen3-30b-a3b (0.0004), gpt-oss-20b (0.0026), all under eps=0.01. End-to-end gpt-oss-20b smoke (3 steps, pp=2/ep=2): loss 12.80 -> 11.18, gradient norm sane.


@gemini-code-assist (Bot) left a comment

Code Review

This pull request implements manual gradient accumulation and conversion for FSDP parameters in dualpipev.py to resolve dtype mismatches in mixed precision. It also updates the test suite to skip gradient checks on values below a noise floor. Feedback recommends adding attribute checks for internal PyTorch methods to improve version compatibility.

Comment on lines +574 to +577
for fsdp_param in state._fsdp_param_group.fsdp_params:
    if hasattr(fsdp_param, "_unsharded_param"):
        fsdp_param.accumulate_unsharded_grad_if_needed()
        fsdp_param.to_accumulated_grad_if_needed()

Severity: high

The manual invocation of accumulate_unsharded_grad_if_needed and to_accumulated_grad_if_needed correctly addresses the dtype mismatch issue caused by deferred weight gradients in FSDP mixed precision. However, since these are internal methods of FSDPParam, it would be safer to verify their existence on the fsdp_param object before calling them, especially to maintain compatibility across different PyTorch versions where internal APIs might shift.

Suggested change
- for fsdp_param in state._fsdp_param_group.fsdp_params:
-     if hasattr(fsdp_param, "_unsharded_param"):
-         fsdp_param.accumulate_unsharded_grad_if_needed()
-         fsdp_param.to_accumulated_grad_if_needed()
+ for fsdp_param in state._fsdp_param_group.fsdp_params:
+     if hasattr(fsdp_param, "accumulate_unsharded_grad_if_needed") and \
+        hasattr(fsdp_param, "to_accumulated_grad_if_needed"):
+         fsdp_param.accumulate_unsharded_grad_if_needed()
+         fsdp_param.to_accumulated_grad_if_needed()

@haok1402 merged commit 078f90b into mlc-ai:main on Apr 25, 2026
1 check passed