
Commit 4c9e0f0

Add support for gradient checkpointing (#19990)
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
NielsRogge and Niels Rogge committed Oct 31, 2022
1 parent 8214a9f commit 4c9e0f0
Showing 3 changed files with 33 additions and 0 deletions.
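For context: setting supports_gradient_checkpointing = True lets these models recompute activations during the backward pass instead of storing them, trading compute for memory. A minimal usage sketch, assuming a BERT-to-BERT pair (the checkpoint names are illustrative, not part of this commit):

    from transformers import EncoderDecoderModel

    # Illustrative checkpoints; any supported encoder/decoder pair works the same way.
    model = EncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-uncased", "bert-base-uncased"
    )
    model.gradient_checkpointing_enable()
    model.train()  # checkpointing only takes effect in training mode
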
src/transformers/models/bert_generation/modeling_bert_generation.py

@@ -581,6 +581,7 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
 
     config_class = BertGenerationConfig
     base_model_prefix = "bert"
+    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
@@ -599,6 +600,10 @@ def _init_weights(self, module):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, BertEncoder):
+            module.gradient_checkpointing = value
+
 
 BERT_GENERATION_START_DOCSTRING = r"""
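
The flag set above is consumed inside BertEncoder's forward pass, which, when gradient_checkpointing is True and the model is training, wraps each layer call in torch.utils.checkpoint.checkpoint. A self-contained sketch of that trade-off in plain PyTorch (not library code):

    import torch
    from torch.utils.checkpoint import checkpoint

    layer = torch.nn.Linear(16, 16)
    x = torch.randn(2, 16, requires_grad=True)

    # Instead of layer(x): store only the inputs during forward, then replay
    # the forward during backward to rebuild the activations it needs.
    y = checkpoint(layer, x)
    y.sum().backward()
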
src/transformers/models/encoder_decoder/modeling_encoder_decoder.py

@@ -175,6 +175,8 @@ class EncoderDecoderModel(PreTrainedModel):
     """
     config_class = EncoderDecoderConfig
     base_model_prefix = "encoder_decoder"
+    main_input_name = "input_ids"
+    supports_gradient_checkpointing = True
 
     def __init__(
         self,
@@ -255,6 +257,11 @@ def tie_weights(self):
                 self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
             )
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        # call _set_gradient_checkpointing on both the encoder and the decoder
+        self.encoder._set_gradient_checkpointing(module, value=value)
+        self.decoder._set_gradient_checkpointing(module, value=value)
+
     def get_encoder(self):
         return self.encoder
 
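
This override is needed because EncoderDecoderModel composes two full PreTrainedModels: only the wrapped encoder and decoder know which of their concrete submodules carry the gradient_checkpointing flag, so the composite simply forwards the call to both. A rough, abridged paraphrase of the dispatch in modeling_utils around this release (check the source for the exact body):

    from functools import partial

    # Method of PreTrainedModel (paraphrased): walks the module tree and
    # calls each model's _set_gradient_checkpointing hook with value=True.
    def gradient_checkpointing_enable(self):
        if not self.supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        self.apply(partial(self._set_gradient_checkpointing, value=True))
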
tests/models/encoder_decoder/test_modeling_encoder_decoder.py (21 additions, 0 deletions)
@@ -611,6 +611,21 @@ def test_encoder_decoder_model_shared_weights(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.create_and_check_encoder_decoder_shared_weights(**input_ids_dict)
 
+    def test_training_gradient_checkpointing(self):
+        inputs_dict = self.prepare_config_and_inputs()
+        encoder_model, decoder_model = self.get_encoder_decoder_model(
+            inputs_dict["config"], inputs_dict["decoder_config"]
+        )
+
+        model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        model.train()
+        model.gradient_checkpointing_enable()
+        model.config.decoder_start_token_id = 0
+        model.config.pad_token_id = 0
+
+        model_inputs = {
+            "input_ids": inputs_dict["input_ids"],
+            "attention_mask": inputs_dict["attention_mask"],
+            "labels": inputs_dict["labels"],
+            "decoder_input_ids": inputs_dict["decoder_input_ids"],
+        }
+        loss = model(**model_inputs).loss
+        loss.backward()
+
     @slow
     def test_real_model_save_load_from_pretrained(self):
         model_2 = self.get_pretrained_model()
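
The new test exercises a full forward and backward pass with checkpointing enabled on a randomly initialized encoder-decoder pair; it should be runnable in isolation with pytest tests/models/encoder_decoder/test_modeling_encoder_decoder.py -k test_training_gradient_checkpointing.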
