From 0676d992a5c1f6107a611018494ec952613a4d7f Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Mon, 11 Dec 2023 12:38:17 +0100
Subject: [PATCH] [`from_pretrained`] Make from_pretrained fast again (#27709)

* Skip nn.Module.reset_parameters

* Actually skip

* Check quality

* Maybe change all inits

* Fix init issues: only modify public functions

* Add a small test for now

* Style

* test updates

* style

* nice test

* style

* make it even faster

* one more second

* remove fx incompatible

* Update tests/test_modeling_common.py

Co-authored-by: Lysandre Debut

* Update tests/test_modeling_common.py

Co-authored-by: Lysandre Debut

* skip

* fix quality

* protect the import

---------

Co-authored-by: Lysandre Debut
---
 src/transformers/modeling_utils.py | 35 ++++++++++++++++++-
 tests/test_modeling_common.py      | 55 +++++++++++++++++++++++++++++-
 2 files changed, 88 insertions(+), 2 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 51fb21987f487..3247c32368581 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -154,6 +154,23 @@ def is_local_dist_rank_0():
 if is_peft_available():
     from .utils import find_adapter_config_file
 
+TORCH_INIT_FUNCTIONS = {
+    "uniform_": nn.init.uniform_,
+    "normal_": nn.init.normal_,
+    "trunc_normal_": nn.init.trunc_normal_,
+    "constant_": nn.init.constant_,
+    "xavier_uniform_": nn.init.xavier_uniform_,
+    "xavier_normal_": nn.init.xavier_normal_,
+    "kaiming_uniform_": nn.init.kaiming_uniform_,
+    "kaiming_normal_": nn.init.kaiming_normal_,
+    "uniform": nn.init.uniform,
+    "normal": nn.init.normal,
+    "xavier_uniform": nn.init.xavier_uniform,
+    "xavier_normal": nn.init.xavier_normal,
+    "kaiming_uniform": nn.init.kaiming_uniform,
+    "kaiming_normal": nn.init.kaiming_normal,
+}
+
 
 @contextmanager
 def no_init_weights(_enable=True):
@@ -164,12 +181,24 @@ def no_init_weights(_enable=True):
     """
     global _init_weights
     old_init_weights = _init_weights
+
     if _enable:
         _init_weights = False
+
+        def _skip_init(*args, **kwargs):
+            pass
+
+        # Save the original initialization functions
+        for name, init_func in TORCH_INIT_FUNCTIONS.items():
+            setattr(torch.nn.init, name, _skip_init)
     try:
         yield
     finally:
         _init_weights = old_init_weights
+        if _enable:
+            # Restore the original initialization functions
+            for name, init_func in TORCH_INIT_FUNCTIONS.items():
+                setattr(torch.nn.init, name, init_func)
 
 
 def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
@@ -1506,7 +1535,10 @@ def get_output_embeddings(self) -> nn.Module:
 
     def _init_weights(self, module):
         """
-        Initialize the weights. This method should be overridden by derived class.
+        Initialize the weights. This method should be overridden by the derived class and is
+        the only initialization method that will be called when loading a checkpoint
+        using `from_pretrained`. Any attempt to initialize outside of this function
+        will be useless, as the torch.nn.init functions are all replaced with skips.
""" pass @@ -3414,6 +3446,7 @@ def from_pretrained( ) with ContextManagers(init_contexts): + # Let's make sure we don't run the init function of buffer modules model = cls(config, *model_args, **model_kwargs) # make sure we use the model's config since the __init__ call might have copied it diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 8293829b009d0..85e6930051616 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -36,8 +36,10 @@ AutoModelForCausalLM, AutoModelForSequenceClassification, PretrainedConfig, + PreTrainedModel, is_torch_available, logging, + set_seed, ) from transformers.models.auto import get_values from transformers.models.auto.modeling_auto import ( @@ -85,7 +87,7 @@ is_torch_fx_available, is_torch_sdpa_available, ) -from transformers.utils.generic import ModelOutput +from transformers.utils.generic import ContextManagers, ModelOutput if is_accelerate_available(): @@ -99,6 +101,7 @@ from torch import nn from transformers import MODEL_MAPPING, AdaptiveEmbedding + from transformers.modeling_utils import no_init_weights from transformers.pytorch_utils import id_tensor_storage @@ -428,6 +431,56 @@ class CopyClass(model_class): max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + def test_fast_init_context_manager(self): + # 1. Create a dummy class. Should have buffers as well? To make sure we test __init__ + class MyClass(PreTrainedModel): + config_class = PretrainedConfig + + def __init__(self, config=None): + super().__init__(config if config is not None else PretrainedConfig()) + self.linear = nn.Linear(10, 10, bias=True) + self.embedding = nn.Embedding(10, 10) + self.std = 1 + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + module.weight.data = nn.init.kaiming_uniform_(module.weight.data, np.sqrt(5)) + if module.bias is not None: + module.bias.data.normal_(mean=0.0, std=self.std) + + # 2. Make sure a linear layer's reset params is properly skipped: + with ContextManagers([no_init_weights(True)]): + no_init_instance = MyClass() + + set_seed(0) + expected_bias = torch.tensor( + ([0.2975, 0.2131, -0.1379, -0.0796, -0.3012, -0.0057, -0.2381, -0.2439, -0.0174, 0.0475]) + ) + init_instance = MyClass() + torch.testing.assert_allclose(init_instance.linear.bias, expected_bias, rtol=1e-3, atol=1e-4) + + set_seed(0) + torch.testing.assert_allclose( + init_instance.linear.weight, nn.init.kaiming_uniform_(no_init_instance.linear.weight, np.sqrt(5)) + ) + + # 3. Make sure weights that are not present use init_weight_ and get expected values + with tempfile.TemporaryDirectory() as tmpdirname: + state_dict = init_instance.state_dict() + del state_dict["linear.weight"] + + init_instance.config.save_pretrained(tmpdirname) + torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin")) + set_seed(0) + model_fast_init = MyClass.from_pretrained(tmpdirname) + + set_seed(0) + model_slow_init = MyClass.from_pretrained(tmpdirname, _fast_init=False) + + for key in model_fast_init.state_dict().keys(): + max_diff = torch.max(torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])) + self.assertLessEqual(max_diff.item(), 1e-3, msg=f"{key} not identical") + def test_save_load_fast_init_to_base(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() if config.__class__ not in MODEL_MAPPING: