diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2c041bcaf398..c345c1a5bee4 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1499,11 +1499,8 @@ def post_init(self): # Attach the different parallel plans and tied weight keys to the top-most model, so that everything is # easily available self._tp_plan, self._ep_plan, self._pp_plan = {}, {}, {} - # Current submodel should register its tied weights keys only if the config is asking for it - if not self.config.tie_word_embeddings and not self.config.tie_encoder_decoder: - self.all_tied_weights_keys = {} - else: - self.all_tied_weights_keys = self._tied_weights_keys.copy() if self._tied_weights_keys is not None else {} + # Current submodel should register its tied weights + self.all_tied_weights_keys = self.get_expanded_tied_weights_keys(all_submodels=False) # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config if self.base_model is self: self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else {} @@ -2341,99 +2338,155 @@ def smart_apply(self, fn): # Let the magic happen with this simple call self.smart_apply(self._initialize_weights) - def tie_weights(self, missing_keys: Optional[set[str]] = None): - """ - If set in the config, tie the weights between the input embeddings and the output embeddings, - and the encoder and decoder. This relies on the `_tied_weights_keys` dict. - - This is very sensible! For many reasons and especially this one: - ```python - from torch import nn - import torch - class MyClass(nn.Module): - def __init__(self): - super().__init__() - self.proj = nn.Linear(8,8) - self.bias = nn.Parameter(torch.empty(8)) - self.proj.bias = self.bias - - c = MyClass() - print(list(c.named_parameters())) + def get_expanded_tied_weights_keys(self, all_submodels: bool = False) -> dict: + r""" + Return the expanded tied weight keys (in case they contain modules or regex patterns) for only the current + model, or recursively for all submodels if `all_submodels=True` (i.e. it will re-check the config values for all + submodels). + + For almost all models, we only require to tie the embeddings, so the model has an internal property + `_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}`. In this case, the mapping is already + "expanded", i.e. it already contains full parameters, and this function will simply return a copy of the property. + For more complex patterns, e.g. for `DFineForObjectDetection`, we have the following attribute ``` - That's for a parameter, for a module, it will just remove the ones that are "shared" (that makes sense) and overwrite getattr for it. - - ```python - from torch import nn - import torch - class Decoder(nn.Module): - def __init__(self): - super().__init__() - self.embedding = nn.Embedding(8,8) - - class Encoder(nn.Module): - def __init__(self): - super().__init__() - self.embedding = nn.Embedding(8,8) - - class EncoderDecoder(nn.Module): - def __init__(self): - super().__init__() - self.encoder = Encoder() - self.decoder = Decoder() - self.encoder.embedding = self.decoder.embedding # setattr is convenient - - c = EncoderDecoder() - print(list(c.named_parameters())) + _tied_weights_keys = { + r"bbox_embed.(?![0])\d+": "bbox_embed.0", + r"class_embed.(?![0])\d+": "class_embed.0", + "model.decoder.class_embed": "class_embed", + "model.decoder.bbox_embed": "bbox_embed", + } ``` - Thus the order of the keys matters. 
If you tie `self.decoder.embedding` you can no longer tie anything inside it.
-
-        If you call this function, it will always tie. There is only 1 tricky case, if all weights are missing, you still want to mention that
-        the ones you tied were missing.
+        In this case, the function looks up all the model's parameters and buffers, matches them against the patterns,
+        and returns the following:
+        ```
+        {
+            'bbox_embed.1.layers.0.bias': 'bbox_embed.0.layers.0.bias',
+            'bbox_embed.1.layers.0.weight': 'bbox_embed.0.layers.0.weight',
+            'bbox_embed.1.layers.1.bias': 'bbox_embed.0.layers.1.bias',
+            'bbox_embed.1.layers.1.weight': 'bbox_embed.0.layers.1.weight',
+            'bbox_embed.1.layers.2.bias': 'bbox_embed.0.layers.2.bias',
+            'bbox_embed.1.layers.2.weight': 'bbox_embed.0.layers.2.weight',
+            'bbox_embed.2.layers.0.bias': 'bbox_embed.0.layers.0.bias',
+            'bbox_embed.2.layers.0.weight': 'bbox_embed.0.layers.0.weight',
+            ...
+            'class_embed.1.bias': 'class_embed.0.bias',
+            'class_embed.1.weight': 'class_embed.0.weight',
+            'class_embed.2.bias': 'class_embed.0.bias',
+            'class_embed.2.weight': 'class_embed.0.weight',
+            ...
+            'model.decoder.class_embed.0.bias': 'class_embed.0.bias',
+            'model.decoder.class_embed.0.weight': 'class_embed.0.weight',
+            'model.decoder.class_embed.1.bias': 'class_embed.0.bias',
+            'model.decoder.class_embed.1.weight': 'class_embed.0.weight',
+            ...
+            'model.decoder.bbox_embed.0.layers.0.bias': 'bbox_embed.0.layers.0.bias',
+            'model.decoder.bbox_embed.0.layers.0.weight': 'bbox_embed.0.layers.0.weight',
+            'model.decoder.bbox_embed.0.layers.1.bias': 'bbox_embed.0.layers.1.bias',
+            'model.decoder.bbox_embed.0.layers.1.weight': 'bbox_embed.0.layers.1.weight',
+            ...
+        }
+        ```
+        i.e. all the parameters matching the regex and module patterns in `_tied_weights_keys`
         """
-        # TODO Cyril: using this fixed set of keys (set in post_init()) does not allow to switch the config flag and re-tie
-        mapping = getattr(self, "all_tied_weights_keys", None)
-        if not isinstance(mapping, dict):
-            return
+        if all_submodels:
+            expanded_tied_weights = {}
+            for prefix, submodule in self.named_modules(remove_duplicate=False):
+                if isinstance(submodule, PreTrainedModel):
+                    # Will dynamically re-check the config, in case it has changed
+                    submodel_tied_weights = submodule.get_expanded_tied_weights_keys(all_submodels=False)
+                    if prefix != "":
+                        submodel_tied_weights = {
+                            f"{prefix}.{k}": f"{prefix}.{v}" for k, v in submodel_tied_weights.items()
+                        }
+                    expanded_tied_weights.update(submodel_tied_weights)
+            return expanded_tied_weights

-        # TODO let's pray this is not too slow :)
-        top_level_params = dict(self.named_parameters(remove_duplicate=False)) | dict(
-            self.named_buffers(remove_duplicate=False)
-        )
-        for target_name, source_name in mapping.items():
-            source_name = "^" + source_name
+        tied_mapping = self._tied_weights_keys
+        # If the config does not specify any tying, return an empty dict
+        if not self.config.tie_word_embeddings and not self.config.tie_encoder_decoder:
+            return {}
+        # If None, return an empty dict
+        elif tied_mapping is None:
+            return {}
+        # Short-cut for the most common cases: if the tied weights mapping only contains already expanded params,
+        # return it directly (the regex matches names containing only letters, numbers, dots, and underscores to make
+        # sure it does not contain a regex pattern, and ending with "bias" or "weight" to make sure it's not a module)
+        common_case_regex = re.compile(r"^[A-Za-z0-9_\.]+(weight|bias)$")
+        if all(common_case_regex.match(k) for k in tied_mapping.keys() | tied_mapping.values()):
+            return tied_mapping.copy()
+
+        # We need to expand the regex patterns or the modules into proper parameters
+        expanded_tied_weights = {}
+        all_param_names = {k for k, _ in self.named_parameters(remove_duplicate=False)} | {
+            k for k, _ in self.named_buffers(remove_duplicate=False)
+        }
+        for target_name, source_name in tied_mapping.items():
             target_name = "^" + target_name
+            source_name = "^" + source_name
-            source_is_there = bool(missing_keys) and not re.search(
-                source_name, "\n".join(missing_keys), flags=re.MULTILINE
-            )
-            source_params = sorted(filter(lambda x: re.search(source_name, x), top_level_params.keys()))
-            target_params = sorted(filter(lambda x: re.search(target_name, x), top_level_params.keys()))
-            if not len(source_params) > 0 or len(target_params) % len(source_params) != 0:
+            source_params = sorted(filter(lambda x: re.search(source_name, x), all_param_names))
+            target_params = sorted(filter(lambda x: re.search(target_name, x), all_param_names))
+            if (
+                not len(source_params) > 0
+                or not len(target_params) > 0
+                or len(target_params) % len(source_params) != 0
+            ):
                 raise ValueError(
-                    f"There is an issue with your definition of `tie_weights_keys` for {source_name}:{target_name}. We found {source_params} to tie into {target_params}"
+                    f"There is an issue with your definition of `_tied_weights_keys` for {source_name}:{target_name}. "
+                    f"We found {source_params} to tie into {target_params}"
                 )
-            if len(target_params) > 0:
-                # we cycle source as it should be dispatch in many target if regex
-                for target_n, source_n in zip(target_params, cycle(source_params)):
-                    if "." in target_n:
-                        parent_path, last = target_n.rsplit(".", 1)
-                        parent = self.get_submodule(parent_path)
-                    else:
-                        parent_path, last = "", target_n
-                        parent = self  # top-level
-                    setattr(parent, last, top_level_params[source_n])
-                    self._adjust_bias(parent, top_level_params[source_n])
-                    if missing_keys and source_is_there:  # test_model_weights_reload_no_missing_tied_weights
-                        missing_keys.discard(target_n)
+            # we cycle the sources, as one source may be dispatched to many targets when the pattern is a regex
+            for target_n, source_n in zip(target_params, cycle(source_params)):
+                # If the source is already registered as a target, use the original corresponding source. This should
+                # never happen in general, but some models such as `d_fine` have complicated regex patterns, so it ends
+                # up being the case in order to keep those regexes simple. Fix it silently here
+                if source_n in expanded_tied_weights.keys():
+                    # Use original source instead of having keys both as source and targets
+                    expanded_tied_weights[target_n] = expanded_tied_weights[source_n]
+                # Usual case, everything is already correct
+                else:
+                    expanded_tied_weights[target_n] = source_n
+
+        return expanded_tied_weights
+
+    def tie_weights(self, missing_keys: Optional[set[str]] = None, recompute_mapping: bool = True):
+        """
+        Tie the model weights. If `recompute_mapping=False` (used when called internally), it will rely on the
+        `model.all_tied_weights_keys` attribute, which contains the `{target: source}` mapping for the tied params.
+        If `recompute_mapping=True` (the default), it will re-check all internal submodels and their configs to
+        determine the params that need to be tied. This is the case when `model.tie_weights()` is called on its own,
+        outside of `__init__` and `from_pretrained`, in case the config values were changed in the meantime.
+        """
+        # In this case, the keys stored in `all_tied_weights_keys` are already correct
+        if not recompute_mapping:
+            tied_keys = self.all_tied_weights_keys
+        else:
+            tied_keys = self.get_expanded_tied_weights_keys(all_submodels=True)
+
+        for target_param_name, source_param_name in tied_keys.items():
+            source_param = self.get_parameter_or_buffer(source_param_name)
+            if "." in target_param_name:
+                parent_name, name = target_param_name.rsplit(".", 1)
+                parent = self.get_submodule(parent_name)
             else:
-                target_is_not_there = missing_keys and re.search(
-                    target_name, "\n".join(missing_keys), flags=re.MULTILINE
-                )
-                raise ValueError(
-                    "There is a problem in the way you tie your keys or the way they were saved.\n"
-                    f"source_is_there={source_is_there}, target_is_there={not target_is_not_there}, missing_keys={missing_keys},"
-                    "tie_word_embeddings/tie_encoder_decoder={(self.config.tie_word_embeddings or self.config.tie_encoder_decoder)}"
-                )
+                name = target_param_name
+                parent = self
+            setattr(parent, name, source_param)
+            self._adjust_bias(parent, source_param)
+            if missing_keys is not None:
+                source_is_there = source_param_name not in missing_keys
+                target_is_there = target_param_name not in missing_keys
+                # If we tied correctly, remove the target from the missing keys
+                if source_is_there:
+                    missing_keys.discard(target_param_name)
+                # If the source is not present, but the target is, the checkpoint is corrupted
+                # TODO: maybe we could simply tie in the opposite direction here instead of raising an error?
+                elif target_is_there:
+                    raise ValueError(
+                        f"This checkpoint seems corrupted. The tied weights mapping for this model ties "
+                        f"{target_param_name} (which is present in the checkpoint) to {source_param_name}, "
+                        f"but {source_param_name} is missing from the checkpoint."
+                    )

     def _adjust_bias(self, output_embeddings, input_embeddings):
         if getattr(output_embeddings, "bias", None) is not None and hasattr(output_embeddings, "weight"):
@@ -2949,9 +3002,8 @@ def init_weights(self):
         if _init_weights:
             # Initialize weights
             self.initialize_weights()
-            # Tie weights needs to be called as it figures out recursively if sub modules
-            # need to tie
-            self.tie_weights()
+            # Tying the weights needs to happen here, but it can use the pre-computed `all_tied_weights_keys`
+            self.tie_weights(recompute_mapping=False)

     def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
         """
@@ -4246,7 +4298,7 @@ def _load_pretrained_model(
            model._initialize_missing_keys(is_quantized)

            # Tie the weights
-            model.tie_weights(missing_keys)
+            model.tie_weights(missing_keys=missing_keys, recompute_mapping=False)

        # Adjust missing and unexpected keys
        missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(missing_keys, unexpected_keys)
@@ -4563,13 +4615,8 @@ def mark_tied_weights_as_initialized(self):
         This is very important as most embeddings are tied, and they are huge params (vocabularies are often 256k),
         so running inits on them is very costly."""
         for tied_param in self.all_tied_weights_keys.keys():
-            # It's always a proper weight except for 2 or 3 old models where it's a regex or module set to None
-            # -> just skip it in those cases (they will just re-init before tying, so they loose the added optimization)
-            try:
-                param = self.get_parameter(tied_param)
-                param._is_hf_initialized = True
-            except AttributeError:
-                pass
+            param = self.get_parameter(tied_param)
+            param._is_hf_initialized = True

     def get_parameter_or_buffer(self, target: str):
         """
diff --git a/src/transformers/models/esm/modeling_esm.py 
b/src/transformers/models/esm/modeling_esm.py index c4b3e1ea11ec..a64684eddc18 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -732,8 +732,6 @@ def __init__(self, config): self.esm = EsmModel(config, add_pooling_layer=False) self.lm_head = EsmLMHead(config) - self.init_weights() - self.post_init() def get_output_embeddings(self): @@ -828,8 +826,6 @@ def __init__(self, config): self.esm = EsmModel(config, add_pooling_layer=False) self.classifier = EsmClassificationHead(config) - self.init_weights() - self.post_init() @can_return_tuple @@ -903,8 +899,6 @@ def __init__(self, config): self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self.init_weights() - self.post_init() @can_return_tuple diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 614492f23ae0..8adc51426fef 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -993,7 +993,7 @@ def __init__(self, config, target_lang: Optional[str] = None): # Initialize weights and apply final processing self.post_init() - def tie_weights(self, missing_keys=None): + def tie_weights(self, **kwargs): """ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when passing `target_lang=...` to `from_pretrained(...)`. diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index cb57a0fd8d20..fee52071b4c0 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1111,7 +1111,7 @@ def __init__(self, config, vision_model=None): # Initialize weights and apply final processing self.post_init() - def tie_weights(self, missing_keys=None): + def tie_weights(self, **kwargs): """ Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of IdeficsDecoupledLinear and IdeficsDecoupledEmbedding. diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 28b5541b80e5..58d5c80b5bc5 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -857,7 +857,7 @@ def __init__(self, config, target_lang: Optional[str] = None): # Initialize weights and apply final processing self.post_init() - def tie_weights(self, missing_keys=None): + def tie_weights(self, **kwargs): """ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when passing `target_lang=...` to `from_pretrained(...)`. diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index ca426c3e23be..acbd0ab16b82 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -1400,7 +1400,7 @@ def __init__(self, config, target_lang: Optional[str] = None): # Initialize weights and apply final processing self.post_init() - def tie_weights(self, missing_keys=None): + def tie_weights(self, **kwargs): """ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when passing `target_lang=...` to `from_pretrained(...)`. 
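The overrides above (and the ones that follow for the remaining audio models) all switch from `tie_weights(self, missing_keys=None)` to `tie_weights(self, **kwargs)`, so the base class can grow new keyword arguments such as `recompute_mapping` without breaking them. Below is a minimal, hedged sketch of that override shape, modeled on the toy helpers in `tests/utils/test_modeling_utils.py` at the end of this patch; the class and layer names are hypothetical, not part of the diff:

```python
from typing import Optional

import torch.nn as nn


class TinyTiedModel(nn.Module):
    """Hypothetical toy module illustrating a **kwargs-tolerant tie_weights override."""

    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_2 = nn.Linear(hidden_size, hidden_size, bias=False)

    def tie_weights(self, missing_keys: Optional[set[str]] = None, **kwargs):
        # Imperative tying: point the second projection at the first one's weight
        self.linear_2.weight = self.linear.weight
        # During loading, the tied target should no longer count as missing
        if missing_keys is not None:
            missing_keys.discard("linear_2.weight")
        # Any new arguments (e.g. `recompute_mapping`) are simply absorbed by **kwargs
```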
diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 03160d33bd90..30951a1f03ad 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -1210,7 +1210,7 @@ def __init__(self, config, target_lang: Optional[str] = None): # Initialize weights and apply final processing self.post_init() - def tie_weights(self, missing_keys=None): + def tie_weights(self, **kwargs): """ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when passing `target_lang=...` to `from_pretrained(...)`. diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index b83c184e2b46..d3a405b758cf 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -1206,7 +1206,7 @@ def __init__(self, config, target_lang: Optional[str] = None): # Initialize weights and apply final processing self.post_init() - def tie_weights(self, missing_keys=None): + def tie_weights(self, **kwargs): """ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when passing `target_lang=...` to `from_pretrained(...)`. diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 361064816991..60b8f7f04c9e 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1684,7 +1684,7 @@ def __init__(self, config, target_lang: Optional[str] = None): # Initialize weights and apply final processing self.post_init() - def tie_weights(self, missing_keys=None): + def tie_weights(self, **kwargs): """ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when passing `target_lang=...` to `from_pretrained(...)`. diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 25b600086582..06c571e0bf45 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -1135,7 +1135,7 @@ def __init__(self, config, target_lang: Optional[str] = None): # Initialize weights and apply final processing self.post_init() - def tie_weights(self, missing_keys=None): + def tie_weights(self, **kwargs): """ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when passing `target_lang=...` to `from_pretrained(...)`. 
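With the model-level overrides in place, the new mapping API from `modeling_utils.py` can be exercised end to end. This is a hedged usage sketch that assumes the patch is applied; the tiny checkpoint name is an arbitrary example, not something prescribed by the diff:

```python
from transformers import AutoModelForCausalLM

# Any causal LM works here; this tiny test checkpoint is just a convenient example
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

# Parameter-level {target: source} mapping, pre-computed once in post_init()
print(model.all_tied_weights_keys)

# Re-derive the mapping from the current config values, recursing into all submodels
print(model.get_expanded_tied_weights_keys(all_submodels=True))

# Called on its own, tie_weights() now defaults to recompute_mapping=True, so flipping a
# config flag and re-tying works (the removed "TODO Cyril" noted this was not possible before)
model.config.tie_word_embeddings = True
model.tie_weights()
```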
diff --git a/tests/models/openai/test_modeling_openai.py b/tests/models/openai/test_modeling_openai.py index c27f35656d7e..094dc816233e 100644 --- a/tests/models/openai/test_modeling_openai.py +++ b/tests/models/openai/test_modeling_openai.py @@ -269,6 +269,13 @@ def test_model_from_pretrained(self): model = OpenAIGPTModel.from_pretrained(model_name) self.assertIsNotNone(model) + @unittest.skip("Tied weights mapping is reversed, so this is supposed to error out") + def test_correct_missing_keys(self): + # openai defines `_tied_weights_keys = {"transformer.tokens_embed.weight": "lm_head.weight"}` instead + # of the usual `_tied_weights_keys = {"lm_head.weight": "transformer.tokens_embed.weight"}`, so removing + # the head parameters actually removes the source weight, so this test is supposed to fail + pass + @require_torch class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase): diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 9c0be2abb8ac..2bea77d23c5d 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -180,7 +180,7 @@ def __init__(self, config): def forward(self, x): return self.linear_2(self.linear(x)) - def tie_weights(self, missing_keys=None): + def tie_weights(self, missing_keys=None, **kwargs): self.linear_2.weight = self.linear.weight if missing_keys is not None: missing_keys.discard("linear_2.weight") @@ -254,7 +254,7 @@ def __init__(self, config): def forward(self, x): return self.decoder(self.base(x)) - def tie_weights(self, missing_keys=None): + def tie_weights(self, missing_keys=None, **kwargs): self.decoder.weight = self.base.linear.weight if missing_keys is not None: missing_keys.discard("decoder.weight")
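For reference, the expansion step that `get_expanded_tied_weights_keys` performs on regex and module patterns can be illustrated standalone. This is a hedged sketch, not library code: `expand_tied_keys` and the example parameter names are hypothetical, loosely modeled on the DFine example in the new docstring, and the real method additionally checks the config flags, includes buffers, and prefixes submodel keys.

```python
import re
from itertools import cycle


def expand_tied_keys(tied_mapping: dict[str, str], all_param_names: set[str]) -> dict[str, str]:
    """Resolve a {target_pattern: source_pattern} mapping into parameter-level {target: source} pairs."""
    expanded: dict[str, str] = {}
    for target_name, source_name in tied_mapping.items():
        target_re, source_re = "^" + target_name, "^" + source_name
        source_params = sorted(p for p in all_param_names if re.search(source_re, p))
        target_params = sorted(p for p in all_param_names if re.search(target_re, p))
        if not source_params or not target_params or len(target_params) % len(source_params) != 0:
            raise ValueError(f"Bad tying spec {target_name} -> {source_name}")
        # One source can serve several targets when the target pattern is a regex or a
        # module prefix, hence the cycle()
        for target, source in zip(target_params, cycle(source_params)):
            # If the source was itself registered as a target earlier, point directly at its own source
            expanded[target] = expanded.get(source, source)
    return expanded


params = {
    "class_embed.0.weight", "class_embed.0.bias",
    "class_embed.1.weight", "class_embed.1.bias",
    "model.decoder.class_embed.0.weight", "model.decoder.class_embed.0.bias",
    "model.decoder.class_embed.1.weight", "model.decoder.class_embed.1.bias",
}
mapping = {r"class_embed.(?![0])\d+": "class_embed.0", "model.decoder.class_embed": "class_embed"}

# Every class_embed.N and model.decoder.class_embed.N parameter ends up tied to class_embed.0.*
print(expand_tied_keys(mapping, params))
```

In particular, `model.decoder.class_embed.1.weight` resolves to `class_embed.0.weight` rather than `class_embed.1.weight`, mirroring the "source already registered as a target" redirection in the patched method.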