
Make gradient_checkpointing a training argument #13657

Merged: 12 commits merged into master on Sep 22, 2021

Conversation

@sgugger (Collaborator) commented on Sep 20, 2021

What does this PR do?

This PR reworks the logic behind gradient checkpointing. It is currently set as a configuration argument, which is annoying because:

  • it's not easily discoverable
  • when someone pushes a model trained with gradient checkpointing activated to the Hub, that model keeps gradient checkpointing enabled even if new users don't want to use it.

That's why this PR deprecates the gradient_checkpointing argument in all configs and adds:

  • a gradient_checkpointing_enable method on PreTrainedModel to activate gradient checkpointing
  • a gradient_checkpointing training argument for users of the Trainer API, which calls that method.

Internally, the implementation still relies on the config, as it's the easiest place to set something that needs to be passed through several layers of a model (for a BertForMaskedLM, for instance, the actual gradient checkpointing only applies to the BertEncoder inside the BertModel inside that BertForMaskedLM), but that argument is made private and is not saved when pushing to the Hub.
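
As a quick illustration of the two new entry points (a minimal sketch assuming a checkpoint such as bert-base-cased and a transformers version that includes this change; argument names follow the PR description):

from transformers import AutoModelForMaskedLM, TrainingArguments

# Model-level API: activate gradient checkpointing on a loaded model
# instead of baking the setting into the config that gets pushed to the Hub.
model = AutoModelForMaskedLM.from_pretrained("bert-base-cased")
model.gradient_checkpointing_enable()

# Trainer-level API: the new training argument makes the Trainer call the
# method above on the model it receives.
args = TrainingArguments(output_dir="output", gradient_checkpointing=True)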

@stas00 (Contributor) left a comment

Fantastic! Thanks so much for adding this feature and making it independent from tweaking the config object. Loving it!

Left a few small suggestions.

src/transformers/configuration_utils.py (outdated, resolved)
        Will activate gradient checkpointing if :obj:`True`, deactivate it if :obj:`False`.
        """
        if not self.supports_gradient_checkpointing and flag:
            logger.warn(f"{self.__class__.__name__} does not support gradient checkpointing so nothing will happen.")
Contributor commented:

Any reason not to assert here instead? The user can then change their setup and proceed without problems.

It's a clear error to activate this option if a model doesn't support it, IMHO.

Collaborator (Author) commented:

It's to be consistent with the previous behavior, where we did nothing if the user set gradient_checkpointing for a model that did not support it.

I'm not opposed to asserting, but let's see what @LysandreJik and @patrickvonplaten think.

Contributor commented:

I would also be in favor of raising an error here, actually. It's a new function, so I think we can add this behavior here.

Collaborator (Author) commented:

Will switch then!
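
For reference, a rough sketch of the agreed-upon behavior; everything after the check is a placeholder, not the merged implementation:

def gradient_checkpointing_enable(self):
    # Raise instead of silently warning when the architecture does not
    # support gradient checkpointing, as agreed in this thread.
    if not self.supports_gradient_checkpointing:
        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
    # Actual activation is delegated to the model-specific internals.
    ...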

src/transformers/modeling_utils.py (outdated, resolved)
sgugger and others added 4 commits on September 20, 2021 (Co-authored-by: Stas Bekman)
@LysandreJik (Member) left a comment

Looks good to me! Thanks for taking care of all the mentions of gradient_checkpointing in the repository, very cool work!

Before:
- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by setting
  ``config.gradient_checkpointing = True``.
After:
- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by executing
  ``model.gradient_checkpointing_enable()``.
Member commented:

How about enable_gradient_checkpointing?

@@ -932,6 +933,21 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]):

        self.base_model._prune_heads(heads_to_prune)

    def gradient_checkpointing_enable(self, flag: bool = True):
Member commented:

Should there be a disable too?

Member commented:

Ah I didn't see this had a flag! Maybe toggle then? Or set_gradient_checkpointing to follow traditional boolean setter conventions?

Collaborator (Author) commented:

@stas00 really wanted the method name to start with gradient_checkpointing to be more easily discoverable.

Collaborator (Author) commented:

After some discussion with Lysandre, we decided to try gradient_checkpointing_enable and gradient_checkpointing_disable (no arguments for either).
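
In user code, the agreed pair would then read roughly like this (sketch only):

# Trade compute for memory during training...
model.gradient_checkpointing_enable()
# ...and turn it back off, e.g. before pure inference, where it only costs speed.
model.gradient_checkpointing_disable()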

@stas00 (Contributor) commented on Sep 20, 2021

I took the liberty of also documenting this feature in https://huggingface.co/transformers/performance.html and pushed it here, so if you rename the method, please adjust the doc as well. Thank you!

@patrickvonplaten (Contributor) commented:

I'm not very happy about keeping gradient_checkpointing in the config internally, as it adds, IMO, significantly more complexity to what a user has to know about model configurations. Before this PR, every configuration parameter one sees in configuration_utils.py is stored when saving the configuration file. If we now introduce private configuration parameters that are not saved when the model is saved, it forces users to learn and understand a new exception and makes the code harder to understand and read.

I'm very much in favor of removing gradient_checkpointing from the config, but the better option IMO is not to go through the config at all and instead provide _disable_gradient_checkpointing and _enable_gradient_checkpointing functions to all sub-modules. It's much more work, but IMO there are also many more upsides to this approach.

@sgugger (Collaborator, Author) commented on Sep 21, 2021

> I'm not very happy about keeping gradient_checkpointing in the config internally, as it adds, IMO, significantly more complexity to what a user has to know about model configurations. Before this PR, every configuration parameter one sees in configuration_utils.py is stored when saving the configuration file. If we now introduce private configuration parameters that are not saved when the model is saved, it forces users to learn and understand a new exception and makes the code harder to understand and read.

I am not following, since this is all private. The user does not have to know anything about model configurations to use this option. I'm also not sure which new exceptions you mean.

> I'm very much in favor of removing gradient_checkpointing from the config, but the better option IMO is not to go through the config at all and instead provide _disable_gradient_checkpointing and _enable_gradient_checkpointing functions to all sub-modules. It's much more work, but IMO there are also many more upsides to this approach.

Note that those submodules are often not even PreTrainedModels, so we would have to add those functions manually to tons of nn.Modules. For backward compatibility, we would also need to keep something stored in the config, since the config can't call the gradient_checkpointing_enable method on the model, so this effort is a bit pointless before v5, in the sense that there will be private parameters that are not saved anyway.

In any case, if this second approach is selected, I would still urge merging this PR as soon as possible to avoid merge conflicts or many users diverging from the templates. We can then change the internal implementation of the models more progressively.
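
For concreteness, a rough sketch of how the flag could be propagated to arbitrary submodules without going through the config; the hook name and the gradient_checkpointing attribute are assumptions for illustration, not the merged code:

from functools import partial

import torch.nn as nn


class SomePreTrainedModel(nn.Module):  # stand-in for a PreTrainedModel subclass
    def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
        # Model-specific hook: flip the flag only on submodules that actually
        # implement a checkpointed forward pass (hypothetical attribute).
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def gradient_checkpointing_enable(self):
        # Walk every submodule instead of routing the flag through the config.
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    def gradient_checkpointing_disable(self):
        self.apply(partial(self._set_gradient_checkpointing, value=False))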

@patrickvonplaten (Contributor) commented:

I'm just a bit worried that we'll start using the "private" configuration parameters of PreTrainedConfig just as a way to easily pass flags to all the nn.Modules, even though those parameters shouldn't be in the config at all. For me, the configuration should really just be static configuration and not serve any purpose other than defining the model architecture.

For a user who just looks at the configuration on the Hub, this PR is great, but for users who actually look into the code, adding a NO_SAVE_CONFIG_KEYS option to PreTrainedConfig adds a new layer of complexity to understand. This could be avoided, IMO.

I think we should be able to add a single method to BertPreTrainedModel like this:

def _enable_gradient_checkpointing(self):
    model = self
    if hasattr(model, self.base_model_prefix):
        model = getattr(model, self.base_model_prefix)
    
    # set gradient checkpointing to True in the encoder
    model.encoder.gradient_checkpointing = True

=> this should work just fine, no?

Given that we will have to leave it in the config anyway until v5, I'm fine with leveraging the config, I guess. I just don't think it's good practice to introduce "special" configuration parameters with NO_SAVE_CONFIG_KEYS.

@sgugger mentioned this pull request on Sep 21, 2021
@stas00 (Contributor) commented on Sep 21, 2021

If we leave the config as is, as Patrick proposes, should we perhaps discuss giving the user the ability to choose what goes into the published model's config? We are sort of trying to do DWIM (do what I mean) and magically have the published model end up with all the right settings.

How about adding default filters to the model-saving interface, which would, for example, automatically disable gradient_checkpointing, and then allowing users to override those filters if they need to? That way we keep the ease of use of sensible defaults while still letting users override any of them.

In the current PR, the user has no control over NO_SAVE_CONFIG_KEYS.

And we won't need to wait until v5 to do so.
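
A hypothetical sketch of that idea, with default save-time filters that the user can override (all names below are made up for illustration, assuming a PretrainedConfig-like object with a to_dict method; this is not an existing API):

# Defaults applied when exporting a config for the Hub.
DEFAULT_SAVE_OVERRIDES = {"gradient_checkpointing": False}

def config_dict_for_hub(config, user_overrides=None):
    # Start from the full config, apply the sensible defaults, then let the
    # user have the final say on anything they explicitly want to publish.
    output = config.to_dict()
    output.update(DEFAULT_SAVE_OVERRIDES)
    output.update(user_overrides or {})
    return output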

@sgugger (Collaborator, Author) commented on Sep 21, 2021

@stas00 This is out of scope for this PR (which, by the way, no longer contains NO_SAVE_CONFIG_KEYS, to address Patrick's comments), so maybe the discussion should be moved elsewhere?

@stas00 (Contributor) commented on Sep 21, 2021

I was just following up on Patrick's comment. I have no problem with not discussing it here.

@patrickvonplaten (Contributor) left a comment

Thanks a lot for the extra effort! Really like the new design

@LysandreJik (Member) left a comment

Looks good to me! Thank you for iterating.

@sgugger merged commit 27d4639 into master on Sep 22, 2021
@sgugger deleted the gradient_checkpointing branch on September 22, 2021 at 11:51
Narsil pushed a commit to Narsil/transformers that referenced this pull request Sep 25, 2021
* Make gradient_checkpointing a training argument

* Update src/transformers/modeling_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/configuration_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Fix tests

* Style

* document Gradient Checkpointing as a performance feature

* Small rename

* PoC for not using the config

* Adapt BC to new PoC

* Forgot to save

* Rollout changes to all other models

* Fix typo

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
stas00 added a commit to stas00/transformers that referenced this pull request Oct 12, 2021
Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request Jan 13, 2022
Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request Jan 27, 2022
versae added a commit to versae/transformers that referenced this pull request Apr 20, 2023
It uses `flax.linen.remat` and follows on PRs huggingface#13657 and huggingface#17994