
Fix BF16_Optimizer last-microbatch grad leak under ZeRO-1 #7985

Merged
delock merged 3 commits into deepspeedai:master from
maxyu1115:fix/bf16-optimizer-grad-accum-boundary-leak on Apr 29, 2026
Conversation

@maxyu1115
Contributor

Fix BF16_Optimizer last-microbatch grad leak under ZeRO-1

Summary

DeepSpeedEngine._backward_epilogue calls allreduce_gradients() before optimizer.backward_epilogue(). For BF16_Optimizer (used when bf16 model + grad_accum_dtype: fp32 + ZeRO stage 1) without immediate_grad_update, this means the boundary microbatch's gradient is added to the rank-local fp32 accumulator AFTER the cross-rank allreduce, so it is silently skipped from the average.

The bias is (world_size − 1) / world_size × 1 / gradient_accumulation_steps of the per-step gradient. Because the bias scales with per-microbatch grad weight, training trajectories visibly diverge depending on per_device_train_batch_size even with identical effective batch size — the symptom users see is loss / grad-norm curves drifting apart between otherwise-equivalent configs.

The bug is reproducible in DeepSpeed 0.18.6 through current master (0.18.10 at time of writing).

Fix

Swap the order so that optimizer.backward_epilogue() runs before allreduce_gradients(); exit_backward() stays after the allreduce. exit_backward() only manages backward-hook state (_backward_hook_state) and has no ordering dependency on the gradient accumulator.

def _backward_epilogue(self):
    self._stop_timers(self.engine_timers.backward_inner_timers)
    self._start_timers(self.engine_timers.backward_reduce_timers)
    # NEW: run backward_epilogue() before allreduce so the boundary microbatch
    # grad lands in the optimizer accumulator that gets reduced.
    if isinstance(self.optimizer, ZeROOptimizer):
        self.optimizer.backward_epilogue()

    if self.enable_backward_allreduce and not self.inside_no_sync_ctxt:
        self.allreduce_gradients()

    if isinstance(self.optimizer, ZeROOptimizer):
        self.optimizer.exit_backward()

    see_memory_usage("Engine after backward", force=self.memory_breakdown())
    self._stop_timers(self.engine_timers.backward_reduce_timers)
    self._stop_timers(self.engine_timers.backward_timers)

Diff: +10 / −1, single file (deepspeed/runtime/engine.py).

Root cause walkthrough

The bug requires both of the following to be true:

  1. The accumulator that optimizer.backward_epilogue() mutates is the same tensor that engine.allreduce_gradients() later reduces, AND
  2. The accumulator is updated only by optimizer.backward_epilogue() (no per-param hooks updating it inline during backward).

Both conditions hold for BF16_Optimizer without immediate_grad_update:

  • It maintains a separate fp32 accumulator (fp32_groups_gradients_flat) — distinct from param.grad.
  • Its backward_epilogue() calls update_hp_grads(), which casts each param's bf16 lp.grad to fp32 and adds it into that accumulator (and only this code path fills the accumulator when immediate_grad_update=False); see the sketch after this list.
  • engine.allreduce_gradients() → buffered_allreduce_fallback() → optimizer.get_grads_for_reduction() returns the same non_expert_gradients list, i.e. fp32_groups_gradients_flat.
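To make the second bullet concrete, here is a toy sketch of that accumulation path (illustrative only; the flat-buffer bookkeeping is simplified and is not BF16_Optimizer's actual code):

def update_hp_grads_sketch(lp_params, fp32_flat_buffer):
    """Accumulate each bf16 param grad into a flat fp32 buffer, in place.

    Mirrors the behavior described above for immediate_grad_update=False:
    nothing touches the fp32 buffer during backward; this runs once per
    microbatch from backward_epilogue().
    """
    offset = 0
    for lp in lp_params:
        n = lp.numel()
        if lp.grad is not None:
            # cast the bf16 microbatch gradient to fp32 and add it to this
            # parameter's slice of the flat accumulator
            fp32_flat_buffer.narrow(0, offset, n).add_(
                lp.grad.detach().flatten().float())
        offset += n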

So on the gradient-accumulation boundary microbatch:

  1. loss.backward() populates bf16 lp.grad for that microbatch.
  2. _backward_epilogue first calls allreduce_gradients(). The fp32 accumulator at this point contains microbatches 0..ga-2's grads (summed locally on each rank). The allreduce averages only that across ranks.
  3. _backward_epilogue then calls optimizer.backward_epilogue() → update_hp_grads() → adds the boundary microbatch's local lp.grad to the now-allreduced accumulator.

Result, per rank i:

fp32_buffer_rank_i = avg_ranks(Σ_{m=0..ga-2} grad_m) + local_grad_{ga-1}_rank_i
                     └────── shared across ranks ─────┘   └─── rank-private leak ───┘

ZeRO-1 partitions optimizer states across ranks, so each rank then runs optimizer.step() on its slice of this rank-divergent buffer; update_lp_params() allgathers the bf16 params back. The effective gradient applied to parameter p is:

g_p = avg_ranks(prior microbatches' grad for p) + local_grad_last_for_p_at_owning_rank

i.e. the boundary microbatch's contribution captures only 1 / world_size of the cross-rank average, biasing the global gradient by (world_size − 1) / world_size × 1 / ga_steps.
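To make the bias concrete, here is a tiny self-contained simulation of the two orderings with plain tensors and no DeepSpeed; every name in it is illustrative:

# Toy simulation of the two orderings for two "ranks" and ga microbatches,
# with one scalar "gradient" per microbatch. Only the arithmetic mirrors
# the walkthrough above.
import torch

world, ga = 2, 4
torch.manual_seed(0)
grads = torch.randn(world, ga)  # grads[r, m] = rank r's grad for microbatch m

# Correct ordering: accumulate all ga microbatches locally, then average the
# full sum across ranks.
correct = grads.sum(dim=1).mean()

# Buggy ordering: the allreduce averages only microbatches 0..ga-2; the
# boundary microbatch's grad is added per rank afterwards and never reduced.
reduced_prefix = grads[:, :-1].sum(dim=1).mean()   # shared across ranks
buggy_per_rank = reduced_prefix + grads[:, -1]     # rank-private leak

# Per-rank bias: the boundary grad minus its cross-rank average, which is the
# leak the text quantifies as (world_size - 1)/world_size x 1/ga_steps of the
# per-step gradient once grads carry their 1/ga microbatch weight.
print(buggy_per_rank - correct)
print(grads[:, -1] - grads[:, -1].mean())  # identical to the line above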

Impact on other optimizers (no behavior change)

  • BF16_Optimizer (immediate_grad_update=False)
    - Accumulator: separate fp32_groups_gradients_flat
    - How it is filled: only via optimizer.backward_epilogue() → update_hp_grads()
    - Reduction path: engine.allreduce_gradients() → buffered_allreduce_fallback() → optimizer.get_grads_for_reduction() returns the same fp32 buffer
    - Affected by current bug: Yes — leak
    - Affected by this fix: Yes — leak fixed

  • BF16_Optimizer (immediate_grad_update=True)
    - Accumulator: same fp32 buffer
    - How it is filled: per-param hooks (create_grad_acc_hooks) fire inline during backward
    - Reduction path: same allreduce path
    - Affected by current bug: No (hooks already filled the buffer before the allreduce)
    - Affected by this fix: No-op (update_hp_grads early-returns when immediate_grad_update is set)

  • DeepSpeedZeroOptimizer_Stage1And2 (ZeRO-1, default for bf16 + bf16 grad-accum)
    - Accumulator: param.grad directly + ipg buckets
    - How it is filled: hooks fire inline during backward (overlap_comm=True default), or reduce_gradients() walks all params at the boundary
    - Reduction path: engine.allreduce_gradients() takes the "if hasattr(self.optimizer, 'reduce_gradients')" branch → optimizer.reduce_gradients() walks all params; the boundary microbatch's grad is already on param.grad (autograd populates it before _backward_epilogue runs)
    - Affected by current bug: No
    - Affected by this fix: No-op (Stage1And2.backward_epilogue does not mutate the reduction buffer)

  • DeepSpeedZeroOptimizer_Stage3
    - Accumulator: partitioned via overlapping_partition_gradients_reduce_epilogue()
    - How it is filled: reduce-scatter inline during backward via hooks
    - Reduction path: engine.allreduce_gradients() takes the "if zero_optimization_partition_gradients()" branch → calls the overlapping epilogue, which is fed by hooks
    - Affected by current bug: No
    - Affected by this fix: No-op

In short: the fix is functionally relevant only for BF16_Optimizer without immediate_grad_update. For every other ZeRO optimizer the change is observably a no-op because their backward_epilogue does not mutate the buffer being reduced.
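For orientation, the dispatch described above can be condensed into the following sketch; it is a paraphrase built only from the branch names cited in the table, not the actual engine source:

# Rough paraphrase of the reduction dispatch described above. The real
# allreduce_gradients() has additional logic (MoE/sparse grads, bucket sizes)
# that is omitted here.
def allreduce_gradients_sketch(engine):
    if engine.zero_optimization_partition_gradients():
        # ZeRO-3: gradients were reduce-scattered inline by backward hooks;
        # this call just finishes that overlapping reduction.
        engine.optimizer.overlapping_partition_gradients_reduce_epilogue()
    elif hasattr(engine.optimizer, "reduce_gradients"):
        # ZeRO-1/2 (Stage1And2): walks param.grad, which autograd already
        # populated for the boundary microbatch, so ordering does not matter.
        engine.optimizer.reduce_gradients()
    else:
        # BF16_Optimizer path: reduces whatever get_grads_for_reduction()
        # returns, i.e. the fp32 accumulator whose fill order this PR fixes.
        grads = engine.optimizer.get_grads_for_reduction()
        engine.buffered_allreduce_fallback(grads=grads)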

Reproducer

The minimum reproducer is a 2-rank standalone script that runs one gradient-accumulation cycle and prints the per-rank fp32 accumulator norm at each microbatch and immediately before engine.step(). With the bug present the per-rank values disagree at the boundary microbatch and going into the optimizer step; with the fix they agree.

Save as probe_bf16_grad_accum.py:

"""Probe whether DeepSpeed's BF16_Optimizer leaks the boundary microbatch grad
out of the cross-rank average. Run with two ranks, e.g.:
    deepspeed --num_gpus 2 probe_bf16_grad_accum.py
or
    accelerate launch --num_processes 2 --num_machines 1 probe_bf16_grad_accum.py
"""
import os
import torch
import torch.nn as nn
import deepspeed
import torch.distributed as dist


def main():
    rank = int(os.environ.get("RANK", "0"))
    world = int(os.environ.get("WORLD_SIZE", "1"))
    GA_STEPS = 4
    HIDDEN = 64
    BATCH = 4

    torch.manual_seed(0)  # SAME init across ranks
    model = nn.Sequential(
        nn.Linear(HIDDEN, HIDDEN),
        nn.GELU(),
        nn.Linear(HIDDEN, HIDDEN),
    ).to(torch.bfloat16).cuda()

    ds_config = {
        "bf16": {"enabled": True},  # set "immediate_grad_update": True to also bypass the bug
        "data_types": {"grad_accum_dtype": "fp32"},
        "communication_data_type": "fp32",
        "zero_optimization": {
            "stage": 1,
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_scatter": True,
            "allgather_partitions": True,
            "allgather_bucket_size": 200000000,
            "reduce_bucket_size": 200000000,
        },
        "train_micro_batch_size_per_gpu": BATCH,
        "gradient_accumulation_steps": GA_STEPS,
        "train_batch_size": BATCH * world * GA_STEPS,
        "gradient_clipping": 0.0,
        "steps_per_print": 9999,
    }
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.95))
    engine, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config)
    bf16_opt = engine.optimizer
    if rank == 0:
        print(f"[INFO] DeepSpeed optimizer class: {type(bf16_opt).__name__}", flush=True)
        print(f"[INFO] grad_acc_dtype = {bf16_opt.grad_acc_dtype}", flush=True)
        print(f"[INFO] world={world} ga={GA_STEPS} batch={BATCH}", flush=True)

    def buffer_summary(label):
        bufs = bf16_opt.fp32_groups_gradients_flat
        local_norm = sum(b.detach().to(torch.float64).norm().item() ** 2 for b in bufs) ** 0.5
        rank_buf = torch.tensor([local_norm], device="cuda", dtype=torch.float64)
        gathered = [torch.zeros_like(rank_buf) for _ in range(world)]
        dist.all_gather(gathered, rank_buf)
        if rank == 0:
            vals = [g.item() for g in gathered]
            diffs = [(v - vals[0]) / vals[0] * 100 if vals[0] != 0 else 0 for v in vals]
            print(f"[{label}] per_rank={vals}  diff_pct_vs_rank0={diffs}", flush=True)
        dist.barrier()

    buffer_summary("init (zero)")

    # Different inputs per rank so per-rank grads differ.
    torch.manual_seed(100 + rank)
    inputs = [torch.randn(BATCH, HIDDEN, device="cuda", dtype=torch.bfloat16) for _ in range(GA_STEPS)]
    targets = [torch.randn(BATCH, HIDDEN, device="cuda", dtype=torch.bfloat16) for _ in range(GA_STEPS)]

    for i in range(GA_STEPS):
        is_boundary = (i == GA_STEPS - 1)
        engine.set_gradient_accumulation_boundary(is_boundary=is_boundary)  # what accelerate's wrapper does
        out = engine(inputs[i])
        loss = ((out - targets[i]) ** 2).mean()
        engine.backward(loss)
        buffer_summary(f"after backward microbatch {i} (boundary={is_boundary})")

    buffer_summary("BEFORE engine.step()")

    engine.step()


if __name__ == "__main__":
    main()

Verification

Probe (synthetic, 2 GPUs, 1 grad-accum cycle)

Run on master (bug):

[init (zero)]                            per_rank=[0.0,    0.0   ]   diff = 0%
[after microbatch 0  (boundary=False)]   per_rank=[0.1495, 0.1378]   diff = -7.78%   (local-only accumulation)
[after microbatch 1  (boundary=False)]   per_rank=[0.1998, 0.2098]   diff = +5.03%
[after microbatch 2  (boundary=False)]   per_rank=[0.2322, 0.2434]   diff = +4.82%
[after microbatch 3  (boundary=True)]    per_rank=[0.2206, 0.2123]   diff = -3.80%   ← bug
[BEFORE engine.step()]                   per_rank=[0.2206, 0.2123]   diff = -3.80%   ← bug

Run on this PR (fixed):

[init (zero)]                            per_rank=[0.0,    0.0   ]   diff = 0%
[after microbatch 0  (boundary=False)]   per_rank=[0.1495, 0.1378]   diff = -7.78%
[after microbatch 1  (boundary=False)]   per_rank=[0.1998, 0.2098]   diff = +5.03%
[after microbatch 2  (boundary=False)]   per_rank=[0.2322, 0.2434]   diff = +4.82%
[after microbatch 3  (boundary=True)]    per_rank=[0.1942, 0.1942]   diff = 0.00%   ← fixed
[BEFORE engine.step()]                   per_rank=[0.1942, 0.1942]   diff = 0.00%   ← fixed

The same agreement is reproduced by the existing bf16: { immediate_grad_update: true } workaround, which uses per-param hooks to fill the fp32 accumulator inline during backward (and is therefore not affected by the _backward_epilogue ordering).
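For completeness, that workaround changes only the bf16 block of the DeepSpeed config. A minimal fragment in the style of the probe's ds_config (other keys unchanged):

# Existing workaround: per-param hooks fill the fp32 accumulator inline during
# backward, so the epilogue/allreduce ordering no longer matters.
ds_config_workaround = {
    "bf16": {
        "enabled": True,
        "immediate_grad_update": True,  # the workaround; no code change needed
    },
    "data_types": {"grad_accum_dtype": "fp32"},
    "zero_optimization": {"stage": 1},
    # ... remaining keys as in the reproducer config above
}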

End-to-end training (HuggingFace Trainer + accelerate + DeepSpeed, 2× A100)

A small custom Qwen3-derived model (~64M params, bf16, ZeRO-1 with grad_accum_dtype: fp32), 10 optimizer steps, identical seed and data ordering, identical effective batch size (global_batch_size = 64), only per_device_train_batch_size varies (so gradient_accumulation_steps = global_batch_size / (per_device * world_size) differs).

Configuration                                              Run A (per_device=2, ga=16)   Run B (per_device=8, ga=4)   Final loss gap
DeepSpeed master + grad_accum_dtype: fp32 (broken)         train_loss = 6.896            train_loss = 6.999           0.103
No DeepSpeed (DDP + native grad-accum, control)            train_loss = 6.9035           train_loss = 6.9037          0.0002 (bf16 noise)
master + bf16.immediate_grad_update: true (workaround)     train_loss = 6.9057           train_loss = 6.9057          < 0.0001
This PR + original config                                  train_loss = 6.9057           train_loss = 6.9059          0.0002 (bf16 noise)

The broken case also produces qualitatively misleading instabilities — e.g. at step 5 in the broken run, B's grad-norm spikes to 17.0 vs A's 1.35 (≈ 12× ratio), while in the fixed case the two grad-norm trajectories agree to within bf16 noise at every step.

Per-step loss / grad-norm trajectories under the fixed engine (this PR), for completeness:

step A loss A gnorm B loss B gnorm
1 9.1907 8.0613 9.1907 8.0615
2 8.1962 5.0553 8.1961 5.0561
3 7.2035 2.2668 7.2035 2.2683
4 7.0588 3.1079 7.0618 3.1118
5 6.5661 2.6192 6.5627 2.4634
6 6.3097 1.8798 6.3086 1.8913
7 6.1332 1.2811 6.1317 1.2584
8 6.1297 2.8776 6.1305 2.9574
9 6.1739 1.4336 6.1748 1.4687
10 6.0950 1.5941 6.0957 1.6031

Notes

  • BF16_Optimizer users on master who are not pinning per_device_train_batch_size may see silently degraded training when sweeping per-device batch sizes (the symptom that triggered this investigation). The bug also makes per-rank model weights briefly diverge between the optimizer step and the next update_lp_params() allgather, which means cross-rank invariants (e.g. asserts that compare per-rank state) can flip behavior depending on gradient_accumulation_steps.
  • Tested on DeepSpeed 0.18.6 (where the bug was first observed) and confirmed unchanged on master (0.18.10).
  • No new tests are added in this PR, but a regression test that asserts cross-rank fp32 buffer agreement after the boundary microbatch in BF16_Optimizer would be a natural follow-up.
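As a sketch of what that regression test's core assertion could look like (hypothetical helper; fp32_groups_gradients_flat is the buffer named in the walkthrough above, and the setup would mirror the probe script):

import torch
import torch.distributed as dist

def assert_fp32_grad_buffers_agree(bf16_opt):
    """All-gather each rank's flat fp32 grad accumulator and assert equality.

    Intended to run right after the boundary microbatch's engine.backward():
    per the analysis above, this fails on master (rank-private leak) and
    passes with this PR.
    """
    world = dist.get_world_size()
    for buf in bf16_opt.fp32_groups_gradients_flat:
        gathered = [torch.empty_like(buf) for _ in range(world)]
        dist.all_gather(gathered, buf)
        for other in gathered[1:]:
            assert torch.allclose(gathered[0], other), \
                "fp32 grad accumulator diverged across ranks at the boundary microbatch"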

In `DeepSpeedEngine._backward_epilogue`, `allreduce_gradients()` ran before
`optimizer.backward_epilogue()`. For `BF16_Optimizer` (used when bf16 model +
grad_accum_dtype=fp32 + ZeRO-1) without `immediate_grad_update`, this means the
boundary microbatch's gradient is added to the rank-local fp32 accumulator
*after* the cross-rank allreduce, so it is silently skipped from the average.

Effect: each rank's fp32 buffer ends with
  fp32_buffer_rank_i = avg_ranks(sum_{m=0..ga-2} grad_m) + local_grad_{ga-1}_rank_i
which biases the global gradient by `(world_size-1)/world_size * 1/ga_steps`.
Because the bias scales with per-microbatch grad weight, training trajectories
diverge depending on `per_device_train_batch_size` even with identical effective
batch size — the symptom is loss/grad-norm curves that drift apart between
otherwise-equivalent configs.

Fix: call `optimizer.backward_epilogue()` *before* `allreduce_gradients()` so
the boundary microbatch's grad is in the buffer being reduced. `exit_backward()`
remains after the allreduce; it only manages backward-hook state and has no
ordering dependency on the accumulator.

This is a no-op for `DeepSpeedZeroOptimizer_Stage1And2` and Stage3, whose
`backward_epilogue` does not mutate the reduction buffer (their grads are
either on `param.grad` already populated by autograd or accumulated via
inline backward hooks). It is also a no-op for `BF16_Optimizer` with
`immediate_grad_update=true` because the per-param hooks already fill the
fp32 buffer synchronously during backward.

Signed-off-by: Max Yu <18641481+maxyu1115@users.noreply.github.com>
@delock
Collaborator

delock commented Apr 28, 2026

Collaborator

@tohtana tohtana left a comment


Thank you for your fix and thorough investigation, @maxyu1115! This is very important. Let's merge it and release a new version soon.

@tohtana
Collaborator

tohtana commented Apr 28, 2026

Hi @maxyu1115,
I found the fix should be applied to the BF16_Optimizer path. That's why the CI test fails.
I opened a PR on your fork to patch it. Can you review and merge it if it is okay?

@maxyu1115 maxyu1115 requested a review from loadams as a code owner April 29, 2026 01:52
@delock delock merged commit 5999fb0 into deepspeedai:master Apr 29, 2026
9 checks passed
@maxyu1115 maxyu1115 deleted the fix/bf16-optimizer-grad-accum-boundary-leak branch April 29, 2026 03:39