[Fix][FSDP] Don't remove post backward hooks for multiple backward fix #923
cc: @myleott
I remember that if we don't remove it, bad things will happen. :-) Exactly which bad things is the good question. I suspect that some backward hooks will fire excessively or some memory will leak. But if we can be confident that this is indeed safe to do, it would be really nice! Less code is always better!
hmm, is there any way to test "some backward hooks will fire excessively or some memory will leak"? Or any doc or any past issues about this part? Or who's the best POC about this? Myle?
I am the POC. :-) I remember touching this part of the code last time. I don't remember anyone else touching it after I did. I wish I remembered the details of the potential issues with removing it. I think the best way for now is to use the unit tests. Once the unit tests pass, if I can't think of any potential issues that are not covered by them, we can certainly go ahead and remove these lines.
cc28e10 to 40b9436
Updated PR, local tests are still running but I am hoping now they will all pass 🤞
@min-xu-ai looks like all but MeVo tests are passing, which should be unrelated to this change
hmm, are you sure? I clicked on the first one, and it seems the serialization test failed and then it timed out.
that has to be flaky tests 🤷‍♂️ because all of them pass locally and the change has nothing to do with serialization
We do have flaky tests. We try to track them here: #908. However, I just briefly looked at this error, and I have actually never seen it before. But I could be wrong. Let me trigger some re-runs on CI.
@zhaojuanmao, can you take a look at this diff as well? Do you think we could remove the code that removes the hooks? |
@@ -1543,8 +1543,6 @@ def _register_post_backward_hooks(self) -> None:
             return  # don't register grad hooks if grad isn't enabled
         for p in self.params:
             if p.requires_grad:
                 if hasattr(p, "_shard_bwd_hook"):
Do the comments above need to be updated? It seems like we no longer need to remove the hook at the end of the backward pass.
ohh yeah definitely, will do that
If the two lines are removed, will the hooks fire multiple times in multiple-forward cases (e.g. multiple activation checkpointing)?
Even for a single activation checkpointing case like Checkpoint(FSDP(module)), if there is forward recomputation in the backward pass, will the hooks be registered twice and fire twice unexpectedly?
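To make the double-registration concern concrete, here is a minimal toy sketch in plain Python. It is not FSDP or autograd code: `ToyParam`, `register_unguarded`, `register_guarded`, and `count_fires` are hypothetical names made up for illustration. The only thing borrowed from the diff is the `hasattr(p, "_shard_bwd_hook")` check; the hunk does not show what the real code does after that check, so skipping registration is an assumption here.

```python
from typing import Callable, List


class ToyParam:
    """Hypothetical stand-in for a parameter; NOT torch.nn.Parameter."""

    def __init__(self) -> None:
        self.hooks: List[Callable[[], None]] = []

    def backward(self) -> None:
        # Fire every registered hook once, as a single backward pass would.
        for hook in list(self.hooks):
            hook()


def register_unguarded(p: ToyParam, hook: Callable[[], None]) -> None:
    # Every call adds another hook, so repeated forwards stack hooks up.
    p.hooks.append(hook)


def register_guarded(p: ToyParam, hook: Callable[[], None]) -> None:
    # Same attribute check as in the diff hunk; skipping is assumed behavior.
    if hasattr(p, "_shard_bwd_hook"):
        return
    p.hooks.append(hook)
    p._shard_bwd_hook = hook  # marker showing a hook is already attached


def count_fires(register: Callable, num_forwards: int) -> int:
    """Register once per simulated forward, run one backward, count firings."""
    p = ToyParam()
    fired: List[int] = []
    for _ in range(num_forwards):
        register(p, lambda: fired.append(1))
    p.backward()
    return len(fired)


if __name__ == "__main__":
    print("unguarded, 2 forwards:", count_fires(register_unguarded, 2))
    print("guarded,   2 forwards:", count_fires(register_guarded, 2))
```

In this toy model, a marker attribute that survives between forwards keeps registration idempotent, while registering without any guard accumulates one hook per forward. Whether real autograd hooks behave the same way under checkpoint recomputation is exactly what the questions above ask, and the sketch does not answer that.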
@ngoyal2707 Can you share the use case that this supports? I don't think I understood the context from the PR title.
Can you add a unit test for this using the smaller repro you have? Great that you figured out the solution!
yes, definitely, will add unit test as well
This is very strange, but I have triggered the re-run twice, and both re-runs and the original run failed with the exact same set of errors. This could be a CI infra issue, but it doesn't seem to be due to flakiness at this point. :-(
hmm, that's weird. It doesn't reproduce locally, so unfortunately I have no idea how to fix it.
Is there a way to trigger the serialization unit test in the main branch? This can help to identify if the failure is due to the CI infra or the code.
good point. let me trigger a rerun on main
@ngoyal2707 Friendly ping on this PR! Do we still want to check in this fix?
Thanks for this PR! I have taken it over in #1079, with the new test from you as well. Will close this and merge that one in once all tests are done.
fixes #918
I am quite confident that we don't need to remove backward hooks even after finalizing. They will be removed automatically once the leaf variables go out of scope and the CUDA autograd graph is cleaned up.
Most tests were succeeding locally, apart from one related to CPU offload; will debug that.