[Fix][FSDP] Don't remove post backward hooks for multiple backward fix #923

Closed
wants to merge 1 commit into from
4 changes: 0 additions & 4 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1543,8 +1543,6 @@ def _register_post_backward_hooks(self) -> None:
             return  # don't register grad hooks if grad isn't enabled
         for p in self.params:
             if p.requires_grad:
-                if hasattr(p, "_shard_bwd_hook"):

Contributor:
Do the comments above need to be updated? It seems like we no longer need to remove the hook at the end of the backward pass.

Contributor Author:
Ohh yeah, definitely, will do that.

Contributor:
If the two lines are removed, will the hooks fire multiple times in cases with multiple forward passes (e.g. multiple activation checkpointing)?

Contributor:
Even for a single activation checkpointing case like Checkpoint(FSDP(module)): if there is forward recomputation in the backward pass, won't the hooks be registered twice and fired twice unexpectedly? (See the sketch after this hunk.)

-                    continue
                 # Register a hook on the first call, empirically, autograd
                 # fires it at the end for this param, which makes sense.
                 p_tmp = p.expand_as(p)  # Get a grad_fn on p_tmp.
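
For readers following the thread above, here is a minimal standalone sketch (plain PyTorch, not FSDP code; the variable and hook names are made up for illustration) of the registration pattern in this hunk and of the concern raised in the comments: a hook attached to a parameter's AccumulateGrad node fires once per registration during a backward pass, so if the registration path runs again without a guard (for example on forward recomputation under activation checkpointing), it fires more than once.

```python
import torch

# A parameter plays the role of `p` in the hunk above.
p = torch.nn.Parameter(torch.ones(3))

# Same trick as the diff: expand_as gives a non-leaf view whose grad_fn leads
# back to the AccumulateGrad node that autograd runs last for this parameter.
p_tmp = p.expand_as(p)
grad_acc = p_tmp.grad_fn.next_functions[0][0]

fired = []

def post_backward_hook(*unused):
    fired.append(1)

# Register twice with no hasattr-style guard, mimicking a second pass through
# the registration code (e.g. forward recomputation in the backward pass).
grad_acc.register_hook(post_backward_hook)
grad_acc.register_hook(post_backward_hook)

(p * 2).sum().backward()
print(len(fired))  # 2 -- the hook fires once per registration in one backward
```

This is only a toy reproduction of the question being asked, not a claim about how FSDP behaves after this change.
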
@@ -1751,8 +1749,6 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                     continue
                 if hasattr(p, "_shard_bwd_hook"):
                     p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}")
-                    p._shard_bwd_hook[1].remove()
-                    delattr(p, "_shard_bwd_hook")

                 # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
                 # remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
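
To see how the two hunks relate, here is a rough sketch of the register/remove pair this diff touches. The helper function names are hypothetical; the `(grad_acc, handle)` tuple stored as `p._shard_bwd_hook` and the `hasattr` guard follow the code shown in the diff.

```python
import torch

def register_post_backward_hook(p: torch.nn.Parameter, hook) -> None:
    # The hasattr guard deleted in the first hunk: skip if already registered.
    if hasattr(p, "_shard_bwd_hook"):
        return
    p_tmp = p.expand_as(p)                          # non-leaf view with a grad_fn
    grad_acc = p_tmp.grad_fn.next_functions[0][0]   # AccumulateGrad node for p
    handle = grad_acc.register_hook(hook)
    # Keep grad_acc alive and remember the handle so the hook can be removed.
    p._shard_bwd_hook = (grad_acc, handle)

def remove_post_backward_hook(p: torch.nn.Parameter) -> None:
    # Mirrors the cleanup in _finalize_parameters whose remove()/delattr lines
    # this diff deletes, i.e. the hook now stays registered across backwards.
    if hasattr(p, "_shard_bwd_hook"):
        assert len(p._shard_bwd_hook) == 2
        p._shard_bwd_hook[1].remove()
        delattr(p, "_shard_bwd_hook")
```

Before this change, the removal ran at the end of each backward pass; the PR title indicates the intent is to keep the hooks registered so that multiple backward passes can reuse them.
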