Fix 29807, sinusoidal positional encodings overwritten by post_init() (#29813)

* Check for requires_grad when initializing weights

* Add unit test

* Move sinusoidal positional encoding generation after post_init()

* Add modules to skip init list

* Move create_sinusoidal_embeddings to _init_weights
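
For context on the bug being fixed: `PreTrainedModel.post_init()` calls `init_weights()`, which applies `_init_weights` to every submodule and resamples `nn.Embedding` weights, so a sinusoidal table written in `__init__` was silently overwritten. A minimal, self-contained sketch of that failure mode (a toy module, not the actual transformers code):

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # Toy stand-in for DistilBERT's Embeddings: a deterministic table is
    # written at construction time, then clobbered by a post-init pass.
    def __init__(self):
        super().__init__()
        self.position_embeddings = nn.Embedding(4, 2)
        with torch.no_grad():
            # Stand-in for the sinusoidal table built when sinusoidal_pos_embds=True.
            self.position_embeddings.weight.copy_(torch.arange(8.0).view(4, 2))

    def _init_weights(self, module):
        # Mirrors what _init_weights does to embeddings: resample from a normal.
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)

model = ToyModel()
table = model.position_embeddings.weight.clone()
model.apply(model._init_weights)  # roughly what post_init() triggers
print(torch.equal(table, model.position_embeddings.weight))  # False: the table was overwritten

Moving the table generation into `_init_weights` itself, as this commit does, makes the deterministic fill part of initialization, so initialization can no longer undo it.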
hovnatan authored and Ita Zaporozhets committed May 14, 2024
1 parent f63f27d commit 5ac0e3a
Showing 2 changed files with 14 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/transformers/models/distilbert/modeling_distilbert.py
@@ -106,10 +106,6 @@ def __init__(self, config: PretrainedConfig):
         super().__init__()
         self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
-        if config.sinusoidal_pos_embds:
-            create_sinusoidal_embeddings(
-                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
-            )
 
         self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
         self.dropout = nn.Dropout(config.dropout)
@@ -634,6 +630,10 @@ def _init_weights(self, module: nn.Module):
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+        elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds:
+            create_sinusoidal_embeddings(
+                self.config.max_position_embeddings, self.config.dim, module.position_embeddings.weight
+            )
 
 
 DISTILBERT_START_DOCSTRING = r"""
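
For reference, the table that `create_sinusoidal_embeddings` fills follows the standard Transformer formula PE[pos, 2i] = sin(pos / 10000^(2i/dim)) and PE[pos, 2i+1] = cos(pos / 10000^(2i/dim)). A sketch of a helper with the signature used above (the in-repo implementation may differ in details, e.g. DeepSpeed ZeRO-3 handling):

import numpy as np
import torch

def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor) -> None:
    # Fill `out` (shape n_pos x dim) in place with fixed sinusoidal encodings.
    position_enc = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
    )
    out.requires_grad = False  # the table is fixed, not learned
    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.detach_()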
10 changes: 10 additions & 0 deletions tests/models/distilbert/test_modeling_distilbert.py
@@ -37,6 +37,7 @@
     DistilBertForTokenClassification,
     DistilBertModel,
 )
+from transformers.models.distilbert.modeling_distilbert import _create_sinusoidal_embeddings
 
 
 class DistilBertModelTester(object):
@@ -238,6 +239,15 @@ def test_distilbert_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_distilbert_model(*config_and_inputs)
 
+    def test_distilbert_model_with_sinusoidal_encodings(self):
+        config = DistilBertConfig(sinusoidal_pos_embds=True)
+        model = DistilBertModel(config=config)
+        sinusoidal_pos_embds = torch.empty((config.max_position_embeddings, config.dim), dtype=torch.float32)
+        _create_sinusoidal_embeddings(config.max_position_embeddings, config.dim, sinusoidal_pos_embds)
+        self.model_tester.parent.assertTrue(
+            torch.equal(model.embeddings.position_embeddings.weight, sinusoidal_pos_embds)
+        )
+
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_distilbert_for_masked_lm(*config_and_inputs)
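
The new test constructs the model directly, and since `DistilBertModel.__init__` ends by calling `post_init()`, the assertion fails without this fix. The same check as a standalone script (assumes a transformers install containing this patch):

import torch
from transformers import DistilBertConfig, DistilBertModel
from transformers.models.distilbert.modeling_distilbert import _create_sinusoidal_embeddings

config = DistilBertConfig(sinusoidal_pos_embds=True)
model = DistilBertModel(config)  # __init__ runs post_init(), which calls _init_weights
expected = torch.empty((config.max_position_embeddings, config.dim), dtype=torch.float32)
_create_sinusoidal_embeddings(config.max_position_embeddings, config.dim, expected)
print(torch.equal(model.embeddings.position_embeddings.weight, expected))  # True with the fix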
