Cast deferred wgrads to reduce_dtype before FSDP reduce-scatter #32
haok1402 merged 1 commit into mlc-ai:main
Conversation
GroupLinearFunc defers the weight gradient via WeightGradStore. The
deferred wgrad runs in _weight_chunk after the per-step FSDP backward
callback that would have cast unsharded_param.grad from param_dtype
(bf16) to reduce_dtype (fp32), so the trailing weight write stays
bf16. With FSDP MixedPrecisionPolicy(reduce_dtype=fp32) and a param
group that mixes weights with bias-style params (gpt-oss
gate_up_proj_bias / down_proj_bias), foreach_reduce then asserts on
"reduce-scatter expects uniform gradient dtype but got {bf16, fp32}".
deepseek/qwen3 only have GroupLinear weights in their expert FSDP
groups, so all grads were uniform bf16 and the assertion was silent.
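The failing invariant can be shown with a minimal stand-in for the uniform-dtype check that foreach_reduce performs before flattening a param group's gradients for reduce-scatter (illustrative only, not the actual FSDP internals):

```python
def check_uniform_grad_dtype(grad_dtypes):
    """Raise if the gradients in one FSDP param group mix dtypes,
    mirroring the assertion foreach_reduce trips on."""
    unique = set(grad_dtypes)
    if len(unique) != 1:
        raise AssertionError(
            f"reduce-scatter expects uniform gradient dtype but got {unique}"
        )

# deepseek/qwen3 expert groups: only GroupLinear weights, all bf16 -> passes
check_uniform_grad_dtype(["bf16", "bf16"])

# gpt-oss: deferred wgrad stays bf16 while bias grads were cast to fp32
try:
    check_uniform_grad_dtype(["bf16", "fp32"])
except AssertionError as e:
    print(e)
```

This is why the bug was silent on deepseek/qwen3: uniformity held by accident, not because the cast happened.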
Fix in DualPipeV.run_post_backward: before invoking each FSDPParamGroup's
post_backward, iterate fsdp_params and run accumulate_unsharded_grad_if_needed
followed by to_accumulated_grad_if_needed. Order matters: to_accumulated
replaces unsharded_accumulated_grad rather than adding, so calling it
alone clobbers prior chunks' fp32 contributions; accumulate first folds
the trailing bf16 write into the existing fp32 accumulator. Mathematically
identical to FSDP's chunk_cat-time cast for the deepseek/qwen3 case
(bf16 -> fp32 is exact).
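The ordering constraint can be sketched with plain floats standing in for tensors (the replace-vs-add semantics described above are modeled; the function names mimic, but are not, the FSDPParam internals):

```python
# p["accum"] holds prior chunks' fp32 contributions;
# p["grad"] is the trailing bf16 wgrad write from _weight_chunk.

def to_accumulated(p):
    # Mimics to_accumulated_grad_if_needed: REPLACES the accumulator
    # with the current unsharded grad (cast to fp32), clearing grad.
    if p["grad"] is not None:
        p["accum"] = p["grad"]
        p["grad"] = None

def accumulate(p):
    # Mimics accumulate_unsharded_grad_if_needed: folds the trailing
    # write into the existing fp32 accumulator, then clears grad.
    if p["accum"] is not None and p["grad"] is not None:
        p["accum"] += p["grad"]
        p["grad"] = None

# Wrong order: to_accumulated alone clobbers prior chunks' 3.0.
p = {"accum": 3.0, "grad": 0.5}
to_accumulated(p)
assert p["accum"] == 0.5  # prior fp32 contributions lost

# Fix's order: accumulate first, so to_accumulated sees no grad left.
p = {"accum": 3.0, "grad": 0.5}
accumulate(p)
to_accumulated(p)
assert p["accum"] == 3.5  # trailing write folded in correctly
```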
Drive-by fix in test_fsdp.py: skip the cosine-diff check on params whose
reference gradient max is below 1e-5. Reference accumulates grads in
bf16 while DualPipeV accumulates in fp32, so cosine similarity on tiny
gradients (e.g. gpt-oss router.bias, ~1e-8) is dominated by bf16
quantization noise rather than signal. Pre-existing issue, surfaced
once the dtype assertion was unblocked.
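The skip logic amounts to a noise-floor gate in front of the cosine check; a minimal sketch (helper names are illustrative, not the actual test code):

```python
import math

def cosine_sim(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

NOISE_FLOOR = 1e-5  # threshold from the test change

def grads_match(ref_grad, test_grad, eps=0.01):
    # Skip the cosine check when the reference gradient sits below the
    # noise floor: at ~1e-8 magnitudes (gpt-oss router.bias), bf16
    # quantization noise dominates and cosine similarity is meaningless.
    if max(abs(x) for x in ref_grad) < NOISE_FLOOR:
        return True
    return 1.0 - cosine_sim(ref_grad, test_grad) < eps

assert grads_match([1e-8, -2e-8], [5e-8, 1e-8])   # tiny: skipped
assert grads_match([1.0, 0.0], [1.0, 1e-4])       # aligned: passes
assert not grads_match([1.0, 0.0], [0.0, 1.0])    # orthogonal: fails
```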
Validated: test_fsdp on deepseek-v2-lite (largest diff 0.0007),
qwen3-30b-a3b (0.0004), gpt-oss-20b (0.0026), all under eps=0.01.
End-to-end gpt-oss-20b smoke (3 steps, pp=2/ep=2): loss 12.80 -> 11.18,
gradient norm sane.
Code Review
This pull request manually accumulates and casts FSDP parameter gradients in dualpipev.py to resolve a dtype mismatch under mixed precision. It also updates the test suite to skip the cosine check on gradients below the bf16 noise floor. Feedback recommends adding attribute checks for the internal PyTorch methods to improve version compatibility.
```python
for fsdp_param in state._fsdp_param_group.fsdp_params:
    if hasattr(fsdp_param, "_unsharded_param"):
        fsdp_param.accumulate_unsharded_grad_if_needed()
        fsdp_param.to_accumulated_grad_if_needed()
```
The manual invocation of accumulate_unsharded_grad_if_needed and to_accumulated_grad_if_needed correctly addresses the dtype mismatch issue caused by deferred weight gradients in FSDP mixed precision. However, since these are internal methods of FSDPParam, it would be safer to verify their existence on the fsdp_param object before calling them, especially to maintain compatibility across different PyTorch versions where internal APIs might shift.
Suggested change:

```python
for fsdp_param in state._fsdp_param_group.fsdp_params:
    if hasattr(fsdp_param, "accumulate_unsharded_grad_if_needed") and \
            hasattr(fsdp_param, "to_accumulated_grad_if_needed"):
        fsdp_param.accumulate_unsharded_grad_if_needed()
        fsdp_param.to_accumulated_grad_if_needed()
```