diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index d663821578a2..925bbe8ac372 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -897,6 +897,21 @@ def __init__(self, config: BartConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
+        """We need to override this to handle the wrong key saved in some official checkpoints."""
+        if self.config.tie_word_embeddings:
+            # Some checkpoints, e.g. "facebook/bart-large-cnn", store the embedding weight under decoder.embed_tokens
+            # instead of shared, so we need to re-tie from the decoder here; see issue #36247
+            if missing_keys is not None:
+                if "shared.weight" in missing_keys and "decoder.embed_tokens.weight" not in missing_keys:
+                    self.encoder.embed_tokens.weight = self.decoder.embed_tokens.weight
+                    self.shared.weight = self.decoder.embed_tokens.weight
+                    missing_keys.discard("encoder.embed_tokens.weight")
+                    missing_keys.discard("shared.weight")
+
+        # This must run afterwards, otherwise it raises an error because the correct weights are not yet present
+        super().tie_weights(missing_keys=missing_keys, recompute_mapping=recompute_mapping)
+
     def get_input_embeddings(self):
         return self.shared
 
@@ -1034,6 +1049,21 @@ def __init__(self, config: BartConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
+        """We need to override this to handle the wrong key saved in some official checkpoints."""
+        if self.config.tie_word_embeddings:
+            # Some checkpoints, e.g. "facebook/bart-large-cnn", store the embedding weight under decoder.embed_tokens
+            # instead of shared, so we need to re-tie from the decoder here; see issue #36247
+            if missing_keys is not None:
+                if "model.shared.weight" in missing_keys and "model.decoder.embed_tokens.weight" not in missing_keys:
+                    self.model.encoder.embed_tokens.weight = self.model.decoder.embed_tokens.weight
+                    self.model.shared.weight = self.model.decoder.embed_tokens.weight
+                    missing_keys.discard("model.encoder.embed_tokens.weight")
+                    missing_keys.discard("model.shared.weight")
+
+        # This must run afterwards, otherwise it raises an error because the correct weights are not yet present
+        super().tie_weights(missing_keys=missing_keys, recompute_mapping=recompute_mapping)
+
     def resize_token_embeddings(
         self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
     ) -> nn.Embedding:
@@ -1212,6 +1242,21 @@ def __init__(self, config: BartConfig, **kwargs):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
+        """We need to override this to handle the wrong key saved in some official checkpoints."""
+        if self.config.tie_word_embeddings:
+            # Some checkpoints, e.g. "facebook/bart-large-cnn", store the embedding weight under decoder.embed_tokens
+            # instead of shared, so we need to re-tie from the decoder here; see issue #36247
+            if missing_keys is not None:
+                if "model.shared.weight" in missing_keys and "model.decoder.embed_tokens.weight" not in missing_keys:
+                    self.model.encoder.embed_tokens.weight = self.model.decoder.embed_tokens.weight
+                    self.model.shared.weight = self.model.decoder.embed_tokens.weight
+                    missing_keys.discard("model.encoder.embed_tokens.weight")
+                    missing_keys.discard("model.shared.weight")
+
+        # This must run afterwards, otherwise it raises an error because the correct weights are not yet present
+        super().tie_weights(missing_keys=missing_keys, recompute_mapping=recompute_mapping)
+
     @auto_docstring
     def forward(
         self,
@@ -1343,6 +1388,21 @@ def __init__(self, config):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
+        """We need to override this to handle the wrong key saved in some official checkpoints."""
+        if self.config.tie_word_embeddings:
+            # Some checkpoints, e.g. "facebook/bart-large-cnn", store the embedding weight under decoder.embed_tokens
+            # instead of shared, so we need to re-tie from the decoder here; see issue #36247
+            if missing_keys is not None:
+                if "model.shared.weight" in missing_keys and "model.decoder.embed_tokens.weight" not in missing_keys:
+                    self.model.encoder.embed_tokens.weight = self.model.decoder.embed_tokens.weight
+                    self.model.shared.weight = self.model.decoder.embed_tokens.weight
+                    missing_keys.discard("model.encoder.embed_tokens.weight")
+                    missing_keys.discard("model.shared.weight")
+
+        # This must run afterwards, otherwise it raises an error because the correct weights are not yet present
+        super().tie_weights(missing_keys=missing_keys, recompute_mapping=recompute_mapping)
+
     @auto_docstring
     def forward(
         self,
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index 078f78fff2a9..b42d5ce37409 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -2183,7 +2183,6 @@ def forward(
     The BigBirdPegasus Model with a language modeling head. Can be used for summarization.
     """
 )
-# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
 class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin):
     base_model_prefix = "model"
     _tied_weights_keys = {
@@ -2191,6 +2190,7 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, Gene
     }
     _keys_to_ignore_on_load_missing = ["final_logits_bias"]
 
+    # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration.__init__ with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
     def __init__(self, config: BigBirdPegasusConfig):
         super().__init__(config)
         self.model = BigBirdPegasusModel(config)
@@ -2200,6 +2200,7 @@ def __init__(self, config: BigBirdPegasusConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
+    # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration.resize_token_embeddings with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
     def resize_token_embeddings(
         self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
     ) -> nn.Embedding:
@@ -2207,6 +2208,7 @@ def resize_token_embeddings(
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
 
+    # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._resize_final_logits_bias with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
     def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
         old_num_tokens = self.final_logits_bias.shape[-1]
         if new_num_tokens <= old_num_tokens:
@@ -2217,7 +2219,6 @@ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
         self.register_buffer("final_logits_bias", new_bias)
 
     @auto_docstring
-    # Ignore copy
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -2327,6 +2328,7 @@ def forward(
             encoder_attentions=outputs.encoder_attentions,
         )
 
+    # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration.prepare_decoder_input_ids_from_labels with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
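
The snippet below is a minimal, self-contained sketch of the re-tying step the overridden tie_weights performs, not the transformers API: TinySeq2Seq, encoder_embed, decoder_embed and tie_from_decoder are hypothetical stand-ins for BART's shared/encoder/decoder embedding tables. It shows the same idea as the patch: when a checkpoint only provides the decoder embedding, point the encoder and shared tables at it and drop the now-satisfied entries from missing_keys before the parent tying logic runs.

    # Minimal sketch, assuming plain PyTorch modules; names are hypothetical.
    from torch import nn


    class TinySeq2Seq(nn.Module):
        def __init__(self, vocab_size: int = 16, d_model: int = 4):
            super().__init__()
            self.shared = nn.Embedding(vocab_size, d_model)
            self.encoder_embed = nn.Embedding(vocab_size, d_model)
            self.decoder_embed = nn.Embedding(vocab_size, d_model)

        def tie_from_decoder(self, missing_keys: set) -> set:
            # If the checkpoint only contained the decoder embedding, re-point the
            # encoder and shared tables at it and discard the now-satisfied keys,
            # mirroring the check added in the patched tie_weights.
            if "shared.weight" in missing_keys and "decoder_embed.weight" not in missing_keys:
                self.encoder_embed.weight = self.decoder_embed.weight
                self.shared.weight = self.decoder_embed.weight
                missing_keys.discard("encoder_embed.weight")
                missing_keys.discard("shared.weight")
            return missing_keys


    model = TinySeq2Seq()
    # Pretend the state dict only provided decoder_embed.weight:
    missing = {"shared.weight", "encoder_embed.weight"}
    missing = model.tie_from_decoder(missing)
    assert model.shared.weight is model.decoder_embed.weight
    assert model.encoder_embed.weight is model.decoder_embed.weight
    assert not missing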