[M2M100] fix positional embeddings #10590

Merged · 6 commits · Mar 8, 2021

src/transformers/models/m2m_100/modeling_m2m_100.py (24 changes: 19 additions & 5 deletions)
@@ -121,8 +121,17 @@ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional
self.offset = 2
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.weights = self.get_embedding(num_positions + self.offset, embedding_dim, padding_idx)
self.register_buffer("_float_tensor", torch.FloatTensor(1))
self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)

def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
if hasattr(self, "weights"):
# in forward, put the weights on correct device
emb_weights = emb_weights.to(self.weights.device)

self.weights = nn.Parameter(emb_weights)
self.weights.requires_grad = False
self.weights.detach_()

@staticmethod
def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
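
For context on the hunk above: the old code kept the sinusoidal table as a plain tensor attribute and used the `_float_tensor` buffer to move it onto the right device and dtype by hand, whereas make_weights now registers it as a frozen nn.Parameter so that Module.to() handles this automatically. A minimal sketch of the difference (illustrative only, not the transformers code):

```python
# Minimal sketch (illustrative, not the transformers implementation): a plain
# tensor attribute is ignored by Module.to()/.half(), while a frozen
# nn.Parameter follows the module, which is what make_weights relies on.
import torch
import torch.nn as nn


class PlainAttr(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = torch.zeros(4, 8)  # plain attribute, not a Parameter


class FrozenParam(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.zeros(4, 8), requires_grad=False)


plain, frozen = PlainAttr().half(), FrozenParam().half()
print(plain.weights.dtype)   # torch.float32 -- the attribute did not follow .half()
print(frozen.weights.dtype)  # torch.float16 -- the parameter was converted with the module
```
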
@@ -142,6 +151,7 @@ def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0

return emb

@torch.no_grad()
@@ -161,9 +171,7 @@ def forward(
# expand embeddings if needed
max_pos = self.padding_idx + 1 + seq_len
if max_pos > self.weights.size(0):
self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx)

self.weights = self.weights.to(self._float_tensor)
self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)

return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
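
The collapsed body of get_embedding is not shown above, but the table it returns is a pure function of its arguments, which is what makes the expand-on-demand call to make_weights in forward safe. A rough, self-contained sketch of the fairseq-style sinusoidal scheme (an approximation, not copied from modeling_m2m_100.py):

```python
# Rough sketch of a fairseq-style sinusoidal table; an approximation of what
# get_embedding builds, not the actual modeling_m2m_100.py code.
import math

import torch


def sinusoidal_table(num_embeddings: int, embedding_dim: int, padding_idx=None) -> torch.Tensor:
    half_dim = embedding_dim // 2
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float) * -(math.log(10000.0) / (half_dim - 1)))
    angles = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * freqs.unsqueeze(0)
    table = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:
        table = torch.cat([table, torch.zeros(num_embeddings, 1)], dim=1)  # zero-pad odd dims
    if padding_idx is not None:
        table[padding_idx, :] = 0
    return table


# The table depends only on its shape, so growing it for a longer sequence
# (what make_weights does when max_pos exceeds the current size) reproduces
# the existing rows exactly.
small = sinusoidal_table(16, 8, padding_idx=1)
large = sinusoidal_table(64, 8, padding_idx=1)
assert torch.allclose(small, large[:16])
```
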

@@ -1149,6 +1157,12 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
r"encoder\.version",
r"decoder\.version",
r"lm_head\.weight",
r"model.encoder.embed_positions.weights",
r"model.decoder.embed_positions.weights",
]
_keys_to_ignore_on_save = [
r"model.encoder.embed_positions.weights",
r"model.decoder.embed_positions.weights",
Comment on lines +1163 to +1165 (Contributor Author):
Since M2M100 uses sinusoidal positional embeddings, we don't need to save the pos embed weights.

]

def __init__(self, config: M2M100Config):
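
In effect, listing the positional-embedding weights under _keys_to_ignore_on_save means they are left out of the checkpoint and rebuilt deterministically when the model is constructed. A hypothetical illustration of that filtering step (not the actual save_pretrained code; key names and shapes are made up):

```python
# Hypothetical illustration of dropping ignored keys from a state dict before
# saving; not the transformers save_pretrained implementation.
import re

import torch

ignore_on_save = [
    r"model.encoder.embed_positions.weights",
    r"model.decoder.embed_positions.weights",
]

state_dict = {
    "model.encoder.embed_positions.weights": torch.zeros(1026, 16),  # recomputable sinusoids
    "model.encoder.layers.0.self_attn.k_proj.weight": torch.zeros(16, 16),
    "lm_head.weight": torch.zeros(32, 16),
}

to_save = {
    key: tensor
    for key, tensor in state_dict.items()
    if not any(re.search(pattern, key) for pattern in ignore_on_save)
}
print(sorted(to_save))  # the embed_positions entry is gone, everything else is kept
```
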
tests/test_modeling_m2m_100.py (14 changes: 10 additions & 4 deletions)
@@ -96,13 +96,19 @@ def __init__(

def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
3,
)
input_ids[:, -1] = self.eos_token_id # Eos Token

decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

# we need to clamp the input ids here to avoid having pad tokens in between;
# this is because for M2M100 the position_ids are prepared such that
# all pad tokens have pos id = 2 and the rest are between 2..seq_length,
# and the seq_length here is seq_length - num_pad_tokens.
# But when using past, there is no way of knowing whether the past input ids had
# pad tokens in them, which results in an incorrect seq_length and, in turn, in
# position_ids being off by num_pad_tokens in the past input
input_ids = input_ids.clamp(self.pad_token_id + 1)
decoder_input_ids = decoder_input_ids.clamp(self.pad_token_id + 1)

Comment (Contributor): Great! Thanks for the very detailed explanation!
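
To make the comment above concrete, here is a simplified sketch of a fairseq-style position-id computation (assuming a pad token id of 1; names are illustrative): pad tokens keep a fixed position id, real tokens count up past it, and with cached past key values the model only knows how many tokens came before, not how many of them were pads.

```python
# Simplified sketch of fairseq-style position ids (assumed padding_idx = 1);
# shows why a pad token inside the cached prefix shifts later positions.
import torch

PAD = 1


def positions_from_input_ids(input_ids: torch.Tensor, padding_idx: int, past_length: int = 0) -> torch.Tensor:
    mask = input_ids.ne(padding_idx).int()
    return ((torch.cumsum(mask, dim=1) + past_length) * mask).long() + padding_idx


full = torch.tensor([[5, PAD, 6, 7]])
print(positions_from_input_ids(full, PAD))                         # tensor([[2, 1, 3, 4]])

# With past, the new token only sees that 3 tokens were cached, not that one
# of them was a pad, so its position comes out off by num_pad_tokens:
print(positions_from_input_ids(full[:, -1:], PAD, past_length=3))  # tensor([[5]]) instead of 4
```
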

config = M2M100Config(
vocab_size=self.vocab_size,
d_model=self.hidden_size,