Fix BF16_Optimizer last-microbatch grad leak under ZeRO-1 #7985
Merged
delock merged 3 commits into deepspeedai:master on Apr 29, 2026
Conversation
In `DeepSpeedEngine._backward_epilogue`, `allreduce_gradients()` ran before
`optimizer.backward_epilogue()`. For `BF16_Optimizer` (used when bf16 model +
grad_accum_dtype=fp32 + ZeRO-1) without `immediate_grad_update`, this means the
boundary microbatch's gradient is added to the rank-local fp32 accumulator
*after* the cross-rank allreduce, so it is silently skipped from the average.
Effect: each rank's fp32 buffer ends with
fp32_buffer_rank_i = avg_ranks(sum_{m=0..ga-2} grad_m) + local_grad_{ga-1}_rank_i
which biases the global gradient by `(world_size-1)/world_size * 1/ga_steps`.
Because the bias scales with per-microbatch grad weight, training trajectories
diverge depending on `per_device_train_batch_size` even with identical effective
batch size — the symptom is loss/grad-norm curves that drift apart between
otherwise-equivalent configs.
Fix: call `optimizer.backward_epilogue()` *before* `allreduce_gradients()` so
the boundary microbatch's grad is in the buffer being reduced. `exit_backward()`
remains after the allreduce; it only manages backward-hook state and has no
ordering dependency on the accumulator.
This is a no-op for `DeepSpeedZeroOptimizer_Stage1And2` and Stage3, whose
`backward_epilogue` does not mutate the reduction buffer (their grads are
either on `param.grad` already populated by autograd or accumulated via
inline backward hooks). It is also a no-op for `BF16_Optimizer` with
`immediate_grad_update=true` because the per-param hooks already fill the
fp32 buffer synchronously during backward.
Signed-off-by: Max Yu <18641481+maxyu1115@users.noreply.github.com>
Collaborator
Hi @tohtana, can you confirm whether this is a regression introduced by #7665? Thanks!
tohtana approved these changes on Apr 28, 2026
Collaborator
Thank you for your fix and thorough investigation, @maxyu1115! This is very important. Let's merge it and release a new version soon.
Collaborator
Hi @maxyu1115,
Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Fix BF16_Optimizer last-microbatch grad leak under ZeRO-1
Summary
`DeepSpeedEngine._backward_epilogue` calls `allreduce_gradients()` before `optimizer.backward_epilogue()`. For `BF16_Optimizer` (used when bf16 model + `grad_accum_dtype: fp32` + ZeRO stage 1) without `immediate_grad_update`, this means the boundary microbatch's gradient is added to the rank-local fp32 accumulator AFTER the cross-rank allreduce, so it is silently skipped from the average.

The bias is `(world_size − 1) / world_size × 1 / gradient_accumulation_steps` of the per-step gradient. Because the bias scales with the per-microbatch grad weight, training trajectories visibly diverge depending on `per_device_train_batch_size` even with identical effective batch size; the symptom users see is loss / grad-norm curves drifting apart between otherwise-equivalent configs.

The bug is reproducible in DeepSpeed 0.18.6 through current `master` (0.18.10 at time of writing).

Fix
Swap the order so `optimizer.backward_epilogue()` runs before `allreduce_gradients()`, with `exit_backward()` left after. `exit_backward()` only manages backward-hook state (`_backward_hook_state`); it has no ordering dependency on the gradient accumulator.

Diff: +10 / −1, single file (`deepspeed/runtime/engine.py`).
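For orientation, a schematic of the reordering (not the literal +10/−1 diff; the receiver of `exit_backward()` and the surrounding engine code are simplified assumptions):

```python
# Schematic of DeepSpeedEngine._backward_epilogue before/after this PR.
# Illustrative only: receivers and surrounding engine code are simplified.

def _backward_epilogue_before(engine):        # buggy ordering (current master)
    engine.allreduce_gradients()              # averages the fp32 accumulator across ranks
    engine.optimizer.backward_epilogue()      # only now folds the boundary microbatch's lp.grad in
    engine.optimizer.exit_backward()          # backward-hook state bookkeeping (_backward_hook_state)

def _backward_epilogue_after(engine):         # this PR
    engine.optimizer.backward_epilogue()      # boundary microbatch's grad enters the buffer first...
    engine.allreduce_gradients()              # ...so the cross-rank average includes it
    engine.optimizer.exit_backward()          # unchanged; no ordering dependency on the accumulator
```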
Root cause walkthrough
The bug requires both of the following to be true:
- The buffer that `optimizer.backward_epilogue()` mutates is the same tensor that `engine.allreduce_gradients()` later reduces, AND
- that buffer is filled only by `optimizer.backward_epilogue()` (no per-param hooks updating it inline during backward).
BF16_Optimizerwithoutimmediate_grad_update:fp32_groups_gradients_flat) — distinct fromparam.grad.backward_epilogue()callsupdate_hp_grads()which casts each param's bf16lp.gradto fp32 and adds it into that accumulator (and only this code path fills the accumulator whenimmediate_grad_update=False).engine.allreduce_gradients()→buffered_allreduce_fallback()→optimizer.get_grads_for_reduction()returns the samenon_expert_gradientslist =fp32_groups_gradients_flat.So on the gradient-accumulation boundary microbatch:
So on the gradient-accumulation boundary microbatch:

1. `loss.backward()` populates bf16 `lp.grad` for that microbatch.
2. `_backward_epilogue` first calls `allreduce_gradients()`. The fp32 accumulator at this point contains microbatches `0..ga-2`'s grads (summed locally on each rank). The allreduce averages only that across ranks.
3. `_backward_epilogue` then calls `optimizer.backward_epilogue()` → `update_hp_grads()` → adds the boundary microbatch's local `lp.grad` to the now-allreduced accumulator.

Result, per rank `i`: `fp32_buffer_rank_i = avg_ranks(sum_{m=0..ga-2} grad_m) + local_grad_{ga-1}_rank_i`.

ZeRO-1 partitions optimizer states across ranks, so each rank then runs `optimizer.step()` on its slice of this rank-divergent buffer; `update_lp_params()` allgathers the bf16 params back. The effective gradient applied to parameter `p` (whose partition lives on rank `r(p)`) is `avg_ranks(sum_{m=0..ga-2} grad_m) + local_grad_{ga-1}_rank_{r(p)}`, i.e. the boundary microbatch's contribution captures only `1 / world_size` of the cross-rank average, biasing the global gradient by `(world_size − 1) / world_size × 1 / ga_steps`.
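To make the bias concrete, a tiny pure-Python simulation of the formulas above (two ranks, scalar per-microbatch gradients; the gradient values are arbitrary illustrative numbers, not measurements):

```python
world_size, ga = 2, 4
# grads[rank][m]: microbatch m's gradient on that rank (illustrative values)
grads = [[1.0, 2.0, 3.0, 4.0],
         [5.0, 6.0, 7.0, 8.0]]

# What the step should apply: cross-rank average of the full per-step sum.
true_avg = sum(sum(g) for g in grads) / world_size                      # 18.0

# Buggy ordering: allreduce first (microbatches 0..ga-2 only), then add the local boundary grad.
partial_avg = sum(sum(g[:ga - 1]) for g in grads) / world_size          # 12.0
buggy = [partial_avg + grads[r][ga - 1] for r in range(world_size)]     # [16.0, 20.0] -> rank-divergent

# Fixed ordering: the boundary grad is folded in before the allreduce.
fixed = [true_avg] * world_size                                         # [18.0, 18.0]

# Per-rank error = local_grad_{ga-1} - avg_ranks(grad_{ga-1}): the boundary microbatch
# keeps only 1/world_size of its cross-rank average.
errors = [b - true_avg for b in buggy]                                  # [-2.0, 2.0]
print(buggy, fixed, errors)
```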
Impact on other optimizers (no behavior change)
- `BF16_Optimizer` (`immediate_grad_update=False`): the reduction buffer is `fp32_groups_gradients_flat`, filled by `optimizer.backward_epilogue()` → `update_hp_grads()`; `engine.allreduce_gradients` → `buffered_allreduce_fallback` → `optimizer.get_grads_for_reduction()` returns the same fp32 buffer. This is the affected case the fix targets.
- `BF16_Optimizer` (`immediate_grad_update=True`): per-param hooks (`create_grad_acc_hooks`) fire inline during backward (`update_hp_grads` early-returns when `immediate_grad_update`). No-op.
- `DeepSpeedZeroOptimizer_Stage1And2` (ZeRO-1, default for bf16 + bf16 grad accum): gradients live on `param.grad` directly + ipg buckets, reduced via backward hooks (`overlap_comm=True` default), or `reduce_gradients()` walks all params at the boundary; `engine.allreduce_gradients` takes the `if hasattr(self.optimizer, 'reduce_gradients')` branch → `optimizer.reduce_gradients()` walks all params, and the boundary microbatch's grad is already on `param.grad` (autograd populates this before `_backward_epilogue` runs). No-op (`Stage1And2.backward_epilogue` does not mutate the reduction buffer).
- `DeepSpeedZeroOptimizer_Stage3`: reduction goes through `overlapping_partition_gradients_reduce_epilogue()`; `engine.allreduce_gradients` takes the `if zero_optimization_partition_gradients()` branch → calls the overlapping epilogue, which is fed by hooks. No-op.
BF16_Optimizerwithoutimmediate_grad_update. For every other ZeRO optimizer the change is observably a no-op because theirbackward_epiloguedoes not mutate the buffer being reduced.Reproducer
Reproducer

The minimal reproducer is a 2-rank standalone script that runs one gradient-accumulation cycle and prints the per-rank fp32 accumulator norm at each microbatch and immediately before `engine.step()`. With the bug present the per-rank values disagree at the boundary microbatch and going into the optimizer step; with the fix they agree.

Save as `probe_bf16_grad_accum.py`:
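The script body attached to the PR is not reproduced here; what follows is a minimal sketch of such a probe. The config keys (`bf16`, ZeRO stage 1, `data_types.grad_accum_dtype`) and the `fp32_groups_gradients_flat` attribute are the ones referenced in this description; the toy model, loss, and data are assumptions.

```python
import torch
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 4,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 1},
    "data_types": {"grad_accum_dtype": "fp32"},     # selects the fp32 accumulator path
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

def accum_norm(engine):
    # fp32_groups_gradients_flat: the rank-local fp32 gradient accumulator discussed above
    flats = engine.optimizer.fp32_groups_gradients_flat
    return torch.sqrt(sum(f.float().pow(2).sum() for f in flats)).item()

def main():
    torch.manual_seed(0)
    model = torch.nn.Linear(256, 256)
    engine, _, _, _ = deepspeed.initialize(model=model,
                                           model_parameters=model.parameters(),
                                           config=ds_config)
    rank = torch.distributed.get_rank()
    ga = ds_config["gradient_accumulation_steps"]

    for micro in range(ga):
        # rank-dependent data so a missing cross-rank average shows up in the norms
        x = torch.randn(4, 256, dtype=torch.bfloat16, device=engine.device) * (rank + 1)
        loss = engine(x).float().pow(2).mean()
        engine.backward(loss)
        # printed after backward, i.e. immediately before engine.step() on the boundary microbatch
        print(f"rank {rank} microbatch {micro}: fp32 accum norm = {accum_norm(engine):.6f}")
        engine.step()   # only steps the optimizer on the accumulation boundary

if __name__ == "__main__":
    main()
```

Launched with 2 ranks, e.g. `deepspeed --num_gpus 2 probe_bf16_grad_accum.py`.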
Verification
Probe (synthetic, 2 GPUs, 1 grad-accum cycle)
Run on `master` (bug):

Run on this PR (fixed):
The same agreement is reproduced by the existing `bf16: { immediate_grad_update: true }` workaround, which uses per-param hooks to fill the fp32 accumulator inline during backward (and is therefore not affected by the `_backward_epilogue` ordering).
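For reference, that workaround as a DeepSpeed config fragment (key names exactly as written above; everything else omitted):

```python
workaround_fragment = {
    "bf16": {
        "enabled": True,
        "immediate_grad_update": True,   # per-param hooks fill the fp32 accumulator during backward
    },
}
```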
End-to-end training (HuggingFace Trainer + accelerate + DeepSpeed, 2× A100)

A small custom Qwen3-derived model (~64M params, bf16, ZeRO-1 with `grad_accum_dtype: fp32`), 10 optimizer steps, identical seed and data ordering, identical effective batch size (`global_batch_size = 64`); only `per_device_train_batch_size` varies (so `gradient_accumulation_steps = global_batch_size / (per_device * world_size)` differs).

- `master` + `grad_accum_dtype: fp32` (broken): `train_loss` = 6.896, 6.999, 6.9035, 6.9037 across the four per-device batch-size settings tested.
- `master` + `bf16.immediate_grad_update: true` (existing workaround): `train_loss` = 6.9057, 6.9057, 6.9057, 6.9059 across the same settings.

The broken case also produces qualitatively misleading instabilities: e.g. at step 5 in the broken run, B's grad-norm spikes to 17.0 vs A's 1.35 (≈ 12× ratio), while in the fixed case the two grad-norm trajectories agree to within bf16 noise at every step.
Per-step loss / grad-norm trajectories under the fixed engine (this PR), for completeness:
Notes
- `BF16_Optimizer` users on `master` who are not pinning `per_device_train_batch_size` may see silently degraded training when sweeping per-device batch sizes (the symptom that triggered this investigation). The bug also makes per-rank model weights briefly diverge between the optimizer step and the next `update_lp_params()` allgather, which means cross-rank invariants (e.g. asserts that compare per-rank state) can flip behavior depending on `gradient_accumulation_steps`.
- The bug is present from DeepSpeed 0.18.6 through current `master` (0.18.10).
- A regression test exercising this path in `BF16_Optimizer` would be a natural follow-up.