Cleared backward hooks to avoid accumulating over iterations (#1143)
awgu committed Oct 10, 2023
1 parent 71aeffe commit 17ecf4a
Showing 1 changed file with 1 addition and 0 deletions: fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1865,6 +1865,7 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
     if hasattr(p, "_shard_bwd_hooks") and self._require_backward_grad_sync:
         for _, handle in p._shard_bwd_hooks:
             handle.remove()
+    p._shard_bwd_hooks.clear()

     # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
     # remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
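The effect of the one-line fix can be sketched in isolation. Calling `handle.remove()` detaches each hook, but it does not empty the list that stores the `(grad_acc, handle)` pairs, so without `clear()` the list grows by one entry per iteration and stale handles are re-visited every pass. The classes below (`FakeHandle`, `FakeParam`, `run_iteration`, `registry`) are hypothetical stand-ins for illustration, not fairscale or PyTorch APIs:

```python
class FakeHandle:
    """Illustrative stand-in for a removable hook handle."""

    def __init__(self, hooks, key):
        self.hooks, self.key = hooks, key

    def remove(self):
        # Detach the hook from the registry; does NOT touch any list
        # that happens to hold a reference to this handle.
        self.hooks.pop(self.key, None)


class FakeParam:
    """Illustrative stand-in for a parameter carrying _shard_bwd_hooks."""

    def __init__(self):
        self._shard_bwd_hooks = []  # list of (grad_acc, handle) pairs


def run_iteration(p, registry, it):
    # Register one backward hook for this iteration, as FSDP does per parameter.
    registry[it] = lambda grad: grad
    p._shard_bwd_hooks.append((None, FakeHandle(registry, it)))

    # Finalize: remove the handles...
    for _, handle in p._shard_bwd_hooks:
        handle.remove()
    # ...and clear the list so entries do not accumulate across iterations
    # (the line added in this commit).
    p._shard_bwd_hooks.clear()


registry = {}
p = FakeParam()
for it in range(3):
    run_iteration(p, registry, it)

# With clear() in place, neither the hook registry nor the per-parameter
# list retains stale entries after three iterations.
print(len(p._shard_bwd_hooks), len(registry))
```

Dropping the `clear()` call in this sketch leaves `p._shard_bwd_hooks` with one entry per iteration, which is the accumulation the commit title describes.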
