From ae5cadb1cda8806d0bfd0820a876b61d251286a4 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 14 Nov 2023 12:17:29 +0100 Subject: [PATCH 1/5] [WIP] Hard error when ignoring tensors. --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index fcb51e6a56be2..2354fee19bfb9 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2146,7 +2146,7 @@ def save_pretrained( del state_dict[name] warn_names.add(name) if len(warn_names) > 0: - logger.warning_once( + raise Exception( f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading", ) From 92d1715d93cb4e4793a443e57cca2f5bd7201abb Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 14 Nov 2023 15:12:24 +0100 Subject: [PATCH 2/5] Better selection/error when saving a checkpoint. - Find all names we should normally drop (those are in the transformers config) - Find all disjoint tensors (for those we can safely trigger a copy to get rid of the sharing before saving) - Clone those disjoint tensors getting rid of the issue - Find all identical names (those should be declared in the config but we try to find them all anyway.) - For all identical names: - If they are in the config, just ignore them everything is fine - If they are not, warn about them. - For all remainder tensors which are shared yet neither identical NOR disjoint. raise a hard error. --- src/transformers/modeling_utils.py | 108 ++++++++++++++++++++++++----- 1 file changed, 92 insertions(+), 16 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2354fee19bfb9..8e6f887c9f60f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -27,7 +27,7 @@ from contextlib import contextmanager from dataclasses import dataclass from functools import partial, wraps -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import torch from packaging import version @@ -515,6 +515,64 @@ def set_initialized_submodules(model, state_dict_keys): module._is_hf_initialized = True +def _end_ptr(tensor: torch.Tensor) -> int: + if tensor.nelement(): + stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size() + else: + stop = tensor.data_ptr() + return stop + + +def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]: + filtered_tensors = [] + for shared in tensors: + if len(shared) < 2: + filtered_tensors.append(shared) + continue + + areas = [] + for name in shared: + tensor = state_dict[name] + areas.append((tensor.data_ptr(), _end_ptr(tensor), name)) + areas.sort() + + _, last_stop, last_name = areas[0] + filtered_tensors.append({last_name}) + for start, stop, name in areas[1:]: + if start >= last_stop: + filtered_tensors.append({name}) + else: + filtered_tensors[-1].add(name) + last_stop = stop + disjoint_tensors = [] + shared_tensors = [] + for tensors in filtered_tensors: + if len(tensors) == 1: + disjoint_tensors.append(tensors.pop()) + else: + shared_tensors.append(tensors) + return shared_tensors, disjoint_tensors + + +def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]: + shared_tensors = [] + identical = [] + for shared in tensors: + if len(shared) < 
2: + continue + + areas = collections.defaultdict(set) + for name in shared: + tensor = state_dict[name] + area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor)) + areas[area].add(name) + if len(areas) == 1: + identical.append(shared) + else: + shared_tensors.append(shared) + return shared_tensors, identical + + def _load_state_dict_into_model(model_to_load, state_dict, start_prefix): # Convert old format to new format if needed from a PyTorch state_dict old_keys = [] @@ -2121,6 +2179,8 @@ def save_pretrained( # These are all the pointers of shared tensors. shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} warn_names = set() + error_names = set() + to_delete_names = set() for names in shared_ptrs.values(): # Removing the keys which are declared as known duplicates on # load. This allows to make sure the name which is kept is consistent. @@ -2131,25 +2191,41 @@ def save_pretrained( if matches_pattern and name in state_dict: found += 1 if found < len(names): - del state_dict[name] - - # When not all duplicates have been cleaned, still remove those keys, but put a clear warning. - # If the link between tensors was done at runtime then `from_pretrained` will not get - # the key back leading to random tensor. A proper warning will be shown - # during reload (if applicable), but since the file is not necessarily compatible with - # the config, better show a proper warning. - found = 0 - for name in names: - if name in state_dict: - found += 1 - if found > 1: - del state_dict[name] - warn_names.add(name) + to_delete_names.add(name) + # We are entering a place where the weights and the transformers configuration do NOT match. + names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict) + # Those are actually tensor sharing but disjoint from each other, we can safely clone them + # Reloaded won't have the same property, but it shouldn't matter in any meaningful way. + for name in disjoint_names: + state_dict[name] = state_dict[name].clone() + + # When not all duplicates have been cleaned, still remove those keys, but put a clear warning. + # If the link between tensors was done at runtime then `from_pretrained` will not get + # the key back leading to random tensor. A proper warning will be shown + # during reload (if applicable), but since the file is not necessarily compatible with + # the config, better show a proper warning. + names, identical_names = _find_identical(names, state_dict) + for inames in identical_names: + known = inames.intersection(to_delete_names) + for name in known: + del state_dict[name] + unknown = sorted(inames.difference(to_delete_names)) + for name in unknown[1:]: + del state_dict[name] + warn_names.add(name) + + error_names.update(names) + if len(warn_names) > 0: - raise Exception( + logger.warning_once( f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading", ) + if len(error_names) > 0: + raise RuntimeError( + f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.", + ) + # Shard the model if it is too big. 
if not _hf_peft_config_loaded: weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME From 88571f3d580c7013e096a865328d558b28a4fa74 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 14 Nov 2023 15:35:57 +0100 Subject: [PATCH 3/5] Adding a failing test on `main` that passes here. --- tests/test_modeling_utils.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index e457dc07a9fb7..b11605a6554dd 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -212,6 +212,31 @@ def test_model_from_pretrained_subfolder(self): self.assertTrue(check_models_equal(model, model_loaded)) + def test_model_manually_shared_disjointed_tensors_optimum(self): + config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert") + model = BertModel(config) + + subfolder = "bert" + # Let's fuse qkv + attn = model.encoder.layer[0].attention.self + q = attn.query.weight + k = attn.key.weight + v = attn.value.weight + # Force some shared storage + qkv = torch.stack([q, k, v], dim=0) + attn.query.weight = torch.nn.Parameter(qkv[0]) + attn.key.weight = torch.nn.Parameter(qkv[1]) + attn.value.weight = torch.nn.Parameter(qkv[2]) + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(os.path.join(tmp_dir, subfolder)) + + with self.assertRaises(OSError): + _ = BertModel.from_pretrained(tmp_dir) + + model_loaded = BertModel.from_pretrained(tmp_dir, subfolder=subfolder) + + self.assertTrue(check_models_equal(model, model_loaded)) + def test_model_from_pretrained_subfolder_sharded(self): config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert") model = BertModel(config) From d8e1ed17ee7e640a1d5ba999345c71d4039a5a34 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 14 Nov 2023 15:40:32 +0100 Subject: [PATCH 4/5] We don't need to keep the subfolder logic in this test. 
--- tests/test_modeling_utils.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index b11605a6554dd..e84e3821a5100 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -216,7 +216,6 @@ def test_model_manually_shared_disjointed_tensors_optimum(self): config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert") model = BertModel(config) - subfolder = "bert" # Let's fuse qkv attn = model.encoder.layer[0].attention.self q = attn.query.weight @@ -228,12 +227,8 @@ def test_model_manually_shared_disjointed_tensors_optimum(self): attn.key.weight = torch.nn.Parameter(qkv[1]) attn.value.weight = torch.nn.Parameter(qkv[2]) with tempfile.TemporaryDirectory() as tmp_dir: - model.save_pretrained(os.path.join(tmp_dir, subfolder)) - - with self.assertRaises(OSError): - _ = BertModel.from_pretrained(tmp_dir) - - model_loaded = BertModel.from_pretrained(tmp_dir, subfolder=subfolder) + model.save_pretrained(tmp_dir) + model_loaded = BertModel.from_pretrained(tmp_dir) self.assertTrue(check_models_equal(model, model_loaded)) From 7ee5bd03646eda152359fb82d0682b0d32cf42e0 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 28 Nov 2023 18:34:41 +0100 Subject: [PATCH 5/5] Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/modeling_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8e6f887c9f60f..1fa6818d42274 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -516,6 +516,7 @@ def set_initialized_submodules(model, state_dict_keys): def _end_ptr(tensor: torch.Tensor) -> int: + # extract the end of the pointer if the tensor is a slice of a bigger tensor if tensor.nelement(): stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size() else: @@ -2193,7 +2194,7 @@ def save_pretrained( if found < len(names): to_delete_names.add(name) # We are entering a place where the weights and the transformers configuration do NOT match. - names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict) + shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict) # Those are actually tensor sharing but disjoint from each other, we can safely clone them # Reloaded won't have the same property, but it shouldn't matter in any meaningful way. for name in disjoint_names: @@ -2204,7 +2205,8 @@ def save_pretrained( # the key back leading to random tensor. A proper warning will be shown # during reload (if applicable), but since the file is not necessarily compatible with # the config, better show a proper warning. - names, identical_names = _find_identical(names, state_dict) + shared_names, identical_names = _find_identical(shared_names, state_dict) + # delete tensors that have identical storage for inames in identical_names: known = inames.intersection(to_delete_names) for name in known: @@ -2214,7 +2216,7 @@ def save_pretrained( del state_dict[name] warn_names.add(name) - error_names.update(names) + error_names.update(shared_names) if len(warn_names) > 0: logger.warning_once(
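The helpers introduced in PATCH 2/5 (`_end_ptr`, `_find_disjoint`, `_find_identical`) classify every group of names that share one storage by comparing raw byte ranges: ranges that never overlap can be cloned apart, ranges that are byte-for-byte identical are pure aliases of which only one name has to survive, and anything in between is now a hard error. A minimal sketch of that classification in plain `torch` (the `end_ptr` helper mirrors the patch; the example tensors are illustrative, not transformers internals):

import torch


def end_ptr(tensor: torch.Tensor) -> int:
    # One byte past the last element, or the start pointer for empty tensors,
    # mirroring `_end_ptr` from PATCH 2/5.
    if tensor.nelement():
        return tensor.view(-1)[-1].data_ptr() + tensor.element_size()
    return tensor.data_ptr()


base = torch.zeros(6)      # a single storage shared by every view below
a, b = base[:3], base[3:]  # disjoint views: their byte ranges never overlap
c = base[2:5]              # overlaps `a`: neither disjoint from it nor identical to it
d = base[:3]               # identical to `a`: same device, start and end

# Disjoint views can safely be cloned before saving (`disjoint_names` in the patch).
assert end_ptr(a) <= b.data_ptr()
# Identical aliases share the exact same (device, start, end) triple, so every
# name but one can be dropped from the state dict.
assert (a.device, a.data_ptr(), end_ptr(a)) == (d.device, d.data_ptr(), end_ptr(d))
# Partial overlap is the case that now raises in `save_pretrained`.
assert c.data_ptr() < end_ptr(a) < end_ptr(c)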
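The regression test added in PATCH 3/5 (and simplified in PATCH 4/5) builds exactly the "disjoint but shared" case: stacking the q/k/v projections into one tensor and re-wrapping its rows as `nn.Parameter` leaves three weights inside a single storage at non-overlapping offsets. A reduced sketch of that setup on bare `nn.Linear` layers (illustrative, not the BERT internals the test uses):

import torch
from torch import nn

# Three independent projections, as in the attention block touched by the test.
q, k, v = (nn.Linear(4, 4, bias=False) for _ in range(3))

# "Fuse" them the way the test does: one stacked tensor, three row views.
qkv = torch.stack([q.weight, k.weight, v.weight], dim=0)
q.weight = nn.Parameter(qkv[0])
k.weight = nn.Parameter(qkv[1])
v.weight = nn.Parameter(qkv[2])

# All three parameters now live in the same storage...
assert q.weight.untyped_storage().data_ptr() == v.weight.untyped_storage().data_ptr()
# ...at non-overlapping offsets (one 4x4 float32 row apart), so cloning each view
# before saving, which the disjoint detection enables, loses no information.
starts = sorted(p.data_ptr() for p in (q.weight, k.weight, v.weight))
row_bytes = qkv[0].nelement() * qkv.element_size()
assert starts[1] - starts[0] == starts[2] - starts[1] == row_bytes

Before this series, `save_pretrained` deleted two of the three names with only a warning, so a reload re-initialized those tensors at random; with the disjoint handling in place the views are cloned and the round trip asserted by the test succeeds.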
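Finally, for groups of shared tensors that overlap without being byte-for-byte identical, PATCH 2/5 raises a `RuntimeError` instead of warning, and the message already names the escape hatches. A hedged sketch of how calling code might react (`build_model_with_custom_sharing` is a placeholder for any model whose weights were aliased by hand):

# Sketch only: `build_model_with_custom_sharing` stands in for a model whose
# weights were manually aliased in a way the architecture does not declare.
model = build_model_with_custom_sharing()

try:
    # The safetensors path is where the new check applies.
    model.save_pretrained("checkpoint")
except RuntimeError:
    # Either break the sharing explicitly before saving, or fall back to the
    # pickle-based format, which preserves sharing, as the error message suggests.
    model.save_pretrained("checkpoint", safe_serialization=False)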