diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 19aab734784a4..fd0afa521a145 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -30,7 +30,7 @@
 from dataclasses import dataclass
 from functools import partial, wraps
 from threading import Thread
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 from zipfile import is_zipfile
 
 import torch
@@ -573,6 +573,79 @@ def set_initialized_submodules(model, state_dict_keys):
     return not_initialized_submodules
 
 
+def _end_ptr(tensor: torch.Tensor) -> int:
+    # extract the end of the pointer if the tensor is a slice of a bigger tensor
+    if tensor.nelement():
+        stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
+    else:
+        stop = tensor.data_ptr()
+    return stop
+
+
+def _get_tied_weight_keys(module: nn.Module, prefix=""):
+    tied_weight_keys = []
+    if getattr(module, "_tied_weights_keys", None) is not None:
+        names = [f"{prefix}.{k}" if prefix else k for k in module._tied_weights_keys]
+        tied_weight_keys.extend(names)
+    if getattr(module, "_dynamic_tied_weights_keys", None) is not None:
+        names = [f"{prefix}.{k}" if prefix else k for k in module._dynamic_tied_weights_keys]
+        tied_weight_keys.extend(names)
+    for name, submodule in module.named_children():
+        local_prefix = f"{prefix}.{name}" if prefix else name
+        tied_weight_keys.extend(_get_tied_weight_keys(submodule, prefix=local_prefix))
+    return tied_weight_keys
+
+
+def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], List[str]]:
+    filtered_tensors = []
+    for shared in tensors:
+        if len(shared) < 2:
+            filtered_tensors.append(shared)
+            continue
+
+        areas = []
+        for name in shared:
+            tensor = state_dict[name]
+            areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
+        areas.sort()
+
+        _, last_stop, last_name = areas[0]
+        filtered_tensors.append({last_name})
+        for start, stop, name in areas[1:]:
+            if start >= last_stop:
+                filtered_tensors.append({name})
+            else:
+                filtered_tensors[-1].add(name)
+            last_stop = stop
+    disjoint_tensors = []
+    shared_tensors = []
+    for tensors in filtered_tensors:
+        if len(tensors) == 1:
+            disjoint_tensors.append(tensors.pop())
+        else:
+            shared_tensors.append(tensors)
+    return shared_tensors, disjoint_tensors
+
+
+def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
+    shared_tensors = []
+    identical = []
+    for shared in tensors:
+        if len(shared) < 2:
+            continue
+
+        areas = collections.defaultdict(set)
+        for name in shared:
+            tensor = state_dict[name]
+            area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
+            areas[area].add(name)
+        if len(areas) == 1:
+            identical.append(shared)
+        else:
+            shared_tensors.append(shared)
+    return shared_tensors, identical
+
+
 def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     # Convert old format to new format if needed from a PyTorch state_dict
     old_keys = []
@@ -1646,15 +1719,24 @@ def tie_weights(self):
         if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
             if hasattr(self, self.base_model_prefix):
                 self = getattr(self, self.base_model_prefix)
-            self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
+            tied_weights = self._tie_encoder_decoder_weights(
+                self.encoder, self.decoder, self.base_model_prefix, "encoder"
+            )
+            # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
+            # attribute, not an instance member; modifying it would modify the entire class,
+            # leading to issues on subsequent calls (e.g. from other tests).
+            self._dynamic_tied_weights_keys = tied_weights
 
         for module in self.modules():
             if hasattr(module, "_tie_weights"):
                 module._tie_weights()
 
     @staticmethod
-    def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
+    def _tie_encoder_decoder_weights(
+        encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, base_encoder_name: str
+    ):
         uninitialized_encoder_weights: List[str] = []
+        tied_weights: List[str] = []
         if decoder.__class__ != encoder.__class__:
             logger.info(
                 f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder"
@@ -1665,8 +1747,11 @@ def tie_encoder_to_decoder_recursively(
             decoder_pointer: nn.Module,
             encoder_pointer: nn.Module,
             module_name: str,
+            base_encoder_name: str,
             uninitialized_encoder_weights: List[str],
             depth=0,
+            total_decoder_name="",
+            total_encoder_name="",
         ):
             assert isinstance(decoder_pointer, nn.Module) and isinstance(
                 encoder_pointer, nn.Module
@@ -1674,8 +1759,10 @@ def tie_encoder_to_decoder_recursively(
             if hasattr(decoder_pointer, "weight"):
                 assert hasattr(encoder_pointer, "weight")
                 encoder_pointer.weight = decoder_pointer.weight
+                tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight")
                 if hasattr(decoder_pointer, "bias"):
                     assert hasattr(encoder_pointer, "bias")
+                    tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias")
                     encoder_pointer.bias = decoder_pointer.bias
                 return
 
@@ -1713,19 +1800,26 @@ def tie_encoder_to_decoder_recursively(
                         decoder_modules[decoder_name],
                         encoder_modules[encoder_name],
                         module_name + "/" + name,
+                        base_encoder_name,
                         uninitialized_encoder_weights,
                         depth=depth + 1,
+                        total_encoder_name=f"{total_encoder_name}.{encoder_name}",
+                        total_decoder_name=f"{total_decoder_name}.{decoder_name}",
                     )
                     all_encoder_weights.remove(module_name + "/" + encoder_name)
 
                 uninitialized_encoder_weights += list(all_encoder_weights)
 
        # tie weights recursively
-        tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights)
+        tie_encoder_to_decoder_recursively(
+            decoder, encoder, base_model_prefix, base_encoder_name, uninitialized_encoder_weights
+        )
+
         if len(uninitialized_encoder_weights) > 0:
             logger.warning(
                 f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
             )
+        return tied_weights
 
     def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
         """Tie or clone module weights depending of whether we are using TorchScript or not"""
@@ -2402,34 +2496,49 @@ def save_pretrained(
 
             # These are all the pointers of shared tensors.
             shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
-            warn_names = set()
+            error_names = []
+            to_delete_names = set()
+            # Recursively descend to find tied weight keys
+            _tied_weights_keys = _get_tied_weight_keys(self)
             for names in shared_ptrs.values():
                 # Removing the keys which are declared as known duplicates on
                 # load. This allows to make sure the name which is kept is consistent.
-                if self._tied_weights_keys is not None:
+                if _tied_weights_keys is not None:
                     found = 0
                     for name in sorted(names):
-                        matches_pattern = any(re.search(pat, name) for pat in self._tied_weights_keys)
+                        matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
                         if matches_pattern and name in state_dict:
                             found += 1
                             if found < len(names):
-                                del state_dict[name]
-
-                # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
-                # If the link between tensors was done at runtime then `from_pretrained` will not get
-                # the key back leading to random tensor. A proper warning will be shown
-                # during reload (if applicable), but since the file is not necessarily compatible with
-                # the config, better show a proper warning.
-                found = 0
-                for name in names:
-                    if name in state_dict:
-                        found += 1
-                        if found > 1:
-                            del state_dict[name]
-                            warn_names.add(name)
-            if len(warn_names) > 0:
-                logger.warning_once(
-                    f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
+                                to_delete_names.add(name)
+            # At this point the weights and the transformers configuration do NOT match.
+            shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
+            # These tensors share storage but cover disjoint regions, so we can safely clone them.
+            # The reloaded tensors won't share storage anymore, but that shouldn't matter in any meaningful way.
+            for name in disjoint_names:
+                state_dict[name] = state_dict[name].clone()
+
+            # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+            # If the link between tensors was done at runtime then `from_pretrained` will not get
+            # the key back, leading to a random tensor. A proper warning will be shown
+            # during reload (if applicable), but since the file is not necessarily compatible with
+            # the config, better show a proper warning.
+            shared_names, identical_names = _find_identical(shared_names, state_dict)
+            # Delete tensors that have identical storage
+            for inames in identical_names:
+                known = inames.intersection(to_delete_names)
+                for name in known:
+                    del state_dict[name]
+                unknown = inames.difference(to_delete_names)
+                if len(unknown) > 1:
+                    error_names.append(unknown)
+
+            if shared_names:
+                error_names.append(set(shared_names))
+
+            if len(error_names) > 0:
+                raise RuntimeError(
+                    f"The weights being saved contain shared tensors {error_names} that do not match the transformers base configuration. Try saving with `safe_serialization=False` or remove this tensor sharing.",
                 )
 
             # Shard the model if it is too big.
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 1b06c375780b7..262fc79f0d403 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -15,7 +15,6 @@
 # limitations under the License.
"""PyTorch BERT model.""" - import math import os import warnings @@ -1128,7 +1127,7 @@ def forward( """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING ) class BertLMHeadModel(BertPreTrainedModel): - _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] + _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 1a6adcee1f838..16248fee64ce5 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -262,9 +262,16 @@ def tie_weights(self): if self.config.tie_encoder_decoder: # tie encoder and decoder base model decoder_base_model_prefix = self.decoder.base_model_prefix - self._tie_encoder_decoder_weights( - self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix + tied_weights = self._tie_encoder_decoder_weights( + self.encoder, + self.decoder._modules[decoder_base_model_prefix], + self.decoder.base_model_prefix, + "encoder", ) + # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class + # attributed not an instance member, therefore modifying it will modify the entire class + # Leading to issues on subsequent calls by different tests or subsequent calls. + self._dynamic_tied_weights_keys = tied_weights def get_encoder(self): return self.encoder diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 7c39acbcd4361..10d7f1b6b2d16 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1343,7 +1343,13 @@ def tie_weights(self): if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False): if hasattr(self, self.base_model_prefix): self = getattr(self, self.base_model_prefix) - self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix) + tied_weights = self._tie_encoder_decoder_weights( + self.encoder, self.decoder, self.base_model_prefix, "encoder" + ) + # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class + # attributed not an instance member, therefore modifying it will modify the entire class + # Leading to issues on subsequent calls by different tests or subsequent calls. 
+            self._dynamic_tied_weights_keys = tied_weights
 
         for module in self.modules():
             if hasattr(module, "_tie_weights"):
diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index 2520268f746f4..7e7c7cb7232c5 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -1891,9 +1891,16 @@ def tie_weights(self):
         if self.config.tie_encoder_decoder:
             # tie text encoder and decoder base model
             decoder_base_model_prefix = self.decoder.base_model_prefix
-            self._tie_encoder_decoder_weights(
-                self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
+            tied_weights = self._tie_encoder_decoder_weights(
+                self.text_encoder,
+                self.decoder._modules[decoder_base_model_prefix],
+                self.decoder.base_model_prefix,
+                "text_encoder",
             )
+            # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
+            # attribute, not an instance member; modifying it would modify the entire class,
+            # leading to issues on subsequent calls (e.g. from other tests).
+            self._dynamic_tied_weights_keys = tied_weights
 
     def get_audio_encoder(self):
         return self.audio_encoder
diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
index 8b0afb2367352..0840635f6535b 100644
--- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
+++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -1810,9 +1810,16 @@ def tie_weights(self):
         if self.config.tie_encoder_decoder:
             # tie text encoder and decoder base model
             decoder_base_model_prefix = self.decoder.base_model_prefix
-            self._tie_encoder_decoder_weights(
-                self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
+            tied_weights = self._tie_encoder_decoder_weights(
+                self.text_encoder,
+                self.decoder._modules[decoder_base_model_prefix],
+                self.decoder.base_model_prefix,
+                "text_encoder",
             )
+            # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
+            # attribute, not an instance member; modifying it would modify the entire class,
+            # leading to issues on subsequent calls (e.g. from other tests).
+            self._dynamic_tied_weights_keys = tied_weights
 
     def get_text_encoder(self):
         return self.text_encoder
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 7f82d0dfcaf63..e6f57d68cc6a3 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -101,7 +101,7 @@
         _prepare_4d_attention_mask,
         _prepare_4d_causal_attention_mask,
     )
-    from transformers.modeling_utils import shard_checkpoint
+    from transformers.modeling_utils import _find_disjoint, _find_identical, shard_checkpoint
 
     # Fake pretrained models for tests
     class BaseModel(PreTrainedModel):
@@ -256,6 +256,26 @@ def test_model_from_pretrained_subfolder(self):
 
         self.assertTrue(check_models_equal(model, model_loaded))
 
+    def test_model_manually_shared_disjointed_tensors_optimum(self):
+        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
+        model = BertModel(config)
+
+        # Let's fuse qkv
+        attn = model.encoder.layer[0].attention.self
+        q = attn.query.weight
+        k = attn.key.weight
+        v = attn.value.weight
+        # Force some shared storage
+        qkv = torch.stack([q, k, v], dim=0)
+        attn.query.weight = torch.nn.Parameter(qkv[0])
+        attn.key.weight = torch.nn.Parameter(qkv[1])
+        attn.value.weight = torch.nn.Parameter(qkv[2])
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+            model_loaded = BertModel.from_pretrained(tmp_dir)
+
+        self.assertTrue(check_models_equal(model, model_loaded))
+
     def test_model_from_pretrained_subfolder_sharded(self):
         config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
         model = BertModel(config)
@@ -2222,3 +2242,40 @@ def test_partial_stacked_causal_mask(self):
         ]
 
         self.assertEqual(decoded_0, decoded_1b)
+
+
+@require_torch
+class TestTensorSharing(TestCasePlus):
+    def test_disjoint(self):
+        main = torch.zeros(10)
+        a = main[:5]
+        b = main[5:]
+        state_dict = {"a": a, "b": b}
+
+        shared_names, disjoint_names = _find_disjoint([{"a", "b"}], state_dict)
+        self.assertEqual(shared_names, [])
+        self.assertEqual(disjoint_names, ["a", "b"])
+
+        a = main[::2]
+        b = main[1::2]
+        state_dict = {"a": a, "b": b}
+
+        shared_names, disjoint_names = _find_disjoint([{"a", "b"}], state_dict)
+        self.assertEqual(shared_names, [{"a", "b"}])
+        self.assertEqual(disjoint_names, [])
+
+    def test_identical(self):
+        a = torch.zeros(10)
+        b = a
+        state_dict = {"a": a, "b": b}
+
+        shared_names, identical_names = _find_identical([{"a", "b"}], state_dict)
+        self.assertEqual(shared_names, [])
+        self.assertEqual(identical_names, [{"a", "b"}])
+
+        b = a[:5]
+        state_dict = {"a": a, "b": b}
+
+        shared_names, identical_names = _find_identical([{"a", "b"}], state_dict)
+        self.assertEqual(shared_names, [{"a", "b"}])
+        self.assertEqual(identical_names, [])
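
Note (not part of the patch): a minimal sketch of how the two new private helpers classify shared storage, mirroring the tests above. The name groups are built by hand here; inside `save_pretrained` they are derived from the storage pointers of the state dict.

import torch

from transformers.modeling_utils import _find_disjoint, _find_identical

main = torch.zeros(10)
state_dict = {
    "a": main[:5],  # first half of the storage
    "b": main[5:],  # second half: same storage, non-overlapping bytes
    "c": main,      # the full storage
    "d": main,      # exactly the same bytes as "c"
}

# Hand-built groups of names that share a storage.
shared_names, disjoint_names = _find_disjoint([{"a", "b"}, {"c", "d"}], state_dict)
# "a" and "b" do not overlap, so save_pretrained can clone and save them independently.
assert sorted(disjoint_names) == ["a", "b"]

shared_names, identical_names = _find_identical(shared_names, state_dict)
# "c" and "d" cover exactly the same bytes, so all but one of the names can simply be dropped.
assert identical_names == [{"c", "d"}]
# Any group left in shared_names overlaps without being declared by the config and now raises.
assert shared_names == []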