
Fix gradient_checkpointing backward compatibility #14408

Merged

sgugger merged 5 commits into master from gradient_checkpointing_fix on Nov 16, 2021

Conversation

sgugger (Collaborator) commented Nov 15, 2021

What does this PR do?

This supersedes #14405 and fixes #14388 by going to the root of the problem. When the backward-compatibility code is executed in the main init, the submodules of the model have not been created yet, so there is nothing for it to do. That code needs to be executed in some kind of post_init.

We currently don't have a post_init in our models, but for another, very similar operation (init_weights, which needs to be executed at the end of the init), we have a call to that method at the end of the init of every model. The proper fix will thus be to replace that call to init_weights with a call to post_init (which will call init_weights internally). This will be a big PR that touches every model, so I will implement it by the end of the week.
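
For illustration, a minimal sketch of the intended pattern (hypothetical names and simplified classes, not the actual implementation): every model's init would end with a post_init() call, which runs init_weights() and any backward-compatibility code once all submodules exist.

from types import SimpleNamespace

class PreTrainedModelSketch:
    def __init__(self, config):
        self.config = config
        self.gradient_checkpointing = False

    def init_weights(self):
        pass  # existing weight initialization would run here

    def gradient_checkpointing_enable(self):
        self.gradient_checkpointing = True

    def post_init(self):
        # Runs at the very end of each model's __init__, once all submodules exist.
        self.init_weights()
        if getattr(self.config, "gradient_checkpointing", False):
            self.gradient_checkpointing_enable()
            # Drop the legacy attribute so it is not saved back into the config.
            delattr(self.config, "gradient_checkpointing")

class SomeModelSketch(PreTrainedModelSketch):
    def __init__(self, config):
        super().__init__(config)
        self.submodules = ["encoder", "decoder"]  # stand-in for real submodules
        self.post_init()  # replaces the old trailing call to init_weights()

model = SomeModelSketch(SimpleNamespace(gradient_checkpointing=True))
assert model.gradient_checkpointing  # the legacy config flag was honored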

As a quick fix, since we need to do a patch release because of the BC problem, this PR uses a forward pre-hook (executed before the forward method) that removes itself. The code thus runs just before the first forward: not as clean as a post_init, but the next best thing.

LysandreJik (Member) left a comment:

OK, this looks good to me! Thanks for working on it.

Comment on lines +415 to +423
def gradient_checkpointing_hook(module, _):
    # Hook to enable backward compatibility for gradient checkpointing. Will be removed once all models have a
    # proper post_init method.
    if getattr(module.config, "gradient_checkpointing", False):
        module.gradient_checkpointing_enable()
        # Remove the attribute now that it has been consumed, so it's not saved in the config.
        delattr(module.config, "gradient_checkpointing")
    # The hook will remove itself after the first execution.
    module._gradient_checkpointing_hook.remove()

Member commented:

Ok that works for me

sgugger merged commit 040fd47 into master on Nov 16, 2021
sgugger deleted the gradient_checkpointing_fix branch on November 16, 2021, 13:58
LysandreJik pushed a commit that referenced this pull request Nov 16, 2021
* Fix gradient_checkpointing backward compatibility

* Remove needless line

* make sure mask prob is big enough and length small enough

* Fix tests

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
stas00 (Contributor) commented Nov 17, 2021

This broke the HF/DeepSpeed integration with pt-1.8 and pt-1.9; it works fine with pt-1.10. Found via git bisect and reported by @jeffra, as their CI broke with our master.

RUN_SLOW=1 pytest tests/deepspeed/test_deepspeed.py::TestDeepSpeedWithLauncher::test_clm_1_zero3 -sv
E           Traceback (most recent call last):
E             File "/mnt/nvme1/code/huggingface/transformers-master/examples/pytorch/language-modeling/run_clm.py", line 524, in <module>
E               main()
E             File "/mnt/nvme1/code/huggingface/transformers-master/examples/pytorch/language-modeling/run_clm.py", line 472, in main
E               train_result = trainer.train(resume_from_checkpoint=checkpoint)
E             File "/mnt/nvme1/code/huggingface/transformers-master/src/transformers/trainer.py", line 1316, in train
E               tr_loss_step = self.training_step(model, inputs)
E             File "/mnt/nvme1/code/huggingface/transformers-master/src/transformers/trainer.py", line 1849, in training_step
E               loss = self.compute_loss(model, inputs)
E             File "/mnt/nvme1/code/huggingface/transformers-master/src/transformers/trainer.py", line 1881, in compute_loss
E               outputs = model(**inputs)
E             File "/home/stas/anaconda3/envs/py38-pt19/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
E               return forward_call(*input, **kwargs)
E             File "/mnt/nvme1/code/github/00optimize/deepspeed/deepspeed/runtime/engine.py", line 1580, in forward
E               loss = self.module(*inputs, **kwargs)
E             File "/home/stas/anaconda3/envs/py38-pt19/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1057, in _call_impl
E               for hook in itertools.chain(
E           RuntimeError: OrderedDict mutated during iteration

Comment on lines +493 to +494
if self.supports_gradient_checkpointing:
    self._gradient_checkpointing_hook = self.register_forward_pre_hook(gradient_checkpointing_hook)

stas00 (Contributor) commented:

This is the culprit for the failure I reported here: #14408 (comment)

stas00 (Contributor) commented:

Removing it fixes the problem.

sgugger (Collaborator, Author) commented:

Does this mean that DeepSpeed does not support PyTorch hooks?

stas00 (Contributor) commented:

It does, and it uses them extensively. That is perhaps the cause of the problem, if some hooks conflict or don't follow the prescribed rule of not modifying certain things.

If you're not sure about the cause, I can investigate it and report back what I find.

stas00 (Contributor) commented Nov 17, 2021:

Investigated: the issue is triggered by:

module._gradient_checkpointing_hook.remove()

stas00 (Contributor) commented Nov 17, 2021:

So it looks like DeepSpeed is just a harbinger here: any other application that also registers hooks will trigger this issue.

What appears to happen is that the hook removes itself from the very dict that is being traversed one or more frames above.

I looked at how others solved this: they had to move the hook removal out of the hook itself and into forward, at a point where it is safe to remove. Except we don't have a forward for this superclass.

I can't reproduce this with pt-1.10, which suggests the loop that traverses the hooks dict was reworked to let hooks self-remove, probably by traversing a copy of the dict.
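
A minimal repro sketch of that failure mode, independent of PyTorch and DeepSpeed (it assumes, as the traceback's itertools.chain over .values() suggests, that pt < 1.10 iterates the live hooks OrderedDict):

from collections import OrderedDict

# Stand-in for nn.Module._forward_pre_hooks, which pt-1.8/1.9 iterate directly.
hooks = OrderedDict()

def self_removing_hook():
    del hooks[0]  # mutates the dict the caller is still traversing

hooks[0] = self_removing_hook

try:
    for hook in hooks.values():  # iterating the live dict, as pt < 1.10 does
        hook()
except RuntimeError as e:
    print(e)  # OrderedDict mutated during iteration

# Traversing a copy instead makes self-removal safe (presumably the pt-1.10 behavior):
hooks[0] = self_removing_hook
for hook in list(hooks.values()):
    hook()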

stas00 (Contributor) commented:

Possible fix: #14427

Successfully merging this pull request may close these issues.

Wav2Vec2 CUDA memory usage doubled in v4.11.3 compared to v4.10.3 with the same batch size