Fix gradient_checkpointing backward compatibility #14408
Conversation
Ok good, this looks good to me! Thanks for working on it
```python
def gradient_checkpointing_hook(module, _):
    # Hook to enable backward compatibility for gradient checkpointing. Will be removed once all models have a
    # proper post_init method.
    if getattr(module.config, "gradient_checkpointing", False):
        module.gradient_checkpointing_enable()
        # Remove the attribute now that it has been consumed, so it's not saved in the config.
        delattr(module.config, "gradient_checkpointing")
    # The hook will remove itself after the first execution.
    module._gradient_checkpointing_hook.remove()
```
Ok that works for me
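For context, here is the general pattern in isolation: a forward pre-hook that does one-time setup and then removes its own handle. This is a minimal sketch on a plain `nn.Module` (the module and attribute names are made up for illustration; as the discussion below shows, removing a hook from inside itself is only reliably safe on recent PyTorch versions):

```python
import torch
from torch import nn


def one_shot_hook(module, inputs):
    # One-time setup work, e.g. enabling gradient checkpointing.
    print("one-time setup before the first forward")
    # The hook removes its own handle so it never runs again.
    module._one_shot_hook.remove()


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Forward pre-hooks run just before forward() is called.
        self._one_shot_hook = self.register_forward_pre_hook(one_shot_hook)

    def forward(self, x):
        return self.linear(x)


model = TinyModel()
model(torch.randn(2, 4))  # prints the setup message, then removes the hook
model(torch.randn(2, 4))  # nothing printed; the hook is gone
```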
… gradient_checkpointing_fix
* Fix gradient_checkpointing backward compatibility
* Remove needless line
* make sure mask prob is big enough and length small enough
* Fix tests

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This broke the HF/DeepSpeed integration with pt-1.8 and pt-1.9; it works fine with pt-1.10. Found with git bisecting and reported by @jeffra, as their CI broke with our master.
```python
if self.supports_gradient_checkpointing:
    self._gradient_checkpointing_hook = self.register_forward_pre_hook(gradient_checkpointing_hook)
```
This is the culprit for the failure I reported here: #14408 (comment)
removing it fixes the problem
Does this mean that DeepSpeed does not support PyTorch hooks?
It does, and it uses them extensively. Which is perhaps the cause of the problem, if some hooks disagree or don't follow the prescribed instruction of not modifying certain things.
If you're not sure about the cause I can investigate it and report back what I find.
Investigated: the issue is triggered by:
```python
module._gradient_checkpointing_hook.remove()
```
So it looks like DeepSpeed is just a harbinger here, and any other application that also uses hooks will trigger this issue.

It appears that what happens is that the hook is being removed from the dict that is being traversed one or more frames above.

I looked at what others did to solve this, and they had to move the hook removal outside of the hook itself and into the `forward`, where it's safe to remove it. Except we don't have a `forward` for this super class.

For some reason I can't reproduce this with pt-1.10, which suggests they have reworked the loop that traverses the hooks dict to allow hooks to self-remove, probably by traversing a copy of the dict.
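For reference, here is a minimal Python-level sketch of the failure mode. It mimics the module's hooks dict with a plain `OrderedDict` and is not the actual PyTorch dispatch code:

```python
from collections import OrderedDict

# Stand-in for module._forward_pre_hooks; the values would be hook callables.
hooks = OrderedDict({1: "gradient_checkpointing_hook", 2: "deepspeed_hook"})

try:
    for handle_id in hooks:
        # A hook that removes itself mutates the dict the caller is still
        # iterating over, so advancing the loop blows up.
        del hooks[handle_id]
except RuntimeError as e:
    print(e)  # e.g. "OrderedDict mutated during iteration"
```

Traversing a copy of the dict, or deferring the removal until after dispatch, avoids the error.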
possible fix: #14427
What does this PR do?
This supersedes #14405 and fixes #14388 by going to the root of the problem. When the code for backward compatibility is executed in the main init, the submodules of the model have not been created yet, so there is nothing to do. That code needs to be executed in some kind of `post_init`.

We currently don't have a `post_init` in our models, and for another operation that is very similar (`init_weights`, which needs to be executed at the end of the init), we have a call to that method at the end of the init of every model. The proper fix will thus be to replace that call to `init_weights` with a call to `post_init` (which will call `init_weights` internally). This will be a big PR that touches every model, so I will implement it by the end of the week.

As a quick fix, since we need to do a patch release because of the BC problem, this PR uses a forward pre-hook (executed before the forward method) that removes itself. That way the code is executed just before the first forward (not as clean as a post-init, but the next best thing).
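A rough sketch of what that `post_init` approach could look like, once submodules exist at the end of each model's init (the class skeleton and helper names here are illustrative, not the final API):

```python
from torch import nn


class PreTrainedModelSketch(nn.Module):
    # Illustrative skeleton only; the real PreTrainedModel has much more.
    supports_gradient_checkpointing = False

    def __init__(self, config):
        super().__init__()
        self.config = config

    def post_init(self):
        # Called at the end of every model's __init__, once all submodules exist,
        # replacing today's direct call to init_weights().
        self.init_weights()
        self._backward_compatibility_gradient_checkpointing()

    def _backward_compatibility_gradient_checkpointing(self):
        # Same logic as the forward pre-hook, but run in a post-init step.
        if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
            self.gradient_checkpointing_enable()
            # Consume the attribute so it is not saved back into the config.
            delattr(self.config, "gradient_checkpointing")

    def init_weights(self):
        pass  # weight initialization elided in this sketch

    def gradient_checkpointing_enable(self):
        pass  # the real implementation enables checkpointing on submodules
```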