[Fix][FSDP] Don't remove post backward hooks for multiple backward fix #923

Closed
wants to merge 1 commit into from

ngoyal2707
Contributor

fixes #918

I am quite confident that we don't need to remove the backward hooks even after finalizing. They will be removed automatically once the leaf variables go out of scope and the CUDA autograd graph cleans up.

Most tests pass locally, apart from one related to CPU offload; I will debug that.

@facebook-github-bot facebook-github-bot added the CLA Signed label Feb 1, 2022
@ngoyal2707
Contributor Author

cc: @myleott

@min-xu-ai
Contributor

I remember that if we don't remove it, bad things will happen. :-) Exactly which bad things is the good question. I suspect that some backward hooks will fire excessively or some memory will leak. But if we can be confident that this is indeed safe to do, it would be really nice! Less code is always better!

@ngoyal2707
Contributor Author

> I remember that if we don't remove it, bad things will happen. :-) Exactly which bad things is the good question. I suspect that some backward hooks will fire excessively or some memory will leak. But if we can be confident that this is indeed safe to do, it would be really nice! Less code is always better!

Hmm, is there any way to test "some backward hooks will fire excessively or some memory will leak"? Or are there any docs or past issues about this part? Or who's the best POC for this? Myle?

@min-xu-ai
Contributor

> Hmm, is there any way to test "some backward hooks will fire excessively or some memory will leak"? Or are there any docs or past issues about this part? Or who's the best POC for this? Myle?

I am the POC. :-) I remember touching this part of the code last, and I don't remember anyone else touching it after I did.

I wish I remembered the details of the potential issues with removing it. I think the best way for now is to rely on the unit tests. Once the unit tests pass, if I can't think of any potential issues that they don't cover, we can certainly go ahead and remove these lines.

@anj-s anj-s changed the title from "dont remove post backward hooks for multiple backward fix" to "[Fix][FSDP] Don't remove post backward hooks for multiple backward fix" Feb 2, 2022
@tmarkstrum tmarkstrum self-requested a review February 2, 2022 20:11
@ngoyal2707 ngoyal2707 force-pushed the ngoyal_fix_for_multiple_backwards branch from cc28e10 to 40b9436 February 2, 2022 20:31
@ngoyal2707
Contributor Author

Updated PR, local tests are still running but I am hoping now they will all pass 🤞

@ngoyal2707
Contributor Author

@min-xu-ai looks like all but MeVo tests are passing, which should be unrelated to this change

@min-xu-ai
Contributor

> @min-xu-ai looks like all but MeVo tests are passing, which should be unrelated to this change

Hmm, are you sure? I clicked on the first one; it seems the serialization test failed and then timed out:

https://app.circleci.com/pipelines/github/facebookresearch/fairscale/4061/workflows/c33d045c-fbe1-4364-8bc7-7a44335cb7d3/jobs/44795

@ngoyal2707
Contributor Author

> > @min-xu-ai looks like all but MeVo tests are passing, which should be unrelated to this change
>
> Hmm, are you sure? I clicked on the first one; it seems the serialization test failed and then timed out:
>
> https://app.circleci.com/pipelines/github/facebookresearch/fairscale/4061/workflows/c33d045c-fbe1-4364-8bc7-7a44335cb7d3/jobs/44795

That has to be a flaky test 🤷‍♂️ because they all pass locally and the change has nothing to do with serialization.

@ngoyal2707
Contributor Author

(fairseq-20211123-py39-meg-v2.6) namangoyal@learnfair7614:~/src/fairscale$ pytest tests/nn/data_parallel/test_fsdp.py::TestSerialization
============================= test session starts ==============================
platform linux -- Python 3.9.7, pytest-5.4.1, py-1.11.0, pluggy-0.13.1 -- /private/home/namangoyal/.conda/envs/fairseq-20211123-py39-meg-v2.6/bin/python
cachedir: .pytest_cache
rootdir: /private/home/namangoyal/src/fairscale, inifile: setup.cfg
plugins: hydra-core-1.0.7, timeout-1.4.2, cov-2.10.0
collected 8 items

tests/nn/data_parallel/test_fsdp.py::TestSerialization::test_multiprocessing__False_False_ PASSED [ 12%]
tests/nn/data_parallel/test_fsdp.py::TestSerialization::test_multiprocessing__False_True_ PASSED [ 25%]
tests/nn/data_parallel/test_fsdp.py::TestSerialization::test_multiprocessing__True_False_ PASSED [ 37%]
tests/nn/data_parallel/test_fsdp.py::TestSerialization::test_multiprocessing__True_True_ PASSED [ 50%]
tests/nn/data_parallel/test_fsdp.py::TestSerialization::test_pickle__False_False_ PASSED [ 62%]
tests/nn/data_parallel/test_fsdp.py::TestSerialization::test_pickle__False_True_ PASSED [ 75%]
tests/nn/data_parallel/test_fsdp.py::TestSerialization::test_pickle__True_False_ PASSED [ 87%]
tests/nn/data_parallel/test_fsdp.py::TestSerialization::test_pickle__True_True_ PASSED [100%]

========================= 8 passed in 77.47s (0:01:17) =========================

@min-xu-ai
Contributor

We do have flaky tests. We try to track them here: #908

However, I just briefly looked at this error:
[image: screenshot of the error]

I have actually never seen it before. But I could be wrong. Let me trigger some re-runs on CI.

@min-xu-ai
Contributor

@zhaojuanmao, can you take a look at this diff as well? Do you think we could remove the code that removes the hooks?

@@ -1543,8 +1543,6 @@ def _register_post_backward_hooks(self) -> None:
             return  # don't register grad hooks if grad isn't enabled
         for p in self.params:
             if p.requires_grad:
                 if hasattr(p, "_shard_bwd_hook"):
Contributor

Do the comments above need to be updated, since it seems we no longer need to remove the hook at the end of the backward pass (BW)?

Contributor Author

ohh yeah definitely, will do that

Contributor

If the two lines are removed, will the hooks fire multiple times in multiple-forward cases (e.g. multiple activation checkpointing)?

Contributor

Even for a single activation checkpointing case like Checkpoint(FSDP(module)), if the forward is recomputed in the backward pass, will the hooks be registered twice and fire twice unexpectedly?
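
The two concerns in this thread (hooks firing more than once after repeated forwards, and hooks being gone by the time a second backward runs) can be illustrated with a toy model. This is only a sketch in plain Python under stated assumptions: `ToyParam`, `register_hook`, `backward`, and `run` are hypothetical names, no PyTorch or fairscale API is used, and the `hasattr(p, "_shard_bwd_hook")` check is borrowed from the hunk above purely as the guard condition. It does not claim to reproduce what FSDP or this PR actually does.

```python
class ToyParam:
    """Hypothetical stand-in for a parameter that carries backward hooks."""

    def __init__(self):
        self.hooks = []


def register_hook(param, fired, guard):
    """Simulate hook registration during one forward pass."""
    # With the guard on, a param that already has a hook is skipped.
    if guard and hasattr(param, "_shard_bwd_hook"):
        return

    def hook():
        fired.append("post_backward")

    param.hooks.append(hook)
    param._shard_bwd_hook = hook


def backward(param, remove_after):
    """Simulate one backward pass: fire every hook, optionally remove them."""
    for hook in list(param.hooks):
        hook()
    if remove_after:
        param.hooks.clear()
        if hasattr(param, "_shard_bwd_hook"):
            del param._shard_bwd_hook


def run(num_forwards, num_backwards, guard, remove_after):
    """Return how many times the hook fired in each backward pass."""
    param, fired = ToyParam(), []
    for _ in range(num_forwards):
        register_hook(param, fired, guard)
    counts = []
    for _ in range(num_backwards):
        before = len(fired)
        backward(param, remove_after)
        counts.append(len(fired) - before)
    return counts


# Two forwards, then two backwards.
print(run(2, 2, guard=True, remove_after=True))    # [1, 0]: second backward has no hook
print(run(2, 2, guard=True, remove_after=False))   # [1, 1]: one firing per backward
print(run(2, 2, guard=False, remove_after=False))  # [2, 2]: double registration fires twice
```

In this toy model, keeping the hooks is only safe when registration is guarded against duplicates; whether the same holds for real autograd hooks under activation checkpointing is exactly the open question raised above.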

@anj-s
Contributor

anj-s commented Feb 3, 2022

@ngoyal2707 Can you share the use case that this supports? I don't think I understood the context from the PR title.

@ngoyal2707
Contributor Author

> @ngoyal2707 Can you share the use case that this supports? I don't think I understood the context from the PR title.

@anj-s it's a fix for this: #918

@anj-s
Contributor

anj-s commented Feb 3, 2022

> > @ngoyal2707 Can you share the use case that this supports? I don't think I understood the context from the PR title.
>
> @anj-s it's a fix for this: #918

Can you add a unit test for this using the smaller repro you have? Great that you figured out the solution!

@ngoyal2707
Contributor Author

> > > @ngoyal2707 Can you share the use case that this supports? I don't think I understood the context from the PR title.
> >
> > @anj-s it's a fix for this: #918
>
> Can you add a unit test for this using the smaller repro you have? Great that you figured out the solution!

yes, definitely, will add unit test as well

@min-xu-ai
Contributor

This is very strange, but I have triggered a re-run twice, and both re-runs plus the original run failed with the exact same set of errors. This could be a CI infra issue, but it doesn't seem to be due to flakiness at this point. :-(

@ngoyal2707
Contributor Author

> ...the first time failed with the exact same set of errors. This could be a CI infra issue. But it doesn't seem to be due to flakiness...

Hmm, that's weird. It doesn't reproduce locally, so unfortunately I have no idea how to fix it.

@tmarkstrum
Contributor

Is there a way to trigger the serialization unit test in the main branch? This can help to identify if the failure is due to the CI infra or the code.

@min-xu-ai
Contributor

> Is there a way to trigger the serialization unit test in the main branch? This can help to identify if the failure is due to the CI infra or the code.

good point. let me trigger a rerun on main

@min-xu-ai
Contributor

OK. It seems that there is some infra issue with CI. The main branch is failing with timeouts as well

[image]

@anj-s
Contributor

anj-s commented Apr 4, 2022

@ngoyal2707 Friendly ping on this PR! Do we still want to check in this fix?

@min-xu-ai
Contributor

Thanks for this PR! I have taken it over in #1079, with the new test from you as well. Will close this and merge that one in once all tests are done.

@min-xu-ai min-xu-ai closed this Sep 24, 2022
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

FSDP fails with multiple forward and then multiple backward calls
6 participants