60 changes: 60 additions & 0 deletions src/transformers/models/bart/modeling_bart.py
@@ -897,6 +897,21 @@ def __init__(self, config: BartConfig):
        # Initialize weights and apply final processing
        self.post_init()

    def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
        """We need to overload here to handle the wrong key saved in some main checkpoints."""
        if self.config.tie_word_embeddings:
            # Some checkpoints, e.g. "facebook/bart-large-cnn", save the embedding weight only under
            # decoder.embed_tokens, so we need to check for that here; see issue #36247
            if missing_keys is not None:
                if "shared.weight" in missing_keys and "decoder.embed_tokens.weight" not in missing_keys:
                    self.encoder.embed_tokens.weight = self.decoder.embed_tokens.weight
                    self.shared.weight = self.decoder.embed_tokens.weight
                    missing_keys.discard("encoder.embed_tokens.weight")
                    missing_keys.discard("shared.weight")

        # This needs to be done after the remapping above, otherwise it raises an error because the correct weights are not present
        super().tie_weights(missing_keys=missing_keys, recompute_mapping=recompute_mapping)

    def get_input_embeddings(self):
        return self.shared

@@ -1034,6 +1049,21 @@ def __init__(self, config: BartConfig):
        # Initialize weights and apply final processing
        self.post_init()

    def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
        """We need to overload here to handle the wrong key saved in some main checkpoints."""
        if self.config.tie_word_embeddings:
            # Some checkpoints, e.g. "facebook/bart-large-cnn", save the embedding weight only under
            # decoder.embed_tokens, so we need to check for that here; see issue #36247
            if missing_keys is not None:
                if "model.shared.weight" in missing_keys and "model.decoder.embed_tokens.weight" not in missing_keys:
                    self.model.encoder.embed_tokens.weight = self.model.decoder.embed_tokens.weight
                    self.model.shared.weight = self.model.decoder.embed_tokens.weight
                    missing_keys.discard("model.encoder.embed_tokens.weight")
                    missing_keys.discard("model.shared.weight")

        # This needs to be done after the remapping above, otherwise it raises an error because the correct weights are not present
        super().tie_weights(missing_keys=missing_keys, recompute_mapping=recompute_mapping)

    def resize_token_embeddings(
        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
    ) -> nn.Embedding:
@@ -1212,6 +1242,21 @@ def __init__(self, config: BartConfig, **kwargs):
        # Initialize weights and apply final processing
        self.post_init()

    def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
        """We need to overload here to handle the wrong key saved in some main checkpoints."""
        if self.config.tie_word_embeddings:
            # Some checkpoints, e.g. "facebook/bart-large-cnn", save the embedding weight only under
            # decoder.embed_tokens, so we need to check for that here; see issue #36247
            if missing_keys is not None:
                if "model.shared.weight" in missing_keys and "model.decoder.embed_tokens.weight" not in missing_keys:
                    self.model.encoder.embed_tokens.weight = self.model.decoder.embed_tokens.weight
                    self.model.shared.weight = self.model.decoder.embed_tokens.weight
                    missing_keys.discard("model.encoder.embed_tokens.weight")
                    missing_keys.discard("model.shared.weight")

        # This needs to be done after the remapping above, otherwise it raises an error because the correct weights are not present
        super().tie_weights(missing_keys=missing_keys, recompute_mapping=recompute_mapping)

    @auto_docstring
    def forward(
        self,
@@ -1343,6 +1388,21 @@ def __init__(self, config):
        # Initialize weights and apply final processing
        self.post_init()

Review thread on the tie_weights override below:

Contributor: Do we want to add a FIXME/TODO here to clean up after allowing the reverse direction in tied weights?

Member (author): We will need to remove all model-specific tie_weights overrides anyway at that point, so it is fine like this IMO. A TODO does not add much value, since we would have to search for it and we don't always do that anyway.

    def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
        """We need to overload here to handle the wrong key saved in some main checkpoints."""
        if self.config.tie_word_embeddings:
            # Some checkpoints, e.g. "facebook/bart-large-cnn", save the embedding weight only under
            # decoder.embed_tokens, so we need to check for that here; see issue #36247
            if missing_keys is not None:
                if "model.shared.weight" in missing_keys and "model.decoder.embed_tokens.weight" not in missing_keys:
                    self.model.encoder.embed_tokens.weight = self.model.decoder.embed_tokens.weight
                    self.model.shared.weight = self.model.decoder.embed_tokens.weight
                    missing_keys.discard("model.encoder.embed_tokens.weight")
                    missing_keys.discard("model.shared.weight")

        # This needs to be done after the remapping above, otherwise it raises an error because the correct weights are not present
        super().tie_weights(missing_keys=missing_keys, recompute_mapping=recompute_mapping)

    @auto_docstring
    def forward(
        self,
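
A hedged sanity check (not part of the diff) of the behaviour the tie_weights overrides above are meant to restore, assuming the facebook/bart-large-cnn checkpoint referenced in the code comments:

from transformers import BartForConditionalGeneration

# This checkpoint stored the embedding weight only under model.decoder.embed_tokens
# (see issue #36247 referenced in the comments above).
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

# With tie_word_embeddings=True, the shared, encoder, and decoder embeddings and
# the LM head should all share storage after loading.
shared = model.model.shared.weight
assert shared.data_ptr() == model.model.encoder.embed_tokens.weight.data_ptr()
assert shared.data_ptr() == model.model.decoder.embed_tokens.weight.data_ptr()
assert shared.data_ptr() == model.lm_head.weight.data_ptr()
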
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -2183,14 +2183,14 @@ def forward(
    The BigBirdPegasus Model with a language modeling head. Can be used for summarization.
    """
)
# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin):
    base_model_prefix = "model"
    _tied_weights_keys = {
        "lm_head.weight": "model.shared.weight",
    }
    _keys_to_ignore_on_load_missing = ["final_logits_bias"]

    # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration.__init__ with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
    def __init__(self, config: BigBirdPegasusConfig):
        super().__init__(config)
        self.model = BigBirdPegasusModel(config)
@@ -2200,13 +2200,15 @@ def __init__(self, config: BigBirdPegasusConfig):
        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration.resize_token_embeddings with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
    def resize_token_embeddings(
        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
    ) -> nn.Embedding:
        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
        return new_embeddings

    # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._resize_final_logits_bias with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
@@ -2217,7 +2219,6 @@ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        self.register_buffer("final_logits_bias", new_bias)

    @auto_docstring
    # Ignore copy
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
@@ -2327,6 +2328,7 @@ def forward(
            encoder_attentions=outputs.encoder_attentions,
        )

    # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration.prepare_decoder_input_ids_from_labels with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

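The prepare_decoder_input_ids_from_labels helper above delegates to shift_tokens_right. A minimal sketch of that behaviour, assuming shift_tokens_right is importable from the module as in current transformers (the token ids below are hypothetical, chosen only for illustration):

import torch
from transformers.models.bigbird_pegasus.modeling_bigbird_pegasus import shift_tokens_right

labels = torch.tensor([[10, 11, 12]])
pad_token_id, decoder_start_token_id = 0, 2  # hypothetical ids for illustration

decoder_input_ids = shift_tokens_right(labels, pad_token_id, decoder_start_token_id)
# Labels are shifted one position to the right and prefixed with the decoder start
# token: tensor([[ 2, 10, 11]])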
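Likewise, the resize_token_embeddings override shown earlier keeps final_logits_bias in step with the vocabulary size. A hedged usage sketch, assuming the google/bigbird-pegasus-large-arxiv checkpoint (any BigBirdPegasus checkpoint should behave the same):

from transformers import BigBirdPegasusForConditionalGeneration

model = BigBirdPegasusForConditionalGeneration.from_pretrained("google/bigbird-pegasus-large-arxiv")
old_vocab_size = model.config.vocab_size

# Growing the embedding matrix should also grow the registered final_logits_bias buffer.
model.resize_token_embeddings(old_vocab_size + 8)
assert model.final_logits_bias.shape[-1] == old_vocab_size + 8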