From 5ac0e3ae8b65a964319b2af4fbd8c57c0db0d557 Mon Sep 17 00:00:00 2001
From: Hovnatan Karapetyan
Date: Wed, 27 Mar 2024 09:28:00 +0400
Subject: [PATCH] Fix 29807, sinusoidal positional encodings overwritten by
 post_init() (#29813)

* Check for requires_grad when initing weights

* Add unit test

* Move sinusoidal positional encoding generation after post_init()

* Add modules to skip init list

* Move create_sinusoidal_embeddings to _init_weights
---
 .../models/distilbert/modeling_distilbert.py        |  8 ++++----
 tests/models/distilbert/test_modeling_distilbert.py | 10 ++++++++++
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index d33ffc8844607a..3a65e0296116dc 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -106,10 +106,6 @@ def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
-        if config.sinusoidal_pos_embds:
-            create_sinusoidal_embeddings(
-                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
-            )
         self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
         self.dropout = nn.Dropout(config.dropout)
 
@@ -634,6 +630,10 @@ def _init_weights(self, module: nn.Module):
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+        elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds:
+            create_sinusoidal_embeddings(
+                self.config.max_position_embeddings, self.config.dim, module.position_embeddings.weight
+            )
 
 
 DISTILBERT_START_DOCSTRING = r"""
diff --git a/tests/models/distilbert/test_modeling_distilbert.py b/tests/models/distilbert/test_modeling_distilbert.py
index c24eb3096033f9..481d4b24cd76dd 100644
--- a/tests/models/distilbert/test_modeling_distilbert.py
+++ b/tests/models/distilbert/test_modeling_distilbert.py
@@ -37,6 +37,7 @@
         DistilBertForTokenClassification,
         DistilBertModel,
     )
+    from transformers.models.distilbert.modeling_distilbert import _create_sinusoidal_embeddings
 
 
 class DistilBertModelTester(object):
@@ -238,6 +239,15 @@ def test_distilbert_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_distilbert_model(*config_and_inputs)
 
+    def test_distilbert_model_with_sinusoidal_encodings(self):
+        config = DistilBertConfig(sinusoidal_pos_embds=True)
+        model = DistilBertModel(config=config)
+        sinusoidal_pos_embds = torch.empty((config.max_position_embeddings, config.dim), dtype=torch.float32)
+        _create_sinusoidal_embeddings(config.max_position_embeddings, config.dim, sinusoidal_pos_embds)
+        self.model_tester.parent.assertTrue(
+            torch.equal(model.embeddings.position_embeddings.weight, sinusoidal_pos_embds)
+        )
+
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_distilbert_for_masked_lm(*config_and_inputs)
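
Note: the root cause is that PreTrainedModel.post_init() runs the weight-initialization pass (_init_weights) over every submodule after __init__ returns, so a sinusoidal table written inside Embeddings.__init__ was immediately overwritten by the generic nn.Embedding init. Moving the creation into _init_weights works because nn.Module.apply() is post-order (children before their parent): the generic init touches the child nn.Embedding first, and the Embeddings branch then rewrites position_embeddings.weight afterwards. Below is a minimal, self-contained sketch of that ordering; create_sinusoidal, the toy Embeddings class, and the standalone _init_weights are hypothetical simplifications, not the transformers implementation.

import math

import torch
import torch.nn as nn


def create_sinusoidal(n_pos: int, dim: int, out: torch.Tensor) -> None:
    # Standard fixed sin/cos positional table, written into `out` in place.
    position = torch.arange(n_pos, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim)
    )
    with torch.no_grad():
        out[:, 0::2] = torch.sin(position * div_term)
        out[:, 1::2] = torch.cos(position * div_term)


class Embeddings(nn.Module):
    def __init__(self, n_pos: int, dim: int):
        super().__init__()
        # No sinusoidal fill here: anything written in __init__ would be
        # clobbered by the init pass that runs afterwards.
        self.position_embeddings = nn.Embedding(n_pos, dim)


def _init_weights(module: nn.Module) -> None:
    if isinstance(module, nn.Embedding):
        # Generic init; apply() reaches the child nn.Embedding first.
        module.weight.data.normal_(mean=0.0, std=0.02)
    elif isinstance(module, Embeddings):
        # The parent container is visited last, so this write survives.
        n_pos, dim = module.position_embeddings.weight.shape
        create_sinusoidal(n_pos, dim, module.position_embeddings.weight)


emb = Embeddings(n_pos=16, dim=8)
emb.apply(_init_weights)  # what post_init() ultimately does
assert torch.equal(emb.position_embeddings.weight[0, 1::2], torch.ones(4))  # cos(0) == 1

The new unit test checks exactly this property on the real model: constructing DistilBertModel (whose __init__ ends with post_init()) must leave position_embeddings.weight equal to a freshly generated sinusoidal table.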