Skip to content

Commit

Permalink
Hard error when ignoring tensors. (#27484) (#29906)
Browse files Browse the repository at this point in the history
* Hard error when ignoring tensors. (#27484)

* [WIP] Hard error when ignoring tensors.

* Better selection/error when saving a checkpoint.

- Find all names we should normally drop (those are in the transformers
  config)
- Find all disjoint tensors (for those we can safely trigger a copy to
  get rid of the sharing before saving)
- Clone those disjoint tensors getting rid of the issue
- Find all identical names (those should be declared in the config
  but we try to find them all anyway.)
- For all identical names:
  - If they are in the config, just ignore them everything is fine
  - If they are not, warn about them.
- For all remainder tensors which are shared yet neither identical NOR
  disjoint. raise a hard error.

* Adding a failing test on `main` that passes here.

* We don't need to keep the subfolder logic in this test.

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Add small tests.

* Dead variable.

* Fixup.

* Fixing tied_Weights_keys on generic models.

* Fixup + T5 encoder/decoder tying (with different layers)

* Code quality.

* Dynamic member.

* trigger

* Fixing encoder name for other types of encoder/decoder combos.

* Fix scoping.

* Update .github/workflows/self-scheduled.yml

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fixing the tied_weights after the call.

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
  • Loading branch information
3 people committed Apr 2, 2024
1 parent 15cd687 commit 9b0a8ea
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 34 deletions.
157 changes: 133 additions & 24 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from dataclasses import dataclass
from functools import partial, wraps
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from zipfile import is_zipfile

import torch
Expand Down Expand Up @@ -573,6 +573,79 @@ def set_initialized_submodules(model, state_dict_keys):
return not_initialized_submodules


def _end_ptr(tensor: torch.Tensor) -> int:
# extract the end of the pointer if the tensor is a slice of a bigger tensor
if tensor.nelement():
stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
else:
stop = tensor.data_ptr()
return stop


def _get_tied_weight_keys(module: nn.Module, prefix=""):
tied_weight_keys = []
if getattr(module, "_tied_weights_keys", None) is not None:
names = [f"{prefix}.{k}" if prefix else k for k in module._tied_weights_keys]
tied_weight_keys.extend(names)
if getattr(module, "_dynamic_tied_weights_keys", None) is not None:
names = [f"{prefix}.{k}" if prefix else k for k in module._dynamic_tied_weights_keys]
tied_weight_keys.extend(names)
for name, submodule in module.named_children():
local_prefix = f"{prefix}.{name}" if prefix else name
tied_weight_keys.extend(_get_tied_weight_keys(submodule, prefix=local_prefix))
return tied_weight_keys


def _find_disjoint(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], List[str]]:
filtered_tensors = []
for shared in tensors:
if len(shared) < 2:
filtered_tensors.append(shared)
continue

areas = []
for name in shared:
tensor = state_dict[name]
areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
areas.sort()

_, last_stop, last_name = areas[0]
filtered_tensors.append({last_name})
for start, stop, name in areas[1:]:
if start >= last_stop:
filtered_tensors.append({name})
else:
filtered_tensors[-1].add(name)
last_stop = stop
disjoint_tensors = []
shared_tensors = []
for tensors in filtered_tensors:
if len(tensors) == 1:
disjoint_tensors.append(tensors.pop())
else:
shared_tensors.append(tensors)
return shared_tensors, disjoint_tensors


def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]) -> Tuple[List[Set[str]], Set[str]]:
shared_tensors = []
identical = []
for shared in tensors:
if len(shared) < 2:
continue

areas = collections.defaultdict(set)
for name in shared:
tensor = state_dict[name]
area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
areas[area].add(name)
if len(areas) == 1:
identical.append(shared)
else:
shared_tensors.append(shared)
return shared_tensors, identical


def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
Expand Down Expand Up @@ -1646,15 +1719,24 @@ def tie_weights(self):
if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
if hasattr(self, self.base_model_prefix):
self = getattr(self, self.base_model_prefix)
self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
tied_weights = self._tie_encoder_decoder_weights(
self.encoder, self.decoder, self.base_model_prefix, "encoder"
)
# Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
# attributed not an instance member, therefore modifying it will modify the entire class
# Leading to issues on subsequent calls by different tests or subsequent calls.
self._dynamic_tied_weights_keys = tied_weights

for module in self.modules():
if hasattr(module, "_tie_weights"):
module._tie_weights()

@staticmethod
def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str):
def _tie_encoder_decoder_weights(
encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, base_encoder_name: str
):
uninitialized_encoder_weights: List[str] = []
tied_weights: List[str] = []
if decoder.__class__ != encoder.__class__:
logger.info(
f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder"
Expand All @@ -1665,17 +1747,22 @@ def tie_encoder_to_decoder_recursively(
decoder_pointer: nn.Module,
encoder_pointer: nn.Module,
module_name: str,
base_encoder_name: str,
uninitialized_encoder_weights: List[str],
depth=0,
total_decoder_name="",
total_encoder_name="",
):
assert isinstance(decoder_pointer, nn.Module) and isinstance(
encoder_pointer, nn.Module
), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
if hasattr(decoder_pointer, "weight"):
assert hasattr(encoder_pointer, "weight")
encoder_pointer.weight = decoder_pointer.weight
tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight")
if hasattr(decoder_pointer, "bias"):
assert hasattr(encoder_pointer, "bias")
tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias")
encoder_pointer.bias = decoder_pointer.bias
return

Expand Down Expand Up @@ -1713,19 +1800,26 @@ def tie_encoder_to_decoder_recursively(
decoder_modules[decoder_name],
encoder_modules[encoder_name],
module_name + "/" + name,
base_encoder_name,
uninitialized_encoder_weights,
depth=depth + 1,
total_encoder_name=f"{total_encoder_name}.{encoder_name}",
total_decoder_name=f"{total_decoder_name}.{decoder_name}",
)
all_encoder_weights.remove(module_name + "/" + encoder_name)

uninitialized_encoder_weights += list(all_encoder_weights)

# tie weights recursively
tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights)
tie_encoder_to_decoder_recursively(
decoder, encoder, base_model_prefix, base_encoder_name, uninitialized_encoder_weights
)

if len(uninitialized_encoder_weights) > 0:
logger.warning(
f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
)
return tied_weights

def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
"""Tie or clone module weights depending of whether we are using TorchScript or not"""
Expand Down Expand Up @@ -2402,34 +2496,49 @@ def save_pretrained(

# These are all the pointers of shared tensors.
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
warn_names = set()
error_names = []
to_delete_names = set()
# Recursively descend to find tied weight keys
_tied_weights_keys = _get_tied_weight_keys(self)
for names in shared_ptrs.values():
# Removing the keys which are declared as known duplicates on
# load. This allows to make sure the name which is kept is consistent.
if self._tied_weights_keys is not None:
if _tied_weights_keys is not None:
found = 0
for name in sorted(names):
matches_pattern = any(re.search(pat, name) for pat in self._tied_weights_keys)
matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
if matches_pattern and name in state_dict:
found += 1
if found < len(names):
del state_dict[name]

# When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
# If the link between tensors was done at runtime then `from_pretrained` will not get
# the key back leading to random tensor. A proper warning will be shown
# during reload (if applicable), but since the file is not necessarily compatible with
# the config, better show a proper warning.
found = 0
for name in names:
if name in state_dict:
found += 1
if found > 1:
del state_dict[name]
warn_names.add(name)
if len(warn_names) > 0:
logger.warning_once(
f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
to_delete_names.add(name)
# We are entering a place where the weights and the transformers configuration do NOT match.
shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
# Those are actually tensor sharing but disjoint from each other, we can safely clone them
# Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
for name in disjoint_names:
state_dict[name] = state_dict[name].clone()

# When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
# If the link between tensors was done at runtime then `from_pretrained` will not get
# the key back leading to random tensor. A proper warning will be shown
# during reload (if applicable), but since the file is not necessarily compatible with
# the config, better show a proper warning.
shared_names, identical_names = _find_identical(shared_names, state_dict)
# delete tensors that have identical storage
for inames in identical_names:
known = inames.intersection(to_delete_names)
for name in known:
del state_dict[name]
unknown = inames.difference(to_delete_names)
if len(unknown) > 1:
error_names.append(unknown)

if shared_names:
error_names.append(set(shared_names))

if len(error_names) > 0:
raise RuntimeError(
f"The weights trying to be saved contained shared tensors {error_names} that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.",
)

# Shard the model if it is too big.
Expand Down
3 changes: 1 addition & 2 deletions src/transformers/models/bert/modeling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# limitations under the License.
"""PyTorch BERT model."""


import math
import os
import warnings
Expand Down Expand Up @@ -1128,7 +1127,7 @@ def forward(
"""Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
)
class BertLMHeadModel(BertPreTrainedModel):
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

def __init__(self, config):
super().__init__(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,16 @@ def tie_weights(self):
if self.config.tie_encoder_decoder:
# tie encoder and decoder base model
decoder_base_model_prefix = self.decoder.base_model_prefix
self._tie_encoder_decoder_weights(
self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
tied_weights = self._tie_encoder_decoder_weights(
self.encoder,
self.decoder._modules[decoder_base_model_prefix],
self.decoder.base_model_prefix,
"encoder",
)
# Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
# attributed not an instance member, therefore modifying it will modify the entire class
# Leading to issues on subsequent calls by different tests or subsequent calls.
self._dynamic_tied_weights_keys = tied_weights

def get_encoder(self):
return self.encoder
Expand Down
8 changes: 7 additions & 1 deletion src/transformers/models/marian/modeling_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,7 +1343,13 @@ def tie_weights(self):
if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
if hasattr(self, self.base_model_prefix):
self = getattr(self, self.base_model_prefix)
self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
tied_weights = self._tie_encoder_decoder_weights(
self.encoder, self.decoder, self.base_model_prefix, "encoder"
)
# Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
# attributed not an instance member, therefore modifying it will modify the entire class
# Leading to issues on subsequent calls by different tests or subsequent calls.
self._dynamic_tied_weights_keys = tied_weights

for module in self.modules():
if hasattr(module, "_tie_weights"):
Expand Down
11 changes: 9 additions & 2 deletions src/transformers/models/musicgen/modeling_musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,9 +1891,16 @@ def tie_weights(self):
if self.config.tie_encoder_decoder:
# tie text encoder and decoder base model
decoder_base_model_prefix = self.decoder.base_model_prefix
self._tie_encoder_decoder_weights(
self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
tied_weights = self._tie_encoder_decoder_weights(
self.text_encoder,
self.decoder._modules[decoder_base_model_prefix],
self.decoder.base_model_prefix,
"text_encoder",
)
# Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
# attributed not an instance member, therefore modifying it will modify the entire class
# Leading to issues on subsequent calls by different tests or subsequent calls.
self._dynamic_tied_weights_keys = tied_weights

def get_audio_encoder(self):
return self.audio_encoder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1810,9 +1810,16 @@ def tie_weights(self):
if self.config.tie_encoder_decoder:
# tie text encoder and decoder base model
decoder_base_model_prefix = self.decoder.base_model_prefix
self._tie_encoder_decoder_weights(
self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
tied_weights = self._tie_encoder_decoder_weights(
self.text_encoder,
self.decoder._modules[decoder_base_model_prefix],
self.decoder.base_model_prefix,
"text_encoder",
)
# Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
# attributed not an instance member, therefore modifying it will modify the entire class
# Leading to issues on subsequent calls by different tests or subsequent calls.
self._dynamic_tied_weights_keys = tied_weights

def get_text_encoder(self):
return self.text_encoder
Expand Down

0 comments on commit 9b0a8ea

Please sign in to comment.