Add a post init method to all models (#14431)
* Add a post init method to all models

* Fix tests

* Fix last tests

* Fix templates

* Add comment

* Forgot to save
sgugger authored Nov 18, 2021
1 parent 08816de commit d83b0e0
Showing 70 changed files with 693 additions and 359 deletions.
27 changes: 14 additions & 13 deletions src/transformers/modeling_utils.py
@@ -412,17 +412,6 @@ def floating_point_ops(
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


def gradient_checkpointing_hook(module, _):
# Hook to enable backward compatibility for gradient checkpointing. Will be removed once all models have a
# proper post_init method.
if getattr(module.config, "gradient_checkpointing", False):
module.gradient_checkpointing_enable()
# Remove the attribute now that it has been consumed, so it's not saved in the config.
delattr(module.config, "gradient_checkpointing")
# The hook will remove itself after the first execution
module._gradient_checkpointing_hook.remove()


class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
r"""
Base class for all models.
@@ -490,8 +479,20 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
# Save config and origin of the pretrained weights if given in model
self.config = config
self.name_or_path = config.name_or_path
if self.supports_gradient_checkpointing:
self._gradient_checkpointing_hook = self.register_forward_pre_hook(gradient_checkpointing_hook)

def post_init(self):
"""
A method executed at the end of each Transformer model initialization, to run code that needs the model's
modules to be properly initialized (such as weight initialization).
"""
self.init_weights()
self._backward_compatibility_gradient_checkpointing()

def _backward_compatibility_gradient_checkpointing(self):
if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
self.gradient_checkpointing_enable()
# Remove the attribute now that it has been consumed, so it's not saved in the config.
delattr(self.config, "gradient_checkpointing")

@classmethod
def _from_config(cls, config, **kwargs):
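To make the new convention concrete, here is a minimal sketch of a custom model written against it: the subclass ends its __init__ with self.post_init() instead of calling self.init_weights() directly, and post_init() also runs the legacy config.gradient_checkpointing handling. ToyConfig and ToyModel are hypothetical names for illustration only; they are not part of this commit.

import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class ToyConfig(PretrainedConfig):
    model_type = "toy"

    def __init__(self, vocab_size=100, hidden_size=32, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size


class ToyModel(PreTrainedModel):
    config_class = ToyConfig

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Initialize weights and apply final processing: replaces the old direct
        # call to self.init_weights() at the end of __init__.
        self.post_init()

    def _init_weights(self, module):
        # Applied to every submodule by init_weights() via self.apply().
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, input_ids):
        return self.dense(self.embeddings(input_ids))


model = ToyModel(ToyConfig())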
21 changes: 14 additions & 7 deletions src/transformers/models/albert/modeling_albert.py
@@ -638,7 +638,8 @@ def __init__(self, config, add_pooling_layer=True):
self.pooler = None
self.pooler_activation = None

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embeddings.word_embeddings
@@ -757,7 +758,8 @@ def __init__(self, config):
self.predictions = AlbertMLMHead(config)
self.sop_classifier = AlbertSOPHead(config)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_output_embeddings(self):
return self.predictions.decoder
@@ -903,7 +905,8 @@ def __init__(self, config):
self.albert = AlbertModel(config, add_pooling_layer=False)
self.predictions = AlbertMLMHead(config)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_output_embeddings(self):
return self.predictions.decoder
@@ -991,7 +994,8 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.classifier_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -1097,7 +1101,8 @@ def __init__(self, config):
self.dropout = nn.Dropout(classifier_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -1187,7 +1192,8 @@ def __init__(self, config):
self.albert = AlbertModel(config, add_pooling_layer=False)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -1286,7 +1292,8 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.classifier_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
15 changes: 10 additions & 5 deletions src/transformers/models/bart/modeling_bart.py
@@ -699,8 +699,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = nn.LayerNorm(embed_dim)

self.init_weights()
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embed_tokens
@@ -870,8 +871,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)

self.init_weights()
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embed_tokens
@@ -1130,7 +1132,8 @@ def __init__(self, config: BartConfig):
self.encoder = BartEncoder(config, self.shared)
self.decoder = BartDecoder(config, self.shared)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.shared
@@ -1248,7 +1251,8 @@ def __init__(self, config: BartConfig):
self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_encoder(self):
return self.model.get_encoder()
@@ -1666,7 +1670,8 @@ def __init__(self, config):

self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.model.decoder.embed_tokens
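Note that the BART encoder and decoder also keep self.gradient_checkpointing = False as a plain module attribute, set before post_init(), rather than reading the flag from the config at forward time. As a rough, standalone sketch of how such a per-module flag is typically consumed in a layer loop with torch.utils.checkpoint (TinyEncoder is an illustrative stand-in, not the actual BartEncoder code):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyEncoder(nn.Module):
    def __init__(self, hidden_size=32, num_layers=4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        # Per-module flag; in transformers models it is flipped by gradient_checkpointing_enable().
        self.gradient_checkpointing = False

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Recompute this layer's activations in the backward pass instead
                # of storing them, trading extra compute for lower memory.
                hidden_states = checkpoint(layer, hidden_states)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


encoder = TinyEncoder()
encoder.gradient_checkpointing = True
encoder.train()
out = encoder(torch.randn(2, 8, 32, requires_grad=True))
out.sum().backward()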
12 changes: 8 additions & 4 deletions src/transformers/models/beit/modeling_beit.py
@@ -598,7 +598,8 @@ def __init__(self, config, add_pooling_layer=True):
)
self.pooler = BeitPooler(config) if add_pooling_layer else None

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embeddings.patch_embeddings
@@ -715,7 +716,8 @@ def __init__(self, config):
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
@@ -805,7 +807,8 @@ def __init__(self, config):
# Classifier head
self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@@ -1121,7 +1124,8 @@ def __init__(self, config):
self.decode_head = BeitUperHead(config)
self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def compute_loss(self, logits, auxiliary_logits, labels):
# upsample logits to the images' original size
27 changes: 18 additions & 9 deletions src/transformers/models/bert/modeling_bert.py
@@ -870,7 +870,8 @@ def __init__(self, config, add_pooling_layer=True):

self.pooler = BertPooler(config) if add_pooling_layer else None

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embeddings.word_embeddings
@@ -1037,7 +1038,8 @@ def __init__(self, config):
self.bert = BertModel(config)
self.cls = BertPreTrainingHeads(config)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_output_embeddings(self):
return self.cls.predictions.decoder
@@ -1145,7 +1147,8 @@ def __init__(self, config):
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_output_embeddings(self):
return self.cls.predictions.decoder
@@ -1294,7 +1297,8 @@ def __init__(self, config):
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_output_embeddings(self):
return self.cls.predictions.decoder
@@ -1394,7 +1398,8 @@ def __init__(self, config):
self.bert = BertModel(config)
self.cls = BertOnlyNSPHead(config)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
@@ -1501,7 +1506,8 @@ def __init__(self, config):
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -1600,7 +1606,8 @@ def __init__(self, config):
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, 1)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
@@ -1698,7 +1705,8 @@ def __init__(self, config):
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -1788,7 +1796,8 @@ def __init__(self, config):
self.bert = BertModel(config, add_pooling_layer=False)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
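From a user's perspective, the visible effect of post_init() is the backward-compatibility path: a config that still carries the deprecated gradient_checkpointing attribute is consumed once at model construction instead of lingering in the saved config. A small usage sketch, assuming the is_gradient_checkpointing property from the earlier gradient-checkpointing refactor is available:

from transformers import BertConfig, BertModel

config = BertConfig()
config.gradient_checkpointing = True  # legacy, pre-refactor way of requesting it

model = BertModel(config)

# post_init() -> _backward_compatibility_gradient_checkpointing() has enabled
# checkpointing on the model and deleted the attribute so it is not saved again.
print(model.is_gradient_checkpointing)                  # expected: True
print(hasattr(model.config, "gradient_checkpointing"))  # expected: False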
src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -282,7 +282,8 @@ def __init__(self, config):
self.embeddings = BertGenerationEmbeddings(config)
self.encoder = BertEncoder(config)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embeddings.word_embeddings
@@ -456,7 +457,8 @@ def __init__(self, config):
self.bert = BertGenerationEncoder(config)
self.lm_head = BertGenerationOnlyLMHead(config)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_output_embeddings(self):
return self.lm_head.decoder
24 changes: 16 additions & 8 deletions src/transformers/models/big_bird/modeling_big_bird.py
@@ -1953,7 +1953,8 @@ def __init__(self, config, add_pooling_layer=True):
)
self.set_attention_type("original_full")

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
return self.embeddings.word_embeddings
@@ -2262,7 +2263,8 @@ def __init__(self, config):
self.bert = BigBirdModel(config, add_pooling_layer=True)
self.cls = BigBirdPreTrainingHeads(config)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_output_embeddings(self):
return self.cls.predictions.decoder
@@ -2370,7 +2372,8 @@ def __init__(self, config):
self.bert = BigBirdModel(config)
self.cls = BigBirdOnlyMLMHead(config)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_output_embeddings(self):
return self.cls.predictions.decoder
@@ -2472,7 +2475,8 @@ def __init__(self, config):
self.bert = BigBirdModel(config)
self.cls = BigBirdOnlyMLMHead(config)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

def get_output_embeddings(self):
return self.cls.predictions.decoder
@@ -2642,7 +2646,8 @@ def __init__(self, config):
self.bert = BigBirdModel(config)
self.classifier = BigBirdClassificationHead(config)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -2737,7 +2742,8 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(
BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
@@ -2834,7 +2840,8 @@ def __init__(self, config):
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -2942,7 +2949,8 @@ def __init__(self, config, add_pooling_layer=False):
self.bert = BigBirdModel(config, add_pooling_layer=add_pooling_layer)
self.qa_classifier = BigBirdForQuestionAnsweringHead(config)

self.init_weights()
# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(