[fix] SDP syncing buffers during gradient accumulation (#1075)
- Fixes from Benjamin.

Original commit msg:
  - Fixes #1041. I just had a minute or two, hoping that it's enough :)

Co-authored-by: Min Xu <min.xu.public@gmail.com>
min-xu-ai and flying-x committed Sep 23, 2022
1 parent abfa719 commit bfd57ff
Showing 2 changed files with 9 additions and 1 deletion.
2 changes: 1 addition & 1 deletion fairscale/nn/data_parallel/sharded_ddp.py
@@ -218,7 +218,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         if needs_setup:
             self.refresh_trainable()

-        if self._enable_broadcast_buffers:
+        if self._enable_broadcast_buffers and not self._should_accumulate_grads:
             # NCCL communications are on a different stream, needs to be blocking
             # for the subsequent FW to be correct
             self.sync_buffers(blocking=True)
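
For context, here is a minimal sketch of how the fixed check behaves in a gradient-accumulation loop. This is illustrative only, not code from this commit: it assumes torch.distributed is already initialized, a CUDA device is available, and the model, data, and hyperparameters are placeholders.

    # Sketch only: shows that with this fix the forward pass skips the buffer
    # broadcast while gradients are being accumulated under no_sync().
    import contextlib
    import torch
    from fairscale.optim.oss import OSS
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

    # BatchNorm gives the model buffers (running stats), which is what
    # broadcast_buffers affects.
    model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.BatchNorm1d(2)).cuda()
    optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
    ddp_model = ShardedDDP(model, optimizer, broadcast_buffers=True)

    accumulate_steps = 4
    for step in range(16):
        batch = torch.rand((64, 2)).cuda()  # placeholder data
        is_accumulating = (step + 1) % accumulate_steps != 0
        # Inside no_sync(), _should_accumulate_grads is True, so the forward pass
        # no longer calls sync_buffers() on every accumulation micro-step.
        ctx = ddp_model.no_sync() if is_accumulating else contextlib.suppress()
        with ctx:
            loss = ddp_model(batch).abs().sum()
            loss.backward()
        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()
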
8 changes: 8 additions & 0 deletions tests/nn/data_parallel/test_sharded_ddp_features.py
@@ -112,6 +112,14 @@ def closure():
         with ddp_model.no_sync() if grad_accumulation else suppress():
             input_tensor = torch.rand((64, 2)).to(device)
             loss = ddp_model(input_tensor).abs().sum()
+
+            # If grad_accumulation, we can check after the forward that the models are different
+            # (not synced)
+            if grad_accumulation:
+                check_same_models_across_ranks(
+                    ddp_model, dist.group.WORLD, params_should_be_equal=False, check_broadcast_buffers=True
+                )
+
             loss.backward()
         return loss

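
A related usage note (an assumption on my part, not part of this commit): because buffers are no longer broadcast while accumulating, code that needs cross-rank buffer consistency in the middle of an accumulation window (for example, an evaluation pass) can request it explicitly via the wrapper's public sync_buffers method, which is the same call the forward pass makes outside of accumulation:

    # ddp_model is the ShardedDDP wrapper; blocking=True so the next forward
    # sees the broadcast buffers.
    ddp_model.sync_buffers(blocking=True)
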
