diff --git a/Makefile b/Makefile index 58994409a06b..591fd5b6387b 100644 --- a/Makefile +++ b/Makefile @@ -45,6 +45,7 @@ repo-consistency: python utils/check_modular_conversion.py python utils/check_dummies.py python utils/check_repo.py + python utils/check_init_weights_data.py python utils/check_inits.py python utils/check_pipeline_typing.py python utils/check_config_docstrings.py diff --git a/docs/source/de/add_new_model.md b/docs/source/de/add_new_model.md index 848dcbc30631..8f19517819b9 100644 --- a/docs/source/de/add_new_model.md +++ b/docs/source/de/add_new_model.md @@ -508,16 +508,16 @@ BERT `_init_weights` Methode: def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) ``` Sie können weitere benutzerdefinierte Schemata verwenden, wenn Sie eine spezielle Initialisierung für einige Module benötigen. Zum Beispiel in @@ -533,9 +533,9 @@ def _init_weights(self, module): module.project_hid._is_hf_initialized = True module.project_q._is_hf_initialized = True elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() ``` Das Flag `_is_hf_initialized` wird intern verwendet, um sicherzustellen, dass wir ein Submodul nur einmal initialisieren. Wenn Sie es auf diff --git a/docs/source/en/add_new_model.md b/docs/source/en/add_new_model.md index a9d8168f7505..2cd88930fbbc 100644 --- a/docs/source/en/add_new_model.md +++ b/docs/source/en/add_new_model.md @@ -314,16 +314,16 @@ Random initialization occurs in the `_init_weights` method of `BrandNewLlamaPreT def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) ``` The initialization scheme can look different if you need to adapt it to your model. For example, [`Wav2Vec2ForPreTraining`] initializes [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) in its last two linear layers. @@ -339,9 +339,9 @@ def _init_weights(self, module): module.project_hid._is_hf_initialized = True module.project_q._is_hf_initialized = True elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() ``` ### Convert checkpoints to Transformers diff --git a/docs/source/en/perf_infer_gpu_multi.md b/docs/source/en/perf_infer_gpu_multi.md index cb426b81916c..893dd28d7b45 100644 --- a/docs/source/en/perf_infer_gpu_multi.md +++ b/docs/source/en/perf_infer_gpu_multi.md @@ -149,7 +149,7 @@ The example below packs `up_proj` and `gate_proj` into a single `gate_up_proj` m ```python class Llama4TextExperts(nn.Module): ... - self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim)) ``` Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module. diff --git a/docs/source/ja/add_new_model.md b/docs/source/ja/add_new_model.md index 75219dcb8f88..f768c094a084 100644 --- a/docs/source/ja/add_new_model.md +++ b/docs/source/ja/add_new_model.md @@ -406,16 +406,16 @@ model = BrandNewBertModel(BrandNewBertConfig()) def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) ``` 特定のモジュールに特別な初期化が必要な場合、カスタムスキームをさらに持つことができます。たとえば、 @@ -431,9 +431,9 @@ def _init_weights(self, module): module.project_hid._is_hf_initialized = True module.project_q._is_hf_initialized = True elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() ``` `_is_hf_initialized`フラグは、サブモジュールを一度だけ初期化することを確実にするために内部で使用されます。 diff --git a/docs/source/ko/add_new_model.md b/docs/source/ko/add_new_model.md index a75032c000d0..be33c92dc4b0 100644 --- a/docs/source/ko/add_new_model.md +++ b/docs/source/ko/add_new_model.md @@ -348,16 +348,16 @@ model = BrandNewBertModel(BrandNewBertConfig()) def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) ``` 몇 가지 모듈에 대해 특별한 초기화가 필요한 경우 사용자 정의 방식을 사용할 수도 있습니다. 예를 들어, `Wav2Vec2ForPreTraining`에서 마지막 두 개의 선형 레이어는 일반적인 PyTorch `nn.Linear`의 초기화를 가져야 하지만, 다른 모든 레이어는 위와 같은 초기화를 사용해야 합니다. 이는 다음과 같이 코드화됩니다: @@ -371,9 +371,9 @@ def _init_weights(self, module): module.project_hid._is_hf_initialized = True module.project_q._is_hf_initialized = True elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() ``` `_is_hf_initialized` 플래그는 서브모듈을 한 번만 초기화하도록 내부적으로 사용됩니다. `module.project_q` 및 `module.project_hid`에 대해 `True`로 설정함으로써, 우리가 수행한 사용자 정의 초기화가 이후에 덮어쓰이지 않도록 합니다. 즉, `_init_weights` 함수가 이들에게 적용되지 않습니다. diff --git a/docs/source/ko/perf_infer_gpu_multi.md b/docs/source/ko/perf_infer_gpu_multi.md index 304b798796f6..676ed5980035 100644 --- a/docs/source/ko/perf_infer_gpu_multi.md +++ b/docs/source/ko/perf_infer_gpu_multi.md @@ -152,7 +152,7 @@ class ParallelInterface(MutableMapping): ```python class Llama4TextExperts(nn.Module): ... - self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim)) ``` 배치 행렬 곱셈을 `forward` 패스에서 사용하여 `gate_up_proj` 모듈의 출력을 계산할 수 있습니다. diff --git a/examples/modular-transformers/modeling_dummy_bert.py b/examples/modular-transformers/modeling_dummy_bert.py index d3dc55f845d2..15c96bf7bbc8 100644 --- a/examples/modular-transformers/modeling_dummy_bert.py +++ b/examples/modular-transformers/modeling_dummy_bert.py @@ -502,16 +502,10 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -536,18 +530,18 @@ class DummyBertPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, DummyBertLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring( diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 0dd5efe4e89b..440878c3df49 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -265,7 +265,7 @@ def _init_weights(self, module): # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) if "RMSNorm" in module.__class__.__name__: - module.weight.data.zero_() + module.weight.zero_() class MyNewModel2ForSequenceClassification(GenericForSequenceClassification, MyNewModel2PreTrainedModel): diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index c74ce212d834..041f1d4a0422 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -104,9 +104,9 @@ def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() def token_type_ids_mask_function( @@ -428,7 +428,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin): "^multi_modal_projector": "model.multi_modal_projector", "^language_model.lm_head": "lm_head", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related def __init__(self, config): @@ -440,7 +440,15 @@ def __init__(self, config): self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim) if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys] + prefix = "model.language_model." + prefixed_mapping = { + f"{prefix}{target}": f"{prefix}{source}" + for target, source in self.language_model._tied_weights_keys.items() + } + if isinstance(self._tied_weights_keys, dict): + self._tied_weights_keys.update(prefixed_mapping) + else: + self._tied_weights_keys = prefixed_mapping self.post_init() def get_input_embeddings(self): diff --git a/examples/modular-transformers/modeling_roberta.py b/examples/modular-transformers/modeling_roberta.py index cb125123bf8c..b1f35119580b 100644 --- a/examples/modular-transformers/modeling_roberta.py +++ b/examples/modular-transformers/modeling_roberta.py @@ -505,16 +505,10 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -539,18 +533,18 @@ class RobertaPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, RobertaLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring( diff --git a/examples/modular-transformers/modeling_test_detr.py b/examples/modular-transformers/modeling_test_detr.py index 3ff225c0b3ff..6f88e341a032 100644 --- a/examples/modular-transformers/modeling_test_detr.py +++ b/examples/modular-transformers/modeling_test_detr.py @@ -846,11 +846,11 @@ def _init_weights(self, module): nn.init.xavier_uniform_(module.output_proj.weight.data) nn.init.constant_(module.output_proj.bias.data, 0.0) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() if hasattr(module, "reference_points") and not self.config.two_stage: diff --git a/examples/modular-transformers/modular_new_task_model.py b/examples/modular-transformers/modular_new_task_model.py index 2a6dc470d74b..43830b12c784 100644 --- a/examples/modular-transformers/modular_new_task_model.py +++ b/examples/modular-transformers/modular_new_task_model.py @@ -19,7 +19,15 @@ def __init__(self, config): self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim) if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys] + prefix = "model.language_model." + prefixed_mapping = { + f"{prefix}{target}": f"{prefix}{source}" + for target, source in self.language_model._tied_weights_keys.items() + } + if isinstance(self._tied_weights_keys, dict): + self._tied_weights_keys.update(prefixed_mapping) + else: + self._tied_weights_keys = prefixed_mapping self.post_init() diff --git a/setup.py b/setup.py index 048087ab84a3..ec5b9ab54ac8 100644 --- a/setup.py +++ b/setup.py @@ -138,7 +138,7 @@ "pyyaml>=5.1", "pydantic>=2", "pytest>=7.2.0", - "pytest-asyncio", + "pytest-asyncio>=1.2.0", "pytest-rerunfailures<16.0", "pytest-timeout", "pytest-xdist", diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index a04b6f2f4332..f94b4b0c5aa4 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -876,7 +876,7 @@ def to_diff_dict(self) -> dict[str, Any]: if hasattr(self, "quantization_config"): serializable_config_dict["quantization_config"] = ( self.quantization_config.to_dict() - if not isinstance(self.quantization_config, dict) + if not isinstance(self.quantization_config, dict) and self.quantization_config is not None else self.quantization_config ) self.dict_dtype_to_str(serializable_config_dict) @@ -910,7 +910,7 @@ def to_dict(self) -> dict[str, Any]: if hasattr(self, "quantization_config"): output["quantization_config"] = ( self.quantization_config.to_dict() - if not isinstance(self.quantization_config, dict) + if not isinstance(self.quantization_config, dict) and self.quantization_config is not None else self.quantization_config ) self.dict_dtype_to_str(output) diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py new file mode 100644 index 000000000000..0498ab2a64f5 --- /dev/null +++ b/src/transformers/conversion_mapping.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright (C) 2025 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .core_model_loading import Concatenate, MergeModulelist, WeightConverter +from .utils import is_torch_available + + +if is_torch_available(): + import torch + + +def _build_checkpoint_conversion_mapping(): + mapping = { + "mixtral": [ + WeightConverter( + source_keys=[ + "block_sparse_moe.experts.*.w1.weight", + "block_sparse_moe.experts.*.w3.weight", + ], # you give me a list of 2 keys, I collect a list of a list of tensors + target_keys="mlp.experts.gate_up_proj", # target key gets the list of two tensors + operations=[ + MergeModulelist( + dim=0 + ), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors + Concatenate(dim=1), # each process has 2 tensors, gate and up, we concat them into gate_up + ], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first + ), + WeightConverter( + source_keys=[ + "block_sparse_moe.experts.*.w2.weight", + ], + target_keys="mlp.experts.down_proj", # target key gets the list of two tensors + operations=[ + MergeModulelist( + dim=0 + ), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors + ], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first + ), + # WeightConverter( + # ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"], + # "self_attn.qkv_proj", + # operations=[Concatenate(dim=0)], # more like stack? + # ), + WeightConverter("*.block_sparse_moe.", "*.mlp."), + ], + "qwen2_moe": [ + WeightConverter( + source_keys=[ + "mlp.experts.*.gate_proj.weight", + "mlp.experts.*.up_proj.weight", + ], + target_keys="mlp.experts.gate_up_proj", + operations=[MergeModulelist(dim=0), Concatenate(dim=1)], + ), + WeightConverter( + source_keys=["mlp.experts.*.down_proj.weight"], + target_keys="mlp.experts.down_proj", + operations=[MergeModulelist(dim=0)], + ), + ], + "legacy": [ + WeightConverter( + source_keys="LayerNorm.gamma", + target_keys="LayerNorm.weight", + ), + WeightConverter( + source_keys="LayerNorm.beta", + target_keys="LayerNorm.bias", + ), + ], + } + if hasattr(torch.nn.utils.parametrizations, "weight_norm"): + mapping["legacy"] += [ + WeightConverter( + source_keys="weight_g", + target_keys="parametrizations.weight.original0", + ), + WeightConverter( + source_keys="weight_v", + target_keys="parametrizations.weight.original1", + ), + ] + else: + mapping["legacy"] += [ + WeightConverter( + source_keys="parametrizations.weight.original0", + target_keys="weight_g", + ), + WeightConverter( + source_keys="parametrizations.weight.original1", + target_keys="weight_v", + ), + ] + + mapping["phimoe"] = mapping["mixtral"].copy() + mapping["deepseek_v2"] = mapping["qwen2_moe"].copy() + mapping["deepseek_v3"] = mapping["qwen2_moe"].copy() + mapping["dot1"] = mapping["qwen2_moe"].copy() + mapping["ernie_4_5_moe"] = mapping["qwen2_moe"].copy() + mapping["glm4_moe"] = mapping["qwen2_moe"].copy() + mapping["glm4v_moe"] = mapping["qwen2_moe"].copy() + mapping["jamba"] = mapping["qwen2_moe"].copy() + mapping["lfm2_moe"] = mapping["mixtral"].copy() + mapping["long_cat_flash"] = mapping["qwen2_moe"].copy() + mapping["qwen3_moe"] = mapping["qwen2_moe"].copy() + mapping["qwen3_omni_moe"] = mapping["qwen2_moe"].copy() + mapping["qwen3_next"] = mapping["qwen2_moe"].copy() + mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy() + mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy() + mapping["minimax"] = mapping["mixtral"].copy() + + return mapping + + +_checkpoint_conversion_mapping_cache = None + + +def get_checkpoint_conversion_mapping(): + global _checkpoint_conversion_mapping_cache + if _checkpoint_conversion_mapping_cache is None: + _checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping() + globals()["_checkpoint_conversion_mapping"] = _checkpoint_conversion_mapping_cache + return _checkpoint_conversion_mapping_cache + + +def __getattr__(name): + if name == "_checkpoint_conversion_mapping": + return get_checkpoint_conversion_mapping() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py new file mode 100644 index 000000000000..84da55315a13 --- /dev/null +++ b/src/transformers/core_model_loading.py @@ -0,0 +1,761 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Core helpers for loading model checkpoints.""" + +from __future__ import annotations + +import itertools +import os +import re +from abc import abstractmethod +from collections import defaultdict +from collections.abc import MutableMapping, MutableSet, Sequence +from concurrent.futures import Future, ThreadPoolExecutor +from contextlib import contextmanager +from dataclasses import dataclass, field +from functools import partial +from types import MethodType +from typing import Any, Optional, Union + +import torch + +from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, TensorParallelLayer, DTensor, Replicate +from .quantizers import HfQuantizer +from .utils import is_torch_greater_or_equal, logging +from .utils.quantization_config import QuantizationMethod + + +_torch_distributed_available = torch.distributed.is_available() +_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5") +if _is_dtensor_available: + from torch.distributed.tensor import DTensor + + +import itertools +import os +import re +from abc import abstractmethod +from collections import defaultdict +from collections.abc import MutableMapping, MutableSet, Sequence +from concurrent.futures import Future, ThreadPoolExecutor +from contextlib import contextmanager +from dataclasses import dataclass, field +from functools import partial +from types import MethodType +from typing import Any, Optional, Union + +import torch + +from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, TensorParallelLayer +from .quantizers import HfQuantizer +from .utils import is_torch_greater_or_equal, logging +from .utils.quantization_config import QuantizationMethod + + +_torch_distributed_available = torch.distributed.is_available() +_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5") +if _is_dtensor_available: + from torch.distributed.tensor import DTensor + + +logger = logging.get_logger(__name__) + +str_to_torch_dtype = { + "BOOL": torch.bool, + "U8": torch.uint8, + "I8": torch.int8, + "I16": torch.int16, + "F16": torch.float16, + "BF16": torch.bfloat16, + "I32": torch.int32, + "F32": torch.float32, + "F64": torch.float64, + "I64": torch.int64, + "F8_E4M3": torch.float8_e4m3fn, + "F8_E5M2": torch.float8_e5m2, +} + + +logger = logging.get_logger(__name__) + +str_to_torch_dtype = { + "BOOL": torch.bool, + "U8": torch.uint8, + "I8": torch.int8, + "I16": torch.int16, + "F16": torch.float16, + "BF16": torch.bfloat16, + "I32": torch.int32, + "F32": torch.float32, + "F64": torch.float64, + "I64": torch.int64, + "F8_E4M3": torch.float8_e4m3fn, + "F8_E5M2": torch.float8_e5m2, +} + + +def _glob_to_regex_src(glob: str, *, digits_only: bool = True) -> str: + """ + Convert a glob with '*' into a regex *source* string. We don't use `glob.translate` + '*' matches (\\d+) if digits_only else (.+). Inner groups are non-capturing. + """ + star = r"(\d+)" if digits_only else r"(.+)" + return re.escape(glob).replace(r"\*", star) + + +def build_glob_alt( + globs: list[str], +) -> tuple[re.Pattern, dict[str, str]]: + r""" + Build one compiled regex alternation with a named group per glob. This allows to run a single + re.match and get the correct group name to finally get which pattern matched. + Returns (compiled_regex, name->glob map). + + Example: + + ```py + >>> reg, map_ = build_glob_alt(["mlp.*.w1", "mlp.*.w2"]) + >>> print(reg) + (re.compile(r'(?P.*mlp\.(\d+)\.w1)|(?P.*mlp\.(\d+)\.w2)', re.UNICODE), + >>> print(map_) + {'g0': 'mlp.*.w1', 'g1': 'mlp.*.w2'}) + >>> match_ = reg.match("model.layers.0.mlp.0.w1.weight") + >>> print(match_.lastgroup) + 'g0' + >>> print(map_[match_.lastgroup]) + mlp.*.w1 + ``` + """ + name_map: dict[str, str] = {} + parts: list[str] = [] + prefix_src = r".*" + + for i, g in enumerate(globs): + name = f"g{i}" + name_map[name] = g + pat_src = _glob_to_regex_src(g) + parts.append(f"(?P<{name}>{prefix_src}{pat_src})") + + alt_src = "|".join(parts) + return re.compile(alt_src), name_map + + +def match_glob(key: str, alt: re.Pattern, name_map: dict[str, str]) -> Optional[str]: + """ + Match the key against the alternation; return the original glob string that matched. + """ + m = alt.match(key) + if not m: + return None + return name_map.get(m.lastgroup) + + +class ConversionOps: + """Base class for weight conversion operations.""" + + # The inverse operation class, will be used when saving the checkpoint + reverse_op: type[ConversionOps] + + @abstractmethod + def convert( + self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], *args, **kwargs + ) -> torch.Tensor: + raise NotImplementedError + + +class Chunk(ConversionOps): + """Split a tensor along ``dim`` into equally sized chunks or using explicit ``sizes``.""" + + reverse_op: type[ConversionOps] + + def __init__(self, dim: int = 0, chunks: Optional[int] = None, sizes: Optional[Sequence[int]] = None): + if chunks is None and sizes is None: + raise ValueError("`chunks` or `sizes` must be provided for Chunk operations.") + if chunks is not None and chunks <= 0: + raise ValueError("`chunks` must be a strictly positive integer.") + self.dim = dim + self.chunks = chunks + self.sizes = list(sizes) if sizes is not None else None + self.reverse_op = Concatenate + + def convert(self, value: torch.Tensor, *args, **kwargs) -> list[torch.Tensor]: + if not isinstance(value, torch.Tensor): + raise TypeError("Chunk expects a torch.Tensor as input.") + if self.sizes is not None: + return list(torch.split(value, self.sizes, dim=self.dim)) + return list(torch.chunk(value, self.chunks, dim=self.dim)) + + +class Concatenate(ConversionOps): + """Concatenate tensors along `dim` using a reusable buffer.""" + + reverse_op: type[ConversionOps] + + def __init__(self, dim: int = 0): + self.dim = dim + self.reverse_op = Chunk + + @torch.no_grad + def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> torch.Tensor: + if isinstance(value[0], list): + value = [v[0] for v in value] + tensors = value + if not tensors: + raise ValueError("Fuse requires at least one tensor to concatenate.") + + return torch.cat(tuple(tensors), dim=self.dim) + + +class MergeModulelist(Concatenate): + """ + Merge a list of tensors into a single tensor along the first dimension. + We explicitly define this because for EP or TP you want to make sure you know what you are doing! + + """ + + def __init__(self, dim: int = 0): + super().__init__(dim=dim) + self.reverse_op = SplitModulelist + + @torch.no_grad + def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> list[torch.Tensor]: + merged = [] + for group in value: + if not isinstance(group, Sequence) or len(group) == 0: + raise ValueError("MergeModulelist requires non-empty sub-sequences.") + group = [k for k in group if k.ndim] + merged.append(torch.stack(group, dim=self.dim)) + return merged + + +class SplitModulelist(ConversionOps): + """Inverse of :class:`MergeModulelist` using explicit split sizes per group.""" + + def __init__(self, sizes: Sequence[Sequence[int]], dim: int = 0): + if not isinstance(sizes, Sequence) or not all(isinstance(sub, Sequence) and sub for sub in sizes): + raise ValueError("`sizes` must be a sequence of non-empty sequences of integers.") + self.sizes = [list(sub) for sub in sizes] + self.dim = dim + self.reverse_op = MergeModulelist + + @torch.no_grad + def convert(self, value: Sequence[torch.Tensor], *, context: dict[str, Any]) -> list[list[torch.Tensor]]: + if not isinstance(value, Sequence): + raise TypeError("SplitModulelist expects a sequence of tensors.") + if len(value) != len(self.sizes): + raise ValueError("Number of tensors does not match the provided split specifications.") + + result: list[list[torch.Tensor]] = [] + for tensor, split_sizes in zip(value, self.sizes): + if not isinstance(tensor, torch.Tensor): + raise TypeError("SplitModulelist can only split torch.Tensor instances.") + splits = torch.split(tensor, split_sizes, dim=self.dim) + result.append(list(splits)) + return result + + +class PermuteForRope(ConversionOps): + """ + Applies the permutation required to convert complex RoPE weights to the split sin/cos format. + """ + + def __init__(self): + pass + + def _apply(self, tensor: torch.Tensor) -> torch.Tensor: + dim1, dim2 = tensor.shape + n_heads = self.config.getattr("num_attention_heads", 1) + + tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2) + tensor = tensor.transpose(1, 2).reshape(dim1, dim2) + return tensor + + @torch.no_grad + def convert( + self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], config + ) -> Union[dict[str, torch.Tensor], list[torch.Tensor], torch.Tensor]: + self.config = config + out = [[self._apply(x) for x in inner] if isinstance(inner, list) else self._apply(inner) for inner in value] + return out + + +@dataclass(slots=True) +class WeightConverter: + r""" + A weight convert that acts on a pattern of source keys. + The keys need to be collected based on the target keys. + + With wild card, glob patterns are matched, so you have to be detailed with what to match. If you match: + `model.layers.*.experts.*` -> it will act on all of them + {"model.layers.*.experts.*": []} + but + `experts.*.mlp` will be layer specific. + {"model.layers.1.experts.*": [], } + - source_keys: str | list[str] (wildcards '*' match digits) + - target_keys: str | list[str] | None + - distributed_operation / operations / quantization_operations are ALWAYS lists. + """ + + source_keys: Union[str, list[str]] + target_keys: Optional[Union[str, list[str]]] = None + operations: list[ConversionOps] = field(default_factory=list, repr=False) + + distributed_operation: Optional[TensorParallelLayer] = None + quantization_operation: Optional[ConversionOps] = None + + def __post_init__(self): + if not isinstance(self.source_keys, list): + self.source_keys = [self.source_keys] + targets_were_none = False + if not isinstance(self.target_keys, list): + if self.target_keys is None: + self.target_keys = list(self.source_keys) + targets_were_none = True + else: + self.target_keys = [self.target_keys] + + if not targets_were_none and bool(len(self.source_keys) - 1) + bool(len(self.target_keys) - 1) >= 2: + raise ValueError( + f"source keys={self.source_keys}, target_keys={self.target_keys} but you can only have one to many, one to one or many to one." + ) + + for pattern in self.source_keys: + if any(ch in pattern for ch in set("^$+?{}[]|()")): + raise AssertionError(f"'{pattern}' is not glob") + for pattern in self.target_keys: + if any(ch in pattern for ch in set("^$+?{}[]|()")): + raise AssertionError(f"'{pattern}' is not glob") + + +@dataclass(slots=True) +class ConversionEntry: + weight_converter: WeightConverter + collected_tensors: dict = field(default_factory=lambda: defaultdict(dict)) + + +GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2) # NVMe: 8-16; HDD/NFS: 2-4 + +# Factory function to create LoadedParameter subclasses dynamically +def get_loaded_parameter_class(base_cls): + """ + base_cls: an nn.Parameter subclass (or nn.Parameter) or a Tensor + Returns a new class that combines the base_cls with LoadedParameterMixin + + """ + class LoadedParam(base_cls): + _inplace_methods = [ + 'add_', 'mul_', 'clamp_', 'zero_', 'fill_', 'normal_', 'uniform_', + 'copy_', 'erfinv_', 'log_', "__getitem__", "neg_", "exp_", "sub_" + ] + def __new__(cls, from_existing, **kwargs): + if isinstance(from_existing, torch.nn.Parameter): + inst = super().__new__(cls, from_existing.data, from_existing.requires_grad, **from_existing.__dict__) + else: + inst = super().__new__(cls, from_existing) + inst._original_type = from_existing + # Explicitly override all in-place methods per instance + for method_name in inst._inplace_methods: + setattr(inst, method_name, MethodType(inst._skip, inst)) + + return inst + + def _skip(self, *args, **kwargs): + """Helper to skip in-place operations.""" + return self + + def __repr__(self): + return f"LoadedParameter(data={self.data})" + + @property + def data(self): + return super().data + + @data.setter + def data(self, new): + pass + + def __lt__(self, other): return torch.Tensor.__lt__(self, other) + def __le__(self, other): return torch.Tensor.__le__(self, other) + def __gt__(self, other): return torch.Tensor.__gt__(self, other) + def __ge__(self, other): return torch.Tensor.__ge__(self, other) + def __eq__(self, other): return torch.Tensor.__eq__(self, other) + def __ne__(self, other): return torch.Tensor.__ne__(self, other) + def __iadd__(self, *args, **kwargs): return self + def __isub__(self, *args, **kwargs): return self + def __imul__(self, *args, **kwargs): return self + def __imatmul__(self, *args, **kwargs): return self + def __itruediv__(self, *args, **kwargs): return self + def __ifloordiv__(self, *args, **kwargs): return self + def __imod__(self, *args, **kwargs): return self + def __ipow__(self, *args, **kwargs): return self + def __iand__(self, *args, **kwargs): return self + def __ior__(self, *args, **kwargs): return self + def __ixor__(self, *args, **kwargs): return self + def __ilshift__(self, *args, **kwargs): return self + def __irshift__(self, *args, **kwargs): return self + + return LoadedParam + +def _materialize_copy(tensor, dtype=None): + tensor = tensor[...] + if dtype is not None: + tensor = tensor.to(dtype) + return tensor + + +def spawn_materialize(thread_pool, tensor, dtype=None) -> Future: + def _job(): + return _materialize_copy(tensor, dtype) + + return thread_pool.submit(_job) + + +def spawn_tp_materialize(thread_pool, tensor, sharding_method, tensor_idx, dtype=None) -> Future: + def _job(): + return sharding_method.shard_tensor(tensor, param_casting_dtype=dtype, tensor_idx=tensor_idx)[0] + + return thread_pool.submit(_job) + + +def dot_natural_key(s: str): + parts = s.split(".") + for i, p in enumerate(parts): + # whole-segment digits -> int; otherwise leave as str + if p.isdigit(): + parts[i] = int(p) + return parts + + +@contextmanager +def log_to_misc( + layer_name: str, + misc: MutableMapping[str, str], + extras: Any = None, + op: Union[list[ConversionOps], ConversionOps, None] = None, +): + # A simple helper to handle errors with contextual messages. + try: + yield + except Exception as e: + + def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) -> Optional[str]: + if curr_op is None: + return None + if isinstance(curr_op, (list, tuple, set)): + names = [o.__class__.__name__ for o in curr_op if o is not None] + if not names: + return None + return ", ".join(names) + return curr_op.__class__.__name__ + + op_name = _format_op_name(op) + if isinstance(extras, tuple) and len(extras) == 2: + values, target_keys = extras + descriptor = f"{op_name} " if op_name else "" + misc[layer_name] = ( + f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values[0])}" + ) + elif isinstance(extras, str): + suffix = f" via {op_name}" if op_name else "" + misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}" + elif extras is None and op_name: + misc[layer_name] = f"{op_name}: {e}" + else: + misc[layer_name] = f"{extras} |Error: {e}" + raise SkipLayer() + + +def set_param_for_module( + model: torch.nn.Module, + layer_name: str, + param_value: torch.Tensor, + meta_model_state_dict: MutableMapping[str, Any], + empty_param: torch.Tensor, + mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]], + missing_keys: MutableSet[str], + misc: MutableMapping[str, Any], + distributed_operation: Optional[TensorParallelLayer], + hf_quantizer, +): + with log_to_misc(layer_name, misc, layer_name): + module_path, _, param_name = layer_name.rpartition(".") + module_obj = model.get_submodule(module_path) if module_path else model + if isinstance(param_value, list): + param_value = param_value[0] + elif not isinstance(param_value, torch.nn.Parameter): + param_value = param_value[...] + ref = meta_model_state_dict.get(layer_name, empty_param) + + + use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor + if not isinstance(param_value, torch.nn.Parameter): + if distributed_operation is not None: + param_value = DTensor.from_local( + param_value, + distributed_operation.device_mesh, + getattr(distributed_operation, "shard", Replicate()), + run_check=False, + shape=ref.size(), + stride=ref.stride(), + ) + if not use_dtensor: + # we convert to local + param_value = param_value.to_local() + + if param_name not in module_obj._buffers: + param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point()) + + # to skip any inplace method that modifies the param data + param_value = get_loaded_parameter_class(param_value.__class__)(from_existing=param_value) + + # skip mismatch for hf_quantizer for now + if ref is not None and ref.shape != param_value.shape and hf_quantizer is None: + mismatch_keys.add((layer_name, param_value.shape, ref.shape)) + setattr(module_obj._parameters[param_name], "_is_hf_initialized", False) # Needs to be initialized + missing_keys.discard(layer_name) + else: + missing_keys.discard(layer_name) + param_value._is_hf_initialized = True # super important otherwise _init_weight re-initi if bias is missing + setattr(module_obj, param_name, param_value) + + +class SkipLayer(Exception): + """Control-flow sentinel: abort processing of the current layer only.""" + + pass + + +def convert_and_load_state_dict_in_model( + model, + state_dict, + weight_mapping, + tp_plan, + hf_quantizer, + dtype=None, + device_map=None, + dtype_plan=None, + device_mesh=None, + loading_task_model_from_base_state_dict: bool = False, + loading_base_model_from_task_state_dict: bool = False, +): + """ + Convert a state dict according to a weight mapping (one WeightConverter per glob pattern), + collecting tensors per *layer instance* (the concrete indices captured from '*'). + """ + + prefix = model.base_model_prefix + tp_plan = tp_plan or {} # {glob_pattern: plan_obj_or_key} + device_map = device_map or {} # {exact_target_key: device} + dtype_plan = dtype_plan or {} # {glob_pattern: dtype} + weight_mapping = weight_mapping or {} # {glob_pattern: WeightConverter} + meta_model_state_dict = model.state_dict() + missing_keys = set(meta_model_state_dict.keys()) + + misc = {} + mismatch_keys = set() + unexpected_keys = set() + # Global thread_pool + thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS) + + _patterns = list(itertools.chain.from_iterable([k.source_keys for k in weight_mapping])) + source_to_target = {sk: k for k in weight_mapping for sk in k.source_keys} + weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns) + tp_plan_alt, tp_plan_by_group_name = build_glob_alt(list(tp_plan.keys())) + dtype_policy_alt, dtype_policy_by_group_name = build_glob_alt(list(dtype_plan.keys())) + + state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0])) + # 1. Create the conversion entries + by_conversion_pattern: dict[str, ConversionEntry] = {} + for original_key, tensor in state_dict: + matched_pattern = match_glob(original_key, weight_pattern_alt, weight_pattern_by_group_name) + if matched_pattern is not None: + converter = source_to_target[matched_pattern] # TODO make sure its the ref + sub_with_extractor = partial(re.sub, _glob_to_regex_src(matched_pattern), string=original_key) + entry_key = "|".join(converter.target_keys) + target_key = "|".join(map(sub_with_extractor, [k.replace("*", "\\1") for k in converter.target_keys])) + entry: ConversionEntry = by_conversion_pattern.setdefault(entry_key, ConversionEntry(converter)) + converter_key = sub_with_extractor(matched_pattern) + else: + converter = WeightConverter(original_key) + converter_key = entry_key = target_key = original_key + entry = by_conversion_pattern.setdefault(converter_key, ConversionEntry(converter)) + + _dtype = dtype + new_target_key = [] # test_load_with_mismatched_shapes for AutoModel.from_pretrained(AutoForCausal, vocab=10) + for t in target_key.split("|"): + if t.startswith(prefix) and meta_model_state_dict.get(t.replace(f"{prefix}.", "")) is not None: + t = t.replace(f"{prefix}.", "") + elif meta_model_state_dict.get(f"{prefix}.{t}") is not None: + t = f"{prefix}.{t}" + new_target_key.append(t) + empty_param = meta_model_state_dict.get(t) + # If it does not exist, it's unexpected + if empty_param is None: + if hf_quantizer is not None and hf_quantizer.is_valid_unexpected_keys(t): + pass + else: + unexpected_keys.add(t) + continue + + if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, t): + converter.quantization_operation = hf_quantizer.get_quantize_ops() + # TODO: to clean later. We need to use the empty_param from the checkpoint to decide if we upcast the param to a specific dtype + k_dtype = tensor.get_dtype() + dtype = str_to_torch_dtype[k_dtype] + empty_param_checkpoint = torch.empty(size=tensor.get_shape(), dtype=dtype, device="meta") + _, _dtype = _infer_parameter_dtype(model, t, empty_param_checkpoint, hf_quantizer) + else: + _dtype = dtype + matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name) + if matched_dtype_pattern is not None: + _dtype = dtype_plan[matched_dtype_pattern] + elif empty_param.dtype != _dtype: + _dtype = empty_param.dtype + + first_target_key = new_target_key[0] + target_key = "|".join(new_target_key) + + future = None + if device_mesh: + if matched_tp_pattern := match_glob(first_target_key, tp_plan_alt, tp_plan_by_group_name): + empty_param = meta_model_state_dict.get(first_target_key) + if getattr(converter, "distributed_operation", {}) is None: + tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__ + converter.distributed_operation = tp_layer( + device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone() + ) + # VERY IMPORTANT: this tells us wether we collected stuffs or not. + shard_index = len(entry.collected_tensors[target_key].get(converter_key, [])) + future = spawn_tp_materialize( + thread_pool, + tensor, + _dtype, + converter.distributed_operation, + shard_index, + ) + + if future is None: # If not TP, async materialize the tensors. TODO handle disk offload? + future = spawn_materialize(thread_pool, tensor, _dtype) + entry.collected_tensors[target_key].setdefault(converter_key, []).append(future) + + # 2. Actually convert the ckpt + inverse_converters = {} + keys = list(by_conversion_pattern.keys()) + + with logging.tqdm(total=len(keys), desc="Loading weights") as pbar: + for key in keys[::-1]: # revert to process simple keys first + group = by_conversion_pattern.pop(key) + converter = group.weight_converter + operations = converter.operations if isinstance(converter.operations, list) else [converter.operations] + for layer_name, tensors_for_this_layer in group.collected_tensors.items(): + concrete_target_keys = layer_name.split("|") + try: + if bool(set(concrete_target_keys) - unexpected_keys): + with log_to_misc(layer_name, misc): + values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()] + + for op in operations: + with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations): + values = op.convert(values, model.config) + + values = [values] if not isinstance(values, list) else values + with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations): + realized_value = { + k: t for k, t in zip(concrete_target_keys, values) if k not in unexpected_keys + } + + for k in list(realized_value.keys()).copy(): + if op := converter.quantization_operation: + with log_to_misc(layer_name, misc, op=op): + realized_value.update( + op.convert({k: realized_value.pop(k)}, model=model, missing_keys=missing_keys) + ) + + for k, output_value in realized_value.items(): + for src in converter.source_keys: # what should happen to k when we meet k at saving + inverse_converters[k] = {src: converter} + set_param_for_module( + model, + k, + output_value, + meta_model_state_dict, + empty_param, + mismatch_keys, + missing_keys, + misc, + converter.distributed_operation, + hf_quantizer + ) + except Exception as e : + raise e + del group + + # Update progress bar + pbar.update() + pbar.refresh() + + model.inverse_converters = inverse_converters + thread_pool.shutdown(wait=False) + return missing_keys, unexpected_keys, mismatch_keys, misc + + +# TODO this is not done yet! +def revert_weight_conversion(model, state_dict): + mapping = getattr(model, "", {}) # IDK why but setting this will fail all llava. + reverse_key_mapping = [(v, k) for k, v in mapping.items()] + original_state_dict = {} + for key, value in state_dict.items(): + for pattern, inverse_converter in reverse_key_mapping: + # TODO FIXME you name it + replacement = inverse_converter.lstrip("^") # strip off un-needed chars and patterns + replacement = re.sub(r"\(.*\)", "", replacement) + key, n_replace = re.subn(pattern, replacement, key) + # Early exit of the loop + if n_replace > 0: + break + original_state_dict[key] = value + state_dict = original_state_dict + return state_dict + +def _infer_parameter_dtype( + model: torch.nn.Module, + param_name: str, + empty_param: torch.Tensor, + hf_quantizer: Optional[HfQuantizer] = None, +) -> tuple[bool, Optional[torch.dtype]]: + try: + old_param = model.get_parameter_or_buffer(param_name) + except Exception as e: + if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in { + QuantizationMethod.HQQ, + QuantizationMethod.QUARK, + QuantizationMethod.MXFP4, + QuantizationMethod.BITS_AND_BYTES, + }: + return True, None + else: + raise e + is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") + # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params + # in int/uint/bool and not cast them. + casting_dtype = None + is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn + if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: + # dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes + if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, param_name): + casting_dtype = model.config._pre_quantization_dtype + else: + casting_dtype = old_param.dtype + return old_param is not None and old_param.is_contiguous(), casting_dtype diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 93203ed665fa..0bf29520fe86 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -48,7 +48,7 @@ "pyyaml": "pyyaml>=5.1", "pydantic": "pydantic>=2", "pytest": "pytest>=7.2.0", - "pytest-asyncio": "pytest-asyncio", + "pytest-asyncio": "pytest-asyncio>=1.2.0", "pytest-rerunfailures": "pytest-rerunfailures<16.0", "pytest-timeout": "pytest-timeout", "pytest-xdist": "pytest-xdist", diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 2c407ecfd919..80d19b097f55 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1635,7 +1635,12 @@ def _validate_model_kwargs(self, model_kwargs: dict[str, Any]): # TransformersKwargs are model-agnostic attention and generation arguments such as 'output_attentions' for key, value in model_kwargs.items(): - if value is not None and key not in model_args and key not in TransformersKwargs.__optional_keys__: + if ( + value is not None + and key not in model_args + and key not in TransformersKwargs.__optional_keys__ + and key != "debug_io" + ): unused_model_args.append(key) if unused_model_args: diff --git a/src/transformers/generation/watermarking.py b/src/transformers/generation/watermarking.py index ed8813b4b33c..da978c3c107e 100644 --- a/src/transformers/generation/watermarking.py +++ b/src/transformers/generation/watermarking.py @@ -383,10 +383,11 @@ def __init__(self, config): ) self.prior = torch.nn.Parameter(torch.tensor([self.base_rate])) + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Parameter): - module.weight.data.normal_(mean=0.0, std=0.02) + module.weight.normal_(mean=0.0, std=0.02) def _compute_posterior( self, diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 237d7420997f..e6f8d1c6afcb 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -36,6 +36,7 @@ "get_keys_to_not_convert", "replace_with_bnb_linear", "validate_bnb_backend_availability", + "Bnb4bitQuantize", ], "deepspeed": [ "HfDeepSpeedConfig", @@ -122,6 +123,7 @@ "quantize_to_mxfp4", "replace_with_mxfp4_linear", "swizzle_mxfp4", + "Mxfp4Quantize", ], "peft": ["PeftAdapterMixin"], "quanto": ["replace_with_quanto_layers"], @@ -177,6 +179,7 @@ unpack_weights, ) from .bitsandbytes import ( + Bnb4bitQuantize, dequantize_and_replace, get_keys_to_not_convert, replace_with_bnb_linear, @@ -256,6 +259,7 @@ ) from .mxfp4 import ( Mxfp4GptOssExperts, + Mxfp4Quantize, dequantize, load_and_swizzle_mxfp4, quantize_to_mxfp4, diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index 79ef98d8a4dc..070d4a072fc5 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -435,6 +435,7 @@ def _get_device_map( if max_memory is not None and device_name in max_memory: inferred_max_memory[device_name] = min(inferred_max_memory[device_name], max_memory[device_name]) + model.tie_weights() device_map = infer_auto_device_map( model, max_memory=inferred_max_memory, @@ -512,10 +513,8 @@ def accelerate_disk_offload( checkpoint_files, device_map, checkpoint_keys, - key_renaming_mapping, sharded_metadata, dtype, - reverse_key_renaming_mapping, ): disk_only_shard_files = [] if disk_offload_folder is not None: @@ -534,19 +533,13 @@ def accelerate_disk_offload( weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0]) else: folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1]) - # Fix the weight map keys according to the key mapping - weight_map = { - key_renaming_mapping[k]: v - for k, v in sharded_metadata["weight_map"].items() - if k in key_renaming_mapping - } weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()} # Find potential checkpoints containing only offloaded weights disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map) disk_offload_index = { name: { "safetensors_file": file, - "weight_name": reverse_key_renaming_mapping[name], + "weight_name": name, "dtype": str_dtype, } for name, file in weight_map.items() diff --git a/src/transformers/integrations/bitsandbytes.py b/src/transformers/integrations/bitsandbytes.py index be117ff3013e..f0818a1407ed 100644 --- a/src/transformers/integrations/bitsandbytes.py +++ b/src/transformers/integrations/bitsandbytes.py @@ -1,7 +1,9 @@ import inspect -from copy import deepcopy +from collections import defaultdict from inspect import signature +from typing import Optional +from ..quantizers.quantizers_utils import get_module_from_name from ..utils import ( get_available_devices, is_accelerate_available, @@ -24,10 +26,52 @@ import accelerate from accelerate import init_empty_weights from accelerate.hooks import add_hook_to_module, remove_hook_from_module - from accelerate.utils import find_tied_parameters logger = logging.get_logger(__name__) +from ..core_model_loading import ConversionOps + + +class Bnb4bitQuantize(ConversionOps): + def __init__(self, hf_quantizer): + self.hf_quantizer = hf_quantizer + + def convert(self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, **kwargs) -> dict[str, torch.Tensor]: + target_key, value = tuple(input_dict.items())[0] + value = value[0] if isinstance(value, list) else value + + full_name = target_key + # update param name to get the weights instead of the quantized stats + target_key = self.hf_quantizer.get_param_name(target_key) + module, _ = get_module_from_name(model, target_key) + + if not self.hf_quantizer.pre_quantized: + # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization. + # Since weights are saved in the correct "orientation", we skip transposing when loading. + if issubclass(module.source_cls, Conv1D): + value = value.T + old_value = model.get_parameter_or_buffer(target_key) + new_value = bnb.nn.Params4bit(value, requires_grad=False, **old_value.__dict__).to(value.device) + return {target_key : new_value} + else: + module_name = target_key.rsplit(".", 1)[0] + # Save the states for later quantization when they are all gathered + if not hasattr(self.hf_quantizer, "param_quant_stats"): + self.hf_quantizer.param_quant_stats = defaultdict(dict) + self.hf_quantizer.param_quant_stats[module_name].update({full_name: value}) + # We are ready for quantization in this case (note, the +1 is for the weight itself) + if len(self.hf_quantizer.param_quant_stats[module_name]) == len(self.hf_quantizer.bnb_keys) + 1: + weight = self.hf_quantizer.param_quant_stats[module_name].pop(f"{module_name}.weight") + new_value = bnb.nn.Params4bit.from_prequantized( + data=weight, + quantized_stats=self.hf_quantizer.param_quant_stats[module_name], + requires_grad=False, + device=value.device, + module=module + ) + del self.hf_quantizer.param_quant_stats[module_name] + return {target_key : new_value} + return {} def _replace_with_bnb_linear( model, @@ -151,52 +195,6 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name return model -def get_keys_to_not_convert(model): - r""" - An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules - we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want - to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in - int8. - - Parameters: - model (`torch.nn.Module`): - Input model - """ - # Create a copy of the model and tie the weights, then - # check if it contains tied weights - tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager` - tied_model.tie_weights() - - tied_params = find_tied_parameters(tied_model) - tied_keys = sum(tied_params, []) - has_tied_params = len(tied_keys) > 0 - - # If there is not tied weights, we want to keep the lm_head(output_embedding) in full precision - if not has_tied_params: - output_emb = model.get_output_embeddings() - if output_emb is not None: - list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)] - return list_last_module - - # otherwise, no tied weights, no output embedding defined, simply keep the last module in full precision - list_modules = list(model.named_parameters()) - list_last_module = [list_modules[-1][0]] - # add last module together with tied weights - intersection = set(list_last_module) - set(tied_keys) - list_untouched = list(set(tied_keys)) + list(intersection) - - # remove ".weight" from the keys - names_to_remove = [".weight", ".bias"] - filtered_module_names = [] - for name in list_untouched: - for name_to_remove in names_to_remove: - if name_to_remove in name: - name = name.replace(name_to_remove, "") - filtered_module_names.append(name) - - return filtered_module_names - - # Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41 def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None): """ diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 8156f1045baa..0ada9460f0e8 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -13,8 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +import re +from collections.abc import Sequence +from typing import Any, Optional, Union +from ..core_model_loading import ConversionOps from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging @@ -30,6 +33,18 @@ logger = logging.get_logger(__name__) +try: + _FP8_DTYPE = torch.float8_e4m3fn + _FP8_MIN = torch.finfo(_FP8_DTYPE).min + _FP8_MAX = torch.finfo(_FP8_DTYPE).max + _FP8_IS_INT = False +except AttributeError: + _FP8_DTYPE = torch.int8 + _FP8_MIN, _FP8_MAX = -127, 127 + _FP8_IS_INT = True + logger.warning_once( + "torch.float8_e4m3fn not available; falling back to int8 emulation for Fp8Quantize operations." + ) # Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py @@ -332,6 +347,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.weight.element_size() > 1: return F.linear(input, self.weight, self.bias) else: + if isinstance(self.weight, torch.distributed.tensor.DTensor): + weight = self.weight._local_tensor.contiguous() + scale_inv = self.weight_scale_inv._local_tensor.contiguous() + else: + weight = self.weight.contiguous() + scale_inv = self.weight_scale_inv.contiguous() # Context manager used to switch among the available accelerators device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda" torch_accelerator_module = getattr(torch, device_type, torch.cuda) @@ -339,9 +360,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: qinput, scale = act_quant(input, self.block_size[1]) output = w8a8_block_fp8_matmul_triton( qinput, - self.weight, + weight, scale, - self.weight_scale_inv, + scale_inv, self.block_size, output_dtype=input.dtype, ) @@ -350,9 +371,124 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: torch_accelerator_module.synchronize() if self.bias is not None: output = output + self.bias + output = torch.nan_to_num(output, nan=0.0) + return output.to(dtype=input.dtype) + + +def _ceil_div(a, b): + return (a + b - 1) // b + + +class FP8Expert(nn.Module): + dtype = torch.float8_e4m3fn + + def __init__(self, config, block_size, device): + super().__init__() + + from ..activations import ACT2FN + + self.block_size = block_size + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + + Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim + Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim + + self.gate_up_proj = nn.Parameter( + torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=FP8Expert.dtype, device=device) + ) + self.down_proj = nn.Parameter( + torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=FP8Expert.dtype, device=device) + ) + + # Create inverse scale tiles only when using 1-byte types (fp8) + if self.gate_up_proj.element_size() == 1: + bo, bi = self.block_size + + # gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi) + gu_scale_o = _ceil_div(Wg_out, bo) + gu_scale_i = _ceil_div(Wg_in, bi) + self.gate_up_proj_scales_inv = nn.Parameter( + torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32, device=device) + ) + + # down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi) + dp_scale_o = _ceil_div(Wd_out, bo) + dp_scale_i = _ceil_div(Wd_in, bi) + self.down_proj_scales_inv = nn.Parameter( + torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32, device=device) + ) + else: + # Match FP8Linear behavior when not using 1-byte weights + self.register_parameter("gate_up_proj_scale_inv", None) + self.register_parameter("down_proj_scale_inv", None) + + # (Optional) bias per projection — many MoEs omit bias; keep None to match your FP8Linear default + self.register_parameter("gate_up_bias", None) + self.register_parameter("down_bias", None) + + # Activation used in the MLP (same as your config / ACT2FN) + # Keep a handle here; actual usage happens in forward of your MoE block + self.act_fn = ACT2FN[config.hidden_act] + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states.index_select(0, token_idx) + gate, up = self.linear( + current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_scales_inv[expert_idx] + ).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = self.linear( + current_hidden_states, self.down_proj[expert_idx], self.down_proj_scales_inv[expert_idx] + ) + + routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1) + current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype) + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states + + def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: torch.Tensor) -> torch.Tensor: + if weight.element_size() > 1: + return F.linear(input, weight, None) + else: + # Context manager used to switch among the available accelerators + device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda" + torch_accelerator_module = getattr(torch, device_type, torch.cuda) + with torch_accelerator_module.device(input.device): + qinput, scale = act_quant(input, self.block_size[1]) + output = w8a8_block_fp8_matmul_triton( + qinput, + weight, + scale, + weight_scale_inv, + self.block_size, + output_dtype=input.dtype, + ) + # Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the + # preceding operations are ready before proceeding + torch_accelerator_module.synchronize() return output.to(dtype=input.dtype) +# TODO: we do need this.... but not recursive... def _replace_with_fp8_linear( model, tp_plan=None, @@ -361,40 +497,48 @@ def _replace_with_fp8_linear( quantization_config=None, has_been_replaced=False, ): - """Replace Linear layers with FP8Linear.""" - if current_key_name is None: - current_key_name = [] - - for name, module in model.named_children(): - current_key_name.append(name) - - if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []): - current_key_name_str = ".".join(current_key_name) - if not any(key in current_key_name_str for key in (modules_to_not_convert or [])): - with init_empty_weights(): - model._modules[name] = FP8Linear( - in_features=module.in_features, - out_features=module.out_features, - bias=module.bias is not None, - device=module.weight.device, - dtype=module.weight.dtype, - activation_scheme=quantization_config.activation_scheme, - block_size=quantization_config.weight_block_size, + iterator = list(model.named_parameters()).copy() + for name, empty_tensor in iterator: + current_key_name = name + name = name.rsplit(".", 1)[0] if "." in name else name + module = model.get_submodule(name) + + current_key_name_str = re.sub(r"\d+", "*", current_key_name) + if not any(key in current_key_name_str for key in (modules_to_not_convert or [])): + with init_empty_weights(): + if ( + "gate_up_proj" in current_key_name + or "down_proj" in current_key_name + and "experts" in current_key_name + ): # Experts! + in_features = empty_tensor.size(-2) + out_features = empty_tensor.size(-1) + model.set_submodule( + name, + FP8Expert( + config=model.config, + block_size=quantization_config.weight_block_size, + device=empty_tensor.device, + ), ) - has_been_replaced = True - # when changing a layer the TP PLAN for that layer should be updated. TODO - - if len(list(module.children())) > 0: - _, has_been_replaced = _replace_with_fp8_linear( - module, - tp_plan, - modules_to_not_convert, - current_key_name, - quantization_config, - has_been_replaced=has_been_replaced, - ) - current_key_name.pop(-1) + elif isinstance(module, nn.Linear): + in_features = module.in_features + out_features = module.out_features + model.set_submodule( + name, + FP8Linear( + in_features=in_features, + out_features=out_features, + bias=module.bias is not None, + device=module.weight.device, + dtype=module.weight.dtype, + activation_scheme=quantization_config.activation_scheme, + block_size=quantization_config.weight_block_size, + ), + ) + has_been_replaced = True + # when changing a layer the TP PLAN for that layer should be updated. TODO return model, has_been_replaced @@ -405,7 +549,7 @@ def replace_with_fp8_linear( quantization_config=None, ): """Helper function to replace model layers with FP8 versions.""" - modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert + modules_to_not_convert += ["lm_head"] if quantization_config.modules_to_not_convert is not None: modules_to_not_convert.extend(quantization_config.modules_to_not_convert) @@ -424,3 +568,124 @@ def replace_with_fp8_linear( ) return model +class Fp8Quantize(ConversionOps): + """ + A quantization operation that creates two tensors, weight and scale out of a weight. + """ + + reverse_op: type[ConversionOps] + + def __init__(self, hf_quantizer): + self.hf_quantizer = hf_quantizer + self.reverse_op = Fp8Dequantize + + def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]: + # Unpack single key/value (value may be wrapped in a list) + target_keys, value = tuple(input_dict.items())[0] + value = value[0] if isinstance(value, list) else value + + # Resolve block size (support dict-like or attr-like quant_config) + block_size = None + if self.hf_quantizer.quantization_config is not None: + if isinstance(self.hf_quantizer.quantization_config, dict): + block_size = self.hf_quantizer.quantization_config.get("weight_block_size") + else: + block_size = getattr(self.hf_quantizer.quantization_config, "weight_block_size", None) + if block_size is None: + block_size = (value.shape[-2], value.shape[-1]) + + block_m, block_n = block_size + rows, cols = value.shape[-2], value.shape[-1] + + # Enforce exact tiling like your original + if rows % block_m != 0 or cols % block_n != 0: + raise ValueError( + f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}). for {target_keys}" + ) + + # Leading dims can be empty (2D) or include num_experts/... (3D+) + leading_shape = value.shape[:-2] + rows_tiles = rows // block_m + cols_tiles = cols // block_n + + original_shape = value.shape + value_fp32 = value.to(torch.float32) + + # Reshape to (..., rows_tiles, block_m, cols_tiles, block_n) + reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n) + + # Per-tile max-abs over the block dims + # dims: block_m is at -3, block_n is at -1 after the reshape + max_abs = reshaped.abs().amax(dim=(-3, -1)) + safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs)) + + # Tile scale (we store inverse scale like your Linear: weight_scale_inv) + scales = _FP8_MAX / safe_max_abs + scales = torch.where(max_abs > 0, scales, torch.ones_like(scales)) # keep zeros stable + + # Broadcast scales back over the block dims and quantize + # max_abs/scales shape: (..., rows_tiles, cols_tiles) + scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3) # -> (..., rows_tiles, 1, cols_tiles, 1) + scaled = reshaped * scales_broadcast + + if _FP8_IS_INT: + quantized = torch.clamp(scaled.round(), min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) + else: + quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) + + quantized = quantized.reshape(original_shape) + + inv_scales = (1.0 / scales).to(torch.float32) # shape: (*leading, rows_tiles, cols_tiles) + if target_keys.endswith("weight"): + scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv" + else: + scale_key = target_keys + "_scales_inv" + + # Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts) + return { + target_keys: quantized, + scale_key: inv_scales, + } + +class Fp8Dequantize(ConversionOps): + """Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor.""" + + def __init__(self, block_size: Optional[tuple[int, int]] = None): + self.block_size = block_size + self.reverse_op = Fp8Quantize + + def convert( + self, + value: Union[Sequence[torch.Tensor], dict[str, torch.Tensor]], + *, + context: dict[str, Any], + ) -> torch.Tensor: + if isinstance(value, dict): + tensors = list(value.values()) + else: + tensors = list(value) if isinstance(value, Sequence) else [value] + if len(tensors) != 2: + raise ValueError("Fp8Dequantize expects exactly two tensors: quantized weights and scales.") + quantized, scales = tensors + if not isinstance(quantized, torch.Tensor) or not isinstance(scales, torch.Tensor): + raise TypeError("Fp8Dequantize expects tensors as inputs.") + + quantized_fp32 = quantized.to(torch.float32) + rows, cols = quantized_fp32.shape[-2:] + block_size = self.block_size + if block_size is None: + quant_config = context.get("quantization_config") + block_size = getattr(quant_config, "weight_block_size", None) + if block_size is None: + block_size = (rows, cols) + block_m, block_n = block_size + if rows % block_m != 0 or cols % block_n != 0: + raise ValueError( + f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})." + ) + + reshaped = quantized_fp32.reshape(-1, rows // block_m, block_m, cols // block_n, block_n) + expanded_scales = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n) + expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2) + dequantized = reshaped * expanded_scales + return dequantized.reshape(quantized_fp32.shape) diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py index 6a6ce1db17e7..cc8840fd496d 100644 --- a/src/transformers/integrations/mxfp4.py +++ b/src/transformers/integrations/mxfp4.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + +from ..core_model_loading import ConversionOps, get_loaded_parameter_class from ..utils import is_accelerate_available, is_torch_available, logging @@ -25,6 +28,8 @@ import re from contextlib import contextmanager +from ..quantizers.quantizers_utils import get_module_from_name + logger = logging.get_logger(__name__) @@ -48,6 +53,61 @@ ] +class Mxfp4Quantize(ConversionOps): + def __init__(self, hf_quantizer): + self.hf_quantizer = hf_quantizer + def convert( + self, input_dict: dict[str, torch.Tensor], model: Optional[torch.nn.Module] = None, missing_keys: Optional[list[str]] = None, **kwargs + ) -> dict[str, torch.Tensor]: + target_key, value = tuple(input_dict.items())[0] + value = value[0] if isinstance(value, list) else value + if not self.hf_quantizer.pre_quantized: + module, _ = get_module_from_name(model, target_key) + with torch.device(value.device): + if isinstance(module, Mxfp4GptOssExperts): + triton_weight_tensor, weight_scale = quantize_to_mxfp4(value, triton_kernels_hub) + PrecisionConfig, FlexCtx, InFlexData = ( + triton_kernels_hub.matmul_ogs.PrecisionConfig, + triton_kernels_hub.matmul_ogs.FlexCtx, + triton_kernels_hub.matmul_ogs.InFlexData, + ) + triton_weight_tensor, weight_scale = swizzle_mxfp4( + triton_weight_tensor, weight_scale, triton_kernels_hub + ) + + proj = "gate_up_proj" if "gate_up_proj" in target_key else "down_proj" + setattr(module, proj, triton_weight_tensor) + setattr( + module, + f"{proj}_precision_config", + PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())), + ) + missing_keys.discard(f"{target_key}_blocks") + missing_keys.discard(f"{target_key}_scales") + delattr(module, f"{proj}_blocks") + delattr(module, f"{proj}_scales") + + else: + + if ("blocks" in target_key or "scales" in target_key) and self.hf_quantizer.quantization_config.dequantize: + # blocks and scales have the same length that's why this works for both + module, _ = get_module_from_name(model, target_key[: -len("_blocks")]) + else: + module, _ = get_module_from_name(model, target_key) + + if self.hf_quantizer.quantization_config.dequantize: + dequantize_convertops(module, target_key, value, value.device, missing_keys) + else: + # Eagerly set tensors on the module and perform swizzle; this function will + # set the appropriate attributes and remove *_blocks/_scales when both are loaded. + load_and_swizzle_mxfp4_convertops(module, target_key, value, value.device, missing_keys, triton_kernels_hub) + + # We return an empty mapping since the module was updated in-place. This prevents + # the loader from trying to materialize the original meta-parameter names again. + # We don't use set_param_for_module since it expects mainly a torch.nn.Parameter or a safetensors pointer + return {} + + @contextmanager def on_device(dev): if is_torch_available(): @@ -88,13 +148,11 @@ def swizzle_mxfp4(w, w_scale, triton_kernels_hub): ) layout = triton_kernels_hub.tensor_details.layout StridedLayout = triton_kernels_hub.tensor_details.layout.StridedLayout - value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts) w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout) return w, w_scale - # Copied from GPT_OSS repo # TODO: Add absolute link when the repo is public def convert_moe_packed_tensors( @@ -355,6 +413,22 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, ** delattr(module, blocks_attr) delattr(module, scales_attr) +def dequantize_convertops(module, param_name, param_value, target_device, missing_keys): + for proj in ["gate_up_proj", "down_proj"]: + if proj in param_name: + blocks_attr = f"{proj}_blocks" + scales_attr = f"{proj}_scales" + setattr(module, param_name.rsplit(".", 1)[1], param_value) + if hasattr(module, blocks_attr) and hasattr(module, scales_attr): + dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr)) + if target_device == "cpu" and torch.cuda.is_available(): + torch.cuda.empty_cache() + dequantized = torch.nn.Parameter(dequantized.to(target_device)) + dequantized = get_loaded_parameter_class(dequantized.__class__)(from_existing=dequantized) + setattr(module, proj, dequantized) + missing_keys.discard(param_name.rsplit("_", 1)[0]) + delattr(module, blocks_attr) + delattr(module, scales_attr) def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, triton_kernels_hub, **kwargs): """ @@ -423,6 +497,68 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, trito del blocks +def load_and_swizzle_mxfp4_convertops(module, param_name, param_value, target_device, missing_keys, triton_kernels_hub): + """ + This transforms the weights obtained using `convert_gpt_oss.py` to load them into `Mxfp4GptOssExperts`. + """ + PrecisionConfig, FlexCtx, InFlexData = ( + triton_kernels_hub.matmul_ogs.PrecisionConfig, + triton_kernels_hub.matmul_ogs.FlexCtx, + triton_kernels_hub.matmul_ogs.InFlexData, + ) + + if "blocks" in param_name: + proj = param_name.split(".")[-1].split("_blocks")[0] + if "scales" in param_name: + proj = param_name.split(".")[-1].split("_scales")[0] + + setattr(module, param_name.rsplit(".", 1)[1], torch.nn.Parameter(param_value, requires_grad=False)) + missing_keys.discard(param_name) + + blocks_attr = f"{proj}_blocks" + scales_attr = f"{proj}_scales" + blocks = getattr(module, blocks_attr) # at this point values were loaded from ckpt + scales = getattr(module, scales_attr) + + # check if blocks or scales are not on meta device + # if so, it means param_value is either a blocks or scales tensor + # and the other blocks or scales tensor is on the correct device + + if blocks.device.type != "meta" and scales.device.type != "meta": + local_experts = blocks.size(0) + if blocks.device.type == "meta": + blocks = param_value + elif scales.device.type == "meta": + scales = param_value + + if proj == "gate_up_proj": + blocks = blocks.reshape(local_experts, module.intermediate_size * 2, -1) + else: + blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2) + if getattr(target_device, "type", target_device) == "cpu": + target_device = "cuda" + + blocks = blocks.to(target_device).contiguous() + scales = scales.to(target_device).contiguous() + with on_device(target_device): + triton_weight_tensor, weight_scale = swizzle_mxfp4( + blocks.transpose(-2, -1), scales.transpose(-2, -1), triton_kernels_hub + ) + # need to overwrite the shapes for the kernels + if proj == "gate_up_proj": + triton_weight_tensor.shape = torch.Size([local_experts, module.hidden_size, module.intermediate_size * 2]) + else: + triton_weight_tensor.shape = torch.Size([local_experts, module.intermediate_size, module.hidden_size]) + + # triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It's like a subtensor + setattr(module, proj, triton_weight_tensor) + setattr(module, f"{proj}_precision_config", PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData()))) + delattr(module, scales_attr) + delattr(module, blocks_attr) + del blocks + del scales + + def _replace_with_mxfp4_linear( model, modules_to_not_convert=None, diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index 2aa515199d72..db3c5df70d91 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -236,7 +236,7 @@ def load_adapter( **adapter_kwargs, ) peft_config.inference_mode = not is_trainable - + # TODO: WE NEED TOO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE! # Create and add fresh new adapters into the model. inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index f8a96d7a476e..6fa40eef0890 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -18,6 +18,7 @@ import os import re from functools import partial, reduce +from typing import Optional import torch import torch.distributed as dist @@ -306,7 +307,7 @@ def repack_weights( return final_ordered_tensor -def get_tensor_shard(param, empty_param, device_mesh, rank, dim): +def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: Optional[int] = None): """ Generalized tensor sharding across a multi-dimensional device mesh. Extract only the fraction of the parameter owned by the given `rank` when the parameter would have gone sharding at provided `dim`. @@ -358,32 +359,57 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim): rank (int): Global rank of the current process/device. dim (int): Dimension along which to shard the tensor. """ - param_dim = empty_param.dim() - + param_dim = empty_param.ndim + # Flatten the mesh to get the total number of devices + mesh_shape = device_mesh.shape + world_size = reduce(operator.mul, mesh_shape) if dim < 0: dim = param_dim + dim + if empty_param.dim() == 3 and dim == 1 and len(param.get_shape()) == 2: + dim = 0 + elif empty_param.dim() == 3 and dim == 2 and len(param.get_shape()) == 2: + dim = 0 + + shard_size = math.ceil(empty_param.size(dim) / world_size) + start = rank * shard_size + end = min(start + shard_size, empty_param.size(dim)) + if dim >= param_dim: raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}") - # Flatten the mesh to get the total number of devices - mesh_shape = device_mesh.shape - world_size = reduce(operator.mul, mesh_shape) - if rank >= world_size: raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}") - shard_size = math.ceil(empty_param.shape[dim] / world_size) - start = rank * shard_size + # we have the full tensor not 1 part of it. + # in that case, we just assume that the weight was properly saved + # and thus because we TP if the layer is colwise it should not use this. Layer should be packed_colwise + # to inform that it needs to read form a packed tensor. It will also take care of the module list thingy. + # here we take care of potential chunking / layer split / layer chunking. + # The only "hard" case is? if we collect q,k,v -> merge it into qkv. In that case + # actually we still shard dim=0 does not change + # so only case is if the dim of the empty param is 3 and the shard dim is 0 -> we put the + # tensor on a certain device (with the input tensor_index) + dimensions = param.get_shape() + + if empty_param.dim() == 3 and dim == 0 and len(param.get_shape()) == 2: + # special case we don't "shard" just send this entire tensor to the correct rank. + if start <= tensor_idx < end: + # this tensor does need to be materialized on this device: + return param[:] + else: + return torch.empty([], dtype=torch.int64, device=rank) - # Construct slicing index dynamically - end = min(start + shard_size, empty_param.shape[dim]) - slice_indices = [slice(None)] * param_dim - if start < empty_param.shape[dim]: + slice_indices = [slice(None)] * len(param.get_shape()) + + if start < param.get_shape()[dim]: slice_indices[dim] = slice(start, end) - return param[tuple(slice_indices)] - dimensions = list(param.shape) + param = param[tuple(slice_indices)] + if isinstance(param, list): # TODO handle the modulelist case! + param = [p[:] for p in param] + return param + dimensions[dim] = 0 - return torch.empty(tuple(dimensions), dtype=torch.int64) + return torch.empty(tuple(dimensions), dtype=torch.int64) # empty allocates memory.... def distribute_module( @@ -410,6 +436,19 @@ class TensorParallelLayer: """ use_dtensor = True + device_mesh = None + rank = None + + # Used to compare the shape of the original tensor + empty_param = None + + # Used to init the corresponding DTensor + shard = None + + def __init__(self, device_mesh=None, rank=None, empty_param=None): + self.rank = rank + self.device_mesh = device_mesh + self.empty_param = empty_param @staticmethod def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): ... @@ -439,12 +478,12 @@ class GatherParallel(TensorParallelLayer): def __init__( self, - *, input_layouts: Placement | None = None, output_layouts: Placement | None = None, use_local_output: bool = True, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.input_layouts = (input_layouts or Replicate(),) self.output_layouts = output_layouts self.desired_input_layouts = (Replicate(),) @@ -465,6 +504,21 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False) return outputs + def shard_tensor( + self, + param, + param_type=None, + param_casting_dtype=None, + to_contiguous=None, + rank=None, + device_mesh=None, + tensor_idx=None, + ): + shard = [Replicate()] + parameter = param[...].to(param_casting_dtype) + self.shard = shard + return parameter, shard + def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: distribute_module( module, @@ -493,6 +547,23 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me # TODO: figure out dynamo support for instance method and switch this to instance method return outputs + def shard_tensor( + self, + param, + param_type=None, + param_casting_dtype=None, + to_contiguous=None, + rank=None, + device_mesh=None, + tensor_idx=None, + ): + mesh = device_mesh or self.device_mesh + parameter = param[...].to(param_casting_dtype) + if mesh is not None: + parameter = parameter / mesh.size() + self.shard = None + return parameter, None + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): param = param[...].to(param_casting_dtype) if to_contiguous: @@ -515,8 +586,8 @@ class ReplicateParallel(TensorParallelLayer): This class is used to replicate computation in a TP layer (used in SP regions when we don't use sequence parallelism for example) """ - def __init__(self, *, use_dtensor=True, use_local_output=True): - super().__init__() + def __init__(self, use_dtensor=True, use_local_output=True, **kwargs): + super().__init__(**kwargs) self.input_layouts = (Replicate(),) self.output_layouts = (Replicate(),) self.desired_input_layouts = (Replicate(),) @@ -537,12 +608,33 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs + def shard_tensor( + self, + param, + param_type=None, + param_casting_dtype=None, + to_contiguous=None, + rank=None, + device_mesh=None, + tensor_idx=None, + ): + parameter = param[...].to(param_casting_dtype) + shard = [Replicate()] + self.shard = shard + return parameter, shard + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): - param = param[...].to(param_casting_dtype) - if to_contiguous: - param = param.contiguous() - param = DTensor.from_local(param, device_mesh, [Replicate()], run_check=False) - return param + parameter, shard = self.shard_tensor( + param, + param_type=param_type, + param_casting_dtype=param_casting_dtype, + to_contiguous=to_contiguous, + rank=rank, + device_mesh=device_mesh, + ) + if self.use_dtensor: + parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False) + return parameter class ColwiseParallel(TensorParallelLayer): @@ -552,13 +644,13 @@ class ColwiseParallel(TensorParallelLayer): def __init__( self, - *, input_layouts: Placement | None = None, output_layouts: Placement | None = None, use_local_output: bool = True, use_dtensor=True, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.input_layouts = (input_layouts or Replicate(),) self.output_layouts = (output_layouts or Shard(-1),) self.desired_input_layouts = (Replicate(),) @@ -578,18 +670,34 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_ input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False) return input_tensor - def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): - # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) - # means Colwise as Linear is input * weight^T + bias, where - # weight would become Shard(1) + def shard_tensor( + self, + param, + param_type=None, + param_casting_dtype=None, + to_contiguous=None, + rank=None, + device_mesh=None, + tensor_idx=None, + ): + device_mesh = self.device_mesh + empty_param = self.empty_param + rank = self.rank if param_type == "bias": - parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1) + parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx) shard = [Shard(-1)] else: shard = [Shard(-2)] - parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2) - + parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2, tensor_idx) parameter = parameter.to(param_casting_dtype) + self.shard = shard + return parameter, shard + + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): + # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) + # means Colwise as Linear is input * weight^T + bias, where + # weight would become Shard(1) + parameter, shard = self.shard_tensor(param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh) if to_contiguous: parameter = parameter.contiguous() if self.use_dtensor: @@ -608,6 +716,26 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me class PackedColwiseParallel(ColwiseParallel): + def shard_tensor( + self, + param, + param_type=None, + param_casting_dtype=None, + to_contiguous=None, + rank=None, + device_mesh=None, + tensor_idx=None, + ): + device_mesh = device_mesh or self.device_mesh + empty_param = self.empty_param + rank = rank if rank is not None else self.rank + return get_packed_weights(param, empty_param, device_mesh, rank, -2).to(param_casting_dtype), [Shard(-2)] + + def create_nn_parameter( + self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh + ): + return nn.Parameter(param, requires_grad=param.is_floating_point()) + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) # means Colwise as Linear is input * weight^T + bias, where @@ -642,18 +770,41 @@ class RowwiseParallel(TensorParallelLayer): def __init__( self, - *, input_layouts: Placement | None = None, output_layouts: Placement | None = None, use_local_output: bool = True, use_dtensor=True, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.input_layouts = (input_layouts or Shard(-1),) self.output_layouts = (output_layouts or Replicate(),) self.use_local_output = use_local_output self.use_dtensor = use_dtensor + def shard_tensor( + self, + param, + param_type=None, + param_casting_dtype=None, + to_contiguous=None, + rank=None, + device_mesh=None, + tensor_idx=None, + ): + device_mesh = device_mesh or self.device_mesh + empty_param = self.empty_param + rank = rank if rank is not None else self.rank + if param_type == "bias": + shard = [Replicate()] + parameter = param[...] + else: + parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx=tensor_idx) + shard = [Shard(-1)] + parameter = parameter.to(param_casting_dtype) + self.shard = shard + return parameter, shard + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1) # means Rowwise as nn.Linear is input * weight^T + bias, where @@ -725,6 +876,21 @@ def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module: class PackedRowwiseParallel(RowwiseParallel): + def shard_tensor( + self, + param, + param_type=None, + param_casting_dtype=None, + to_contiguous=None, + rank=None, + device_mesh=None, + tensor_idx=None, + ): + device_mesh = device_mesh or self.device_mesh + empty_param = self.empty_param + rank = rank if rank is not None else self.rank + return get_packed_weights(param, empty_param, device_mesh, rank, -1), [Shard(-1)] + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only) # means Colwise as Linear is input * weight^T + bias, where @@ -783,8 +949,8 @@ class SequenceParallel(TensorParallelLayer): to ensure that they are replicated. """ - def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False): - super().__init__() + def __init__(self, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False, **kwargs): + super().__init__(**kwargs) self.input_layouts = (Replicate(),) self.desired_input_layouts = (Shard(1),) self.output_layouts = (Replicate(),) @@ -793,6 +959,21 @@ def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use self.sequence_sharding = (Shard(sequence_dim),) self.use_local_output = use_local_output + def shard_tensor( + self, + param, + param_type=None, + param_casting_dtype=None, + to_contiguous=None, + rank=None, + device_mesh=None, + tensor_idx=None, + ): + parameter = param[...].to(param_casting_dtype) + shard = [Replicate()] + self.shard = shard + return parameter, shard + @staticmethod def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): input_tensor = inputs[0] @@ -827,10 +1008,34 @@ class GroupedGemmParallel(TensorParallelLayer): Applies Expert Parallelism to MoE experts by loading the correct experts on each device. """ - def __init__(self): - super().__init__() + def __init__(self, **kwargs): + super().__init__(**kwargs) self.use_dtensor = False + def shard_tensor( + self, + param, + param_type=None, + param_casting_dtype=None, + to_contiguous=None, + rank=None, + device_mesh=None, + tensor_idx=None, + ): + empty_param = self.empty_param + ep_rank = self.rank + device_mesh = self.device_mesh + + global_num_experts = empty_param.shape[0] + if global_num_experts % device_mesh.size() != 0: + raise ValueError( + f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0" + ) + local_num_experts = global_num_experts // device_mesh.size() + parameter = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype) + self.shard = None + return parameter, None + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): ep_rank = rank global_num_experts = empty_param.shape[0] @@ -851,8 +1056,8 @@ class RouterParallel(TensorParallelLayer): """ def __init__(self, *args, **kwargs): + super().__init__(**kwargs) self.args = args - self.kwargs = kwargs self.use_dtensor = False @staticmethod @@ -917,6 +1122,20 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me ) # masking class for one hot return router_scores, router_indices + def shard_tensor( + self, + param, + param_type=None, + param_casting_dtype=None, + to_contiguous=None, + rank=None, + device_mesh=None, + tensor_idx=None, + ): + parameter = param[...].to(param_casting_dtype) + self.shard = None + return parameter, None + def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh): # TODO: i'd like for this to be the default param = param[...].to(param_casting_dtype) @@ -1059,6 +1278,9 @@ def shard_and_distribute_module( if current_shard_plan is not None: try: tp_layer = ALL_PARALLEL_STYLES[current_shard_plan] + tp_layer.empty_param = empty_param + tp_layer.device_mesh = device_mesh + tp_layer.rank = rank param = tp_layer.partition_tensor( param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh ) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 960373ba102a..b212344b4e12 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -23,11 +23,11 @@ import os import re import sys +import time import warnings from abc import abstractmethod from collections import defaultdict -from collections.abc import Callable -from concurrent.futures import ThreadPoolExecutor, as_completed +from collections.abc import Callable, Sequence from contextlib import contextmanager from enum import Enum from functools import partial, wraps @@ -45,17 +45,22 @@ from torch.utils.checkpoint import checkpoint from .configuration_utils import PreTrainedConfig +from .conversion_mapping import get_checkpoint_conversion_mapping +from .core_model_loading import ( + WeightConverter, + _infer_parameter_dtype, + convert_and_load_state_dict_in_model, + revert_weight_conversion, +) from .distributed import DistributedConfig from .dynamic_module_utils import custom_object_save from .generation import CompileConfig, GenerationConfig from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled, is_fsdp_enabled from .integrations.accelerate import ( _get_device_map, - accelerate_disk_offload, accelerate_dispatch, check_and_set_device_map, expand_device_map, - find_tied_parameters, init_empty_weights, ) from .integrations.deepspeed import _load_state_dict_into_zero3_model @@ -122,6 +127,7 @@ is_sagemaker_mp_enabled, is_tracing, ) +from .utils.loading_report import log_state_dict_report from .utils.quantization_config import QuantizationMethod @@ -130,7 +136,6 @@ from accelerate.utils import ( extract_model_from_parallel, offload_weight, - save_offload_index, ) from accelerate.utils.modeling import get_state_dict_from_offload @@ -182,6 +187,7 @@ def is_local_dist_rank_0(): "xavier_normal": nn.init.xavier_normal, "kaiming_uniform": nn.init.kaiming_uniform, "kaiming_normal": nn.init.kaiming_normal, + "orthogonal_": nn.init.orthogonal_, } # DO NOT MODIFY, KEPT FOR BC ONLY @@ -470,7 +476,8 @@ def _end_ptr(tensor: torch.Tensor) -> int: def _get_tied_weight_keys(module: nn.Module, prefix=""): tied_weight_keys = [] if getattr(module, "_tied_weights_keys", None) is not None: - names = [f"{prefix}.{k}" if prefix else k for k in module._tied_weights_keys] + value_names = list(module._tied_weights_keys.keys()) + names = [f"{prefix}.{k}" if prefix else k for k in value_names] tied_weight_keys.extend(names) if getattr(module, "_dynamic_tied_weights_keys", None) is not None: names = [f"{prefix}.{k}" if prefix else k for k in module._dynamic_tied_weights_keys] @@ -530,39 +537,6 @@ def _find_identical(tensors: list[set[str]], state_dict: dict[str, torch.Tensor] shared_tensors.append(shared) return shared_tensors, identical - -def _infer_parameter_dtype( - model: "PreTrainedModel", - param_name: str, - empty_param: torch.Tensor, - hf_quantizer: Optional[HfQuantizer] = None, -) -> Union[bool, Optional[torch.dtype]]: - try: - old_param = model.get_parameter_or_buffer(param_name) - except Exception as e: - if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in { - QuantizationMethod.HQQ, - QuantizationMethod.QUARK, - QuantizationMethod.MXFP4, - QuantizationMethod.BITS_AND_BYTES, - }: - return True, None - else: - raise e - is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn") - # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params - # in int/uint/bool and not cast them. - casting_dtype = None - is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn - if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: - # dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes - if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, param_name): - casting_dtype = model.config._pre_quantization_dtype - else: - casting_dtype = old_param.dtype - return old_param is not None and old_param.is_contiguous(), casting_dtype - - def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor): """Cast a single parameter `param_name` into the `model`, with value `tensor`.""" module, param_type = get_module_from_name(model, param_name) @@ -696,83 +670,6 @@ def _load_state_dict_into_meta_model( return disk_offload_index -def load_shard_file(args): - ( - shard_file, - state_dict, - disk_only_shard_files, - is_quantized, - device_map, - hf_quantizer, - key_renaming_mapping, - weights_only, - model, - reverse_key_renaming_mapping, - disk_offload_folder, - disk_offload_index, - device_mesh, - ) = args - - # Skip the load for shards that only contain disk-offloaded weights - if shard_file in disk_only_shard_files: - return [], disk_offload_index - - map_location = "cpu" - if shard_file.endswith(".safetensors") and not (is_deepspeed_zero3_enabled() and not is_quantized): - map_location = "meta" - - # If shard_file is "", we use the existing state_dict instead of loading it - if shard_file != "": - state_dict = load_state_dict( - shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only - ) - - # Fix the key names - state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping} - - error_msgs = [] - if is_deepspeed_zero3_enabled() and not is_quantized: - error_msgs += _load_state_dict_into_zero3_model(model, state_dict) - # Skip it with fsdp on ranks other than 0 - elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized): - disk_offload_index = _load_state_dict_into_meta_model( - model, - state_dict, - shard_file, - reverse_key_renaming_mapping, - device_map=device_map, - disk_offload_folder=disk_offload_folder, - disk_offload_index=disk_offload_index, - hf_quantizer=hf_quantizer, - device_mesh=device_mesh, - ) - - return error_msgs, disk_offload_index - - -def load_shard_files_with_threadpool(args_list): - num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")) - - # Do not spawn anymore workers than you need - num_workers = min(len(args_list), num_workers) - - logger.info(f"Loading model weights in parallel with {num_workers} workers...") - - error_msgs = [] - - with ThreadPoolExecutor(max_workers=num_workers) as executor: - with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar: - futures = [executor.submit(load_shard_file, arg) for arg in args_list] - for future in as_completed(futures): - _error_msgs, disk_offload_index = future.result() - - error_msgs += _error_msgs - - pbar.update(1) - - return error_msgs, disk_offload_index - - def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: if variant is not None: path, name = weights_name.rsplit(".", 1) @@ -780,40 +677,6 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: return weights_name -def update_key_name(keys): - """ - Updates a dictionary of keys to pack layers together as layer.{0, 1, 4} instead of layers.0, layers.1, layers.4. - """ - key_dict = defaultdict(list) - for key in keys: - all_digits = re.findall(r".(\d+).", key) - for i, k in enumerate(all_digits): - if len(key_dict[re.sub(r".(\d+).", ".*.", key)]) <= i: - key_dict[re.sub(r".(\d+).", ".*.", key)].append(set()) - key_dict[re.sub(r".(\d+).", ".*.", key)][i].add(int(k)) - - final_keys = set() - for key in keys: - text = re.sub(r".(\d+).", ".*.", key) - pattern = key_dict[text] - final_text = "" - for i, part in enumerate(text.split("*")): - if len(pattern) <= i: - final_text += part - else: - data = [str(i) for i in sorted(pattern[i])] - if len(data) > 10: - result = f"{data[0]}...{data[-1]}" - else: - result = ", ".join(data) # If there are only 1 or 2 elements, show them all - if len(data) > 1: - final_text += part + "{" + result + "}" - else: - final_text += part + data[0] - final_keys.add(final_text) - return sorted(final_keys) - - def _get_resolved_checkpoint_files( pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], variant: Optional[str], @@ -1174,102 +1037,33 @@ def _get_dtype( return config, dtype, dtype_orig -def _find_missing_and_unexpected_keys( - model: "PreTrainedModel", - original_checkpoint_keys: list[str], - checkpoint_keys: list[str], - loading_base_model_from_task_state_dict: bool, - hf_quantizer: Optional[HfQuantizer], -) -> tuple[list[str], list[str]]: - """Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys - (keys found in the loaded state dict keys, but that are NOT part of the model parameters) - """ - prefix = model.base_model_prefix - - # Compute expected keys, i.e. keys that the full model expects - expected_keys = list(model.state_dict().keys()) - if hf_quantizer is not None: - expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, checkpoint_keys) - - # Adjust prefix of the keys to make them match loaded keys before removing them - missing_keys = sorted(set(expected_keys) - set(checkpoint_keys)) - unexpected_keys = set(checkpoint_keys) - set(expected_keys) - # If a module has the same name under the base and task specific model, we have to re-add it to unexpected keys - if loading_base_model_from_task_state_dict: - task_specific_keys = [k for k in original_checkpoint_keys if not k.startswith(f"{prefix}.")] - unexpected_keys.update(task_specific_keys) - - # Remove nonpersistent buffers from unexpected keys: they are not in the expected keys (model state dict), but - # may be in the loaded keys. Note that removing all buffers does the job, as they were part of the expected keys anyway - model_buffers = {n for n, _ in model.named_buffers()} - unexpected_keys = sorted(unexpected_keys - model_buffers) - - tied_params = find_tied_parameters(model) - for group in tied_params: - missing_in_group = [k for k in missing_keys if k in group] - if len(missing_in_group) > 0 and len(missing_in_group) < len(group): - missing_keys = [k for k in missing_keys if k not in missing_in_group] - - if hf_quantizer is not None: - missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix) - unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys) - - return missing_keys, unexpected_keys - +@contextmanager +def guard_nn_init_functions(flag_name: str = "_is_hf_initialized"): + import torch.nn.init as I -def _find_mismatched_keys( - model: "PreTrainedModel", - state_dict: Optional[dict], - checkpoint_files: Optional[list[str]], - ignore_mismatched_sizes: bool, - keys_to_rename_mapping: dict[str, str], - is_quantized: bool, - weights_only: bool, -) -> tuple[list[str], list[tuple[int, int]]]: - """ - Find potential shape mismatch between the different state dicts and the model parameters, but only if `ignore_mismatched_sizes` - is True. Otherwise, return immediately and any shape mismatch that may exist will be raised later on. This avoids checking - every parameter in advance, as shape mismatch are extremely rare in practice. If we want to ignore them however, we do - need to check in advance as we need to know which parameters we need to move back from meta to cpu, and initialize - correctly. Indeed, as our model initialization takes place at the module level, and not the weight level, in the - case of a sharded checkpoint we cannot correctly initialize the weights according to `model._init_weights()` if we perform - this check on each state dict at loading time (after the first loaded checkpoint, there are no way to initialize only the - mismatched weights if any, without overwriting the previously loaded weights as well because all the module will be - initialized, not only the weights that are mismatched). - """ + originals = {} - # An error will be raised later on anyway if there is a mismatch - this avoids running the rest of this function - # if there are no mismatch (which is almost always the case) - if not ignore_mismatched_sizes: - return [], [] - - if state_dict is not None: - checkpoint_files = [""] - - model_state_dict = model.state_dict() - mismatched_keys = [] - mismatched_shapes = [] - for shard_file in checkpoint_files: - # If shard_file is "", we use the existing state_dict instead of loading it - if shard_file != "": - state_dict = load_state_dict( - shard_file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only - ) + def make_wrapper(fn): + @wraps(fn) + def wrapped(*args, **kwargs): + # Tensor can come positionally or as a kwarg (e.g. via DeviceContext) + t = args[0] if args else kwargs.get("tensor", kwargs.get("input")) + if t is not None and getattr(t, flag_name, False): + # mimic init.* return convention (returns the tensor) + return t + return fn(*args, **kwargs) # TODO we could set is init here. - # Fix the key names - new_state_dict = {keys_to_rename_mapping[k]: v for k, v in state_dict.items() if k in keys_to_rename_mapping} + return wrapped - for key, tensor in new_state_dict.items(): - if key in model_state_dict and tensor.shape != model_state_dict[key].shape: - # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences. - # Without matching with module type or parameter type it seems like a practical way to detect valid 4bit weights. - if not ( - is_quantized and tensor.shape[-1] == 1 and tensor.numel() * 2 == model_state_dict[key].numel() - ): - mismatched_keys.append(key) - mismatched_shapes.append((tensor.shape, model_state_dict[key].shape)) - - return mismatched_keys, mismatched_shapes + try: + for name in TORCH_INIT_FUNCTIONS: + if hasattr(I, name): + originals[name] = getattr(I, name) + setattr(I, name, make_wrapper(originals[name])) + yield + finally: + for name, fn in originals.items(): + setattr(I, name, fn) class PipelineParallel(Enum): @@ -1677,6 +1471,8 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag _keep_in_fp32_modules_strict = None + dtype_plan: Optional[dict[str, torch.dtype]] = None + # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings. _keys_to_ignore_on_load_missing = None @@ -1841,11 +1637,18 @@ def __init__(self, config: PreTrainedConfig, *inputs, **kwargs): self.name_or_path = config.name_or_path self.warnings_issued = {} self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None + # Overwrite the class attribute to make it an instance attribute, so models like # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute # when a different component (e.g. language_model) is used. self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules) self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict) + self.dtype_plan = {} + + if isinstance(self._keep_in_fp32_modules, list): + self.dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules, torch.float32)) + if isinstance(self._keep_in_fp32_modules_strict, list): + self.dtype_plan.update(dict.fromkeys(self._keep_in_fp32_modules_strict, torch.float32)) self._no_split_modules = self._no_split_modules or [] _CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs # added for executorch support only @@ -1861,32 +1664,6 @@ def post_init(self): self.init_weights() self._backward_compatibility_gradient_checkpointing() - # Make sure the modules correctly exist if the flag is active - if self._keep_in_fp32_modules is not None or self._keep_in_fp32_modules_strict is not None: - all_parameters = {name for name, _ in self.named_parameters() if len(name) > 0} - unique_module_names = set() - # Get all unique module names in the module graph, without the prefixes - for param in all_parameters: - unique_module_names.update( - [name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]] - ) - # Check that every module in the keep_in_fp32 list is part of the module graph - if self._keep_in_fp32_modules is not None: - for module in self._keep_in_fp32_modules: - if module not in unique_module_names: - raise ValueError( - f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in" - f" {self.__class__.__name__}" - ) - - if self._keep_in_fp32_modules_strict is not None: - for module in self._keep_in_fp32_modules_strict: - if module not in unique_module_names: - raise ValueError( - f"{module} was specified in the `_keep_in_fp32_modules_strict` list, but is not part of the modules in" - f" {self.__class__.__name__}" - ) - self._tp_plan, self._ep_plan, self._pp_plan = {}, {}, {} # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config if self.base_model is self: @@ -2625,6 +2402,7 @@ def set_decoder(self, decoder): return + @torch.no_grad() def _init_weights(self, module): """ Initialize the weights. This is quite general on purpose, in the spirit of what we usually do. For more complex @@ -2632,34 +2410,50 @@ def _init_weights(self, module): `nn.Parameter`, this method should also be overridden in order to initialize it correctly. """ if hasattr(self.config, "initializer_range"): - std = self.config.initializer_range + std = self.config.initializer_range or 0.02 else: # 0.02 is the standard default value across the library std = getattr(self.config.get_text_config(), "initializer_range", 0.02) - - if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.MultiheadAttention): - # This uses torch's original init - module._reset_parameters() - # We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names - # between modelings (because they are prefixed with the model name) - elif ( - isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) - or "LayerNorm" in module.__class__.__name__ - or "RMSNorm" in module.__class__.__name__ - ): - # Norms can exist without weights (in which case they are None from torch primitives) - if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(1.0) - if hasattr(module, "bias") and module.bias is not None: - module.bias.data.zero_() + try: + if isinstance(module, PreTrainedModel): + return + elif isinstance( + module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d) + ): + if getattr(module, "weight", None) is not None: + module.weight.normal_(mean=0.0, std=std) + if getattr(module, "bias", None) is not None: + module.bias.zero_() + elif isinstance(module, nn.Embedding): + if getattr(module, "weight", None) is not None: + module.weight.normal_(mean=0.0, std=std) + if getattr(module, "padding_idx", None) is not None: + module.weight[module.padding_idx].zero_() + elif isinstance(module, nn.Parameter): + module.normal_(mean=0.0, std=std) + elif isinstance(module, nn.MultiheadAttention): + # This uses torch's original init + module._reset_parameters() + # We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names + # between modelings (because they are prefixed with the model name) + elif ( + isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) + or "LayerNorm" in module.__class__.__name__ + or "RMSNorm" in module.__class__.__name__ + ): + # Norms can exist without weights (in which case they are None from torch primitives) + if hasattr(module, "weight") and module.weight is not None: + module.weight.fill_(1.0) + if hasattr(module, "bias") and module.bias is not None: + module.bias.zero_() + if isinstance(getattr(module, "gate_up_proj", None), nn.Parameter): + module.gate_up_proj.normal_(mean=0.0, std=std) + if isinstance(getattr(module, "down_proj", None), nn.Parameter): + module.down_proj.normal_(mean=0.0, std=std) + if isinstance(getattr(module, "gate", None), nn.Parameter): + module.gate.normal_(mean=0.0, std=std) + except Exception as e: + logger.warning(f"Failed to init: {str(e)}") def _initialize_weights(self, module): """ @@ -2667,10 +2461,12 @@ def _initialize_weights(self, module): """ if getattr(module, "_is_hf_initialized", False): return + self._init_weights(module) module._is_hf_initialized = True @torch.no_grad() + @guard_nn_init_functions() def initialize_weights(self): """ This is equivalent to calling `self.apply(self._initialize_weights)`, but correctly handles composite models. @@ -2681,8 +2477,9 @@ def initialize_weights(self): Note that the `torch.no_grad()` decorator is very important as well, as most of our `_init_weights` do not use `torch.nn.init` functions (which are all no_grad by default), but simply do in-place ops such as - `module.weight.data.zero_()`. + `module.weight.zero_()`. """ + # Sort by depth (stable) then name for deterministic order. if not hasattr(torch.nn.Module, "smart_apply"): # This function is equivalent to `torch.nn.Module.apply`, except that it dynamically adjust the function # to apply as we go down the graph @@ -2701,155 +2498,127 @@ def smart_apply(self, fn): # Let the magic happen with this simple call self.smart_apply(self._initialize_weights) - def tie_embeddings_and_encoder_decoder(self): + def tie_weight_source_and_target( + self, + top_level: "PreTrainedModel", + missing_keys: Optional[set[str]] = None, + module_prefix: str = "", + ): """ If set in the config, tie the weights between the input embeddings and the output embeddings, - and the encoder and decoder. + and the encoder and decoder. This relies on the `_tied_weights_keys` dict. """ - if getattr(self.config.get_text_config(decoder=True), "tie_word_embeddings", True): - output_embeddings = self.get_output_embeddings() - if output_embeddings is not None: - self._tie_embedding_weights(output_embeddings, self.get_input_embeddings()) - - if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False): - if hasattr(self, self.base_model_prefix): - self = getattr(self, self.base_model_prefix) - tied_weights = self._tie_encoder_decoder_weights( - self.encoder, self.decoder, self.base_model_prefix, "encoder" - ) - # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class - # attributed not an instance member, therefore modifying it will modify the entire class - # Leading to issues on subsequent calls by different tests or subsequent calls. - self._dynamic_tied_weights_keys = tied_weights + mapping = getattr(self, "_tied_weights_keys", None) + if not isinstance(mapping, dict): + return + if ( # we only tie for ourselves, so we look at our config + not self.config.tie_word_embeddings + and not self.config.tie_encoder_decoder # if missing keys is None we init? + ): + return - def tie_weights(self): - """ - Recursively (for all submodels) tie all the weights of the model. - """ - # Note that `self` is included in `self.modules` so we also apply to current PreTrainedModel with this call - for module in self.modules(): - # If it's a PreTrainedModel, may need to tie the embeddings and/or encoder/decoder weights - if isinstance(module, PreTrainedModel): - module.tie_embeddings_and_encoder_decoder() - # Additionally, if it has a custom `_tie_weights`, honor it - if hasattr(module, "_tie_weights"): - module._tie_weights() + for target_name, source_name in mapping.items(): + source_name = f"{module_prefix}.{source_name}" if module_prefix else source_name - @staticmethod - def _tie_encoder_decoder_weights( - encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, base_encoder_name: str - ): - uninitialized_encoder_weights: list[str] = [] - tied_weights: list[str] = [] - if decoder.__class__ != encoder.__class__: - logger.info( - f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder" - " weights are correctly initialized." + # if there are missing keys but the source is also missing, we are out, _init_weights will init later and tie later. + # maybe we still need ot remove tied from missing just because you tie + source_is_there = missing_keys and not re.search( + rf"^{re.escape(source_name)}", "\n".join(missing_keys), flags=re.MULTILINE ) - def tie_encoder_to_decoder_recursively( - decoder_pointer: nn.Module, - encoder_pointer: nn.Module, - module_name: str, - base_encoder_name: str, - uninitialized_encoder_weights: list[str], - depth=0, - total_decoder_name="", - total_encoder_name="", - ): - assert isinstance(decoder_pointer, nn.Module) and isinstance(encoder_pointer, nn.Module), ( - f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module" + # if neither are here, we still want to the training to have same grads + target_is_not_there = ( + missing_keys + and re.search(target_name, "\n".join(missing_keys), flags=re.MULTILINE) + and not source_is_there ) - if hasattr(decoder_pointer, "weight"): - assert hasattr(encoder_pointer, "weight") - encoder_pointer.weight = decoder_pointer.weight - tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight") - if hasattr(decoder_pointer, "bias"): - assert hasattr(encoder_pointer, "bias") - tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias") - encoder_pointer.bias = decoder_pointer.bias - return + if source_is_there or missing_keys is None or target_is_not_there: + try: + if source_name.endswith(".bias") or source_name.endswith(".weight"): + source_param_or_module = top_level.get_parameter_or_buffer(source_name) + else: + source_param_or_module = top_level.get_submodule(source_name) + except AttributeError: + continue - encoder_modules = encoder_pointer._modules - decoder_modules = decoder_pointer._modules - if len(decoder_modules) > 0: - assert len(encoder_modules) > 0, ( - f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" - ) + target_name = f"{module_prefix}.{target_name}" if module_prefix else target_name - all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules} - encoder_layer_pos = 0 - for name in decoder_modules: - if name.isdigit(): - encoder_name = str(int(name) + encoder_layer_pos) - decoder_name = name - if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len( - encoder_modules - ) != len(decoder_modules): - # this can happen if the name corresponds to the position in a list module list of layers - # in this case the decoder has added a cross-attention that the encoder does not have - # thus skip this step and subtract one layer pos from encoder - encoder_layer_pos -= 1 + if "d+" in target_name: + reg = re.compile(target_name) + modules = dict(self.named_modules()) + params = dict(self.named_parameters()) + for target_n in modules.keys() | params.keys(): + if not reg.fullmatch(target_n): continue - elif name not in encoder_modules: - continue - elif depth > 500: - raise ValueError( - "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is" - " a circular dependency between two or more `nn.Modules` of your model." - ) + if "." in target_n: + parent_path, last = target_n.rsplit(".", 1) + parent = self.get_submodule(parent_path) + else: + parent_path, last = "", target_n + parent = self # top-level + if last in parent._modules: + parent._modules[last] = source_param_or_module + if missing_keys: + for k, _ in source_param_or_module.named_parameters(): + missing_keys.discard(f"{parent_path}.{last}.{k}") + else: + setattr(parent, last, source_param_or_module) + self._adjust_bias(parent, source_param_or_module) + if missing_keys: + missing_keys.discard(target_n) + else: + if "." in target_name: + submodule, weight = target_name.rsplit(".", 1) + submodule = top_level.get_submodule(submodule) + setattr(submodule, weight, source_param_or_module) + self._adjust_bias(submodule, source_param_or_module) else: - decoder_name = encoder_name = name - tie_encoder_to_decoder_recursively( - decoder_modules[decoder_name], - encoder_modules[encoder_name], - module_name + "/" + name, - base_encoder_name, - uninitialized_encoder_weights, - depth=depth + 1, - total_encoder_name=f"{total_encoder_name}.{encoder_name}", - total_decoder_name=f"{total_decoder_name}.{decoder_name}", - ) - all_encoder_weights.remove(module_name + "/" + encoder_name) + setattr(top_level, target_name, source_param_or_module) - uninitialized_encoder_weights += list(all_encoder_weights) + if missing_keys: + missing_keys.discard(target_name) - # tie weights recursively - tie_encoder_to_decoder_recursively( - decoder, encoder, base_model_prefix, base_encoder_name, uninitialized_encoder_weights - ) - - if len(uninitialized_encoder_weights) > 0: - logger.warning( - f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" - ) - return tied_weights - - def _tie_embedding_weights(self, output_embeddings, input_embeddings): - """Tie weights, and add hooks and flags if using TP.""" - output_embeddings.weight = input_embeddings.weight - - # Passing hooks over to the embeddings if needed - # (currently limited to tensor parallel hooks and flags only) - if hasattr(input_embeddings, "_is_hooked") and getattr(input_embeddings, "_hf_tp_plan", None): - output_embeddings._is_hooked = input_embeddings._is_hooked - output_embeddings._hf_tp_plan = input_embeddings._hf_tp_plan - output_embeddings._forward_hooks = input_embeddings._forward_hooks - output_embeddings._forward_pre_hooks = input_embeddings._forward_pre_hooks - output_embeddings.__repr__ = ( - lambda: f"{output_embeddings.__repr__()}\nTP Plan: {output_embeddings._hf_tp_plan}" - ) + # source and target are missing, but we don't need to warn about target missing as we are prob gonna tie + elif ( + source_is_there + and missing_keys + and (self.config.tie_word_embeddings or self.config.tie_encoder_decoder) + ): + if "d+" in target_name: + for target_n, _ in self.named_parameters(): + missing_keys.discard(target_n) + else: + missing_keys.discard(target_name) - if getattr(output_embeddings, "bias", None) is not None: + def _adjust_bias(self, output_embeddings, input_embeddings): + if getattr(output_embeddings, "bias", None) is not None and hasattr(output_embeddings, "weight"): + weight_shape = output_embeddings.weight.shape output_embeddings.bias.data = nn.functional.pad( output_embeddings.bias.data, - (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]), + (0, weight_shape[0] - output_embeddings.bias.shape[0]), "constant", 0, ) if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): output_embeddings.out_features = input_embeddings.num_embeddings + def tie_weights(self, missing_keys: Optional[set[str]] = None): + """ + Recursively (for all submodels) tie all the weights of the model. + """ + # Note that `self` is included in `self.modules` so we also apply to current PreTrainedModel with this call + if missing_keys is None: + # called from `post_init` + self.tie_weight_source_and_target(self, missing_keys, "") + else: # this is from_pretrained, so its not called on every sub module + for module_prefix, module in self.named_modules(): + # If it's a PreTrainedModel, may need to tie the embeddings and/or encoder/decoder weights + if isinstance(module, PreTrainedModel): + module.tie_weight_source_and_target(self, missing_keys, module_prefix) + # Additionally, if it has a custom `_tie_weights`, honor it + if hasattr(module, "_tie_weights"): + module._tie_weights() + def _get_no_split_modules(self, device_map: str): """ Get the modules of the model that should not be spit when using device_map. We iterate through the modules to @@ -3352,9 +3121,8 @@ def init_weights(self): if _init_weights: # Initialize weights self.initialize_weights() - - # Tie weights should be skipped when not initializing all weights - # since from_pretrained(...) calls tie weights anyways + # Tie weights needs to be called as it figures out recursively if sub modules + # need to tie self.tie_weights() def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): @@ -3457,6 +3225,7 @@ def save_pretrained( variant: Optional[str] = None, token: Optional[Union[str, bool]] = None, save_peft_format: bool = True, + save_original_format: bool = False, # TODO next PR will make it go to True **kwargs, ): """ @@ -3505,6 +3274,10 @@ def save_pretrained( For backward compatibility with PEFT library, in case adapter weights are attached to the model, all keys of the state dict of adapters needs to be prepended with `base_model.model`. Advanced users can disable this behaviours by setting `save_peft_format` to `False`. + save_original_format (`bool`, *optional*, defaults to `True`): + For backward compatibility with the previous versions of `transfomers` you can save the checkpoint with + its reverse mapping. The reverse mapping needs to exists even if the model was loaded from a None legacy + checkpoint. kwargs (`dict[str, Any]`, *optional*): Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ @@ -3644,24 +3417,18 @@ def save_pretrained( module_map[name + f".{key}"] = module state_dict = model_to_save.state_dict() - if any( - allowed_name in class_name.__name__.lower() - for class_name in self.__class__.__mro__[:-1] - for allowed_name in VLMS + if ( + any( + allowed_name in class_name.__name__.lower() + for class_name in self.__class__.__mro__[:-1] + for allowed_name in VLMS + ) + or save_original_format ): - reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()} - - original_state_dict = {} - for key, value in state_dict.items(): - for pattern, replacement in reverse_key_mapping.items(): - replacement = replacement.lstrip("^") # strip off un-needed chars and patterns - replacement = re.sub(r"\(.*\)", "", replacement) - key, n_replace = re.subn(pattern, replacement, key) - # Early exit of the loop - if n_replace > 0: - break - original_state_dict[key] = value - state_dict = original_state_dict + # MEGA BIG TODO HERE: self._conversion_ops needs to be used to save the final ckpt + # using what was loaded. Actually self._conversion_ops wont work because we need it + # even if the files are not legacy -> thus no conversion happened + state_dict = revert_weight_conversion(self, state_dict) # Translate state_dict from smp to hf if saving with smp >= 1.10 if IS_SAGEMAKER_MP_POST_1_10: @@ -3707,7 +3474,7 @@ def save_pretrained( shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} # Recursively descend to find tied weight keys - _tied_weights_keys = _get_tied_weight_keys(self) + _tied_weights_keys = set(_get_tied_weight_keys(self)) error_names = [] to_delete_names = set() for names in shared_ptrs.values(): @@ -3747,10 +3514,10 @@ def save_pretrained( error_names.extend(shared_names) if len(error_names) > 0: + suggested_fix = {v: k for k, v in list(shared_ptrs.values())} raise RuntimeError( - f"The weights trying to be saved contained shared tensors {error_names} that are mismatching " - "the transformers base configuration. Try saving using `safe_serialization=False`, setting the " - "`_dynamic_tied_weights_keys` attribute for affected modules, or remove this tensor sharing.", + f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined" + f"as being shared in `_tied_weight_keys`. You should probably add: `_tied_weight_keys = {suggested_fix}. If a whole module is shared you can use it directly", ) # Shard the model if it is too big. @@ -3829,7 +3596,8 @@ def save_pretrained( if safe_serialization: # At some point we will need to deal better with save_function (used for TPU and other distributed - # joyfulness), but for now this enough. + # joyfulness), but for now this enough. # TODO: we should def parallelize this we are otherwise just waiting + # too much before scheduling the next write when its in a different file safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata) else: save_function(shard, os.path.join(save_directory, shard_file)) @@ -4399,6 +4167,13 @@ def from_pretrained( config, quantization_config, dtype, device_map, weights_only, user_agent ) + weight_conversions: Optional[list[WeightConverter]] = None + model_type = getattr(config, "model_type", None) + if model_type is not None: + weight_conversions = get_checkpoint_conversion_mapping().get(model_type) + if weight_conversions is None: + weight_conversions = get_checkpoint_conversion_mapping()["legacy"] + if gguf_file: if hf_quantizer is not None: raise ValueError( @@ -4453,31 +4228,30 @@ def from_pretrained( with ContextManagers(model_init_context): # Let's make sure we don't run the init function of buffer modules model = cls(config, *model_args, **model_kwargs) - - # Potentially upcast some modules to avoid loosing precision - model.upcast_modules_in_fp32(hf_quantizer, dtype) - # Make sure to tie the weights correctly - model.tie_weights() + copy_model = cls(config, *model_args, **model_kwargs) # make sure we use the model's config since the __init__ call might have copied it config = model.config if hf_quantizer is not None: # replace module with quantized modules (does not touch weights) - hf_quantizer.preprocess_model( - model=model, - device_map=device_map, - keep_in_fp32_modules=model._keep_in_fp32_modules, - config=config, - checkpoint_files=checkpoint_files, - use_kernels=use_kernels, - ) + for m in [model, copy_model]: + hf_quantizer.preprocess_model( + model=m, + device_map=device_map, + keep_in_fp32_modules=model._keep_in_fp32_modules, # TODO prob no longer needed? + config=config, + checkpoint_files=checkpoint_files, + use_kernels=use_kernels, + ) if _torch_distributed_available and device_mesh is not None: # add hooks to nn.Modules: no weights model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size) # Prepare the full device map if device_map is not None: - device_map = _get_device_map(model, device_map, max_memory, hf_quantizer, dtype) + # simple solution as deepcopy don't work. We want to tie the weights afterwards. + copy_model.tie_weights() + device_map = _get_device_map(copy_model, device_map, max_memory, hf_quantizer, dtype) # restore default dtype if dtype_orig is not None: @@ -4498,15 +4272,15 @@ def from_pretrained( device_mesh=device_mesh, key_mapping=key_mapping, weights_only=weights_only, + weight_mapping=weight_conversions, ) - model.tie_weights() # make sure token embedding weights are still tied if needed model.eval() # Set model in evaluation mode to deactivate DropOut modules by default model.set_use_kernels(use_kernels, kernel_config) # If it is a model with generation capabilities, attempt to load generation files (generation config, # custom generate function) - if model.can_generate() and hasattr(model, "adjust_generation_fn"): + if model.can_generate() and hasattr(model, "adjust_generation_fn") and trust_remote_code: model.adjust_generation_fn( generation_config, from_auto_class, @@ -4517,17 +4291,16 @@ def from_pretrained( **kwargs, ) - # for device_map="auto" : dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly - # harm performances). + # for device_map="auto" : dispatch model with hooks on all devices if necessary if device_map is not None and device_mesh is None: accelerate_dispatch(model, hf_quantizer, device_map, offload_folder, offload_index, offload_buffers) if hf_quantizer is not None: model.hf_quantizer = hf_quantizer - hf_quantizer.postprocess_model(model, config=config) # usually a no-op + hf_quantizer.postprocess_model(model, config=config) # usually a no-op but sometimes needed if _adapter_model_path is not None: - adapter_kwargs["key_mapping"] = key_mapping # TODO: Dynamic weight loader for adapters + adapter_kwargs["key_mapping"] = weight_conversions # TODO: Dynamic weight loader for adapters model.load_adapter( _adapter_model_path, adapter_name=adapter_name, @@ -4545,107 +4318,6 @@ def from_pretrained( return model, loading_info return model - @staticmethod - def _fix_state_dict_key_on_load(key: str) -> tuple[str, bool]: - """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight.""" - # Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert) - # This rename is logged. - if key.endswith("LayerNorm.beta"): - return key.replace("LayerNorm.beta", "LayerNorm.bias"), True - if key.endswith("LayerNorm.gamma"): - return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True - - # Rename weight norm parametrizations to match changes across torch versions. - # Impacts a number of speech/wav2vec models. e.g. Hubert, Wav2Vec2, and others. - # This rename is not logged. - if hasattr(nn.utils.parametrizations, "weight_norm"): - if key.endswith("weight_g"): - return key.replace("weight_g", "parametrizations.weight.original0"), True - if key.endswith("weight_v"): - return key.replace("weight_v", "parametrizations.weight.original1"), True - else: - if key.endswith("parametrizations.weight.original0"): - return key.replace("parametrizations.weight.original0", "weight_g"), True - if key.endswith("parametrizations.weight.original1"): - return key.replace("parametrizations.weight.original1", "weight_v"), True - - return key, False - - def _get_key_renaming_mapping( - self, - checkpoint_keys: list[str], - key_mapping: Optional[dict[str, str]] = None, - loading_base_model_from_task_state_dict: bool = False, - loading_task_model_from_base_state_dict: bool = False, - ): - """ - Compute a mapping between the serialized keys on disk `checkpoint_keys`, and the keys that the model - that we are loading expects. This is the single entry point for key renaming that will be used during - loading. - Log if any parameters have been renamed. - """ - prefix = self.base_model_prefix - _prefix = f"{prefix}." - - if loading_task_model_from_base_state_dict: - task_specific_expected_keys, base_model_keys = [], [] - for key in self.state_dict(): - if key.startswith(_prefix): - base_model_keys.append(key[len(_prefix) :]) - else: - task_specific_expected_keys.append(key) - - renamed_keys = {} - key_renaming_mapping = {} - for key in checkpoint_keys: - # Class specific rename - new_key, has_changed = self._fix_state_dict_key_on_load(key) - - # Optionally map the key according to `key_mapping` - if key_mapping is not None: - for pattern, replacement in key_mapping.items(): - new_key, n_replace = re.subn(pattern, replacement, new_key) - # Early exit of the loop - if n_replace > 0: - has_changed = True - break - - # In this case, we need to add the prefix to the keys, to match them to the expected keys - if loading_task_model_from_base_state_dict: - # small sanity check: if we find a key that is only part of the task-specific keys, we raise - # (if it's also part of the base model, we do not raise and assume it comes from there) - if new_key in task_specific_expected_keys and new_key not in base_model_keys: - raise ValueError( - "The state dictionary of the model you are trying to load is corrupted. Are you sure it was " - "properly saved?" - ) - new_key = ".".join([prefix, new_key]) - # In this case we need to remove the prefix from the key to match them to the expected keys, and use - # only the keys starting with the prefix - elif loading_base_model_from_task_state_dict: - if not new_key.startswith(_prefix): - continue - new_key = new_key[len(_prefix) :] - - key_renaming_mapping[key] = new_key - - # track gamma/beta rename for logging - if has_changed: - if key.endswith("LayerNorm.gamma"): - renamed_keys["LayerNorm.gamma"] = (key, new_key) - elif key.endswith("LayerNorm.beta"): - renamed_keys["LayerNorm.beta"] = (key, new_key) - - if renamed_keys: - warning_msg = f"A pretrained model of type `{self.__class__.__name__}` " - warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n" - for old_key, new_key in renamed_keys.values(): - warning_msg += f"* `{old_key}` -> `{new_key}`\n" - warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users." - logger.info_once(warning_msg) - - return key_renaming_mapping - @staticmethod def _fix_state_dict_key_on_save(key) -> tuple[str, bool]: """ @@ -4677,97 +4349,16 @@ def _load_pretrained_model( device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None, key_mapping: Optional[dict[str, str]] = None, weights_only: bool = True, + weight_mapping: Optional[Sequence[WeightConverter]] = None, ): - # TODO: we should only be calling hf_quantizer.skip_placement or something like that is_quantized = hf_quantizer is not None is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in { QuantizationMethod.HQQ, QuantizationMethod.QUARK, } - # Get all the keys of the state dicts that we have to initialize the model with - if sharded_metadata is not None: - original_checkpoint_keys = sharded_metadata["all_checkpoint_keys"] - elif state_dict is not None: - original_checkpoint_keys = list(state_dict.keys()) - else: - original_checkpoint_keys = list( - load_state_dict(checkpoint_files[0], map_location="meta", weights_only=weights_only).keys() - ) - - # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture - prefix = model.base_model_prefix - has_prefix_module = any(s.startswith(prefix) for s in original_checkpoint_keys) if len(prefix) > 0 else False - expects_prefix_module = hasattr(model, prefix) if len(prefix) > 0 else False - loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module - loading_base_model_from_task_state_dict = has_prefix_module and not expects_prefix_module - - # Find the key names that the model expects from the serialized keys - key_renaming_mapping = model._get_key_renaming_mapping( - original_checkpoint_keys, - key_mapping, - loading_base_model_from_task_state_dict, - loading_task_model_from_base_state_dict, - ) - checkpoint_keys = list(key_renaming_mapping.values()) - - # Find missing and unexpected keys from the state dict - missing_keys, unexpected_keys = _find_missing_and_unexpected_keys( - model, original_checkpoint_keys, checkpoint_keys, loading_base_model_from_task_state_dict, hf_quantizer - ) - # Find all the keys with shape mismatch (if we ignore the mismatch, the weights need to be newly initialized the - # same way as missing keys) - mismatched_keys, mismatched_shapes = _find_mismatched_keys( - model, - state_dict, - checkpoint_files, - ignore_mismatched_sizes, - key_renaming_mapping, - is_quantized, - weights_only, - ) - - # We need to update both the mapping and the list of checkpoint keys to remove the mismatched and unexpected ones - key_renaming_mapping = { - k: v for k, v in key_renaming_mapping.items() if v not in mismatched_keys and v not in unexpected_keys - } - checkpoint_keys = list(key_renaming_mapping.values()) - - # Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when - # loading the weights as they are not in the loaded state dict) - model._move_missing_keys_from_meta_to_cpu(missing_keys + mismatched_keys, dtype, hf_quantizer) - - # correctly initialize the missing (and potentially mismatched) keys - model._initialize_missing_keys(missing_keys + mismatched_keys, is_quantized) - - # Get reverse key mapping - reverse_key_renaming_mapping = {v: k for k, v in key_renaming_mapping.items()} - - is_offloaded_safetensors = False - # This offload index if for params explicitly on the "disk" in the device_map - disk_offload_index = None - disk_only_shard_files = [] - # Prepare parameters offloading if needed - if device_map is not None and "disk" in device_map.values(): - disk_offload_index, disk_only_shard_files, is_offloaded_safetensors = accelerate_disk_offload( - disk_offload_folder, - checkpoint_files, - device_map, - checkpoint_keys, - key_renaming_mapping, - sharded_metadata, - dtype, - reverse_key_renaming_mapping, - ) - # To be able to iterate, even if we don't use it if the state_dict is already provided - elif state_dict is not None: - checkpoint_files = [""] - - # Compute expected model keys + # Model's definition arriving here is final (TP hooks added, quantized layers replaces) expected_keys = list(model.state_dict().keys()) - if hf_quantizer is not None: - expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, checkpoint_keys) - if logger.level >= logging.WARNING: verify_tp_plan(expected_keys, getattr(model, "_tp_plan", None)) @@ -4776,46 +4367,79 @@ def _load_pretrained_model( expanded_device_map = expand_device_map(device_map, expected_keys) caching_allocator_warmup(model, expanded_device_map, hf_quantizer) - # Prepare and compatabilize arguments for serial and parallel shard loading - args_list = [ - ( - shard_file, - state_dict, - disk_only_shard_files, - is_quantized, - device_map, - hf_quantizer, - key_renaming_mapping, - weights_only, + # Now we read all the files to get a pointer on each physical weights + merged_state_dict = {} + all_pointer = set() + + if device_map is None: + device_map = {"": "cpu"} + keys = sorted(device_map.keys(), key=len, reverse=True) + tp_plan = getattr(model, "_tp_plan", None) + error_msgs = [] + misc = {} + + if is_deepspeed_zero3_enabled() and not is_quantized: + error_msgs += _load_state_dict_into_zero3_model(model, state_dict) + else: + if checkpoint_files is not None: + pattern = re.compile(r"(" + "|".join(map(re.escape, keys)) + r")") + if sharded_metadata is None: + k_v_iterator = dict.fromkeys( + safe_open(checkpoint_files[0], framework="pt").keys(), "model.safetensors" + ).items() + else: + k_v_iterator = sharded_metadata["weight_map"].items() + + for k, v in k_v_iterator: + match = pattern.match(k) + if match and match.group(1) != "": + device = device_map[match.group(1)] + else: + device = device_map.get("", "cpu") + if isinstance(device, torch.device): + device = device.index # safetensors only + if device == "disk": + device = "cpu" # we read to cpu to then write to disk + file_pointer = safe_open( + os.path.join(checkpoint_files[0].rsplit("/", 1)[0], v), framework="pt", device=device + ) + all_pointer.add(file_pointer) + merged_state_dict[k] = file_pointer.get_slice(k) # don't materialize yet + elif state_dict is not None: + merged_state_dict = state_dict + else: + raise ValueError("Neither a state dict nor checkpoint files were found.") + start = time.perf_counter() + missing_keys, unexpected_keys, mismatched_keys, misc = convert_and_load_state_dict_in_model( model, - reverse_key_renaming_mapping, - disk_offload_folder, - disk_offload_index, + merged_state_dict, + weight_mapping, + tp_plan, + hf_quantizer, + dtype, + device_map, + model.dtype_plan, device_mesh, ) - for shard_file in checkpoint_files - ] + end = time.perf_counter() - error_msgs = [] + for k in all_pointer: # finally close all opened file pointers TODO async + k.__exit__(None, None, None) - if ( - os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES - and not is_deepspeed_zero3_enabled() - ): - _error_msgs, disk_offload_index = load_shard_files_with_threadpool(args_list) - error_msgs += _error_msgs - else: - if len(args_list) > 1: - args_list = logging.tqdm(args_list, desc="Loading checkpoint shards") + # Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when + # loading the weights as they are not in the loaded state dict) + # Remove tied weights keys and etc + miss_and_mismatched = missing_keys | {k[0] for k in mismatched_keys} + model._move_missing_keys_from_meta_to_cpu(miss_and_mismatched, dtype, hf_quantizer) - for args in args_list: - _error_msgs, disk_offload_index = load_shard_file(args) - error_msgs += _error_msgs + # correctly initialize the missing (and potentially mismatched) keys + model._initialize_missing_keys(miss_and_mismatched, is_quantized) + missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys( + missing_keys, unexpected_keys, False, model + ) - # Save offloaded index if needed - if disk_offload_index is not None and len(disk_offload_index) > 0 and not is_offloaded_safetensors: - save_offload_index(disk_offload_index, disk_offload_folder) - disk_offload_index = None + # We make sure we TIE after _init_ + model.tie_weights(missing_keys) # Post-processing for tensor parallelism if device_mesh is not None: @@ -4823,7 +4447,7 @@ def _load_pretrained_model( tp_device = list(device_map.values())[0] # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is # not part of the state_dict (persistent=False) - for buffer in model.buffers(): + for buffer in model.buffers(): # TODO to avaoid this buffer could be added to the ckpt if buffer.device != tp_device: buffer.data = buffer.to(tp_device) @@ -4831,7 +4455,9 @@ def _load_pretrained_model( # were not part of the loaded weights: do it now if loading_task_model_from_base_state_dict: parameters_to_initialize = { - name: param for name, param in model.named_parameters() if not name.startswith(prefix) + name: param + for name, param in model.named_parameters() + if not name.startswith(model.base_model_prefix) } for name, param in parameters_to_initialize.items(): # If it is still on meta here, it means that it's a tied weight that will be tied later anyway -> skip it @@ -4850,52 +4476,20 @@ def _load_pretrained_model( device_mesh, ) - # Remove potential model-specific exceptions from the warnings - missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys( - missing_keys, unexpected_keys, loading_task_model_from_base_state_dict + logger.warning(f"Loading the checkpoint files into the model took {end - start}") + log_state_dict_report( + model=model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + logger=logger, + error_msgs=error_msgs, + unexpected_keys=unexpected_keys, + missing_keys=missing_keys, + mismatched_keys=mismatched_keys, + mismatched_shapes=mismatched_keys, + misc=misc, + ignore_mismatched_sizes=ignore_mismatched_sizes, ) - - # TODO: separate this in another function: it's not core.... - # All potential warnings/infos - if len(error_msgs) > 0: - error_msg = "\n\t".join(error_msgs) - if "size mismatch" in error_msg: - error_msg += ( - "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." - ) - raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") - if len(unexpected_keys) > 0: - archs = [] if model.config.architectures is None else model.config.architectures - warner = logger.warning if model.__class__.__name__ in archs else logger.info - warner( - f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" - f" initializing {model.__class__.__name__}: {update_key_name(unexpected_keys)}\n- This IS expected if you are" - f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" - " with another architecture (e.g. initializing a BertForSequenceClassification model from a" - " BertForPreTraining model).\n- This IS NOT expected if you are initializing" - f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" - " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." - ) - if len(missing_keys) > 0: - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized: {update_key_name(missing_keys)}\nYou should probably" - " TRAIN this model on a down-stream task to be able to use it for predictions and inference." - ) - if len(mismatched_keys) > 0: - mismatched_warning = "\n".join( - [ - f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" - for key, (shape1, shape2) in zip(mismatched_keys, mismatched_shapes) - ] - ) - logger.warning( - f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" - f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" - f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" - " to use it for predictions and inference." - ) - + disk_offload_index = None return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): @@ -5104,43 +4698,15 @@ def _move_missing_keys_from_meta_to_cpu( value = torch.empty_like(param, dtype=dtype, device="cpu") if not is_quantized or not hf_quantizer.param_needs_quantization(self, key): _load_parameter_into_model(self, key, value) - else: - hf_quantizer.create_quantized_param(self, value, key, "cpu") def _initialize_missing_keys(self, missing_keys: list[str], is_quantized: bool) -> None: - """Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts), according to + """ + Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts), according to `_initialize_weights`. Indeed, since the corresponding weights are missing from the state dict, they will not be replaced and need to be initialized correctly (i.e. weight initialization distribution). - Also take care of setting the `_is_hf_initialized` flag for keys that are not missing. - """ - for key in self.state_dict(): - # If it's part of the keys that will be loaded, mark it as already initialized - if key not in missing_keys: - param_or_buffer = self.get_parameter_or_buffer(key) - param_or_buffer._is_hf_initialized = True - - def set_is_initialized_for_modules(module): - # A module is already initialized if and only if all its children are also already initialized, and all - # its immediate `nn.Parameter` and persistent buffers are also already initialized - if ( - # All immediate children are initialized - all(getattr(child, "_is_hf_initialized", False) for child in module.children()) - # All immediate parameters are initialized - and all(getattr(param, "_is_hf_initialized", False) for param in module.parameters(recurse=False)) - # All immediate persistent buffers are initialized - and all( - getattr(buffer, "_is_hf_initialized", False) - for name, buffer in module.named_buffers(recurse=False) - if name not in module._non_persistent_buffers_set - ) - ): - module._is_hf_initialized = True - - # Set the flag on the modules as well. We do it recursively (depth-first), as it's more efficient (we do not - # need to check the entire state dict of each module, only the immediate children, so we only iterate once over - # each param) - self.apply(set_is_initialized_for_modules) + Params that are not missing have the `is_hf_initialized` flag. + """ # This will only initialize submodules that are not marked as initialized by the line above. if is_deepspeed_zero3_enabled() and not is_quantized: import deepspeed @@ -5154,13 +4720,17 @@ def set_is_initialized_for_modules(module): else: self.initialize_weights() + for p in self.parameters(): # TODO @Cyrilvallez if we are able to do this while we smart apply my be better + setattr(p, "__class__", nn.Parameter) + setattr(p, "_is_hf_initialized", True) + def _adjust_missing_and_unexpected_keys( - self, missing_keys: list[str], unexpected_keys: list[str], loading_task_model_from_base_state_dict: bool - ) -> tuple[list[str], list[str]]: + self, missing_keys: set[str], unexpected_keys: set[str], loading_task_model_from_base_state_dict: bool, model + ) -> tuple[set[str], set[str]]: """Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid raising unneeded warnings/errors. """ - # Old checkpoints may have keys for rotary_emb.inv_freq for each layer, however we moved this buffer to the main model + # Old checkpoints may have keys for rotary_emb.inv_freq forach layer, however we moved this buffer to the main model # (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to # `_keys_to_ignore_on_load_unexpected` as it touches many models -> we add it manually to the existing patterns has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer, _ in self.named_buffers()) @@ -5176,17 +4746,17 @@ def _adjust_missing_and_unexpected_keys( # Clean-up missing keys if ignore_missing_regex is not None: - missing_keys = [key for key in missing_keys if ignore_missing_regex.search(key) is None] + missing_keys = {key for key in missing_keys if ignore_missing_regex.search(key) is None} # Clean-up unexpected keys if ignore_unexpected_regex is not None: - unexpected_keys = [key for key in unexpected_keys if ignore_unexpected_regex.search(key) is None] + unexpected_keys = {key for key in unexpected_keys if ignore_unexpected_regex.search(key) is None} # Note: only the unexpected keys should remove the added prefix here, to correctly display the original name # in the warnings. For missing keys, we should show the prefix in the warning as it's part of the final model if loading_task_model_from_base_state_dict: _prefix = f"{self.base_model_prefix}." - unexpected_keys = [k.removeprefix(_prefix) for k in unexpected_keys] + unexpected_keys = {k.removeprefix(_prefix) for k in unexpected_keys} return missing_keys, unexpected_keys @@ -5223,35 +4793,6 @@ def train(self, mode: bool = True): def eval(self): return self.train(False) - def upcast_modules_in_fp32(self, hf_quantizer: HfQuantizer | None, dtype: torch.dtype) -> None: - """ - Upcast modules defined in `_keep_in_fp32_modules` and `_keep_in_fp32_modules_strict` in fp32, if - `dtype` is different than fp32. - """ - # If the dtype is already fp32, we can skip - if dtype == torch.float32: - return - - keep_in_fp32_modules = [] - # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced - # in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing - # step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details. - if self._keep_in_fp32_modules is not None and ( - dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False) - ): - keep_in_fp32_modules.extend(self._keep_in_fp32_modules) - - if self._keep_in_fp32_modules_strict is not None and (dtype == torch.float16 or dtype == torch.bfloat16): - keep_in_fp32_modules.extend(self._keep_in_fp32_modules_strict) - - if len(keep_in_fp32_modules) > 0: - # We need to match exact layers, so we add either `.` on each side, or start/end of string - keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules])) - for name, param in self.named_parameters(): - if keep_in_fp32_regex.search(name): - # param = param.to(torch.float32) does not work here as only in the local scope. - param.data = param.data.to(torch.float32) - PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub) if PreTrainedModel.push_to_hub.__doc__ is not None: diff --git a/src/transformers/models/aimv2/modeling_aimv2.py b/src/transformers/models/aimv2/modeling_aimv2.py index f44879e37b02..bead6a11dd7b 100644 --- a/src/transformers/models/aimv2/modeling_aimv2.py +++ b/src/transformers/models/aimv2/modeling_aimv2.py @@ -406,13 +406,14 @@ class Aimv2PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if hasattr(module, "logit_scale"): if isinstance(module.logit_scale, nn.Parameter): - module.logit_scale.data.fill_(math.log(1 / 0.07)) + module.logit_scale.fill_(math.log(1 / 0.07)) elif isinstance(module, Aimv2AttentionPoolingHead): - module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range) + module.cls_token.normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring( diff --git a/src/transformers/models/aimv2/modular_aimv2.py b/src/transformers/models/aimv2/modular_aimv2.py index a7ea96f8f2c2..55ff92212b39 100644 --- a/src/transformers/models/aimv2/modular_aimv2.py +++ b/src/transformers/models/aimv2/modular_aimv2.py @@ -449,13 +449,14 @@ class Aimv2PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if hasattr(module, "logit_scale"): if isinstance(module.logit_scale, nn.Parameter): - module.logit_scale.data.fill_(math.log(1 / 0.07)) + module.logit_scale.fill_(math.log(1 / 0.07)) elif isinstance(module, Aimv2AttentionPoolingHead): - module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range) + module.cls_token.normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring( diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index e8d650043169..ac4337e4f269 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -302,21 +302,22 @@ class AlbertPreTrainedModel(PreTrainedModel): "attentions": AlbertAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, AlbertMLMHead): - module.bias.data.zero_() + module.bias.zero_() @dataclass @@ -425,7 +426,10 @@ def forward( """ ) class AlbertForPreTraining(AlbertPreTrainedModel): - _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"] + _tied_weights_keys = { + "predictions.decoder.weight": "albert.embeddings.word_embeddings.weight", + "predictions.decoder.bias": "predictions.bias", + } def __init__(self, config: AlbertConfig): super().__init__(config) @@ -525,7 +529,6 @@ def __init__(self, config: AlbertConfig): self.dense = nn.Linear(config.hidden_size, config.embedding_size) self.decoder = nn.Linear(config.embedding_size, config.vocab_size) self.activation = ACT2FN[config.hidden_act] - self.decoder.bias = self.bias def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.dense(hidden_states) @@ -537,14 +540,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return prediction_scores - def _tie_weights(self) -> None: - # For accelerate compatibility and to not break backward compatibility - if self.decoder.bias.device.type == "meta": - self.decoder.bias = self.bias - else: - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias - class AlbertSOPHead(nn.Module): def __init__(self, config: AlbertConfig): @@ -561,7 +556,10 @@ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor: @auto_docstring class AlbertForMaskedLM(AlbertPreTrainedModel): - _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"] + _tied_weights_keys = { + "predictions.decoder.weight": "albert.embeddings.word_embeddings.weight", + "predictions.decoder.bias": "predictions.bias", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 57b73d38ab48..6ec6d72a4771 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -823,24 +823,25 @@ class AlignPreTrainedModel(PreTrainedModel): input_modalities = ["image", "text"] supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, AlignModel): nn.init.xavier_uniform_(module.text_projection.weight) - module.text_projection.bias.data.zero_() - module.temperature.data.fill_(self.config.temperature_init_value) + module.text_projection.bias.zero_() + module.temperature.fill_(self.config.temperature_init_value) elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() if isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring( diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index be84fb62b66d..1c45432d5f20 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -770,6 +770,7 @@ class AltCLIPPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_module = [] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor @@ -797,23 +798,21 @@ def _init_weights(self, module): module.text_projection.weight, std=module.text_embed_dim**-0.5 * self.config.initializer_factor, ) - module.text_projection._is_hf_initialized = True nn.init.normal_( module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, ) - module.visual_projection._is_hf_initialized = True elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor) + module.weight.normal_(mean=0.0, std=self.config.initializer_factor) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor) + module.weight.normal_(mean=0.0, std=self.config.initializer_factor) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() class AltCLIPVisionTransformer(nn.Module): diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index e92e87a3c280..7cdde33e8ff2 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -429,7 +429,7 @@ def forward( @auto_docstring class ApertusForCausalLM(ApertusPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index 619e72b7a11b..513162398dd7 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -434,7 +434,7 @@ def forward( @auto_docstring(checkpoint="arcee-ai/AFM-4.5B") class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index e702077bf930..61b78357df62 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -585,10 +585,11 @@ class AriaTextPreTrainedModel(PreTrainedModel): "attentions": AriaTextAttention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, AriaGroupedExpertsGemm): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring @@ -608,6 +609,7 @@ class AriaPreTrainedModel(PreTrainedModel): "attentions": AriaTextAttention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, AriaProjector): @@ -760,7 +762,7 @@ def forward( @auto_docstring class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} @@ -1053,7 +1055,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): "^multi_modal_projector": "model.multi_modal_projector", "^language_model.lm_head": "lm_head", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: AriaConfig): super().__init__(config) diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index 66483c248a2a..c3fddb8e1f3d 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -1191,10 +1191,11 @@ class AriaTextPreTrainedModel(PreTrainedModel): "attentions": AriaTextAttention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, AriaGroupedExpertsGemm): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) class AriaPreTrainedModel(LlamaPreTrainedModel): @@ -1203,6 +1204,7 @@ class AriaPreTrainedModel(LlamaPreTrainedModel): _can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing) _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, AriaProjector): @@ -1220,7 +1222,7 @@ def __init__(self, config: AriaTextConfig): class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: AriaTextConfig): super().__init__(config) @@ -1359,6 +1361,8 @@ def forward( """ ) class AriaForConditionalGeneration(LlavaForConditionalGeneration): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + def get_image_features( self, pixel_values: torch.FloatTensor, diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index 0a918edd1886..1f270b96aa95 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -300,23 +300,26 @@ class ASTPreTrainedModel(PreTrainedModel): "attentions": ASTSelfAttention, } + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) + module.weight.copy_( + nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( + module.weight.dtype + ) + ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, ASTEmbeddings): - module.cls_token.data.zero_() - module.position_embeddings.data.zero_() - module.distillation_token.data.zero_() + module.cls_token.zero_() + module.position_embeddings.zero_() + module.distillation_token.zero_() @auto_docstring diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 14b93fb1b66e..782ef440d0a7 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -826,21 +826,22 @@ class AutoformerPreTrainedModel(PreTrainedModel): main_input_name = "past_values" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module: nn.Module): std = self.config.init_std if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, AutoformerSinusoidalPositionalEmbedding): module._init_weight() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() # copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask def _update_full_mask( diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index 39f9d70fcc7b..271845446db7 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -338,7 +338,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi "^multi_modal_projector": "model.multi_modal_projector", "^language_model.lm_head": "lm_head", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: AyaVisionConfig): super().__init__(config) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 9285068292ad..ed07f9345e2b 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1126,12 +1126,13 @@ class BambaPreTrainedModel(PreTrainedModel): # Note: only supports HybridMambaAttentionDynamicCache _is_stateful = True + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, BambaMixer): - module.dt_bias.data.fill_(1.0) - module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1)) - module.D.data.fill_(1.0) + module.dt_bias.fill_(1.0) + module.A_log.copy_(torch.log(torch.arange(1, module.num_heads + 1))) + module.D.fill_(1.0) @auto_docstring @@ -1383,7 +1384,7 @@ def _update_mamba_mask(self, attention_mask, cache_position): @auto_docstring class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 79a1b0e5ea15..024e8415fffe 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -800,12 +800,13 @@ class BambaPreTrainedModel(PreTrainedModel): # Note: only supports HybridMambaAttentionDynamicCache _is_stateful = True + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, BambaMixer): - module.dt_bias.data.fill_(1.0) - module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1)) - module.D.data.fill_(1.0) + module.dt_bias.fill_(1.0) + module.A_log.copy_(torch.log(torch.arange(1, module.num_heads + 1))) + module.D.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 0aa063cebcd3..07d937dd7cdc 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -329,19 +329,20 @@ class BarkPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = False _supports_flash_attn = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear,)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -910,6 +911,9 @@ def __init__(self, config): # non-causal gpt-like model with one embedding layer and one lm_head for each codebook of Encodec super().__init__(config) self.config = config + self._tied_weights_keys = {} + for i in range(self.config.n_codes_total - self.config.n_codes_given): + self._tied_weights_keys[f"lm_heads.{i}.weight"] = f"input_embeds_layers.{i + 1}.weight" # initialize a modified non causal GPT-like model # note that for there is one embedding layer and one lm_head for each codebook of Encodec @@ -1025,25 +1029,6 @@ def resize_token_embeddings( return model_embeds - def _tie_weights(self): - if getattr(self.config, "tie_word_embeddings", True): - self._tied_weights_keys = [] - output_embeddings = self.get_output_embeddings() - input_embeddings = self.get_input_embeddings() - - for i in range(self.config.n_codes_total - self.config.n_codes_given): - # self.input_embeds_layers[i + 1].weight = self.lm_heads[i].weight - self._tie_embedding_weights(output_embeddings[i], input_embeddings[i + 1]) - self._tied_weights_keys.append(f"lm_heads.{i}.weight") - - def tie_weights(self): - """ - Tie the weights between the input embeddings list and the output embeddings list. - """ - for module in self.modules(): - if hasattr(module, "_tie_weights"): - module._tie_weights() - @auto_docstring def forward( self, @@ -1580,14 +1565,6 @@ def generate( return audio - def tie_weights(self): - """ - Tie the weights between the input embeddings list and the output embeddings list. - """ - for module in self.modules(): - if hasattr(module, "_tie_weights"): - module._tie_weights() - __all__ = [ "BarkFineModel", diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index b903becf5e9c..d08608268a15 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -476,19 +476,20 @@ class BartPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() @property def dummy_inputs(self): @@ -527,7 +528,7 @@ class BartEncoder(BartPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: BartConfig): super().__init__(config) self.dropout = config.dropout @@ -538,12 +539,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No self.max_source_positions = config.max_position_embeddings embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = BartScaledWordEmbedding( - config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale - ) + self.embed_tokens = BartScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) self.embed_positions = BartLearnedPositionalEmbedding( config.max_position_embeddings, @@ -674,7 +672,7 @@ class BartDecoder(BartPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: BartConfig): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop @@ -682,12 +680,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No self.max_target_positions = config.max_position_embeddings embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = BartScaledWordEmbedding( - config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale - ) + self.embed_tokens = BartScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) self.embed_positions = BartLearnedPositionalEmbedding( config.max_position_embeddings, @@ -899,7 +894,10 @@ def forward( @auto_docstring class BartModel(BartPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "decoder.embed_tokens.weight": "shared.weight", + "encoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: BartConfig): super().__init__(config) @@ -908,24 +906,12 @@ def __init__(self, config: BartConfig): embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 self.shared = BartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) - self.encoder = BartEncoder(config, self.shared) - self.decoder = BartDecoder(config, self.shared) + self.encoder = BartEncoder(config) + self.decoder = BartDecoder(config) # Initialize weights and apply final processing self.post_init() - def _tie_weights(self): - if self.config.tie_word_embeddings: - # Some model checkpoints like "facebook/bart-large-cnn"'s embedding weight is in decoder.embed_tokens, need check here, see issue #36247 - if self.shared.weight.device == torch.device( - "meta" - ) and self.decoder.embed_tokens.weight.device != torch.device("meta"): - self._tie_embedding_weights(self.encoder.embed_tokens, self.decoder.embed_tokens) - self._tie_embedding_weights(self.shared, self.decoder.embed_tokens) - else: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - def get_input_embeddings(self): return self.shared @@ -1052,7 +1038,9 @@ def forward( ) class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin): base_model_prefix = "model" - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.shared.weight", + } _keys_to_ignore_on_load_missing = ["final_logits_bias"] def __init__(self, config: BartConfig): @@ -1086,11 +1074,6 @@ def _resize_final_logits_bias(self, new_num_tokens: int) -> None: new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) self.register_buffer("final_logits_bias", new_bias) - def _tie_weights(self): - if self.config.tie_word_embeddings: - self.model._tie_weights() - self._tie_embedding_weights(self.lm_head, self.model.shared) - @auto_docstring def forward( self, @@ -1240,8 +1223,6 @@ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): """ ) class BartForSequenceClassification(BartPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] - def __init__(self, config: BartConfig, **kwargs): super().__init__(config, **kwargs) self.model = BartModel(config) @@ -1374,8 +1355,6 @@ def forward( @auto_docstring class BartForQuestionAnswering(BartPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] - def __init__(self, config): super().__init__(config) @@ -1513,7 +1492,9 @@ def forward(self, *args, **kwargs): """ ) class BartForCausalLM(BartPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.decoder.embed_tokens.weight", + } def __init__(self, config): config.is_decoder = True diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index fff3158ab387..afa955985696 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -692,31 +692,32 @@ class BeitPreTrainedModel(PreTrainedModel): _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"] _supports_sdpa = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, BeitEmbeddings): - module.cls_token.data.zero_() + module.cls_token.zero_() if module.mask_token is not None: - module.mask_token.data.zero_() + module.mask_token.zero_() if module.position_embeddings is not None: - module.position_embeddings.data.zero_() + module.position_embeddings.zero_() elif isinstance(module, BeitRelativePositionBias): - module.relative_position_bias_table.data.zero_() + module.relative_position_bias_table.zero_() elif isinstance(module, BeitLayer): if module.lambda_1 is not None: - module.lambda_1.data.fill_(self.config.layer_scale_init_value) - module.lambda_2.data.fill_(self.config.layer_scale_init_value) + module.lambda_1.fill_(self.config.layer_scale_init_value) + module.lambda_2.fill_(self.config.layer_scale_init_value) @auto_docstring diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 444753bef63e..bf7d54108b32 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -506,16 +506,9 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -569,21 +562,22 @@ class BertPreTrainedModel(PreTrainedModel): "cross_attentions": BertCrossAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, BertLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @dataclass @@ -770,7 +764,10 @@ def _create_attention_masks( """ ) class BertForPreTraining(BertPreTrainedModel): - _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] + _tied_weights_keys = { + "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight", + "cls.predictions.decoder.bias": "cls.predictions.bias", + } def __init__(self, config): super().__init__(config) @@ -864,7 +861,10 @@ def forward( """ ) class BertLMHeadModel(BertPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + _tied_weights_keys = { + "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight", + "cls.predictions.decoder.bias": "cls.predictions.bias", + } def __init__(self, config): super().__init__(config) @@ -948,7 +948,10 @@ def forward( @auto_docstring class BertForMaskedLM(BertPreTrainedModel): - _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] + _tied_weights_keys = { + "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight", + "cls.predictions.decoder.bias": "cls.predictions.bias", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 5967774905a1..359ef6889a45 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -456,21 +456,22 @@ class BertGenerationPreTrainedModel(PreTrainedModel): "cross_attentions": BertGenerationCrossAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, BertGenerationOnlyLMHead): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring( @@ -629,20 +630,11 @@ def __init__(self, config): super().__init__() self.decoder = nn.Linear(config.hidden_size, config.vocab_size) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - self.decoder.bias = self.bias def forward(self, hidden_states): logits = self.decoder(hidden_states) return logits - def _tie_weights(self): - # For accelerate compatibility and to not break backward compatibility - if self.decoder.bias.device.type == "meta": - self.decoder.bias = self.bias - else: - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias - @auto_docstring( custom_intro=""" @@ -650,7 +642,10 @@ def _tie_weights(self): """ ) class BertGenerationDecoder(BertGenerationPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "bert.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 3b2d5fcf797a..ccdc0dd8b842 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1464,16 +1464,9 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1521,21 +1514,22 @@ class BigBirdPreTrainedModel(PreTrainedModel): base_model_prefix = "bert" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, BigBirdLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @dataclass @@ -1899,7 +1893,10 @@ def _pad_to_block_size( class BigBirdForPreTraining(BigBirdPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) @@ -1999,7 +1996,10 @@ def forward( @auto_docstring class BigBirdForMaskedLM(BigBirdPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) @@ -2141,7 +2141,10 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ """ ) class BigBirdForCausalLM(BigBirdPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index ada977bfe7fa..220b050496a1 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1539,19 +1539,20 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() @property def dummy_inputs(self): @@ -1574,7 +1575,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: BigBirdPegasusConfig): super().__init__(config) self.attention_type = config.attention_type @@ -1592,9 +1593,6 @@ def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embed config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale ) - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding( config.max_position_embeddings, embed_dim, @@ -1849,7 +1847,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: BigBirdPegasusConfig): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop @@ -1861,9 +1859,6 @@ def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embed config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale ) - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding( config.max_position_embeddings, config.d_model, @@ -2075,7 +2070,10 @@ def forward( @auto_docstring class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: BigBirdPegasusConfig): super().__init__(config) @@ -2086,8 +2084,8 @@ def __init__(self, config: BigBirdPegasusConfig): vocab_size, config.d_model, padding_idx, embed_scale=embed_scale ) - self.encoder = BigBirdPegasusEncoder(config, self.shared) - self.decoder = BigBirdPegasusDecoder(config, self.shared) + self.encoder = BigBirdPegasusEncoder(config) + self.decoder = BigBirdPegasusDecoder(config) # Initialize weights and apply final processing self.post_init() @@ -2100,11 +2098,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder @@ -2213,7 +2206,9 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin): base_model_prefix = "model" - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.shared.weight", + } _keys_to_ignore_on_load_missing = ["final_logits_bias"] def __init__(self, config: BigBirdPegasusConfig): @@ -2247,11 +2242,6 @@ def _resize_final_logits_bias(self, new_num_tokens: int) -> None: new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) self.register_buffer("final_logits_bias", new_bias) - def _tie_weights(self): - if self.config.tie_word_embeddings: - self.model._tie_weights() - self._tie_embedding_weights(self.lm_head, self.model.shared) - @auto_docstring # Ignore copy def forward( @@ -2374,8 +2364,6 @@ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): """ ) class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] - def __init__(self, config: BigBirdPegasusConfig, **kwargs): super().__init__(config, **kwargs) self.model = BigBirdPegasusModel(config) @@ -2497,8 +2485,6 @@ def forward( @auto_docstring class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] - def __init__(self, config): super().__init__(config) @@ -2621,8 +2607,6 @@ def forward(self, *args, **kwargs): class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - def __init__(self, config): config.is_decoder = True config.is_encoder_decoder = False diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 67bca4bae7ed..886d80f9936a 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -510,7 +510,7 @@ def forward( """ ) class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["output_projection.weight"] + _tied_weights_keys = {"output_projection.weight": "biogpt.embed_tokens.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index f267d9fc10ca..0a0e9958c109 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -332,7 +332,7 @@ def forward( """ ) class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["output_projection.weight"] + _tied_weights_keys = {"output_projection.weight": "biogpt.embed_tokens.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py index 916f99a1556e..fe80fcda4dc8 100644 --- a/src/transformers/models/bit/modeling_bit.py +++ b/src/transformers/models/bit/modeling_bit.py @@ -628,6 +628,7 @@ class BitPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" _no_split_modules = ["BitEmbeddings"] + @torch.no_grad() def _init_weights(self, module): if isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index d3972946a203..3b4f3fd69ed0 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -433,7 +433,7 @@ def forward( @auto_docstring class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = None _pp_plan = None diff --git a/src/transformers/models/bitnet/modular_bitnet.py b/src/transformers/models/bitnet/modular_bitnet.py index bc3e7c1cf2b9..093eb2428395 100644 --- a/src/transformers/models/bitnet/modular_bitnet.py +++ b/src/transformers/models/bitnet/modular_bitnet.py @@ -114,7 +114,7 @@ class BitNetModel(LlamaModel): class BitNetForCausalLM(LlamaForCausalLM): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = None _pp_plan = None diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 8faa86b1fd2b..bd7790f5a7a4 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -438,19 +438,20 @@ class BlenderbotPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() @property def dummy_inputs(self): @@ -474,7 +475,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: BlenderbotConfig): super().__init__(config) self.dropout = config.dropout @@ -485,12 +486,9 @@ def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding self.max_source_positions = config.max_position_embeddings embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = BlenderbotScaledWordEmbedding( - config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale - ) + self.embed_tokens = BlenderbotScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) self.embed_positions = BlenderbotLearnedPositionalEmbedding( config.max_position_embeddings, @@ -623,7 +621,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: BlenderbotConfig): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop @@ -631,12 +629,9 @@ def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding self.max_target_positions = config.max_position_embeddings embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = BlenderbotScaledWordEmbedding( - config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale - ) + self.embed_tokens = BlenderbotScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) self.embed_positions = BlenderbotLearnedPositionalEmbedding( config.max_position_embeddings, @@ -852,7 +847,10 @@ def forward( @auto_docstring class BlenderbotModel(BlenderbotPreTrainedModel): - _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: BlenderbotConfig): super().__init__(config) @@ -860,8 +858,8 @@ def __init__(self, config: BlenderbotConfig): padding_idx, vocab_size = config.pad_token_id, config.vocab_size embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 self.shared = BlenderbotScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) - self.encoder = BlenderbotEncoder(config, self.shared) - self.decoder = BlenderbotDecoder(config, self.shared) + self.encoder = BlenderbotEncoder(config) + self.decoder = BlenderbotDecoder(config) # Initialize weights and apply final processing self.post_init() @@ -1001,7 +999,9 @@ def forward( class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] - _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.shared.weight", + } def __init__(self, config: BlenderbotConfig): super().__init__(config) @@ -1184,7 +1184,9 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.decoder.embed_tokens.weight", + } def __init__(self, config): config.is_decoder = True diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 675df2cd49eb..bd1a36cb4d22 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -431,19 +431,20 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() @property def dummy_inputs(self): @@ -467,7 +468,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: BlenderbotSmallConfig): super().__init__(config) self.dropout = config.dropout @@ -478,10 +479,7 @@ def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embe self.max_source_positions = config.max_position_embeddings self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) self.embed_positions = BlenderbotSmallLearnedPositionalEmbedding( config.max_position_embeddings, @@ -612,7 +610,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: BlenderbotSmallConfig): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop @@ -620,10 +618,7 @@ def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embe self.max_target_positions = config.max_position_embeddings self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) self.embed_positions = BlenderbotSmallLearnedPositionalEmbedding( config.max_position_embeddings, @@ -838,7 +833,10 @@ def forward( @auto_docstring class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel): - _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: BlenderbotSmallConfig): super().__init__(config) @@ -846,8 +844,8 @@ def __init__(self, config: BlenderbotSmallConfig): padding_idx, vocab_size = config.pad_token_id, config.vocab_size self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) - self.encoder = BlenderbotSmallEncoder(config, self.shared) - self.decoder = BlenderbotSmallDecoder(config, self.shared) + self.encoder = BlenderbotSmallEncoder(config) + self.decoder = BlenderbotSmallDecoder(config) # Initialize weights and apply final processing self.post_init() @@ -974,7 +972,9 @@ def forward( class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] - _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.shared.weight", + } def __init__(self, config: BlenderbotSmallConfig): super().__init__(config) @@ -1144,7 +1144,9 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.decoder.embed_tokens.weight", + } def __init__(self, config): config.is_decoder = True diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index abde4b5dba0a..4920678d5d87 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -419,13 +419,14 @@ class BlipPreTrainedModel(PreTrainedModel): _no_split_modules = ["BlipEncoderLayer", "BlipTextEmbeddings"] _skip_keys_device_placement = ["past_key_values"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_range if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)): - module.weight.data.normal_(mean=0.0, std=factor) + module.weight.normal_(mean=0.0, std=factor) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() if isinstance(module, BlipVisionEmbeddings): if hasattr(self.config, "vision_config"): @@ -443,10 +444,10 @@ def _init_weights(self, module): ) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() class BlipEncoder(nn.Module): @@ -797,8 +798,11 @@ def forward( ) class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin): config: BlipConfig - _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"] main_input_name = "pixel_values" + _tied_weights_keys = { + "text_decoder.cls.predictions.decoder.bias": "text_decoder.cls.predictions.bias", + "text_decoder.cls.predictions.decoder.weight": "text_decoder.bert.embeddings.word_embeddings.weight", + } # TODO @arthurzucker check why we need this when for other models, their subPreTrainedModel handle it themselves. def __init__(self, config: BlipConfig): super().__init__(config) @@ -963,7 +967,10 @@ def generate( ) class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin): config: BlipConfig - _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"] + _tied_weights_keys = { + "text_decoder.cls.predictions.decoder.bias": "text_decoder.cls.predictions.bias", + "text_decoder.cls.predictions.decoder.weight": "text_decoder.bert.embeddings.word_embeddings.weight", + } def __init__(self, config: BlipConfig): super().__init__(config) @@ -971,7 +978,6 @@ def __init__(self, config: BlipConfig): self.vision_model = BlipVisionModel(config.vision_config) self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False) - self.text_decoder = BlipTextLMHeadModel(config.text_config) self.decoder_pad_token_id = config.text_config.pad_token_id diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index ee67f77d5241..6e9e3bb7c2c3 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -473,16 +473,9 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -511,15 +504,16 @@ class BlipTextPreTrainedModel(PreTrainedModel): base_model_prefix = "bert" _no_split_modules = [] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() # Adapted from https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/med.py#L571 @@ -744,7 +738,10 @@ def forward( # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811 class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/blip_2/configuration_blip_2.py b/src/transformers/models/blip_2/configuration_blip_2.py index 2694fdeb1085..e281d92cd9ea 100644 --- a/src/transformers/models/blip_2/configuration_blip_2.py +++ b/src/transformers/models/blip_2/configuration_blip_2.py @@ -263,9 +263,7 @@ class Blip2Config(PreTrainedConfig): ```""" model_type = "blip-2" - attribute_map = { - "image_token_id": "image_token_index", - } + attribute_map = {"image_token_id": "image_token_index", "tie_words_embeddings": "use_decoder_only_language_model"} sub_configs = {"text_config": AutoConfig, "qformer_config": Blip2QFormerConfig, "vision_config": Blip2VisionConfig} def __init__( diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 806b08469f6f..175e69180935 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -409,19 +409,20 @@ class Blip2PreTrainedModel(PreTrainedModel): ] _skip_keys_device_placement = "past_key_values" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=factor) + module.weight.normal_(mean=0.0, std=factor) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=factor) + module.weight.normal_(mean=0.0, std=factor) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, Blip2VisionEmbeddings): nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) @@ -435,7 +436,7 @@ def _init_weights(self, module): Blip2ForImageTextRetrieval, ), ): - module.query_tokens.data.zero_() + module.query_tokens.zero_() # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2 @@ -1049,10 +1050,6 @@ def __init__(self, config: Blip2Config): else: language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) - # Update _tied_weights_keys using the base model used. - if language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] - self.language_model = language_model # Initialize weights and apply final processing @@ -1076,11 +1073,6 @@ def get_encoder(self): def get_decoder(self): return self.language_model.get_decoder() - def _tie_weights(self): - if not self.config.use_decoder_only_language_model: - self.language_model.encoder.embed_tokens = self.language_model.shared - self.language_model.decoder.embed_tokens = self.language_model.shared - @filter_out_non_signature_kwargs() @auto_docstring def get_text_features( @@ -1612,10 +1604,6 @@ def __init__(self, config: Blip2Config): else: language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) - # Update _tied_weights_keys using the base model used. - if language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] - self.language_model = language_model # Initialize weights and apply final processing @@ -1639,11 +1627,6 @@ def get_encoder(self): def get_decoder(self): return self.language_model.get_decoder() - def _tie_weights(self): - if not self.config.use_decoder_only_language_model: - self.language_model.encoder.embed_tokens = self.language_model.shared - self.language_model.decoder.embed_tokens = self.language_model.shared - def _preprocess_accelerate(self): r""" Some pre-processing hacks to make the model `accelerate` compatible. Check diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index af63b5ef66f2..82a5444b2057 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -425,19 +425,20 @@ class BloomPreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights.""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring @@ -722,7 +723,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( """ ) class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "transformer.word_embeddings.weight"} def __init__(self, config: BloomConfig): super().__init__(config) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 678b67b377b3..5e63d9b203d4 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -1231,7 +1231,7 @@ class BltForCausalLM(BltPreTrainedModel, GenerationMixin): config: BltConfig _can_compile_fullgraph = False base_model_prefix = "model" - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"model.local_encoder.embed_tokens.weight": "lm_head.weight"} def __init__(self, config: BltConfig): super().__init__(config.get_text_config()) diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index f25380d7417c..78d5aa5a15ef 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -964,7 +964,7 @@ class BltForCausalLM(MllamaForCausalLM): config: BltConfig _can_compile_fullgraph = False base_model_prefix = "model" - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"model.local_encoder.embed_tokens.weight": "lm_head.weight"} def __init__(self, config: BltConfig): super().__init__(config) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 9647f8bb38f8..a44eb7bfabb1 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -919,6 +919,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel): _no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"] _skip_keys_device_placement = "past_key_values" + @torch.no_grad() def _init_weights(self, module: nn.Module): std = self.config.initializer_factor if isinstance(module, BridgeTowerVisionTransformer): @@ -927,7 +928,7 @@ def _init_weights(self, module: nn.Module): fc_std = (2 * self.config.hidden_size) ** -0.5 for block in module.transformer.resblocks: nn.init.normal_(block.attn.in_proj_weight, std=attn_std * std) - block.attn.in_proj_bias.data.zero_() + block.attn.in_proj_bias.zero_() nn.init.normal_(block.attn.out_proj.weight, std=proj_std * std) nn.init.normal_(block.mlp.c_fc.weight, std=fc_std * std) nn.init.normal_(block.mlp.c_proj.weight, std=proj_std * std) @@ -935,15 +936,15 @@ def _init_weights(self, module: nn.Module): nn.init.normal_(module.embeddings.class_embedding, std=attn_std * std) nn.init.normal_(module.embeddings.position_embedding.weight, std=attn_std * std) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=0.05 * std) + module.weight.normal_(mean=0.0, std=0.05 * std) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, BridgeTowerForContrastiveLearning): - module.logit_scale.data.fill_(self.config.logit_scale_init_value) + module.logit_scale.fill_(self.config.logit_scale_init_value) if isinstance(module, (nn.Linear, BridgeTowerMLMHead)) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() class BridgeTowerVisionModel(BridgeTowerPreTrainedModel): @@ -1497,7 +1498,7 @@ def forward(self, x): """ ) class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel): - _tied_weights_keys = ["mlm_score.decoder.weight"] + _tied_weights_keys = {"mlm_score.decoder.weight": "bridgetower.text_model.embeddings.word_embeddings.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index 3e7b4b40cb84..74da9e9c8ae8 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -514,20 +514,21 @@ class BrosPreTrainedModel(PreTrainedModel): config: BrosConfig base_model_prefix = "bros" + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, BrosRelationExtractor): nn.init.normal_(module.dummy_node, std=std) diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 26897520a2c7..267aafe5959e 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -383,7 +383,6 @@ def __init__(self, config): self.decoder = nn.Linear(config.hidden_size, config.vocab_size) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - self.decoder.bias = self.bias def forward(self, features, **kwargs): x = self.dense(features) @@ -395,14 +394,6 @@ def forward(self, features, **kwargs): return x - def _tie_weights(self): - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - # For accelerate compatibility and to not break backward compatibility - if self.decoder.bias.device.type == "meta": - self.decoder.bias = self.bias - else: - self.bias = self.decoder.bias - @auto_docstring class CamembertPreTrainedModel(PreTrainedModel): @@ -419,21 +410,22 @@ class CamembertPreTrainedModel(PreTrainedModel): "cross_attentions": CamembertCrossAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, CamembertLMHead): - module.bias.data.zero_() + module.bias.zero_() class CamembertEmbeddings(nn.Module): @@ -745,7 +737,10 @@ def _create_attention_masks( @auto_docstring class CamembertForMaskedLM(CamembertPreTrainedModel): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) @@ -1191,7 +1186,10 @@ def forward( """ ) class CamembertForCausalLM(CamembertPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "camembert.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/camembert/modular_camembert.py b/src/transformers/models/camembert/modular_camembert.py index eb83629ccc4e..6a72534c9132 100644 --- a/src/transformers/models/camembert/modular_camembert.py +++ b/src/transformers/models/camembert/modular_camembert.py @@ -53,6 +53,11 @@ class CamembertModel(RobertaModel): class CamembertForMaskedLM(RobertaForMaskedLM): + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } + def __init__(self, config): super().__init__(config) del self.camembert diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 8965ae9a3f7c..2b0a1e897266 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -688,12 +688,11 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias def forward(self, hidden_states: tuple[torch.FloatTensor]) -> torch.FloatTensor: hidden_states = self.transform(hidden_states) @@ -720,19 +719,20 @@ class CaninePreTrainedModel(PreTrainedModel): base_model_prefix = "canine" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 136b47b016c2..0930e44cb718 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1009,7 +1009,7 @@ def forward( """ ) class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 6e254f9bb3a7..815124aa45f8 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -562,6 +562,7 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel): input_modalities = ["image", "text"] supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor @@ -576,7 +577,7 @@ def _init_weights(self, module): nn.init.normal_(module.token_type_embeddings.weight, mean=0.0, std=self.config.initializer_range) for embedding in [module.word_embeddings, module.position_embeddings, module.token_type_embeddings]: if embedding.padding_idx is not None: - embedding.weight.data[embedding.padding_idx].zero_() + embedding.weight[embedding.padding_idx].zero_() elif isinstance(module, ChineseCLIPVisionAttention): factor = self.config.initializer_factor in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor @@ -602,12 +603,12 @@ def _init_weights(self, module): ) if isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() # Copied from transformers.models.align.modeling_align.AlignTextEncoder with Align->ChineseCLIP diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 89ad2ec26a61..0a44ecb7ffe7 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -1308,28 +1308,29 @@ class ClapPreTrainedModel(PreTrainedModel): input_modalities = ["audio", "text"] supports_gradient_checkpointing = False + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, ClapTextEmbeddings): - module.position_embeddings.weight.data.normal_(mean=0.0, std=factor * 0.02) - module.token_type_embeddings.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.position_embeddings.weight.normal_(mean=0.0, std=factor * 0.02) + module.token_type_embeddings.weight.normal_(mean=0.0, std=factor * 0.02) elif isinstance(module, ClapModel): - module.logit_scale_a.data.fill_(math.log(self.config.logit_scale_init_value)) - module.logit_scale_t.data.fill_(math.log(self.config.logit_scale_init_value)) + module.logit_scale_a.fill_(math.log(self.config.logit_scale_init_value)) + module.logit_scale_t.fill_(math.log(self.config.logit_scale_init_value)) elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.weight.normal_(mean=0.0, std=factor * 0.02) elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, (nn.Conv2d, nn.Linear)): in_proj_std = (self.config.hidden_size**-0.5) * ((2 * self.config.num_hidden_layers) ** -0.5) * factor nn.init.normal_(module.weight, std=in_proj_std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, ClapAudioSelfAttention): - module.relative_position_bias_table.data.zero_() + module.relative_position_bias_table.zero_() class ClapAudioModel(ClapPreTrainedModel): diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 33a85df063c7..8ce33c4a0dcf 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -408,12 +408,13 @@ class CLIPPreTrainedModel(PreTrainedModel): "attentions": CLIPAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, CLIPTextEmbeddings): - module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02) elif isinstance(module, CLIPVisionEmbeddings): factor = self.config.initializer_factor nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) @@ -459,10 +460,10 @@ def _init_weights(self, module): ) if isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() class CLIPEncoder(nn.Module): diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index be00e0e70381..9f14686630ba 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -427,12 +427,13 @@ class CLIPSegPreTrainedModel(PreTrainedModel): input_modalities = ["image", "text"] supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, CLIPSegTextEmbeddings): - module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02) elif isinstance(module, CLIPSegVisionEmbeddings): factor = self.config.initializer_factor nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) @@ -463,10 +464,10 @@ def _init_weights(self, module): ) if isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->CLIPSeg diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index fe6c9790b9ae..9893b6bd1442 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -781,17 +781,18 @@ class ClvpPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.weight.normal_(mean=0.0, std=factor * 0.02) elif isinstance(module, (nn.Linear, Conv1D, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.weight.normal_(mean=0.0, std=factor * 0.02) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, ClvpRMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, ClvpEncoderMLP): in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor fc_std = (2 * module.config.hidden_size) ** -0.5 * factor @@ -800,22 +801,22 @@ def _init_weights(self, module: nn.Module): elif isinstance(module, ClvpEncoder): config = self.config.get_text_config() factor = config.initializer_factor - module.projection.weight.data.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5)) + module.projection.weight.normal_(mean=0.0, std=factor * (config.hidden_size**-0.5)) elif isinstance(module, ClvpConditioningEncoder): - module.mel_conv.weight.data.normal_(mean=0.0, std=factor) - module.mel_conv.bias.data.zero_() + module.mel_conv.weight.normal_(mean=0.0, std=factor) + module.mel_conv.bias.zero_() elif isinstance(module, ClvpForCausalLM): for name, p in module.named_parameters(): if name == "c_proj.weight": - p.data.normal_( + p.normal_( mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers)) ) elif isinstance(module, ClvpModelForConditionalGeneration): - module.logit_scale.data.fill_(self.config.logit_scale_init_value) + module.logit_scale.fill_(self.config.logit_scale_init_value) if isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) class ClvpEncoder(ClvpPreTrainedModel): diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 8bb5bc9bda95..b5e350d79d1a 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -283,19 +283,20 @@ class CodeGenPreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear,)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring @@ -560,7 +561,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( """ ) class CodeGenForCausalLM(CodeGenPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 71eb4870fbf2..cf73b48989cd 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -466,7 +466,7 @@ def forward( @auto_docstring class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 8a9929dc3ff2..a9c56cd2491c 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -447,7 +447,7 @@ def forward( @auto_docstring class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/cohere2_vision/modeling_cohere2_vision.py b/src/transformers/models/cohere2_vision/modeling_cohere2_vision.py index c041ce831fe5..af46c765557a 100644 --- a/src/transformers/models/cohere2_vision/modeling_cohere2_vision.py +++ b/src/transformers/models/cohere2_vision/modeling_cohere2_vision.py @@ -268,7 +268,7 @@ def forward( ) class Cohere2VisionForConditionalGeneration(Cohere2VisionPreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = {} - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: Cohere2VisionConfig): super().__init__(config) diff --git a/src/transformers/models/colpali/convert_colpali_weights_to_hf.py b/src/transformers/models/colpali/convert_colpali_weights_to_hf.py index 55de46730074..dab4d8651145 100644 --- a/src/transformers/models/colpali/convert_colpali_weights_to_hf.py +++ b/src/transformers/models/colpali/convert_colpali_weights_to_hf.py @@ -144,7 +144,15 @@ def convert_colpali_weights_to_hf( # Tie the weights (following ColPali's `__init__`` step) if model.vlm.language_model._tied_weights_keys is not None: - model._tied_weights_keys = [f"vlm.language_model.{k}" for k in model.vlm.language_model._tied_weights_keys] + prefix = "vlm.language_model." + prefixed_mapping = { + f"{prefix}{target}": f"{prefix}{source}" + for target, source in model.vlm.language_model._tied_weights_keys.items() + } + if isinstance(model._tied_weights_keys, dict): + model._tied_weights_keys.update(prefixed_mapping) + else: + model._tied_weights_keys = prefixed_mapping # Sanity check: ensure all keys are the same state_dict_keys_old = set(original_state_dict.keys()) diff --git a/src/transformers/models/colpali/modeling_colpali.py b/src/transformers/models/colpali/modeling_colpali.py index 16ced722841c..954722e2b144 100644 --- a/src/transformers/models/colpali/modeling_colpali.py +++ b/src/transformers/models/colpali/modeling_colpali.py @@ -38,6 +38,7 @@ class ColPaliPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): std = ( self.config.initializer_range @@ -46,13 +47,13 @@ def _init_weights(self, module): ) if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() @dataclass @@ -113,7 +114,6 @@ def __init__(self, config: ColPaliConfig): self.vocab_size = config.vlm_config.text_config.vocab_size self.vlm = AutoModelForImageTextToText.from_config(config.vlm_config) - self._tied_weights_keys = [f"vlm.language_model.{k}" for k in (self.vlm._tied_weights_keys or [])] self.embedding_dim = self.config.embedding_dim self.embedding_proj_layer = nn.Linear( @@ -186,9 +186,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.vlm.set_output_embeddings(new_embeddings) - def tie_weights(self): - return self.vlm.tie_weights() - def resize_token_embeddings( self, new_num_tokens: Optional[int] = None, diff --git a/src/transformers/models/colqwen2/modeling_colqwen2.py b/src/transformers/models/colqwen2/modeling_colqwen2.py index c3a6c04ee4db..27b897f70490 100644 --- a/src/transformers/models/colqwen2/modeling_colqwen2.py +++ b/src/transformers/models/colqwen2/modeling_colqwen2.py @@ -46,6 +46,7 @@ class ColQwen2PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): std = ( self.config.initializer_range @@ -54,13 +55,13 @@ def _init_weights(self, module): ) if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() @dataclass @@ -118,7 +119,6 @@ def __init__(self, config: ColQwen2Config): self.config.vlm_config.text_config.hidden_size, self.embedding_dim, ) - self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])] self.post_init() @@ -222,9 +222,6 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.vlm.set_output_embeddings(new_embeddings) - def tie_weights(self): - return self.vlm.tie_weights() - def resize_token_embeddings( self, new_num_tokens: Optional[int] = None, diff --git a/src/transformers/models/colqwen2/modular_colqwen2.py b/src/transformers/models/colqwen2/modular_colqwen2.py index 072591abbab8..366256a66d35 100644 --- a/src/transformers/models/colqwen2/modular_colqwen2.py +++ b/src/transformers/models/colqwen2/modular_colqwen2.py @@ -307,7 +307,6 @@ class ColQwen2ForRetrieval(ColPaliForRetrieval): def __init__(self, config: ColQwen2Config): super().__init__(config) del self._tied_weights_keys - self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])] @can_return_tuple @auto_docstring diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index a9e04ec546b2..b2baf08dcd58 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -970,6 +970,7 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = [r"ConditionalDetrConvEncoder", r"ConditionalDetrEncoderLayer", r"ConditionalDetrDecoderLayer"] + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std xavier_std = self.config.init_xavier_std @@ -983,13 +984,13 @@ def _init_weights(self, module): nn.init.uniform_(module.row_embeddings.weight) nn.init.uniform_(module.column_embeddings.weight) if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() # Copied from transformers.models.detr.modeling_detr.DetrEncoder with Detr->ConditionalDetr,DETR->ConditionalDETR diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index 4fd2fea47724..392f8ec79a1c 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -108,24 +108,25 @@ class ConvBertPreTrainedModel(PreTrainedModel): base_model_prefix = "convbert" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, SeparableConv1D): - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, GroupedLinearLayer): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - module.bias.data.zero_() + module.weight.normal_(mean=0.0, std=self.config.initializer_range) + module.bias.zero_() class SeparableConv1D(nn.Module): @@ -707,7 +708,7 @@ def forward(self, generator_hidden_states: torch.FloatTensor) -> torch.FloatTens @auto_docstring class ConvBertForMaskedLM(ConvBertPreTrainedModel): - _tied_weights_keys = ["generator.lm_head.weight"] + _tied_weights_keys = {"generator_lm_head.weight": "convbert.embeddings.word_embeddings.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index bcdca46a84e6..c0cbc8e55476 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -240,18 +240,19 @@ class ConvNextPreTrainedModel(PreTrainedModel): _no_split_modules = ["ConvNextLayer"] _can_record_outputs = {} # hidden states are collected explicitly + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, ConvNextLayerNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, ConvNextLayer): if module.layer_scale_parameter is not None: - module.layer_scale_parameter.data.fill_(self.config.layer_scale_init_value) + module.layer_scale_parameter.fill_(self.config.layer_scale_init_value) @auto_docstring diff --git a/src/transformers/models/convnextv2/modeling_convnextv2.py b/src/transformers/models/convnextv2/modeling_convnextv2.py index d206ededf0ee..de320116bd16 100644 --- a/src/transformers/models/convnextv2/modeling_convnextv2.py +++ b/src/transformers/models/convnextv2/modeling_convnextv2.py @@ -260,18 +260,19 @@ class ConvNextV2PreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = ["ConvNextV2Layer"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, ConvNextV2LayerNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, ConvNextV2GRN): - module.weight.data.zero_() - module.bias.data.zero_() + module.weight.zero_() + module.bias.zero_() @auto_docstring diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py index fbc64d4b141f..9f8ce38b2b08 100755 --- a/src/transformers/models/cpmant/modeling_cpmant.py +++ b/src/transformers/models/cpmant/modeling_cpmant.py @@ -525,23 +525,24 @@ class CpmAntPreTrainedModel(PreTrainedModel): config: CpmAntConfig base_model_prefix = "cpmant" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.init_std) + module.weight.normal_(mean=0.0, std=self.config.init_std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.init_std) + module.weight.normal_(mean=0.0, std=self.config.init_std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, CpmAntLayerNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, CpmAntSegmentPositionEmbedding): - module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std) + module.relative_attention_bias.normal_(mean=0.0, std=self.config.init_std) @auto_docstring @@ -698,7 +699,7 @@ def forward( """ ) class CpmAntForCausalLM(CpmAntPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "cpmant.input_embedding.weight"} def __init__(self, config: CpmAntConfig): super().__init__(config) diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 7d3f87b2953d..b2e13b0867ab 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -409,12 +409,13 @@ class CsmPreTrainedModel(PreTrainedModel): "attentions": CsmAttention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, CsmCodebooksHead): num_codebooks = module.num_codebooks for i in range(num_codebooks - 1): - module.weight.data[i].normal_(mean=0.0, std=self.config.initializer_range) + module.weight[i].normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring @@ -769,10 +770,9 @@ def forward( """ ) class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin): - _tied_weights_keys = [ - "backbone_model.embed_tokens.embed_audio_tokens.weight", - "depth_decoder.model.embed_tokens.weight", - ] + _tied_weights_keys = { + "backbone_model.embed_tokens.embed_audio_tokens.weight": "depth_decoder.model.embed_tokens.weight" + } def __init__(self, config): super().__init__(config) @@ -790,13 +790,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.backbone_model.embed_tokens = value - def _tie_weights(self): - if self.config.tie_codebooks_embeddings: - self._tie_embedding_weights( - self.backbone_model.embed_tokens.embed_audio_tokens, - self.depth_decoder.model.embed_tokens, - ) - @classmethod def from_pretrained(cls, *args, **kwargs): if kwargs.get("output_loading_info", False): diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py index 9ecc7017d83f..95183dced48b 100644 --- a/src/transformers/models/csm/modular_csm.py +++ b/src/transformers/models/csm/modular_csm.py @@ -140,12 +140,13 @@ class CsmPreTrainedModel(PreTrainedModel): "attentions": CsmAttention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, CsmCodebooksHead): num_codebooks = module.num_codebooks for i in range(num_codebooks - 1): - module.weight.data[i].normal_(mean=0.0, std=self.config.initializer_range) + module.weight[i].normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring @@ -420,10 +421,9 @@ def forward(self, **super_kwargs): """ ) class CsmForConditionalGeneration(CsmPreTrainedModel, CsmGenerationMixin): - _tied_weights_keys = [ - "backbone_model.embed_tokens.embed_audio_tokens.weight", - "depth_decoder.model.embed_tokens.weight", - ] + _tied_weights_keys = { + "backbone_model.embed_tokens.embed_audio_tokens.weight": "depth_decoder.model.embed_tokens.weight" + } def __init__(self, config): super().__init__(config) @@ -441,13 +441,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.backbone_model.embed_tokens = value - def _tie_weights(self): - if self.config.tie_codebooks_embeddings: - self._tie_embedding_weights( - self.backbone_model.embed_tokens.embed_audio_tokens, - self.depth_decoder.model.embed_tokens, - ) - @classmethod def from_pretrained(cls, *args, **kwargs): if kwargs.get("output_loading_info", False): diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 945ba0431c25..f3a5472410ce 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -188,19 +188,20 @@ class CTRLPreTrainedModel(PreTrainedModel): config: CTRLConfig base_model_prefix = "transformer" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear, Conv1D)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring @@ -384,7 +385,7 @@ def forward( """ ) class CTRLLMHeadModel(CTRLPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "transformer.w.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/cvt/modeling_cvt.py b/src/transformers/models/cvt/modeling_cvt.py index 1327a410d03d..55b251a087e7 100644 --- a/src/transformers/models/cvt/modeling_cvt.py +++ b/src/transformers/models/cvt/modeling_cvt.py @@ -489,19 +489,20 @@ class CvtPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" _no_split_modules = ["CvtLayer"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range) + module.weight.copy_(nn.init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, CvtStage): if self.config.cls_token[module.stage]: - module.cls_token.data = nn.init.trunc_normal_( - module.cls_token.data, mean=0.0, std=self.config.initializer_range + module.cls_token.copy_( + nn.init.trunc_normal_(module.cls_token, mean=0.0, std=self.config.initializer_range) ) diff --git a/src/transformers/models/cwm/modeling_cwm.py b/src/transformers/models/cwm/modeling_cwm.py index cf4d996b0c49..df9760ed1ba7 100644 --- a/src/transformers/models/cwm/modeling_cwm.py +++ b/src/transformers/models/cwm/modeling_cwm.py @@ -437,7 +437,7 @@ def forward( @auto_docstring class CwmForCausalLM(CwmPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/d_fine/modeling_d_fine.py b/src/transformers/models/d_fine/modeling_d_fine.py index 5e79b02f5716..94d4de5d1d48 100644 --- a/src/transformers/models/d_fine/modeling_d_fine.py +++ b/src/transformers/models/d_fine/modeling_d_fine.py @@ -444,6 +444,7 @@ class DFinePreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = [r"DFineHybridEncoder", r"DFineDecoderLayer"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" # initialize linear layer bias value according to a given probability value. @@ -467,7 +468,7 @@ def _init_weights(self, module): module.up.fill_(self.config.up) if isinstance(module, DFineMultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + nn.init.constant_(module.sampling_offsets.weight, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( 2.0 * math.pi / module.n_heads @@ -478,10 +479,10 @@ def _init_weights(self, module): scaling = torch.concat([torch.arange(1, n + 1) for n in module.num_points_list]).reshape(1, -1, 1) grid_init *= scaling with torch.no_grad(): - module.sampling_offsets.bias.data[...] = grid_init.flatten() + module.sampling_offsets.bias[...] = grid_init.flatten() - nn.init.constant_(module.attention_weights.weight.data, 0.0) - nn.init.constant_(module.attention_weights.bias.data, 0.0) + nn.init.constant_(module.attention_weights.weight, 0.0) + nn.init.constant_(module.attention_weights.bias, 0.0) if isinstance(module, DFineModel): prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) @@ -490,9 +491,9 @@ def _init_weights(self, module): nn.init.constant_(module.enc_score_head.bias, bias) if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() if isinstance(module, DFineGate): bias = float(-math.log((1 - 0.5) / 0.5)) @@ -504,8 +505,8 @@ def _init_weights(self, module): init.constant_(module.reg_conf.layers[-1].weight, 0) if isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() if hasattr(module, "weight_embedding") and self.config.learn_initial_query: nn.init.xavier_uniform_(module.weight_embedding.weight) @@ -1547,7 +1548,7 @@ class DFineObjectDetectionOutput(ModelOutput): ) class DFineForObjectDetection(DFinePreTrainedModel): # When using clones, all layers > 0 will be clones, but layer 0 *is* required - _tied_weights_keys = ["bbox_embed", "class_embed"] + _tied_weights_keys = {"model.decoder.bbox_embed": "bbox_embed", "model.decoder.class_embed": "class_embed"} # We can't initialize the model on meta device as some weights are modified during the initialization _no_split_modules = None diff --git a/src/transformers/models/d_fine/modular_d_fine.py b/src/transformers/models/d_fine/modular_d_fine.py index 01d59e238acb..2996e1aac3f3 100644 --- a/src/transformers/models/d_fine/modular_d_fine.py +++ b/src/transformers/models/d_fine/modular_d_fine.py @@ -588,6 +588,7 @@ def forward( class DFinePreTrainedModel(RTDetrPreTrainedModel): + @torch.no_grad() def _init_weights(self, module): # initialize linear layer bias value according to a given probability value. if isinstance(module, (DFineForObjectDetection, DFineDecoder)): @@ -610,7 +611,7 @@ def _init_weights(self, module): module.up.fill_(self.config.up) if isinstance(module, DFineMultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + nn.init.constant_(module.sampling_offsets.weight, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( 2.0 * math.pi / module.n_heads @@ -621,10 +622,10 @@ def _init_weights(self, module): scaling = torch.concat([torch.arange(1, n + 1) for n in module.num_points_list]).reshape(1, -1, 1) grid_init *= scaling with torch.no_grad(): - module.sampling_offsets.bias.data[...] = grid_init.flatten() + module.sampling_offsets.bias[...] = grid_init.flatten() - nn.init.constant_(module.attention_weights.weight.data, 0.0) - nn.init.constant_(module.attention_weights.bias.data, 0.0) + nn.init.constant_(module.attention_weights.weight, 0.0) + nn.init.constant_(module.attention_weights.bias, 0.0) if isinstance(module, DFineModel): prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) @@ -633,9 +634,9 @@ def _init_weights(self, module): nn.init.constant_(module.enc_score_head.bias, bias) if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() if isinstance(module, DFineGate): bias = float(-math.log((1 - 0.5) / 0.5)) @@ -647,8 +648,8 @@ def _init_weights(self, module): init.constant_(module.reg_conf.layers[-1].weight, 0) if isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() if hasattr(module, "weight_embedding") and self.config.learn_initial_query: nn.init.xavier_uniform_(module.weight_embedding.weight) diff --git a/src/transformers/models/dab_detr/modeling_dab_detr.py b/src/transformers/models/dab_detr/modeling_dab_detr.py index b5aafb5b8b28..cc48555a72fa 100644 --- a/src/transformers/models/dab_detr/modeling_dab_detr.py +++ b/src/transformers/models/dab_detr/modeling_dab_detr.py @@ -815,6 +815,7 @@ class DabDetrPreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = [r"DabDetrConvEncoder", r"DabDetrEncoderLayer", r"DabDetrDecoderLayer"] + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std xavier_std = self.config.init_xavier_std @@ -825,24 +826,24 @@ def _init_weights(self, module): nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std) nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std) if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, DabDetrForObjectDetection): - nn.init.constant_(module.bbox_predictor.layers[-1].weight.data, 0) - nn.init.constant_(module.bbox_predictor.layers[-1].bias.data, 0) + nn.init.constant_(module.bbox_predictor.layers[-1].weight, 0) + nn.init.constant_(module.bbox_predictor.layers[-1].bias, 0) # init prior_prob setting for focal loss prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) bias_value = -math.log((1 - prior_prob) / prior_prob) - module.class_embed.bias.data.fill_(bias_value) + module.class_embed.bias.fill_(bias_value) elif isinstance(module, nn.PReLU): module.reset_parameters() @@ -1429,10 +1430,7 @@ def forward(self, q, k, mask: Optional[Tensor] = None): ) class DabDetrForObjectDetection(DabDetrPreTrainedModel): # When using clones, all layers > 0 will be clones, but layer 0 *is* required - _tied_weights_keys = [ - r"bbox_predictor\.layers\.\d+\.(weight|bias)", - r"model\.decoder\.bbox_embed\.layers\.\d+\.(weight|bias)", - ] + _tied_weights_keys = {"model.decoder.bbox_embed": "bbox_predictor"} def __init__(self, config: DabDetrConfig): super().__init__(config) diff --git a/src/transformers/models/dac/modeling_dac.py b/src/transformers/models/dac/modeling_dac.py index 81cfcbb931d4..54f1d1a32d49 100644 --- a/src/transformers/models/dac/modeling_dac.py +++ b/src/transformers/models/dac/modeling_dac.py @@ -477,16 +477,17 @@ class DacPreTrainedModel(PreTrainedAudioTokenizerBase): base_model_prefix = "dac" main_input_name = "input_values" + @torch.no_grad() def _init_weights(self, module): if isinstance(module, nn.Conv1d): nn.init.trunc_normal_(module.weight, std=0.02) nn.init.constant_(module.bias, 0) elif isinstance(module, Snake1d): - module.alpha.data.fill_(1.0) + module.alpha.fill_(1.0) elif isinstance(module, nn.ConvTranspose1d): module.reset_parameters() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=0.02) + module.weight.normal_(mean=0.0, std=0.02) def apply_weight_norm(self): weight_norm = nn.utils.weight_norm diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 2559a29abca1..ac78fb0dea8c 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -480,6 +480,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, Data2VecAudioFeatureProjection): @@ -489,15 +490,15 @@ def _init_weights(self, module): elif isinstance(module, Data2VecAudioPositionalConvLayer): nn.init.constant_(module.conv.bias, 0) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() if module.weight is not None: - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 1ef12699360c..b7a2a7ed2300 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -494,23 +494,24 @@ class Data2VecTextPreTrainedModel(PreTrainedModel): "cross_attentions": Data2VecTextCrossAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): if hasattr(module, "bias") and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) class Data2VecTextEncoder(nn.Module): @@ -713,7 +714,6 @@ def __init__(self, config): self.decoder = nn.Linear(config.hidden_size, config.vocab_size) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - self.decoder.bias = self.bias def forward(self, features, **kwargs): x = self.dense(features) @@ -725,14 +725,6 @@ def forward(self, features, **kwargs): return x - def _tie_weights(self): - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - # For accelerate compatibility and to not break backward compatibility - if self.decoder.bias.device.type == "meta": - self.decoder.bias = self.bias - else: - self.bias = self.decoder.bias - class Data2VecTextClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" @@ -762,7 +754,10 @@ def forward(self, features, **kwargs): """ ) class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "data2vec_text.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) @@ -861,7 +856,10 @@ def forward( @auto_docstring class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "data2vec_text.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index b51d7ed0f5d5..ce96ea06324e 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -706,31 +706,32 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel): _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"] _supports_sdpa = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, Data2VecVisionEmbeddings): - module.cls_token.data.zero_() + module.cls_token.zero_() if module.mask_token is not None: - module.mask_token.data.zero_() + module.mask_token.zero_() if module.position_embeddings is not None: - module.position_embeddings.data.zero_() + module.position_embeddings.zero_() elif isinstance(module, Data2VecVisionRelativePositionBias): - module.relative_position_bias_table.data.zero_() + module.relative_position_bias_table.zero_() elif isinstance(module, Data2VecVisionLayer): if module.lambda_1 is not None: - module.lambda_1.data.fill_(self.config.layer_scale_init_value) - module.lambda_2.data.fill_(self.config.layer_scale_init_value) + module.lambda_1.fill_(self.config.layer_scale_init_value) + module.lambda_2.fill_(self.config.layer_scale_init_value) @auto_docstring diff --git a/src/transformers/models/data2vec/modular_data2vec_audio.py b/src/transformers/models/data2vec/modular_data2vec_audio.py index 142bf7a5e783..db850fa2f1d5 100644 --- a/src/transformers/models/data2vec/modular_data2vec_audio.py +++ b/src/transformers/models/data2vec/modular_data2vec_audio.py @@ -144,6 +144,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, Data2VecAudioFeatureProjection): @@ -153,15 +154,15 @@ def _init_weights(self, module): elif isinstance(module, Data2VecAudioPositionalConvLayer): nn.init.constant_(module.conv.bias, 0) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() if module.weight is not None: - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) diff --git a/src/transformers/models/data2vec/modular_data2vec_text.py b/src/transformers/models/data2vec/modular_data2vec_text.py index 1c91e50db8c7..ad0dc81c8e01 100644 --- a/src/transformers/models/data2vec/modular_data2vec_text.py +++ b/src/transformers/models/data2vec/modular_data2vec_text.py @@ -81,23 +81,24 @@ class Data2VecTextPreTrainedModel(PreTrainedModel): "cross_attentions": Data2VecTextCrossAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): if hasattr(module, "bias") and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) @auto_docstring @@ -119,7 +120,10 @@ class Data2VecTextClassificationHead(RobertaClassificationHead): """ ) class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "data2vec_text.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) @@ -218,7 +222,10 @@ def forward( @auto_docstring class Data2VecTextForMaskedLM(Data2VecTextPreTrainedModel): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "data2vec_text.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index a3f995d35b95..db212fd6378e 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -466,24 +466,25 @@ class DbrxPreTrainedModel(PreTrainedModel): "attentions": DbrxAttention, } + @torch.no_grad() def _init_weights(self, module: nn.Module): std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, DbrxExpertGLU): - module.w1.data.normal_(mean=0.0, std=std) - module.v1.data.normal_(mean=0.0, std=std) - module.w2.data.normal_(mean=0.0, std=std) + module.w1.normal_(mean=0.0, std=std) + module.v1.normal_(mean=0.0, std=std) + module.w2.normal_(mean=0.0, std=std) @auto_docstring @@ -663,7 +664,7 @@ def load_balancing_loss_func( class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/dbrx/modular_dbrx.py b/src/transformers/models/dbrx/modular_dbrx.py index 46507e44d52d..c9633e20fe1e 100644 --- a/src/transformers/models/dbrx/modular_dbrx.py +++ b/src/transformers/models/dbrx/modular_dbrx.py @@ -336,24 +336,25 @@ class DbrxPreTrainedModel(PreTrainedModel): "attentions": DbrxAttention, } + @torch.no_grad() def _init_weights(self, module: nn.Module): std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, DbrxExpertGLU): - module.w1.data.normal_(mean=0.0, std=std) - module.v1.data.normal_(mean=0.0, std=std) - module.w2.data.normal_(mean=0.0, std=std) + module.w1.normal_(mean=0.0, std=std) + module.v1.normal_(mean=0.0, std=std) + module.w2.normal_(mean=0.0, std=std) @auto_docstring @@ -451,7 +452,7 @@ def forward( class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index e5432c730404..3b2ea9b53724 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -614,24 +614,25 @@ class DebertaPreTrainedModel(PreTrainedModel): _keys_to_ignore_on_load_unexpected = ["position_embeddings"] supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, (nn.LayerNorm, DebertaLayerNorm)): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, DisentangledSelfAttention): - module.q_bias.data.zero_() - module.v_bias.data.zero_() + module.q_bias.zero_() + module.v_bias.zero_() elif isinstance(module, (LegacyDebertaLMPredictionHead, DebertaLMPredictionHead)): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring @@ -761,16 +762,10 @@ def __init__(self, config): self.embedding_size = getattr(config, "embedding_size", config.hidden_size) # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(self.embedding_size, config.vocab_size, bias=False) + self.decoder = nn.Linear(self.embedding_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -828,7 +823,10 @@ def forward(self, sequence_output, word_embeddings): @auto_docstring class DebertaForMaskedLM(DebertaPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "deberta.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) @@ -837,7 +835,9 @@ def __init__(self, config): if self.legacy: self.cls = LegacyDebertaOnlyMLMHead(config) else: - self._tied_weights_keys = ["lm_predictions.lm_head.weight", "deberta.embeddings.word_embeddings.weight"] + self._tied_weights_keys = { + "lm_predictions.lm_head.weight": "deberta.embeddings.word_embeddings.weight", + } self.lm_predictions = DebertaOnlyMLMHead(config) # Initialize weights and apply final processing diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 28e6c87c71a5..791e433e4d2c 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -693,21 +693,22 @@ class DebertaV2PreTrainedModel(PreTrainedModel): _keys_to_ignore_on_load_unexpected = ["position_embeddings"] supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, (LegacyDebertaV2LMPredictionHead, DebertaV2LMPredictionHead)): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring @@ -839,16 +840,10 @@ def __init__(self, config): self.embedding_size = getattr(config, "embedding_size", config.hidden_size) # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(self.embedding_size, config.vocab_size, bias=False) + self.decoder = nn.Linear(self.embedding_size, config.vocab_size) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -903,7 +898,10 @@ def forward(self, sequence_output, word_embeddings): @auto_docstring class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "deberta.embeddings.word_embeddings.weight", + } _keys_to_ignore_on_load_unexpected = [r"mask_predictions.*"] def __init__(self, config): @@ -913,7 +911,9 @@ def __init__(self, config): if self.legacy: self.cls = LegacyDebertaV2OnlyMLMHead(config) else: - self._tied_weights_keys = ["lm_predictions.lm_head.weight", "deberta.embeddings.word_embeddings.weight"] + self._tied_weights_keys = { + "lm_predictions.lm_head.weight": "deberta.embeddings.word_embeddings.weight", + } self.lm_predictions = DebertaV2OnlyMLMHead(config) # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 48d808432628..180d78cb32ce 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -374,19 +374,20 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear, Conv1D)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale @@ -394,10 +395,11 @@ def _init_weights(self, module): # > -- GPT-2 :: https://openai.com/blog/better-language-models/ # # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - for name, p in module.named_parameters(): - if "c_proj" in name and "weight" in name: - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + if isinstance(module, PreTrainedModel): + for name, p in module.named_parameters(): + if "c_proj" in name and "weight" in name: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): @@ -615,19 +617,20 @@ class DecisionTransformerPreTrainedModel(PreTrainedModel): main_input_name = "states" supports_gradient_checkpointing = False + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring( diff --git a/src/transformers/models/deepseek_v2/configuration_deepseek_v2.py b/src/transformers/models/deepseek_v2/configuration_deepseek_v2.py index aad76507d3a6..f7b81216d332 100644 --- a/src/transformers/models/deepseek_v2/configuration_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/configuration_deepseek_v2.py @@ -127,8 +127,7 @@ class DeepseekV2Config(PreTrainedConfig): "layers.*.self_attn.q_b_proj": "colwise", "layers.*.self_attn.kv_b_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", - "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.gate_up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } base_model_pp_plan = { diff --git a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py index a3f4eb0d3340..2bd5aa73c249 100644 --- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -42,19 +42,23 @@ from .configuration_deepseek_v2 import DeepseekV2Config -class DeepseekV2Experts(nn.ModuleList): - """ - ModuleList of experts. - """ +class DeepseekV2Experts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): super().__init__() self.num_experts = config.n_routed_experts - for _ in range(config.n_routed_experts): - self.append(DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: """ Args: @@ -65,14 +69,24 @@ def forward( (batch_size * sequence_length, hidden_dim) """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states @@ -111,6 +125,7 @@ def route_tokens_to_experts(self, router_logits): topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) topk_weight = topk_weight * self.routed_scaling_factor + topk_weight = torch.zeros_like(router_logits).scatter_(1, topk_idx, topk_weight) return topk_idx, topk_weight def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -459,10 +474,11 @@ class DeepseekV2PreTrainedModel(PreTrainedModel): "attentions": DeepseekV2Attention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, DeepseekV2Moe): - module.gate.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.gate.weight.normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring @@ -546,7 +562,7 @@ def forward( @auto_docstring class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/deepseek_v2/modular_deepseek_v2.py b/src/transformers/models/deepseek_v2/modular_deepseek_v2.py index 7e60d5c858b3..b6fa08ddd890 100644 --- a/src/transformers/models/deepseek_v2/modular_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modular_deepseek_v2.py @@ -142,8 +142,7 @@ class DeepseekV2Config(LlamaConfig): "layers.*.self_attn.q_b_proj": "colwise", "layers.*.self_attn.kv_b_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", - "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.gate_up_proj": "colwise", "layers.*.mlp.down_proj": "rowwise", } @@ -224,12 +223,10 @@ def apply_rotary_emb( return xq_out, xk_out -class DeepseekV2Experts(Qwen2MoeExperts, nn.ModuleList): +class DeepseekV2Experts(Qwen2MoeExperts): def __init__(self, config): - nn.ModuleList.__init__(self) + super().__init__(config) self.num_experts = config.n_routed_experts - for _ in range(config.n_routed_experts): - self.append(DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size)) class DeepseekV2Moe(nn.Module): @@ -267,6 +264,7 @@ def route_tokens_to_experts(self, router_logits): topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) topk_weight = topk_weight * self.routed_scaling_factor + topk_weight = torch.zeros_like(router_logits).scatter_(1, topk_idx, topk_weight) return topk_idx, topk_weight def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -439,10 +437,11 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int): class DeepseekV2PreTrainedModel(LlamaPreTrainedModel): _can_compile_fullgraph = False + @torch.no_grad() def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, DeepseekV2Moe): - module.gate.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.gate.weight.normal_(mean=0.0, std=self.config.initializer_range) class DeepseekV2Model(LlamaModel): diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 51e720a2eedf..98df80837e58 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -149,19 +149,23 @@ def forward(self, hidden_states): return router_logits -class DeepseekV3NaiveMoe(nn.ModuleList): - """ - ModuleList of experts. - """ +class DeepseekV3NaiveMoe(nn.Module): + """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): super().__init__() self.num_experts = config.num_local_experts - for _ in range(self.num_experts): - self.append(DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: """ Args: @@ -172,14 +176,24 @@ def forward( (batch_size * sequence_length, hidden_dim) """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states @@ -542,10 +556,11 @@ class DeepseekV3PreTrainedModel(PreTrainedModel): "attentions": DeepseekV3Attention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, DeepseekV3TopkRouter): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring @@ -631,7 +646,7 @@ def forward( @auto_docstring class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index 3bc9d45e79e9..5a92d135870d 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -102,12 +102,10 @@ def forward(self, hidden_states): return router_logits -class DeepseekV3NaiveMoe(MixtralExperts, nn.ModuleList): +class DeepseekV3NaiveMoe(MixtralExperts): def __init__(self, config): - nn.ModuleList.__init__(self) + super().__init__(config) self.num_experts = config.num_local_experts - for _ in range(self.num_experts): - self.append(DeepseekV3MLP(config, intermediate_size=config.moe_intermediate_size)) class DeepseekV3MoE(nn.Module): @@ -306,10 +304,11 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int): class DeepseekV3PreTrainedModel(LlamaPreTrainedModel): _can_compile_fullgraph = False + @torch.no_grad() def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, DeepseekV3TopkRouter): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) class DeepseekV3Model(LlamaModel): diff --git a/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py b/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py index 41b6460e12bc..849eb5ef34f0 100644 --- a/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py @@ -132,13 +132,14 @@ class DeepseekVLPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True _supports_param_buffer_assignment = False + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" # Required only for Linear layer in DeepseekVLAligner if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.text_config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.text_config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() @auto_docstring @@ -243,7 +244,7 @@ def forward( class DeepseekVLForConditionalGeneration(DeepseekVLPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} output_modalities = "text" _can_compile_fullgraph = True diff --git a/src/transformers/models/deepseek_vl/modular_deepseek_vl.py b/src/transformers/models/deepseek_vl/modular_deepseek_vl.py index 9b894b7f7505..1cc14a35bf3a 100644 --- a/src/transformers/models/deepseek_vl/modular_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/modular_deepseek_vl.py @@ -134,13 +134,14 @@ def forward(self, vision_encodings: torch.Tensor) -> torch.Tensor: class DeepseekVLPreTrainedModel(JanusPreTrainedModel): _no_split_modules = ["LlamaDecoderLayer"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" # Required only for Linear layer in DeepseekVLAligner if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.text_config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.text_config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() @auto_docstring diff --git a/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py index 531da23a5c51..17fed96166ce 100644 --- a/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py @@ -214,21 +214,22 @@ class DeepseekVLHybridPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True _supports_param_buffer_assignment = False + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.text_config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.text_config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, DeepseekVLHybridLayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, DeepseekVLHybridModel): - module.high_res_vision_alpha.data.zero_() + module.high_res_vision_alpha.zero_() DEEPSEEK_VL_COMMON_CUSTOM_ARGS = r""" @@ -388,7 +389,7 @@ def get_high_res_image_features(self, pixel_values): class DeepseekVLHybridForConditionalGeneration(DeepseekVLHybridPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} output_modalities = "text" _can_compile_fullgraph = True diff --git a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py index 27062cfd06b2..c8f5be1638d4 100644 --- a/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py @@ -216,21 +216,22 @@ def forward( class DeepseekVLHybridPreTrainedModel(DeepseekVLPreTrainedModel): + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.text_config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.text_config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, DeepseekVLHybridLayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, DeepseekVLHybridModel): - module.high_res_vision_alpha.data.zero_() + module.high_res_vision_alpha.zero_() class DeepseekVLHybridModel(DeepseekVLModel): diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 04a45b413c73..f60d2c29eae2 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch Deformable DETR model.""" -import copy import math import warnings from dataclasses import dataclass @@ -234,10 +233,6 @@ class DeformableDetrObjectDetectionOutput(ModelOutput): enc_outputs_coord_logits: Optional[torch.FloatTensor] = None -def _get_clones(module, N): - return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) - - def inverse_sigmoid(x, eps=1e-5): x = x.clamp(min=0, max=1) x1 = x.clamp(min=eps) @@ -931,6 +926,7 @@ class DeformableDetrPreTrainedModel(PreTrainedModel): r"DeformableDetrDecoderLayer", ] + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std @@ -938,7 +934,7 @@ def _init_weights(self, module): nn.init.uniform_(module.row_embeddings.weight) nn.init.uniform_(module.column_embeddings.weight) elif isinstance(module, DeformableDetrMultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + nn.init.constant_(module.sampling_offsets.weight, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( 2.0 * math.pi / module.n_heads @@ -953,23 +949,23 @@ def _init_weights(self, module): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(module.attention_weights.weight.data, 0.0) - nn.init.constant_(module.attention_weights.bias.data, 0.0) - nn.init.xavier_uniform_(module.value_proj.weight.data) - nn.init.constant_(module.value_proj.bias.data, 0.0) - nn.init.xavier_uniform_(module.output_proj.weight.data) - nn.init.constant_(module.output_proj.bias.data, 0.0) + nn.init.constant_(module.attention_weights.weight, 0.0) + nn.init.constant_(module.attention_weights.bias, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight) + nn.init.constant_(module.value_proj.bias, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight) + nn.init.constant_(module.output_proj.bias, 0.0) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() if hasattr(module, "reference_points") and not self.config.two_stage: - nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0) - nn.init.constant_(module.reference_points.bias.data, 0.0) + nn.init.xavier_uniform_(module.reference_points.weight, gain=1.0) + nn.init.constant_(module.reference_points.bias, 0.0) if hasattr(module, "level_embed"): nn.init.normal_(module.level_embed) @@ -1703,13 +1699,15 @@ def forward(self, x): ) class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel): # When using clones, all layers > 0 will be clones, but layer 0 *is* required - _tied_weights_keys = [r"bbox_embed\.[1-9]\d*", r"class_embed\.[1-9]\d*"] # We can't initialize the model on meta device as some weights are modified during the initialization _no_split_modules = None + _tied_weights_keys = { + r"bbox_embed.(?![0])\d+": "bbox_embed.0", + r"class_embed.(?![0])\d+": "class_embed.0", + } def __init__(self, config: DeformableDetrConfig): super().__init__(config) - # Deformable DETR encoder-decoder model self.model = DeformableDetrModel(config) # Detection heads on top @@ -1720,23 +1718,24 @@ def __init__(self, config: DeformableDetrConfig): output_dim=4, num_layers=3, ) - # if two-stage, the last class_embed and bbox_embed is for region proposal generation num_pred = (config.decoder_layers + 1) if config.two_stage else config.decoder_layers + self.class_embed = nn.ModuleList([nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)]) + self.bbox_embed = nn.ModuleList( + [ + DeformableDetrMLPPredictionHead( + input_dim=config.d_model, + hidden_dim=config.d_model, + output_dim=4, + num_layers=3, + ) + for _ in range(num_pred) + ] + ) if config.with_box_refine: - self.class_embed = _get_clones(self.class_embed, num_pred) - self.bbox_embed = _get_clones(self.bbox_embed, num_pred) - # hack implementation for iterative bounding box refinement - self.model.decoder.bbox_embed = self.bbox_embed - else: - self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) - self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)]) - self.model.decoder.bbox_embed = None + self._tied_weights_keys["model.decoder.bbox_embed"] = "bbox_embed" if config.two_stage: - # hack implementation for two-stage - self.model.decoder.class_embed = self.class_embed - - # Initialize weights and apply final processing + self._tied_weights_keys["model.decoder.class_embed"] = "class_embed" self.post_init() @auto_docstring diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 4d6a16c0a438..b80a02d83a14 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -366,25 +366,28 @@ class DeiTPreTrainedModel(PreTrainedModel): "attentions": DeiTSelfAttention, } + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) + module.weight.copy_( + nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( + module.weight.dtype + ) + ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, DeiTEmbeddings): - module.cls_token.data.zero_() - module.position_embeddings.data.zero_() - module.distillation_token.data.zero_() + module.cls_token.zero_() + module.position_embeddings.zero_() + module.distillation_token.zero_() if module.mask_token is not None: - module.mask_token.data.zero_() + module.mask_token.zero_() @auto_docstring diff --git a/src/transformers/models/deprecated/deta/modeling_deta.py b/src/transformers/models/deprecated/deta/modeling_deta.py index 4c881c4365a0..d7336d304a76 100644 --- a/src/transformers/models/deprecated/deta/modeling_deta.py +++ b/src/transformers/models/deprecated/deta/modeling_deta.py @@ -988,6 +988,7 @@ class DetaPreTrainedModel(PreTrainedModel): _no_split_modules = [r"DetaBackboneWithPositionalEncodings", r"DetaEncoderLayer", r"DetaDecoderLayer"] supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std @@ -997,16 +998,16 @@ def _init_weights(self, module): elif isinstance(module, DetaMultiscaleDeformableAttention): module._reset_parameters() elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() if hasattr(module, "reference_points") and not self.config.two_stage: - nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0) - nn.init.constant_(module.reference_points.bias.data, 0.0) + nn.init.xavier_uniform_(module.reference_points.weight, gain=1.0) + nn.init.constant_(module.reference_points.bias, 0.0) if hasattr(module, "level_embed"): nn.init.normal_(module.level_embed) @@ -1793,13 +1794,12 @@ def forward( ) class DetaForObjectDetection(DetaPreTrainedModel): # When using clones, all layers > 0 will be clones, but layer 0 *is* required - _tied_weights_keys = [r"bbox_embed\.\d+", r"class_embed\.\d+"] # We can't initialize the model on meta device as some weights are modified during the initialization _no_split_modules = None def __init__(self, config: DetaConfig): super().__init__(config) - + self._tied_weights_keys = {} # Deformable DETR encoder-decoder model self.model = DetaModel(config) @@ -1823,6 +1823,11 @@ def __init__(self, config: DetaConfig): nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) # hack implementation for iterative bounding box refinement self.model.decoder.bbox_embed = self.bbox_embed + self._tied_weights_keys.update( + { + "model.decoder.bbox_embed ": "bbox_embed", + } + ) else: nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) @@ -1831,6 +1836,11 @@ def __init__(self, config: DetaConfig): if config.two_stage: # hack implementation for two-stage self.model.decoder.class_embed = self.class_embed + self._tied_weights_keys.update( + { + "model.decoder.class_embed ": "class_embed", + } + ) for box_embed in self.bbox_embed: nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0) diff --git a/src/transformers/models/deprecated/efficientformer/modeling_efficientformer.py b/src/transformers/models/deprecated/efficientformer/modeling_efficientformer.py index 2167df912d87..f3303da0f6fd 100644 --- a/src/transformers/models/deprecated/efficientformer/modeling_efficientformer.py +++ b/src/transformers/models/deprecated/efficientformer/modeling_efficientformer.py @@ -498,15 +498,16 @@ class EfficientFormerPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = False + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) EFFICIENTFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py b/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py index 1aaccbe3f146..7ed73c5a49a8 100755 --- a/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py +++ b/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py @@ -368,19 +368,20 @@ class ErnieMPreTrainedModel(PreTrainedModel): config: ErnieMConfig base_model_prefix = "ernie_m" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) ERNIE_M_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py index a0aa6c8b5c17..aca490be1430 100644 --- a/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py @@ -528,60 +528,61 @@ def dummy_inputs(self): } return dummy_inputs + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, nn.LayerNorm): - module.weight.data.fill_(factor * 1.0) - module.bias.data.zero_() + module.weight.fill_(factor * 1.0) + module.bias.zero_() elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.weight.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, GPTSanJapaneseModel): # Mesh TensorFlow embeddings initialization # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 - module.embed_tokens.weight.data.normal_(mean=0.0, std=factor * 1.0) - module.position_embeddings.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.embed_tokens.weight.normal_(mean=0.0, std=factor * 1.0) + module.position_embeddings.weight.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "extra_position_embeddings") and module.extra_position_embeddings is not None: - module.extra_position_embeddings.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.extra_position_embeddings.weight.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, (GPTSanJapaneseModel, GPTSanJapaneseForConditionalGeneration)): # Mesh TensorFlow embeddings initialization # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 - module.final_logits_bias.data.normal_(mean=0.0, std=factor * 1.0) + module.final_logits_bias.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, GPTSanJapaneseDenseActDense): # Mesh TensorFlow FF initialization # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 - module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wi.bias.zero_() + module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + module.wo.bias.zero_() elif isinstance(module, GPTSanJapaneseAttention): # Multi-headed attention d_model = self.config.d_model key_value_proj_dim = self.config.d_model n_heads = self.config.num_heads - module.k_proj.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.v_proj.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.q_proj.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + module.k_proj.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.v_proj.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.q_proj.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.out_proj.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) elif isinstance(module, GPTSanJapaneseSparseMLP): # Mesh TensorFlow attention initialization to avoid scaling before softmax # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 d_model = self.config.d_model key_value_proj_dim = self.config.d_model n_heads = self.config.num_heads - module.router.classifier.weight.data.normal_(mean=0.0, std=factor * 1) + module.router.classifier.weight.normal_(mean=0.0, std=factor * 1) for idx in range(self.config.num_experts): - module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.experts[f"expert_{idx}"].wi.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.experts[f"expert_{idx}"].wo.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -853,7 +854,7 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: GPTSanJapaneseConfig): super().__init__(config) diff --git a/src/transformers/models/deprecated/graphormer/modeling_graphormer.py b/src/transformers/models/deprecated/graphormer/modeling_graphormer.py index b3e8ea742c8d..bc74d7a5e7d5 100755 --- a/src/transformers/models/deprecated/graphormer/modeling_graphormer.py +++ b/src/transformers/models/deprecated/graphormer/modeling_graphormer.py @@ -721,7 +721,7 @@ def init_graphormer_params(self, module: Union[nn.Linear, nn.Embedding, Graphorm if isinstance(module, nn.Linear): self.normal_(module.weight.data) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() if isinstance(module, nn.Embedding): self.normal_(module.weight.data) if module.padding_idx is not None: @@ -731,6 +731,7 @@ def init_graphormer_params(self, module: Union[nn.Linear, nn.Embedding, Graphorm self.normal_(module.k_proj.weight.data) self.normal_(module.v_proj.weight.data) + @torch.no_grad() def _init_weights( self, module: Union[ @@ -742,28 +743,28 @@ def _init_weights( """ if isinstance(module, (nn.Linear, nn.Conv2d)): # We might be missing part of the Linear init, dependent on the layer num - module.weight.data.normal_(mean=0.0, std=0.02) + module.weight.normal_(mean=0.0, std=0.02) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=0.02) + module.weight.normal_(mean=0.0, std=0.02) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, GraphormerMultiheadAttention): - module.q_proj.weight.data.normal_(mean=0.0, std=0.02) - module.k_proj.weight.data.normal_(mean=0.0, std=0.02) - module.v_proj.weight.data.normal_(mean=0.0, std=0.02) + module.q_proj.weight.normal_(mean=0.0, std=0.02) + module.k_proj.weight.normal_(mean=0.0, std=0.02) + module.v_proj.weight.normal_(mean=0.0, std=0.02) module.reset_parameters() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, GraphormerGraphEncoder): if module.apply_graphormer_init: module.apply(self.init_graphormer_params) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) class GraphormerModel(GraphormerPreTrainedModel): diff --git a/src/transformers/models/deprecated/jukebox/modeling_jukebox.py b/src/transformers/models/deprecated/jukebox/modeling_jukebox.py index ac8597361522..d71fadd8bf6c 100755 --- a/src/transformers/models/deprecated/jukebox/modeling_jukebox.py +++ b/src/transformers/models/deprecated/jukebox/modeling_jukebox.py @@ -601,22 +601,23 @@ class JukeboxVQVAE(PreTrainedModel): config: JukeboxVQVAEConfig base_model_prefix = "vqvae" + @torch.no_grad() def _init_weights(self, module): if isinstance(module, nn.Embedding): # embed_tokens - module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_scale) + module.weight.normal_(mean=0.0, std=0.02 * self.config.init_scale) elif isinstance(module, JukeboxConv1D): if self.config.zero_out: - module.weight.data.zero_() + module.weight.zero_() else: - module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_scale) + module.weight.normal_(mean=0.0, std=0.02 * self.config.init_scale) elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out: - module.conv1d_2.weight.data.zero_() - module.conv1d_2.bias.data.zero_() + module.conv1d_2.weight.zero_() + module.conv1d_2.bias.zero_() if isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() def __init__(self, config: JukeboxVQVAEConfig): super().__init__(config) @@ -1790,32 +1791,33 @@ class JukeboxPrior(PreTrainedModel): config: JukeboxPriorConfig + @torch.no_grad() def _init_weights(self, module): init_scale = self.config.init_scale if isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=0.02 * init_scale) + module.weight.normal_(mean=0.0, std=0.02 * init_scale) elif isinstance(module, JukeboxConv1D): if self.config.zero_out: - module.weight.data.zero_() + module.weight.zero_() else: - module.weight.data.normal_(mean=0.0, std=0.02 * init_scale) + module.weight.normal_(mean=0.0, std=0.02 * init_scale) elif isinstance(module, JukeboxPositionalEmbedding): - module.pos_emb.data.normal_(mean=0.0, std=0.01 * init_scale) + module.pos_emb.normal_(mean=0.0, std=0.01 * init_scale) elif isinstance(module, JukeboxRangeEmbedding): - module.emb.weight.data.normal_(mean=0.0, std=0.01 * init_scale) + module.emb.weight.normal_(mean=0.0, std=0.01 * init_scale) elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "lm_head"): - module.lm_head.weight.data.normal_(mean=0.0, std=0.02 * init_scale) + module.lm_head.weight.normal_(mean=0.0, std=0.02 * init_scale) elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "start_token"): - module.start_token.data.normal_(mean=0.0, std=0.01 * init_scale) + module.start_token.normal_(mean=0.0, std=0.01 * init_scale) elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out: - module.conv1d_2.weight.data.zero_() - module.conv1d_2.bias.data.zero_() + module.conv1d_2.weight.zero_() + module.conv1d_2.bias.zero_() if isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() def __init__(self, config: JukeboxPriorConfig, level=None, nb_priors=3, vqvae_encoder=None, vqvae_decoder=None): super().__init__(config) @@ -2268,6 +2270,7 @@ class JukeboxPreTrainedModel(PreTrainedModel): base_model_prefix = "jukebox" supports_gradient_checkpointing = False + @torch.no_grad() def _init_weights(self, module): if isinstance(module, (JukeboxPrior, JukeboxVQVAE)): module.apply(module._init_weights) diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index 4f74c775a36a..db7c475dabd4 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -392,27 +392,28 @@ class MCTCTPreTrainedModel(PreTrainedModel): main_input_name = "input_features" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, MCTCTLayerNorm): - module.singleton_weight.data.fill_(1.0) - module.singleton_bias.data.zero_() + module.singleton_weight.fill_(1.0) + module.singleton_bias.zero_() if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ diff --git a/src/transformers/models/deprecated/mega/modeling_mega.py b/src/transformers/models/deprecated/mega/modeling_mega.py index 7342cba3d608..d66848e1d2b1 100644 --- a/src/transformers/models/deprecated/mega/modeling_mega.py +++ b/src/transformers/models/deprecated/mega/modeling_mega.py @@ -1332,6 +1332,7 @@ class MegaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = False _no_split_modules = ["MegaMovingAverageGatedAttention"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, MegaMultiDimensionDampedEma): @@ -1365,16 +1366,16 @@ def _init_weights(self, module): nn.init.constant_(module.qk_bias, 0.0) elif isinstance(module, nn.Linear): # initializes all linear layers in the entire network - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) MEGA_START_DOCSTRING = r""" @@ -1638,7 +1639,7 @@ def forward( """MEGA Model with a `language modeling` head on top for CLM fine-tuning.""", MEGA_START_DOCSTRING ) class MegaForCausalLM(MegaPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "mega.embedding_layer.word_embeddings.weight"} def __init__(self, config: MegaConfig): super().__init__(config) @@ -1785,7 +1786,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti @add_start_docstrings("""MEGA Model with a `language modeling` head on top.""", MEGA_START_DOCSTRING) class MegaForMaskedLM(MegaPreTrainedModel): - _tied_weights_keys = ["mlm_head.weight"] + _tied_weights_keys = {"mlm_head.weight": "mega.embedding_layer.word_embeddings.weight"} def __init__(self, config: MegaConfig): super().__init__(config) diff --git a/src/transformers/models/deprecated/nat/modeling_nat.py b/src/transformers/models/deprecated/nat/modeling_nat.py index 4f16a1bfbafd..a43562406ce6 100644 --- a/src/transformers/models/deprecated/nat/modeling_nat.py +++ b/src/transformers/models/deprecated/nat/modeling_nat.py @@ -592,15 +592,16 @@ class NatPreTrainedModel(PreTrainedModel): base_model_prefix = "nat" main_input_name = "pixel_values" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) NAT_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/nezha/modeling_nezha.py b/src/transformers/models/deprecated/nezha/modeling_nezha.py index bf617665c542..8e3cb0cd3f4b 100644 --- a/src/transformers/models/deprecated/nezha/modeling_nezha.py +++ b/src/transformers/models/deprecated/nezha/modeling_nezha.py @@ -535,16 +535,10 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.decoder = nn.Linear(config.hidden_size, config.vocab_size) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -593,19 +587,20 @@ class NezhaPreTrainedModel(PreTrainedModel): base_model_prefix = "nezha" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @dataclass @@ -873,7 +868,10 @@ def forward( NEZHA_START_DOCSTRING, ) class NezhaForPreTraining(NezhaPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder"] + _tied_weights_keys = { + "cls.predictions.decoder.weight": "nezha.embeddings.word_embeddings.weight", + "cls.predictions.decoder.bias": "cls.predictions.bias", + } def __init__(self, config): super().__init__(config) @@ -974,7 +972,10 @@ def forward( @add_start_docstrings("""Nezha Model with a `language modeling` head on top.""", NEZHA_START_DOCSTRING) class NezhaForMaskedLM(NezhaPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder"] + _tied_weights_keys = { + "cls.predictions.decoder.weight": "nezha.embeddings.word_embeddings.weight", + "cls.predictions.decoder.bias": "cls.predictions.bias", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index bf39cfca912a..7da07eca1e34 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -439,19 +439,20 @@ class OpenLlamaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["OpenLlamaDecoderLayer"] + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): if self.config.use_stable_embedding: - torch.nn.init.xavier_normal_(module.weight.data) + torch.nn.init.xavier_normal_(module.weight) else: - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() OPEN_LLAMA_INPUTS_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py index 86478bcf5a18..f395fe51d645 100755 --- a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py @@ -540,15 +540,11 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias def forward(self, hidden_states): hidden_states = self.transform(hidden_states) @@ -601,19 +597,20 @@ class QDQBertPreTrainedModel(PreTrainedModel): base_model_prefix = "bert" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) QDQBERT_START_DOCSTRING = r""" @@ -853,7 +850,7 @@ def forward( """QDQBERT Model with a `language modeling` head on top for CLM fine-tuning.""", QDQBERT_START_DOCSTRING ) class QDQBertLMHeadModel(QDQBertPreTrainedModel): - _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"] + _tied_weights_keys = {"predictions.decoder.weight": "predictions.decoder.bias"} def __init__(self, config): super().__init__(config) @@ -1007,7 +1004,7 @@ def prepare_inputs_for_generation( @add_start_docstrings("""QDQBERT Model with a `language modeling` head on top.""", QDQBERT_START_DOCSTRING) class QDQBertForMaskedLM(QDQBertPreTrainedModel): - _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"] + _tied_weights_keys = {"predictions.decoder.weight": "predictions.decoder.bias"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/deprecated/realm/modeling_realm.py b/src/transformers/models/deprecated/realm/modeling_realm.py index 7a135b9fdb5e..0b8062c5c900 100644 --- a/src/transformers/models/deprecated/realm/modeling_realm.py +++ b/src/transformers/models/deprecated/realm/modeling_realm.py @@ -624,16 +624,9 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -794,19 +787,20 @@ class RealmPreTrainedModel(PreTrainedModel): config: RealmConfig base_model_prefix = "realm" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) def _flatten_inputs(self, *inputs): """Flatten inputs' shape to (-1, input_shape[-1])""" @@ -961,7 +955,10 @@ def forward( REALM_START_DOCSTRING, ) class RealmEmbedder(RealmPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.weight": "realm.embeddings.word_embeddings.weight", + "cls.predictions.decoder.bias": "cls.predictions.bias", + } def __init__(self, config): super().__init__(config) @@ -1186,7 +1183,10 @@ def forward( REALM_START_DOCSTRING, ) class RealmKnowledgeAugEncoder(RealmPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder"] + _tied_weights_keys = { + "cls.predictions.decoder.weight": "realm.embeddings.word_embeddings.weight", + "cls.predictions.decoder.bias": "cls.predictions.bias", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/deprecated/retribert/modeling_retribert.py b/src/transformers/models/deprecated/retribert/modeling_retribert.py index fa7695133fb8..7a762e46b890 100644 --- a/src/transformers/models/deprecated/retribert/modeling_retribert.py +++ b/src/transformers/models/deprecated/retribert/modeling_retribert.py @@ -42,19 +42,20 @@ class RetriBertPreTrainedModel(PreTrainedModel): config: RetriBertConfig base_model_prefix = "retribert" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) RETRIBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py index 617e4d757c94..821467abccba 100755 --- a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py @@ -371,16 +371,17 @@ class Speech2Text2PreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, Speech2Text2SinusoidalPositionalEmbedding): weight = module.get_embedding(*module.weight.shape, module.padding_idx) weight = nn.Parameter(weight, requires_grad=False) @@ -628,7 +629,7 @@ def forward(self, *args, **kwargs): SPEECH_TO_TEXT_2_START_DOCSTRING, ) class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.decoder.embed_tokens.weight"} def __init__(self, config): config.is_decoder = True diff --git a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py index 1b4126f9ef20..2bc57636b944 100644 --- a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py @@ -84,14 +84,15 @@ class TrajectoryTransformerPreTrainedModel(PreTrainedModel): main_input_name = "trajectories" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, EinLinear): for i in range(module.n_models): nn.init.kaiming_uniform_(module.weight[i], a=math.sqrt(5) / self.config.kaiming_initializer_range) diff --git a/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py b/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py index ba9cd4025dc2..b28613d71b7f 100644 --- a/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py +++ b/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py @@ -841,7 +841,7 @@ def forward( TRANSFO_XL_START_DOCSTRING, ) class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): - _tied_weights_keys = [r"crit\.out_projs\.\d+", r"crit\.out_layers\.\d+\.weight"] + _tied_weights_keys = {r"crit\.out_projs\.\d+": r"crit\.out_layers\.\d+\.weight"} def __init__(self, config): super().__init__(config) @@ -874,9 +874,6 @@ def tie_weights(self): Run this to be sure output and input (adaptive) softmax weights are tied """ - if self.config.tie_word_embeddings: - for i in range(len(self.crit.out_layers)): - self._tie_embedding_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i]) if self.config.tie_projs: for i, tie_proj in enumerate(self.config.tie_projs): if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed: diff --git a/src/transformers/models/deprecated/tvlt/modeling_tvlt.py b/src/transformers/models/deprecated/tvlt/modeling_tvlt.py index 9cdba679bc0a..fbea2e2b77a3 100644 --- a/src/transformers/models/deprecated/tvlt/modeling_tvlt.py +++ b/src/transformers/models/deprecated/tvlt/modeling_tvlt.py @@ -548,15 +548,16 @@ class TvltPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) TVLT_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/van/modeling_van.py b/src/transformers/models/deprecated/van/modeling_van.py index 6ee0e881e558..007b74755e5d 100644 --- a/src/transformers/models/deprecated/van/modeling_van.py +++ b/src/transformers/models/deprecated/van/modeling_van.py @@ -359,6 +359,7 @@ class VanPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): @@ -371,9 +372,9 @@ def _init_weights(self, module): elif isinstance(module, nn.Conv2d): fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels fan_out //= module.groups - module.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + module.weight.normal_(0, math.sqrt(2.0 / fan_out)) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() VAN_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py index efa98eada009..bbc6554ff5d5 100644 --- a/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py @@ -457,31 +457,38 @@ class ViTHybridPreTrainedModel(PreTrainedModel): _no_split_modules = ["ViTHybridEmbeddings", "ViTHybridLayer"] _supports_sdpa = True + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) + module.weight.copy_( + nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( + module.weight.dtype + ) + ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, ViTHybridEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - module.cls_token.data = nn.init.trunc_normal_( - module.cls_token.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) - module.mask_token.data.zero_() + module.position_embeddings.copy_( + nn.init.trunc_normal_( + module.position_embeddings.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + ) + module.cls_token.copy_( + nn.init.trunc_normal_( + module.cls_token.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + ) + module.mask_token.zero_() VIT_START_DOCSTRING = r""" diff --git a/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py index bf44f7c19f34..c592e756b7c9 100644 --- a/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -520,15 +520,16 @@ class XLMProphetNetPreTrainedModel(PreTrainedModel): base_model_prefix = "prophetnet" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.init_std) + module.weight.normal_(mean=0.0, std=self.config.init_std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.init_std) + module.weight.normal_(mean=0.0, std=self.config.init_std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1169,14 +1170,10 @@ class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel): embeddings instead of randomly initialized word embeddings. """ - def __init__(self, config: XLMProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None): + def __init__(self, config: XLMProphetNetConfig): super().__init__(config) - self.word_embeddings = ( - word_embeddings - if word_embeddings is not None - else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) - ) + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.position_embeddings = XLMProphetNetPositionalEmbeddings(config) self.embeddings_layer_norm = LayerNorm(config.hidden_size) @@ -1287,7 +1284,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel): embeddings instead of randomly initialized word embeddings. """ - def __init__(self, config: XLMProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None): + def __init__(self, config: XLMProphetNetConfig): super().__init__(config) self.ngram = config.ngram @@ -1296,11 +1293,7 @@ def __init__(self, config: XLMProphetNetConfig, word_embeddings: Optional[nn.Emb self.dropout = config.dropout self.max_target_positions = config.max_position_embeddings - self.word_embeddings = ( - word_embeddings - if word_embeddings is not None - else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) - ) + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.position_embeddings = XLMProphetNetPositionalEmbeddings(config) self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None) @@ -1611,7 +1604,10 @@ def prepare_predict_attention_mask(self, hidden_states, attention_mask): XLM_PROPHETNET_START_DOCSTRING, ) class XLMProphetNetModel(XLMProphetNetPreTrainedModel): - _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight"] + _tied_weights_keys = { + "encoder.word_embeddings.weight": "word_embeddings.weight", + "decoder.word_embeddings.weight": "word_embeddings.weight", + } def __init__(self, config: XLMProphetNetConfig): super().__init__(config) @@ -1620,12 +1616,12 @@ def __init__(self, config: XLMProphetNetConfig): encoder_config = copy.deepcopy(config) encoder_config.is_encoder_decoder = False encoder_config.use_cache = False - self.encoder = XLMProphetNetEncoder(encoder_config, self.word_embeddings) + self.encoder = XLMProphetNetEncoder(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.is_encoder_decoder = False - self.decoder = XLMProphetNetDecoder(decoder_config, self.word_embeddings) + self.decoder = XLMProphetNetDecoder(decoder_config) # Initialize weights and apply final processing self.post_init() @@ -1638,11 +1634,6 @@ def set_input_embeddings(self, value): self.encoder.word_embeddings = self.word_embeddings self.decoder.word_embeddings = self.word_embeddings - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.word_embeddings, self.word_embeddings) - self._tie_embedding_weights(self.decoder.word_embeddings, self.word_embeddings) - def get_encoder(self): return self.encoder @@ -1736,7 +1727,7 @@ def forward( XLM_PROPHETNET_START_DOCSTRING, ) class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel): - _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "prophetnet.word_embeddings.weight"} def __init__(self, config: XLMProphetNetConfig): super().__init__(config) @@ -1749,10 +1740,6 @@ def __init__(self, config: XLMProphetNetConfig): # Initialize weights and apply final processing self.post_init() - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.prophetnet.word_embeddings, self.lm_head) - def get_input_embeddings(self): return self.prophetnet.word_embeddings @@ -1934,11 +1921,9 @@ def get_decoder(self): XLM_PROPHETNET_START_DOCSTRING, ) class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel): - _tied_weights_keys = [ - "prophetnet.word_embeddings.weight", - "prophetnet.decoder.word_embeddings.weight", - "lm_head.weight", - ] + _tied_weights_keys = { + "lm_head.weight": "model.decoder.embed_tokens.weight", + } def __init__(self, config: XLMProphetNetConfig): # set config for CLM @@ -1962,10 +1947,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.prophetnet.decoder.word_embeddings = value - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.prophetnet.decoder.word_embeddings, self.lm_head) - def set_decoder(self, decoder): self.prophetnet.decoder = decoder @@ -2163,6 +2144,10 @@ class XLMProphetNetDecoderWrapper(XLMProphetNetPreTrainedModel): classes. """ + _tied_weights_keys = { + "model.decoder.embed_tokens.weight": "word_embeddings.weight", + } + def __init__(self, config: XLMProphetNetConfig): super().__init__(config) @@ -2172,9 +2157,6 @@ def __init__(self, config: XLMProphetNetConfig): # Initialize weights and apply final processing self.post_init() - def _tie_weights(self): - self._tie_embedding_weights(self.word_embeddings, self.decoder.get_input_embeddings()) - def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) diff --git a/src/transformers/models/depth_anything/modeling_depth_anything.py b/src/transformers/models/depth_anything/modeling_depth_anything.py index 862b77807d3a..d6dae7cb72ee 100644 --- a/src/transformers/models/depth_anything/modeling_depth_anything.py +++ b/src/transformers/models/depth_anything/modeling_depth_anything.py @@ -216,15 +216,16 @@ class DepthAnythingPreTrainedModel(PreTrainedModel): input_modalities = "image" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) class DepthAnythingNeck(nn.Module): diff --git a/src/transformers/models/depth_pro/modeling_depth_pro.py b/src/transformers/models/depth_pro/modeling_depth_pro.py index c8a90eaaef02..b754cf9074c1 100644 --- a/src/transformers/models/depth_pro/modeling_depth_pro.py +++ b/src/transformers/models/depth_pro/modeling_depth_pro.py @@ -608,19 +608,20 @@ class DepthProPreTrainedModel(PreTrainedModel): _no_split_modules = ["DepthProPreActResidualLayer"] _keys_to_ignore_on_load_unexpected = ["fov_model.*"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() @auto_docstring diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index f0378c25a381..742bf8785731 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -727,6 +727,7 @@ class DetrPreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"] + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std xavier_std = self.config.init_xavier_std @@ -740,13 +741,13 @@ def _init_weights(self, module): nn.init.uniform_(module.row_embeddings.weight) nn.init.uniform_(module.column_embeddings.weight) if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() class DetrEncoder(DetrPreTrainedModel): diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index d82430b623e1..7e67ac52768c 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -596,13 +596,14 @@ class DiffLlamaPreTrainedModel(PreTrainedModel): "attentions": DiffLlamaAttention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, DiffLlamaAttention): - module.lambda_q1.data.normal_(0, self.config.lambda_std_dev) - module.lambda_k1.data.normal_(0, self.config.lambda_std_dev) - module.lambda_q2.data.normal_(0, self.config.lambda_std_dev) - module.lambda_k2.data.normal_(0, self.config.lambda_std_dev) + module.lambda_q1.normal_(0, self.config.lambda_std_dev) + module.lambda_k1.normal_(0, self.config.lambda_std_dev) + module.lambda_q2.normal_(0, self.config.lambda_std_dev) + module.lambda_k2.normal_(0, self.config.lambda_std_dev) @auto_docstring @@ -686,7 +687,7 @@ def forward( @auto_docstring class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index 331c7327b681..97b1cc051660 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -399,13 +399,14 @@ class DiffLlamaPreTrainedModel(LlamaPreTrainedModel): _supports_flex_attn = False _supports_attention_backend = False + @torch.no_grad() def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, DiffLlamaAttention): - module.lambda_q1.data.normal_(0, self.config.lambda_std_dev) - module.lambda_k1.data.normal_(0, self.config.lambda_std_dev) - module.lambda_q2.data.normal_(0, self.config.lambda_std_dev) - module.lambda_k2.data.normal_(0, self.config.lambda_std_dev) + module.lambda_q1.normal_(0, self.config.lambda_std_dev) + module.lambda_k1.normal_(0, self.config.lambda_std_dev) + module.lambda_q2.normal_(0, self.config.lambda_std_dev) + module.lambda_k2.normal_(0, self.config.lambda_std_dev) class DiffLlamaModel(LlamaModel): diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py index 8f3220cfa1e9..103e12ce5ed9 100644 --- a/src/transformers/models/dinat/modeling_dinat.py +++ b/src/transformers/models/dinat/modeling_dinat.py @@ -561,15 +561,16 @@ class DinatPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" input_modalities = "image" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index fa1887588020..49693d507733 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -414,36 +414,43 @@ class Dinov2PreTrainedModel(PreTrainedModel): "attentions": Dinov2SelfAttention, } + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) + module.weight.copy_( + nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( + module.weight.dtype + ) + ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, Dinov2Embeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - - module.cls_token.data = nn.init.trunc_normal_( - module.cls_token.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) + module.position_embeddings.copy_( + nn.init.trunc_normal_( + module.position_embeddings.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + ) + + module.cls_token.copy_( + nn.init.trunc_normal_( + module.cls_token.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + ) if self.config.use_mask_token: - module.mask_token.data.zero_() + module.mask_token.zero_() elif isinstance(module, Dinov2LayerScale): - module.lambda1.data.fill_(self.config.layerscale_value) + module.lambda1.fill_(self.config.layerscale_value) @auto_docstring diff --git a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py index bf16e8eadc40..ddbc6e05b1a5 100644 --- a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py @@ -431,36 +431,43 @@ class Dinov2WithRegistersPreTrainedModel(PreTrainedModel): "attentions": Dinov2WithRegistersSelfAttention, } + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) + module.weight.copy_( + nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( + module.weight.dtype + ) + ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, Dinov2WithRegistersEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - - module.cls_token.data = nn.init.trunc_normal_( - module.cls_token.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) - - module.mask_token.data.zero_() - module.register_tokens.data.zero_() + module.position_embeddings.copy_( + nn.init.trunc_normal_( + module.position_embeddings.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + ) + + module.cls_token.copy_( + nn.init.trunc_normal_( + module.cls_token.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + ) + + module.mask_token.zero_() + module.register_tokens.zero_() elif isinstance(module, Dinov2WithRegistersLayerScale): # noqa: F821 - module.lambda1.data.fill_(self.config.layerscale_value) + module.lambda1.fill_(self.config.layerscale_value) @auto_docstring diff --git a/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py index 05a843361db4..1cb6cf79bc0b 100644 --- a/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py @@ -277,36 +277,43 @@ class Dinov2WithRegistersEncoder(Dinov2Encoder): class Dinov2WithRegistersPreTrainedModel(Dinov2PreTrainedModel): + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) + module.weight.copy_( + nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( + module.weight.dtype + ) + ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, Dinov2WithRegistersEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - - module.cls_token.data = nn.init.trunc_normal_( - module.cls_token.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) - - module.mask_token.data.zero_() - module.register_tokens.data.zero_() + module.position_embeddings.copy_( + nn.init.trunc_normal_( + module.position_embeddings.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + ) + + module.cls_token.copy_( + nn.init.trunc_normal_( + module.cls_token.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + ) + + module.mask_token.zero_() + module.register_tokens.zero_() elif isinstance(module, Dinov2WithRegistersLayerScale): # noqa: F821 - module.lambda1.data.fill_(self.config.layerscale_value) + module.lambda1.fill_(self.config.layerscale_value) class Dinov2WithRegistersModel(Dinov2Model): diff --git a/src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py b/src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py index bc6720ebfe73..286cc87c3ca3 100644 --- a/src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +++ b/src/transformers/models/dinov3_convnext/modeling_dinov3_convnext.py @@ -191,18 +191,19 @@ class DINOv3ConvNextPreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = ["DINOv3ConvNextLayer"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, DINOv3ConvNextLayerNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, DINOv3ConvNextLayer): if module.gamma is not None: - module.gamma.data.fill_(self.config.layer_scale_init_value) + module.gamma.fill_(self.config.layer_scale_init_value) @auto_docstring diff --git a/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py b/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py index 49e75dcd35bf..462e02377837 100644 --- a/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py +++ b/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py @@ -447,36 +447,43 @@ class DINOv3ViTPreTrainedModel(PreTrainedModel): "attentions": DINOv3ViTAttention, } + @torch.no_grad() def _init_weights(self, module) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.weight.dtype) + module.weight.copy_( + nn.init.trunc_normal_( + module.weight.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.weight.dtype) + ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, DINOv3ViTEmbeddings): - module.cls_token.data = nn.init.trunc_normal_( - module.cls_token.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) - if module.config.num_register_tokens > 0: - module.register_tokens.data = nn.init.trunc_normal_( - module.register_tokens.data.to(torch.float32), + module.cls_token.copy_( + nn.init.trunc_normal_( + module.cls_token.to(torch.float32), mean=0.0, std=self.config.initializer_range, - ).to(module.register_tokens.dtype) - module.mask_token.data.zero_() + ).to(module.cls_token.dtype) + ) + if module.config.num_register_tokens > 0: + module.register_tokens.copy_( + nn.init.trunc_normal_( + module.register_tokens.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.register_tokens.dtype) + ) + module.mask_token.zero_() elif isinstance(module, DINOv3ViTLayerScale): - module.lambda1.data.fill_(self.config.layerscale_value) + module.lambda1.fill_(self.config.layerscale_value) @auto_docstring diff --git a/src/transformers/models/dinov3_vit/modular_dinov3_vit.py b/src/transformers/models/dinov3_vit/modular_dinov3_vit.py index edb6cf82b240..19c85d5829e0 100644 --- a/src/transformers/models/dinov3_vit/modular_dinov3_vit.py +++ b/src/transformers/models/dinov3_vit/modular_dinov3_vit.py @@ -342,36 +342,43 @@ class DINOv3ViTPreTrainedModel(Dinov2PreTrainedModel): "attentions": DINOv3ViTAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.weight.dtype) + module.weight.copy_( + nn.init.trunc_normal_( + module.weight.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.weight.dtype) + ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, DINOv3ViTEmbeddings): - module.cls_token.data = nn.init.trunc_normal_( - module.cls_token.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) - if module.config.num_register_tokens > 0: - module.register_tokens.data = nn.init.trunc_normal_( - module.register_tokens.data.to(torch.float32), + module.cls_token.copy_( + nn.init.trunc_normal_( + module.cls_token.to(torch.float32), mean=0.0, std=self.config.initializer_range, - ).to(module.register_tokens.dtype) - module.mask_token.data.zero_() + ).to(module.cls_token.dtype) + ) + if module.config.num_register_tokens > 0: + module.register_tokens.copy_( + nn.init.trunc_normal_( + module.register_tokens.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.register_tokens.dtype) + ) + module.mask_token.zero_() elif isinstance(module, DINOv3ViTLayerScale): - module.lambda1.data.fill_(self.config.layerscale_value) + module.lambda1.fill_(self.config.layerscale_value) @auto_docstring diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 6f2fb86fb885..0638a99124b6 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -299,19 +299,20 @@ class DistilBertPreTrainedModel(PreTrainedModel): "attentions": DistilBertSelfAttention, } + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights.""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, Embeddings) and self.config.sinusoidal_pos_embds: create_sinusoidal_embeddings( self.config.max_position_embeddings, self.config.dim, module.position_embeddings.weight @@ -430,7 +431,7 @@ def forward( """ ) class DistilBertForMaskedLM(DistilBertPreTrainedModel): - _tied_weights_keys = ["vocab_projector.weight"] + _tied_weights_keys = {"vocab_projector.weight": "distilbert.embeddings.word_embeddings.weight"} def __init__(self, config: PreTrainedConfig): super().__init__(config) diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py index 1ced8dbbdd63..c3cc3033d5bf 100644 --- a/src/transformers/models/doge/modeling_doge.py +++ b/src/transformers/models/doge/modeling_doge.py @@ -524,17 +524,18 @@ class DogePreTrainedModel(PreTrainedModel): "attentions": DogeAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" super()._init_weights(module) if isinstance(module, DogeAttention): if hasattr(module, "A"): - module.A.data.zero_() + module.A.zero_() elif isinstance(module, DogeDecoderLayer): if hasattr(module, "input_residual"): - module.input_residual.data.fill_(1.0) + module.input_residual.fill_(1.0) if hasattr(module, "post_attention_residual"): - module.post_attention_residual.data.fill_(1.0) + module.post_attention_residual.fill_(1.0) @auto_docstring @@ -726,7 +727,7 @@ def load_balancing_loss_func( @auto_docstring class DogeForCausalLM(DogePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/doge/modular_doge.py b/src/transformers/models/doge/modular_doge.py index fd71f7479f6b..261f7ba42458 100644 --- a/src/transformers/models/doge/modular_doge.py +++ b/src/transformers/models/doge/modular_doge.py @@ -540,17 +540,18 @@ class DogePreTrainedModel(LlamaPreTrainedModel): "attentions": DogeAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" PreTrainedModel._init_weights(self, module) if isinstance(module, DogeAttention): if hasattr(module, "A"): - module.A.data.zero_() + module.A.zero_() elif isinstance(module, DogeDecoderLayer): if hasattr(module, "input_residual"): - module.input_residual.data.fill_(1.0) + module.input_residual.fill_(1.0) if hasattr(module, "post_attention_residual"): - module.post_attention_residual.data.fill_(1.0) + module.post_attention_residual.fill_(1.0) class DogeModel(MixtralModel): diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index eac5d7449604..e7d9422e69e2 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -789,22 +789,23 @@ class DonutSwinPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["DonutSwinStage"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, DonutSwinEmbeddings): if module.mask_token is not None: - module.mask_token.data.zero_() + module.mask_token.zero_() if module.position_embeddings is not None: - module.position_embeddings.data.zero_() + module.position_embeddings.zero_() elif isinstance(module, DonutSwinSelfAttention): - module.relative_position_bias_table.data.zero_() + module.relative_position_bias_table.zero_() @auto_docstring diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index d4a8188e24c6..6f8f1429dfa9 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -305,19 +305,23 @@ def forward(self, hidden_states): return router_logits -class Dots1NaiveMoe(nn.ModuleList): - """ - ModuleList of experts. - """ +class Dots1NaiveMoe(nn.Module): + """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): super().__init__() self.num_experts = config.num_local_experts - for _ in range(self.num_experts): - self.append(Dots1MLP(config, intermediate_size=config.moe_intermediate_size)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: """ Args: @@ -328,14 +332,24 @@ def forward( (batch_size * sequence_length, hidden_dim) """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states @@ -460,10 +474,11 @@ class Dots1PreTrainedModel(PreTrainedModel): "attentions": Dots1Attention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Dots1TopkRouter): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring @@ -559,7 +574,7 @@ def forward( @auto_docstring class Dots1ForCausalLM(Dots1PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/dpr/modeling_dpr.py b/src/transformers/models/dpr/modeling_dpr.py index 7ee4dcaf52e1..6ed58db0184c 100644 --- a/src/transformers/models/dpr/modeling_dpr.py +++ b/src/transformers/models/dpr/modeling_dpr.py @@ -105,19 +105,20 @@ class DPRReaderOutput(ModelOutput): class DPRPreTrainedModel(PreTrainedModel): _supports_sdpa = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) class DPREncoder(DPRPreTrainedModel): diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 6185ab3a45d0..6562e7891772 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -732,18 +732,19 @@ class DPTPreTrainedModel(PreTrainedModel): "attentions": DPTSelfAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) if isinstance(module, (DPTViTEmbeddings, DPTViTHybridEmbeddings)): - module.cls_token.data.zero_() - module.position_embeddings.data.zero_() + module.cls_token.zero_() + module.position_embeddings.zero_() @auto_docstring diff --git a/src/transformers/models/edgetam/modeling_edgetam.py b/src/transformers/models/edgetam/modeling_edgetam.py index 417583b4a18e..279957a52d7f 100644 --- a/src/transformers/models/edgetam/modeling_edgetam.py +++ b/src/transformers/models/edgetam/modeling_edgetam.py @@ -308,22 +308,23 @@ class EdgeTamPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, (nn.LayerNorm, EdgeTamLayerNorm)): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() if isinstance(module, EdgeTamModel): if module.no_memory_embedding is not None: - module.no_memory_embedding.data.zero_() + module.no_memory_embedding.zero_() # copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding @@ -921,7 +922,9 @@ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): ) class EdgeTamModel(EdgeTamPreTrainedModel): input_modalities = ["image", "text"] - _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + _tied_weights_keys = { + "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding" + } # need to be ignored, as it's a buffer and will not be correctly detected as tied weight _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamTwoWayAttentionBlock, index=2)} @@ -953,11 +956,6 @@ def __init__(self, config: EdgeTamConfig): self.post_init() - def _tie_weights(self): - self.prompt_encoder.shared_embedding.positional_embedding.data = ( - self.shared_image_embedding.positional_embedding.data - ) - def get_image_wide_positional_embeddings(self) -> torch.Tensor: size = self.prompt_encoder.image_embedding_size target_device = self.shared_image_embedding.positional_embedding.device diff --git a/src/transformers/models/edgetam/modular_edgetam.py b/src/transformers/models/edgetam/modular_edgetam.py index d432a725b021..594cb6084aa0 100644 --- a/src/transformers/models/edgetam/modular_edgetam.py +++ b/src/transformers/models/edgetam/modular_edgetam.py @@ -174,22 +174,23 @@ class EdgeTamFeedForward(Sam2FeedForward): @auto_docstring class EdgeTamPreTrainedModel(Sam2PreTrainedModel): + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, (nn.LayerNorm, EdgeTamLayerNorm)): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() if isinstance(module, EdgeTamModel): if module.no_memory_embedding is not None: - module.no_memory_embedding.data.zero_() + module.no_memory_embedding.zero_() @auto_docstring( diff --git a/src/transformers/models/edgetam_video/modeling_edgetam_video.py b/src/transformers/models/edgetam_video/modeling_edgetam_video.py index 1e5f1290c8c1..65bb962bdeab 100644 --- a/src/transformers/models/edgetam_video/modeling_edgetam_video.py +++ b/src/transformers/models/edgetam_video/modeling_edgetam_video.py @@ -778,31 +778,32 @@ class EdgeTamVideoPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, (nn.LayerNorm, EdgeTamVideoLayerNorm)): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, EdgeTamVideoModel): if module.no_memory_positional_encoding is not None: - module.no_memory_positional_encoding.data.zero_() + module.no_memory_positional_encoding.zero_() if module.memory_temporal_positional_encoding is not None: - module.memory_temporal_positional_encoding.data.zero_() + module.memory_temporal_positional_encoding.zero_() if module.no_object_pointer is not None: - module.no_object_pointer.data.zero_() + module.no_object_pointer.zero_() if module.occlusion_spatial_embedding_parameter is not None: - module.occlusion_spatial_embedding_parameter.data.zero_() + module.occlusion_spatial_embedding_parameter.zero_() if isinstance(module, EdgeTamVideoMemoryFuserCXBlock): if module.scale is not None: - module.scale.data.zero_() + module.scale.zero_() class EdgeTamVideoInferenceCache: @@ -1977,7 +1978,9 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000): @auto_docstring class EdgeTamVideoModel(EdgeTamVideoPreTrainedModel): input_modalities = ["video", "text"] - _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + _tied_weights_keys = { + "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding" + } # need to be ignored, as it's a buffer and will not be correctly detected as tied weight _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(EdgeTamVideoTwoWayAttentionBlock, index=2)} @@ -2034,11 +2037,6 @@ def __init__(self, config: EdgeTamVideoConfig): self.post_init() - def _tie_weights(self): - self.prompt_encoder.shared_embedding.positional_embedding.data = ( - self.shared_image_embedding.positional_embedding.data - ) - def get_input_embeddings(self): return self.vision_encoder.get_input_embeddings() diff --git a/src/transformers/models/edgetam_video/modular_edgetam_video.py b/src/transformers/models/edgetam_video/modular_edgetam_video.py index 65ca8ac1bdbe..06dc598a2772 100644 --- a/src/transformers/models/edgetam_video/modular_edgetam_video.py +++ b/src/transformers/models/edgetam_video/modular_edgetam_video.py @@ -1025,7 +1025,9 @@ def _forward_2d(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch. @auto_docstring class EdgeTamVideoModel(Sam2VideoModel): - _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + _tied_weights_keys = { + "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding" + } # need to be ignored, as it's a buffer and will not be correctly detected as tied weight _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] _keys_to_ignore_on_load_unexpected = [] diff --git a/src/transformers/models/efficientloftr/modeling_efficientloftr.py b/src/transformers/models/efficientloftr/modeling_efficientloftr.py index 16c9eabdcd65..5f21d7cad00f 100644 --- a/src/transformers/models/efficientloftr/modeling_efficientloftr.py +++ b/src/transformers/models/efficientloftr/modeling_efficientloftr.py @@ -675,15 +675,16 @@ class EfficientLoFTRPreTrainedModel(PreTrainedModel): "attentions": EfficientLoFTRAttention, } + @torch.no_grad() def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv1d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) # Copied from transformers.models.superpoint.modeling_superpoint.SuperPointPreTrainedModel.extract_one_channel_pixel_values with SuperPoint->EfficientLoFTR def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor: diff --git a/src/transformers/models/efficientnet/modeling_efficientnet.py b/src/transformers/models/efficientnet/modeling_efficientnet.py index 0e35f791f9d2..4c55a3058b98 100644 --- a/src/transformers/models/efficientnet/modeling_efficientnet.py +++ b/src/transformers/models/efficientnet/modeling_efficientnet.py @@ -436,12 +436,13 @@ class EfficientNetPreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = [] + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() @auto_docstring diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index cb915277f6bb..2fd477541986 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -532,19 +532,20 @@ class ElectraPreTrainedModel(PreTrainedModel): "cross_attentions": ElectraCrossAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @dataclass @@ -1004,7 +1005,7 @@ def forward( """ ) class ElectraForMaskedLM(ElectraPreTrainedModel): - _tied_weights_keys = ["generator_lm_head.weight"] + _tied_weights_keys = {"generator_lm_head.weight": "electra.embeddings.word_embeddings.weight"} def __init__(self, config): super().__init__(config) @@ -1304,7 +1305,7 @@ def forward( """ ) class ElectraForCausalLM(ElectraPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["generator_lm_head.weight"] + _tied_weights_keys = {"generator_lm_head.weight": "electra.embeddings.word_embeddings.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index e2d1b1c98535..3ccd79801601 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -938,6 +938,7 @@ class Emu3VQVAE(PreTrainedModel): "Emu3VQVAEVectorQuantizer", ] + @torch.no_grad() def _init_weights(self, module): if isinstance(module, (nn.Conv2d, nn.Conv3d)): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") @@ -955,9 +956,9 @@ def _init_weights(self, module): nn.init.constant_(module.weight, 1.0) nn.init.constant_(module.bias, 0.0) elif isinstance(module, nn.Embedding): - module.weight.data.normal_() + module.weight.normal_() if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() def __init__(self, config: Emu3VQVAEConfig): super().__init__(config) @@ -1258,7 +1259,7 @@ def forward( @auto_docstring class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config: Emu3TextConfig @@ -1489,7 +1490,7 @@ def forward( class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): base_model_prefix = "" output_modalities = ["image", "text"] - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _checkpoint_conversion_mapping = { "^text_model.model": "model.text_model", "^vqmodel": "model.vqmodel", diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index 0dfadf53ad80..bd85a98641df 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -688,6 +688,7 @@ class Emu3VQVAE(PreTrainedModel): "Emu3VQVAEVectorQuantizer", ] + @torch.no_grad() def _init_weights(self, module): if isinstance(module, (nn.Conv2d, nn.Conv3d)): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") @@ -705,9 +706,9 @@ def _init_weights(self, module): nn.init.constant_(module.weight, 1.0) nn.init.constant_(module.bias, 0.0) elif isinstance(module, nn.Embedding): - module.weight.data.normal_() + module.weight.normal_() if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() def __init__(self, config: Emu3VQVAEConfig): super().__init__(config) @@ -1043,7 +1044,7 @@ def forward( class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): base_model_prefix = "" output_modalities = ["image", "text"] - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _checkpoint_conversion_mapping = { "^text_model.model": "model.text_model", "^vqmodel": "model.vqmodel", diff --git a/src/transformers/models/encodec/modeling_encodec.py b/src/transformers/models/encodec/modeling_encodec.py index c3c32f5bd61d..a9449caa707f 100644 --- a/src/transformers/models/encodec/modeling_encodec.py +++ b/src/transformers/models/encodec/modeling_encodec.py @@ -454,11 +454,12 @@ class EncodecPreTrainedModel(PreTrainedAudioTokenizerBase): base_model_prefix = "encodec" main_input_name = "input_values" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.GroupNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) if module.bias is not None: diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 6944045ddd16..e62cb8f623cc 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -166,24 +166,7 @@ def __init__( # tie encoder, decoder weights if config set accordingly self.tie_weights() - def tie_weights(self): - self.encoder.tie_weights() - self.decoder.tie_weights() - # tie encoder & decoder if needed - if self.config.tie_encoder_decoder: - # tie encoder and decoder base model - decoder_base_model_prefix = self.decoder.base_model_prefix - tied_weights = self._tie_encoder_decoder_weights( - self.encoder, - self.decoder._modules[decoder_base_model_prefix], - self.decoder.base_model_prefix, - "encoder", - ) - # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class - # attributed not an instance member, therefore modifying it will modify the entire class - # Leading to issues on subsequent calls by different tests or subsequent calls. - self._dynamic_tied_weights_keys = tied_weights - + @torch.no_grad() def _init_weights(self, module): if module in self.encoder.modules(): self.encoder._init_weights(module) diff --git a/src/transformers/models/eomt/modeling_eomt.py b/src/transformers/models/eomt/modeling_eomt.py index 8579e1b7a443..e52e98364c09 100644 --- a/src/transformers/models/eomt/modeling_eomt.py +++ b/src/transformers/models/eomt/modeling_eomt.py @@ -996,6 +996,7 @@ class EomtPreTrainedModel(PreTrainedModel): "attentions": EomtAttention, } + @torch.no_grad() def _init_weights(self, module: nn.Module) -> None: std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): @@ -1005,20 +1006,20 @@ def _init_weights(self, module: nn.Module) -> None: bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(module.bias, -bound, bound) elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=1) + module.weight.normal_(mean=0.0, std=1) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, EomtLayerScale): if hasattr(module, "lambda1"): - module.lambda1.data.fill_(self.config.layerscale_value) + module.lambda1.fill_(self.config.layerscale_value) elif isinstance(module, EomtEmbeddings): - module.cls_token.data = nn.init.trunc_normal_( - module.cls_token.data.to(torch.float32), mean=0.0, std=std - ).to(module.cls_token.dtype) - module.register_tokens.data.zero_() + module.cls_token.copy_( + nn.init.trunc_normal_(module.cls_token.to(torch.float32), mean=0.0, std=std).to(module.cls_token.dtype) + ) + module.register_tokens.zero_() @auto_docstring( diff --git a/src/transformers/models/eomt/modular_eomt.py b/src/transformers/models/eomt/modular_eomt.py index be66a7b7598d..2c95affa154e 100644 --- a/src/transformers/models/eomt/modular_eomt.py +++ b/src/transformers/models/eomt/modular_eomt.py @@ -401,6 +401,7 @@ class EomtPreTrainedModel(PreTrainedModel): "attentions": EomtAttention, } + @torch.no_grad() def _init_weights(self, module: nn.Module) -> None: std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): @@ -410,20 +411,20 @@ def _init_weights(self, module: nn.Module) -> None: bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(module.bias, -bound, bound) elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=1) + module.weight.normal_(mean=0.0, std=1) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, EomtLayerScale): if hasattr(module, "lambda1"): - module.lambda1.data.fill_(self.config.layerscale_value) + module.lambda1.fill_(self.config.layerscale_value) elif isinstance(module, EomtEmbeddings): - module.cls_token.data = nn.init.trunc_normal_( - module.cls_token.data.to(torch.float32), mean=0.0, std=std - ).to(module.cls_token.dtype) - module.register_tokens.data.zero_() + module.cls_token.copy_( + nn.init.trunc_normal_(module.cls_token.to(torch.float32), mean=0.0, std=std).to(module.cls_token.dtype) + ) + module.register_tokens.zero_() @auto_docstring( diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index b45e56d587c0..24890d50ac2e 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -488,16 +488,9 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -553,23 +546,24 @@ class ErniePreTrainedModel(PreTrainedModel): "cross_attentions": ErnieCrossAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, ErnieLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring( @@ -788,7 +782,10 @@ def forward(self, sequence_output, pooled_output): """ ) class ErnieForPreTraining(ErniePreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "ernie.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) @@ -899,7 +896,10 @@ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: """ ) class ErnieForCausalLM(ErniePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + _tied_weights_keys = { + "cls.predictions.decoder.weight": "ernie.embeddings.word_embeddings.weight", + "cls.predictions.decoder.bias": "cls.predictions.bias", + } def __init__(self, config): super().__init__(config) @@ -990,7 +990,10 @@ def forward( @auto_docstring class ErnieForMaskedLM(ErniePreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "ernie.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/ernie/modular_ernie.py b/src/transformers/models/ernie/modular_ernie.py index 491ce971e24b..4bf0440d7c16 100644 --- a/src/transformers/models/ernie/modular_ernie.py +++ b/src/transformers/models/ernie/modular_ernie.py @@ -162,23 +162,24 @@ class ErniePreTrainedModel(PreTrainedModel): "cross_attentions": ErnieCrossAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, ErnieLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() class ErnieModel(BertModel): @@ -337,7 +338,10 @@ class ErnieForPreTrainingOutput(BertForPreTrainingOutput): class ErnieForPreTraining(BertForPreTraining): - _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "ernie.embeddings.word_embeddings.weight", + } @can_return_tuple @auto_docstring @@ -486,7 +490,10 @@ def forward( class ErnieForMaskedLM(BertForMaskedLM): - _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "ernie.embeddings.word_embeddings.weight", + } @can_return_tuple @auto_docstring diff --git a/src/transformers/models/ernie4_5/modeling_ernie4_5.py b/src/transformers/models/ernie4_5/modeling_ernie4_5.py index 5658c7691c3c..68d279fb9abf 100644 --- a/src/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/src/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -432,7 +432,7 @@ def forward( @auto_docstring class Ernie4_5ForCausalLM(Ernie4_5PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py index c2dbd8d436d8..8ff07d9f638f 100644 --- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py @@ -315,45 +315,64 @@ def forward(self, hidden_states): return hidden_states + self.e_score_correction_bias.squeeze() -class Ernie4_5_MoeExperts(nn.ModuleList): +class Ernie4_5_MoeExperts(nn.Module): def __init__(self, config): super().__init__() self.num_experts = config.moe_num_experts - for _ in range(self.num_experts): - self.append(Ernie4_5_MoeMLP(config, config.moe_intermediate_size)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.use_bias = config.use_bias + self.act_fn = ACT2FN[config.hidden_act] + + self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim)) + if self.use_bias: + self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim)) + self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + else: + self.gate_up_proj_bias = None + self.down_proj_bias = None def forward( self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) + if selected_experts.numel() == 0: + return final_hidden_states + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: + expert_idx = int(expert_idx.item()) idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * routing_weights[top_x, idx, None] + gate_inputs = F.linear( + current_state, + self.gate_up_proj[expert_idx], + None if self.gate_up_proj_bias is None else self.gate_up_proj_bias[expert_idx], + ) + gate, up = gate_inputs.chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = F.linear( + current_hidden_states, + self.down_proj[expert_idx], + None if self.down_proj_bias is None else self.down_proj_bias[expert_idx], + ) + current_hidden_states = current_hidden_states * routing_weights[top_x, idx, None] final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) return final_hidden_states -class Ernie4_5_MoeSparseMoeBlock(nn.Module): +class Ernie4_5_MoeTopKRouter(nn.Module): def __init__(self, config): super().__init__() - self.hidden_dim = config.hidden_size - self.num_experts = config.moe_num_experts + self.weight = nn.Parameter(torch.zeros(config.moe_num_experts, config.hidden_size, dtype=torch.float32)) + self.moe_statics = Ernie4_5_MoeStatics(config) self.top_k = config.moe_k self.norm_min = config.moe_norm_min - self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32) - self.moe_statics = Ernie4_5_MoeStatics(config) - self.experts = Ernie4_5_MoeExperts(config) - - self.shared_experts = None - if config.moe_num_shared_experts > 0: - self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts) - - def route_tokens_to_experts(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: device_type = ( hidden_states.device.type if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps" @@ -361,7 +380,7 @@ def route_tokens_to_experts(self, hidden_states): ) with torch.autocast(device_type=device_type, enabled=False): # Force float32 - router_logits = self.gate(hidden_states.float()) + router_logits = F.linear(hidden_states.float(), self.weight) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1) routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts) @@ -369,7 +388,21 @@ def route_tokens_to_experts(self, hidden_states): routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min ) routing_weights = routing_weights.to(router_logits.dtype) - return selected_experts, routing_weights + return routing_weights, selected_experts + + +class Ernie4_5_MoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + self.num_experts = config.moe_num_experts + self.top_k = config.moe_k + self.gate = Ernie4_5_MoeTopKRouter(config) + self.experts = Ernie4_5_MoeExperts(config) + + self.shared_experts = None + if config.moe_num_shared_experts > 0: + self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, _ = hidden_states.shape @@ -378,14 +411,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.shared_experts is not None: shared_output = self.shared_experts(hidden_states) - selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states) + routing_weights, selected_experts = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights) if self.shared_experts is not None: final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, self.hidden_dim) - return final_hidden_states + return final_hidden_states.to(hidden_states.dtype) class Ernie4_5_MoeDecoderLayer(GradientCheckpointingLayer): @@ -454,18 +487,19 @@ class Ernie4_5_MoePreTrainedModel(PreTrainedModel): _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True _can_record_outputs = { - "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), + "router_logits": OutputRecorder(Ernie4_5_MoeTopKRouter, layer_name="mlp.gate", index=0), "hidden_states": Ernie4_5_MoeDecoderLayer, "attentions": Ernie4_5_MoeAttention, } - _keep_in_fp32_modules_strict = ["gate", "moe_statics"] # Not supporting multi-token prediction (MTP) atm _keys_to_ignore_on_load_unexpected = ["mtp"] + _keep_in_fp32_modules_strict = ["gate.weight", "moe_statics"] + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Ernie4_5_MoeStatics): - module.e_score_correction_bias.data.zero_() + module.e_score_correction_bias.zero_() @auto_docstring @@ -634,7 +668,7 @@ def load_balancing_loss_func( @auto_docstring class Ernie4_5_MoeForCausalLM(Ernie4_5_MoePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py index b12958b785b7..fe403f81afad 100644 --- a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py @@ -19,6 +19,7 @@ import torch.nn.functional as F from torch import nn +from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask from ...modeling_outputs import MoeModelOutputWithPast @@ -96,45 +97,64 @@ def forward(self, hidden_states): return hidden_states + self.e_score_correction_bias.squeeze() -class Ernie4_5_MoeExperts(nn.ModuleList): +class Ernie4_5_MoeExperts(nn.Module): def __init__(self, config): super().__init__() self.num_experts = config.moe_num_experts - for _ in range(self.num_experts): - self.append(Ernie4_5_MoeMLP(config, config.moe_intermediate_size)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.use_bias = config.use_bias + self.act_fn = ACT2FN[config.hidden_act] + + self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim, self.intermediate_dim)) + if self.use_bias: + self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.intermediate_dim)) + self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + else: + self.gate_up_proj_bias = None + self.down_proj_bias = None def forward( self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) + if selected_experts.numel() == 0: + return final_hidden_states + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: + expert_idx = int(expert_idx.item()) idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * routing_weights[top_x, idx, None] + gate_inputs = F.linear( + current_state, + self.gate_up_proj[expert_idx], + None if self.gate_up_proj_bias is None else self.gate_up_proj_bias[expert_idx], + ) + gate, up = gate_inputs.chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = F.linear( + current_hidden_states, + self.down_proj[expert_idx], + None if self.down_proj_bias is None else self.down_proj_bias[expert_idx], + ) + current_hidden_states = current_hidden_states * routing_weights[top_x, idx, None] final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) return final_hidden_states -class Ernie4_5_MoeSparseMoeBlock(nn.Module): +class Ernie4_5_MoeTopKRouter(nn.Module): def __init__(self, config): super().__init__() - self.hidden_dim = config.hidden_size - self.num_experts = config.moe_num_experts + self.weight = nn.Parameter(torch.zeros(config.moe_num_experts, config.hidden_size, dtype=torch.float32)) + self.moe_statics = Ernie4_5_MoeStatics(config) self.top_k = config.moe_k self.norm_min = config.moe_norm_min - self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32) - self.moe_statics = Ernie4_5_MoeStatics(config) - self.experts = Ernie4_5_MoeExperts(config) - - self.shared_experts = None - if config.moe_num_shared_experts > 0: - self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts) - - def route_tokens_to_experts(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: device_type = ( hidden_states.device.type if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps" @@ -142,7 +162,7 @@ def route_tokens_to_experts(self, hidden_states): ) with torch.autocast(device_type=device_type, enabled=False): # Force float32 - router_logits = self.gate(hidden_states.float()) + router_logits = F.linear(hidden_states.float(), self.weight) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1) routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts) @@ -150,7 +170,21 @@ def route_tokens_to_experts(self, hidden_states): routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min ) routing_weights = routing_weights.to(router_logits.dtype) - return selected_experts, routing_weights + return routing_weights, selected_experts + + +class Ernie4_5_MoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + self.num_experts = config.moe_num_experts + self.top_k = config.moe_k + self.gate = Ernie4_5_MoeTopKRouter(config) + self.experts = Ernie4_5_MoeExperts(config) + + self.shared_experts = None + if config.moe_num_shared_experts > 0: + self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, _ = hidden_states.shape @@ -159,14 +193,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.shared_experts is not None: shared_output = self.shared_experts(hidden_states) - selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states) + routing_weights, selected_experts = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights) if self.shared_experts is not None: final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, self.hidden_dim) - return final_hidden_states + return final_hidden_states.to(hidden_states.dtype) class Ernie4_5_MoeDecoderLayer(Qwen3MoeDecoderLayer): @@ -193,19 +227,20 @@ def __init__(self, config, layer_idx): class Ernie4_5_MoePreTrainedModel(MixtralPreTrainedModel): config: Ernie4_5_MoeConfig _no_split_modules = ["Ernie4_5_MoeDecoderLayer"] - _keep_in_fp32_modules_strict = ["gate", "moe_statics"] # Not supporting multi-token prediction (MTP) atm _keys_to_ignore_on_load_unexpected = ["mtp"] _can_record_outputs = { - "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), + "router_logits": OutputRecorder(Ernie4_5_MoeTopKRouter, layer_name="mlp.gate", index=0), "hidden_states": Ernie4_5_MoeDecoderLayer, "attentions": Ernie4_5_MoeAttention, } + _keep_in_fp32_modules_strict = ["gate.weight", "moe_statics"] + @torch.no_grad() def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, Ernie4_5_MoeStatics): - module.e_score_correction_bias.data.zero_() + module.e_score_correction_bias.zero_() @auto_docstring diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 358370d0f9f0..a3f1fbdf58b5 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -551,22 +551,22 @@ class EsmPreTrainedModel(PreTrainedModel): ], } - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->EsmLMHead + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, EsmLMHead): - module.bias.data.zero_() + module.bias.zero_() def get_output_embeddings(self): # NOTE: get_output_embeddings() must return None to prevent accidental weight tying. @@ -727,7 +727,7 @@ def predict_contacts(self, tokens, attention_mask): @auto_docstring class EsmForMaskedLM(EsmPreTrainedModel): - _tied_weights_keys = ["lm_head.decoder.weight"] + _tied_weights_keys = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/esm/modeling_esmfold.py b/src/transformers/models/esm/modeling_esmfold.py index b08d3569de17..0c676d631b24 100644 --- a/src/transformers/models/esm/modeling_esmfold.py +++ b/src/transformers/models/esm/modeling_esmfold.py @@ -915,6 +915,7 @@ class EsmFoldPreTrainedModel(EsmPreTrainedModel): """ # Subclass `EsMPreTrainedModel` to deal with special init + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, EsmFoldLinear): diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py index c405df1bb85c..994ce020f811 100644 --- a/src/transformers/models/evolla/modeling_evolla.py +++ b/src/transformers/models/evolla/modeling_evolla.py @@ -517,20 +517,21 @@ class EvollaSaProtPreTrainedModel(PreTrainedModel): ], } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel): @@ -1268,15 +1269,16 @@ class EvollaPreTrainedModel(PreTrainedModel): "attentions": EvollaAttention, } + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range super()._init_weights(module) if isinstance(module, EvollaSequenceAlignerCrossAttention): module.gate_attention.zero_() module.gate_ffw.zero_() - module.attention_norm.weight.data.fill_(1.0) + module.attention_norm.weight.fill_(1.0) elif isinstance(module, EvollaSequenceCompressorResampler): - module.latents.data.normal_(mean=0.0, std=std) + module.latents.normal_(mean=0.0, std=std) class EvollaModel(EvollaPreTrainedModel): diff --git a/src/transformers/models/evolla/modular_evolla.py b/src/transformers/models/evolla/modular_evolla.py index 51d327370ee3..b31f6645c5be 100644 --- a/src/transformers/models/evolla/modular_evolla.py +++ b/src/transformers/models/evolla/modular_evolla.py @@ -202,20 +202,21 @@ class EvollaSaProtPreTrainedModel(PreTrainedModel): ], } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel): @@ -732,15 +733,16 @@ class EvollaPreTrainedModel(LlamaPreTrainedModel): "EvollaSequenceAlignerCrossAttention", ] + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range PreTrainedModel._init_weights(self, module) if isinstance(module, EvollaSequenceAlignerCrossAttention): module.gate_attention.zero_() module.gate_ffw.zero_() - module.attention_norm.weight.data.fill_(1.0) + module.attention_norm.weight.fill_(1.0) elif isinstance(module, EvollaSequenceCompressorResampler): - module.latents.data.normal_(mean=0.0, std=std) + module.latents.normal_(mean=0.0, std=std) class EvollaModel(EvollaPreTrainedModel): diff --git a/src/transformers/models/exaone4/modeling_exaone4.py b/src/transformers/models/exaone4/modeling_exaone4.py index efc82d192f02..cb70c9cff142 100644 --- a/src/transformers/models/exaone4/modeling_exaone4.py +++ b/src/transformers/models/exaone4/modeling_exaone4.py @@ -455,7 +455,7 @@ def forward( @auto_docstring class Exaone4ForCausalLM(Exaone4PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 1b89172a19cd..4446169eb6c6 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -678,19 +678,20 @@ class FalconPreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights.""" if isinstance(module, (nn.Linear, FalconLinear)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa @classmethod @@ -1001,7 +1002,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( """ ) class FalconForCausalLM(FalconPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "transformer.word_embeddings.weight"} def __init__(self, config: FalconConfig): super().__init__(config) diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 28117b49d52b..f15f8ee1c3b1 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -1194,21 +1194,23 @@ class FalconH1PreTrainedModel(PreTrainedModel): _supports_sdpa = True _is_stateful = True + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range - for name, param in module.named_parameters(recurse=True): - if not param.requires_grad: - continue - if "layernorm" in name.lower() and "weight" in name: - # LayerNorm weights usually initialized to 1 - param.data.fill_(1.0) - elif "bias" in name: - param.data.zero_() - else: - try: - param.data.normal_(mean=0.0, std=std) - except Exception as e: - print(f"Skipping init for {name} due to error: {e}") + if isinstance(module, nn.Module): + for name, param in module.named_parameters(recurse=True): + if not param.requires_grad: + continue + if "layernorm" in name.lower() and "weight" in name: + # LayerNorm weights usually initialized to 1 + param.fill_(1.0) + elif "bias" in name: + param.zero_() + else: + try: + param.normal_(mean=0.0, std=std) + except Exception as e: + print(f"Skipping init for {name} due to error: {e}") def compute_mup_vector(config): @@ -1503,7 +1505,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( @auto_docstring class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index 62cbab82c3e6..5371cab2bf20 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -920,21 +920,23 @@ class FalconH1PreTrainedModel(PreTrainedModel): _supports_sdpa = True _is_stateful = True + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range - for name, param in module.named_parameters(recurse=True): - if not param.requires_grad: - continue - if "layernorm" in name.lower() and "weight" in name: - # LayerNorm weights usually initialized to 1 - param.data.fill_(1.0) - elif "bias" in name: - param.data.zero_() - else: - try: - param.data.normal_(mean=0.0, std=std) - except Exception as e: - print(f"Skipping init for {name} due to error: {e}") + if isinstance(module, nn.Module): + for name, param in module.named_parameters(recurse=True): + if not param.requires_grad: + continue + if "layernorm" in name.lower() and "weight" in name: + # LayerNorm weights usually initialized to 1 + param.fill_(1.0) + elif "bias" in name: + param.zero_() + else: + try: + param.normal_(mean=0.0, std=std) + except Exception as e: + print(f"Skipping init for {name} due to error: {e}") def compute_mup_vector(config): diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index b5f03cfe7076..d7acfd8f1a53 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -568,6 +568,7 @@ class FalconMambaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _is_stateful = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" std = self.config.initializer_range @@ -577,7 +578,7 @@ def _init_weights(self, module): A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :] A = A.expand(module.intermediate_size, -1).contiguous() module.A_log.copy_(torch.log(A)) - module.D.data.fill_(1.0) + module.D.fill_(1.0) dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale if self.config.time_step_init_scheme == "constant": @@ -622,7 +623,7 @@ def _init_weights(self, module): if not getattr(module.bias, "_no_reinit", False): nn.init.zeros_(module.bias) elif isinstance(module, FalconMambaRMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, std=std) @@ -780,7 +781,7 @@ def forward( """ ) class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "backbone.embeddings.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py index fa1544a0171c..51f50d298e27 100644 --- a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +++ b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py @@ -991,24 +991,25 @@ class FastSpeech2ConformerPreTrainedModel(PreTrainedModel): main_input_name = "input_ids" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): nn.init.normal_(module.weight, std=1.0 / math.sqrt(module.weight.size(1))) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) if module.bias is not None: key = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-key, b=key) elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Embedding): - module.weight.data.normal_() + module.weight.normal_() if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, FastSpeech2ConformerAttention): nn.init.xavier_uniform_(module.pos_bias_u) nn.init.xavier_uniform_(module.pos_bias_v) @@ -1403,12 +1404,13 @@ def __init__(self, config: FastSpeech2ConformerHifiGanConfig): # Initialize weights and apply final processing self.post_init() + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights.""" if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() def apply_weight_norm(self): weight_norm = nn.utils.weight_norm diff --git a/src/transformers/models/flaubert/modeling_flaubert.py b/src/transformers/models/flaubert/modeling_flaubert.py index 4dcef63f3f49..5a22aff9c047 100644 --- a/src/transformers/models/flaubert/modeling_flaubert.py +++ b/src/transformers/models/flaubert/modeling_flaubert.py @@ -671,21 +671,22 @@ def dummy_inputs(self): langs_list = None return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list} + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Embedding): if self.config is not None and self.config.embed_init_std is not None: nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() if isinstance(module, nn.Linear): if self.config is not None and self.config.init_std is not None: nn.init.normal_(module.weight, mean=0, std=self.config.init_std) if module.bias is not None: nn.init.constant_(module.bias, 0.0) if isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) if isinstance(module, FlaubertModel) and self.config.sinusoidal_embeddings: create_sinusoidal_embeddings( self.config.max_position_embeddings, self.config.emb_dim, out=module.position_embeddings.weight @@ -947,7 +948,7 @@ def forward( """ ) class FlaubertWithLMHeadModel(FlaubertPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["pred_layer.proj.weight"] + _tied_weights_keys = {"pred_layer.proj.weight": "transformer.embeddings.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 8a19b90ac2cf..bcca5d13d528 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -665,31 +665,32 @@ class FlavaPreTrainedModel(PreTrainedModel): input_modalities = ["image", "text"] supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, FlavaMaskedPredictionHead): - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, FlavaImageEmbeddings): - module.cls_token.data.zero_() - module.position_embeddings.data.zero_() + module.cls_token.zero_() + module.position_embeddings.zero_() if module.mask_token is not None: - module.mask_token.data.zero_() + module.mask_token.zero_() elif isinstance(module, FlavaMultimodalModel): if module.use_cls_token: - module.cls_token.data.zero_() + module.cls_token.zero_() elif isinstance(module, FlavaModel): - module.logit_scale.data.fill_(self.config.logit_scale_init_value) + module.logit_scale.fill_(self.config.logit_scale_init_value) @auto_docstring @@ -1445,17 +1446,11 @@ def __init__(self, config, weight=None): super().__init__() self.config = config self.transform = FlavaPredictionHeadTransform(config) - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) if weight is not None: self.decoder.weight = weight - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, x): x = self.transform(x) x = self.decoder(x) @@ -1522,12 +1517,12 @@ def forward(self, image_embeddings, text_embeddings, logit_scale): ) class FlavaForPreTraining(FlavaPreTrainedModel): # Those are linked to xxx.bias - _tied_weights_keys = [ - "mmm_text_head.decoder.bias", - "mmm_image_head.decoder.bias", - "mlm_head.decoder.bias", - "mim_head.decoder.bias", - ] + _tied_weights_keys = { + "mmm_text_head.bias": "mmm_text_head.decoder.bias", + "mim_head.bias": "mim_head.decoder.bias", + "mlm_head.bias": "mlm_head.decoder.bias", + "mmm_image_head.bias": "mmm_image_head.decoder.bias", + } def __init__(self, config: FlavaConfig, image_codebook: Optional[nn.Module] = None): r""" diff --git a/src/transformers/models/flex_olmo/configuration_flex_olmo.py b/src/transformers/models/flex_olmo/configuration_flex_olmo.py index 515301b93c0c..4cc2cbe6f7f3 100644 --- a/src/transformers/models/flex_olmo/configuration_flex_olmo.py +++ b/src/transformers/models/flex_olmo/configuration_flex_olmo.py @@ -109,6 +109,7 @@ class FlexOlmoConfig(PreTrainedConfig): model_type = "flex_olmo" keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_local_experts": "num_experts"} base_model_tp_plan = { "layers.*.self_attn.q_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k "layers.*.self_attn.k_proj": "colwise_rep", # we need to replicate here due to the added norm on q and k diff --git a/src/transformers/models/flex_olmo/modeling_flex_olmo.py b/src/transformers/models/flex_olmo/modeling_flex_olmo.py index 01d10317cf09..fc65a865ecd8 100644 --- a/src/transformers/models/flex_olmo/modeling_flex_olmo.py +++ b/src/transformers/models/flex_olmo/modeling_flex_olmo.py @@ -23,6 +23,7 @@ from typing import Optional, Union import torch +import torch.nn.functional as F from torch import nn from ...activations import ACT2FN @@ -291,21 +292,23 @@ def forward( return attn_output, attn_weights -class FlexOlmoExperts(nn.ModuleList): - """ - ModuleList of experts. - """ +class FlexOlmoExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" - def __init__(self, config): + def __init__(self, config: FlexOlmoConfig): super().__init__() - for _ in range(config.num_experts): - self.append(FlexOlmoMLP(config)) - self.num_experts = config.num_experts - self.top_k = config.num_experts_per_tok - self.norm_topk_prob = config.norm_topk_prob + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: """ Args: @@ -316,39 +319,58 @@ def forward( (batch_size * sequence_length, hidden_dim) """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states -class FlexOlmoSparseMoeBlock(nn.Module): +class FlexOlmoTopKRouter(nn.Module): def __init__(self, config): super().__init__() - self.num_experts = config.num_experts self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts self.norm_topk_prob = config.norm_topk_prob - self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False) - self.experts = FlexOlmoExperts(config) + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) - def route_tokens_to_experts(self, hidden_states, router_logits): - routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1) - top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1) + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) if self.norm_topk_prob: - top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) - top_k_weights = top_k_weights.to(hidden_states.dtype) - return top_k_index, top_k_weights + router_top_value /= router_top_value.sum(dim=-1, keepdim=True) + router_top_value = router_top_value.to(router_logits.dtype) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices + + +class FlexOlmoSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.gate = FlexOlmoTopKRouter(config) + self.experts = FlexOlmoExperts(config) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - router_logits = self.gate(hidden_states) - top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits) + top_k_weights, top_k_index = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape( batch_size, sequence_length, hidden_dim ) @@ -415,6 +437,16 @@ class FlexOlmoPreTrainedModel(PreTrainedModel): "attentions": FlexOlmoAttention, } + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + std = self.config.initializer_range + if isinstance(module, FlexOlmoExperts): + module.gate_up_proj.normal_(mean=0.0, std=std) + module.down_proj.normal_(mean=0.0, std=std) + elif isinstance(module, FlexOlmoTopKRouter): + module.weight.normal_(mean=0.0, std=std) + @auto_docstring class FlexOlmoModel(FlexOlmoPreTrainedModel): @@ -582,7 +614,7 @@ def load_balancing_loss_func( @auto_docstring class FlexOlmoForCausalLM(FlexOlmoPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/florence2/modeling_florence2.py b/src/transformers/models/florence2/modeling_florence2.py index 4e1250231a99..66c46ebe7872 100644 --- a/src/transformers/models/florence2/modeling_florence2.py +++ b/src/transformers/models/florence2/modeling_florence2.py @@ -637,10 +637,6 @@ class Florence2PreTrainedModel(PreTrainedModel): ) class Florence2Model(Florence2PreTrainedModel): _checkpoint_conversion_mapping = {} - _tied_weights_keys = [ - "language_model.encoder.embed_tokens.weight", - "language_model.decoder.embed_tokens.weight", - ] def __init__(self, config: Florence2Config): super().__init__(config) @@ -806,11 +802,9 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start ) class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = {} - _tied_weights_keys = [ - "model.language_model.encoder.embed_tokens.weight", - "model.language_model.decoder.embed_tokens.weight", - "lm_head.weight", - ] + _tied_weights_keys = { + "lm_head.weight": "model.language_model.shared.weight", + } def __init__(self, config: Florence2Config): super().__init__(config) diff --git a/src/transformers/models/florence2/modular_florence2.py b/src/transformers/models/florence2/modular_florence2.py index 6ae43c0b69a7..bbffc6d96e56 100644 --- a/src/transformers/models/florence2/modular_florence2.py +++ b/src/transformers/models/florence2/modular_florence2.py @@ -1511,10 +1511,6 @@ class Florence2PreTrainedModel(LlavaPreTrainedModel): ) class Florence2Model(LlavaModel): _checkpoint_conversion_mapping = {} - _tied_weights_keys = [ - "language_model.encoder.embed_tokens.weight", - "language_model.decoder.embed_tokens.weight", - ] def __init__(self, config: Florence2Config): super().__init__(config) @@ -1627,11 +1623,9 @@ def forward( ) class Florence2ForConditionalGeneration(LlavaForConditionalGeneration): _checkpoint_conversion_mapping = {} - _tied_weights_keys = [ - "model.language_model.encoder.embed_tokens.weight", - "model.language_model.decoder.embed_tokens.weight", - "lm_head.weight", - ] + _tied_weights_keys = { + "lm_head.weight": "model.language_model.shared.weight", + } def get_encoder(self): return self.model.get_encoder() diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index b8cdd1f2ea58..5cc5c870fa9e 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -325,27 +325,14 @@ class FNetLMPredictionHead(nn.Module): def __init__(self, config): super().__init__() self.transform = FNetPredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. self.decoder = nn.Linear(config.hidden_size, config.vocab_size) - self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - self.decoder.bias = self.bias def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) return hidden_states - def _tie_weights(self) -> None: - # For accelerate compatibility and to not break backward compatibility - if self.decoder.bias.device.type == "meta": - self.decoder.bias = self.bias - else: - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias - class FNetOnlyMLMHead(nn.Module): def __init__(self, config): @@ -387,20 +374,21 @@ class FNetPreTrainedModel(PreTrainedModel): base_model_prefix = "fnet" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) # NOTE: Original code uses same initialization as weights for biases as well. if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @dataclass @@ -536,7 +524,10 @@ def forward( """ ) class FNetForPreTraining(FNetPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "fnet.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) @@ -626,7 +617,10 @@ def forward( @auto_docstring class FNetForMaskedLM(FNetPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "fnet.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index 9b5d4daed70c..a297378f5492 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -581,22 +581,23 @@ class FocalNetPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["FocalNetStage"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, FocalNetEmbeddings): if module.mask_token is not None: - module.mask_token.data.zero_() + module.mask_token.zero_() elif isinstance(module, FocalNetLayer): if self.config.use_layerscale: - module.gamma_1.data.fill_(self.config.layerscale_value) - module.gamma_2.data.fill_(self.config.layerscale_value) + module.gamma_1.fill_(self.config.layerscale_value) + module.gamma_2.fill_(self.config.layerscale_value) @auto_docstring diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index f2b45525dfea..f0edc8cce1a1 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -220,21 +220,22 @@ class PretrainedFSMTModel(PreTrainedModel): config: FSMTConfig base_model_prefix = "model" + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, SinusoidalPositionalEmbedding): weight = module.get_embedding(*module.weight.shape, module.padding_idx) weight = nn.Parameter(weight, requires_grad=False) weight.detach_() module.weight = weight elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() @property def dummy_inputs(self): @@ -338,13 +339,13 @@ class FSMTEncoder(nn.Module): config: FSMTConfig """ - def __init__(self, config: FSMTConfig, embed_tokens): + def __init__(self, config: FSMTConfig): super().__init__() self.dropout = config.dropout self.layerdrop = config.encoder_layerdrop - self.padding_idx = embed_tokens.padding_idx - self.embed_tokens = embed_tokens - embed_dim = embed_tokens.embedding_dim + self.padding_idx = config.pad_token_id + self.embed_tokens = nn.Embedding(config.src_vocab_size, config.d_model, config.pad_token_id) + embed_dim = self.embed_tokens.embedding_dim self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 self.embed_positions = SinusoidalPositionalEmbedding( config.max_position_embeddings + self.padding_idx + 1, embed_dim, self.padding_idx @@ -531,31 +532,19 @@ class FSMTDecoder(nn.Module): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: FSMTConfig, embed_tokens: nn.Embedding): + def __init__(self, config: FSMTConfig): super().__init__() self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop - self.padding_idx = embed_tokens.padding_idx + self.padding_idx = config.pad_token_id self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - self.embed_tokens = embed_tokens - embed_dim = embed_tokens.embedding_dim + self.embed_tokens = nn.Embedding(config.tgt_vocab_size, config.d_model, self.padding_idx) + embed_dim = self.embed_tokens.embedding_dim self.embed_positions = SinusoidalPositionalEmbedding( config.max_position_embeddings + self.padding_idx + 1, embed_dim, self.padding_idx ) self.layers = nn.ModuleList([DecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) # type: list[DecoderLayer] - - if is_deepspeed_zero3_enabled(): - import deepspeed - - with deepspeed.zero.GatheredParameters(self.embed_tokens.weight, modifier_rank=None): - embed_tokens_weight_shape = self.embed_tokens.weight.shape - else: - embed_tokens_weight_shape = self.embed_tokens.weight.shape - self.output_projection = nn.Linear(embed_tokens_weight_shape[1], embed_tokens_weight_shape[0], bias=False) - self.output_projection.weight = self.embed_tokens.weight - - def _tie_weights(self): - self.embed_tokens.weight = self.output_projection.weight + self.output_projection = nn.Linear(config.d_model, config.tgt_vocab_size, bias=False) def forward( self, @@ -828,29 +817,20 @@ def _get_shape(t): @auto_docstring class FSMTModel(PretrainedFSMTModel): - _tied_weights_keys = ["decoder.embed_tokens.weight", "decoder.output_projection.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "decoder.embed_tokens.weight", + "decoder.output_projection.weight": "decoder.embed_tokens.weight", + } def __init__(self, config: FSMTConfig): super().__init__(config) - - padding_idx = config.pad_token_id - encoder_embed_tokens = nn.Embedding(config.src_vocab_size, config.d_model, padding_idx) - decoder_embed_tokens = nn.Embedding(config.tgt_vocab_size, config.d_model, padding_idx) - - self.encoder = FSMTEncoder(config, encoder_embed_tokens) - self.decoder = FSMTDecoder(config, decoder_embed_tokens) - - # Initialize weights and apply final processing + self.encoder = FSMTEncoder(config) + self.decoder = FSMTDecoder(config) self.post_init() def get_encoder(self): return self.encoder - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.decoder.embed_tokens, self.get_input_embeddings()) - self._tie_embedding_weights(self.decoder.output_projection, self.get_input_embeddings()) - @auto_docstring def forward( self, @@ -978,7 +958,6 @@ def set_output_embeddings(self, value): ) class FSMTForConditionalGeneration(PretrainedFSMTModel, GenerationMixin): base_model_prefix = "model" - _tied_weights_keys = ["decoder.embed_tokens.weight", "decoder.output_projection.weight"] def __init__(self, config: FSMTConfig): super().__init__(config) diff --git a/src/transformers/models/funnel/modeling_funnel.py b/src/transformers/models/funnel/modeling_funnel.py index 1b477dbb551a..7290c54e091a 100644 --- a/src/transformers/models/funnel/modeling_funnel.py +++ b/src/transformers/models/funnel/modeling_funnel.py @@ -672,6 +672,7 @@ class FunnelPreTrainedModel(PreTrainedModel): config: FunnelConfig base_model_prefix = "funnel" + @torch.no_grad() def _init_weights(self, module): classname = module.__class__.__name__ if classname.find("Linear") != -1: @@ -694,7 +695,7 @@ def _init_weights(self, module): std = 1.0 if self.config.initializer_std is None else self.config.initializer_std nn.init.normal_(module.word_embeddings.weight, std=std) if module.word_embeddings.padding_idx is not None: - module.word_embeddings.weight.data[module.word_embeddings.padding_idx].zero_() + module.word_embeddings.weight[module.word_embeddings.padding_idx].zero_() class FunnelClassificationHead(nn.Module): @@ -982,7 +983,7 @@ def forward( @auto_docstring class FunnelForMaskedLM(FunnelPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "funnel.embeddings.word_embeddings.weight"} def __init__(self, config: FunnelConfig) -> None: super().__init__(config) diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index fdacd7409615..0a412375ae59 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -44,16 +44,17 @@ class FuyuPreTrainedModel(PreTrainedModel): _no_split_modules = [] _skip_keys_device_placement = "past_key_values" + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() @auto_docstring( @@ -257,7 +258,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin): "^vision_embed_tokens": "model.vision_embed_tokens", "^language_model.lm_head": "lm_head", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: FuyuConfig): super().__init__(config) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 335c2b2cf7b5..1acb039017dc 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -349,12 +349,13 @@ class GemmaPreTrainedModel(PreTrainedModel): "attentions": GemmaAttention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) if "RMSNorm" in module.__class__.__name__: - module.weight.data.zero_() + module.weight.zero_() @auto_docstring @@ -447,7 +448,7 @@ def forward( @auto_docstring class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index aa64cc9e63e8..d1b3070a5ad0 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -394,12 +394,13 @@ def __init__(self, config: GemmaConfig, layer_idx: int): class GemmaPreTrainedModel(LlamaPreTrainedModel): + @torch.no_grad() def _init_weights(self, module): PreTrainedModel._init_weights(self, module) # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) if "RMSNorm" in module.__class__.__name__: - module.weight.data.zero_() + module.weight.zero_() class GemmaModel(LlamaModel): diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index f824053201ad..6db748900375 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -381,12 +381,13 @@ class Gemma2PreTrainedModel(PreTrainedModel): "attentions": Gemma2Attention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) if "RMSNorm" in module.__class__.__name__: - module.weight.data.zero_() + module.weight.zero_() @auto_docstring @@ -519,7 +520,7 @@ def forward( @auto_docstring class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 8dff40771914..00f74c850dc5 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -466,13 +466,14 @@ class Gemma3PreTrainedModel(PreTrainedModel): } input_modalities = ["image", "text"] + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Gemma3MultiModalProjector): - module.mm_input_projection_weight.data.zero_() + module.mm_input_projection_weight.zero_() # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) elif "RMSNorm" in module.__class__.__name__: - module.weight.data.zero_() + module.weight.zero_() def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]: @@ -626,7 +627,7 @@ def forward( @auto_docstring class Gemma3ForCausalLM(Gemma3PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config: Gemma3TextConfig @@ -1044,7 +1045,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): "^multi_modal_projector": "model.multi_modal_projector", "^language_model.lm_head": "lm_head", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} # we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch # Fix: https://github.com/huggingface/transformers/issues/40564 accepts_loss_kwargs = False diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index f4b4ce22381e..addd9ac994b9 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -569,13 +569,14 @@ class Gemma3PreTrainedModel(Gemma2PreTrainedModel): "SiglipMultiheadAttentionPoolingHead", ] + @torch.no_grad() def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, Gemma3MultiModalProjector): - module.mm_input_projection_weight.data.zero_() + module.mm_input_projection_weight.zero_() # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) elif "RMSNorm" in module.__class__.__name__: - module.weight.data.zero_() + module.weight.zero_() def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]: diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 452860d956f9..2ab5b224d725 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1600,14 +1600,15 @@ class Gemma3nPreTrainedModel(PreTrainedModel): } input_modalities = ["image", "text", "audio"] + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Gemma3nAudioCumulativeGroupNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, Gemma3nAudioAttention): - module.per_dim_scale.data.zero_() + module.per_dim_scale.zero_() elif isinstance(module, Gemma3nTextAltUp): - module.correct_output_scale.data.zero_() + module.correct_output_scale.zero_() class Gemma3nRotaryEmbedding(nn.Module): @@ -1932,7 +1933,7 @@ def project_per_layer_inputs( @auto_docstring(custom_intro="The base Gemma 3n language model with a language modeling head.") class Gemma3nForCausalLM(Gemma3nPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config: Gemma3nTextConfig @@ -2345,7 +2346,7 @@ def get_audio_features( ) class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = {} - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} base_model_prefix = "model" def __init__(self, config: Gemma3nConfig): diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index 6d431e9acc55..14a6ce1f8f1a 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -1876,14 +1876,15 @@ class Gemma3nPreTrainedModel(Gemma2PreTrainedModel): input_modalities = ["image", "text", "audio"] _no_split_modules = ["Gemma3nTextDecoderLayer"] + @torch.no_grad() def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, Gemma3nAudioCumulativeGroupNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, Gemma3nAudioAttention): - module.per_dim_scale.data.zero_() + module.per_dim_scale.zero_() elif isinstance(module, Gemma3nTextAltUp): - module.correct_output_scale.data.zero_() + module.correct_output_scale.zero_() @auto_docstring(custom_intro="The base Gemma 3n language model without a language modeling head.") diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 5cc3195b4c38..24ce421e1d5e 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -388,6 +388,7 @@ class GitPreTrainedModel(PreTrainedModel): input_modalities = ["image", "text"] supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, GitVisionEmbeddings): @@ -395,16 +396,16 @@ def _init_weights(self, module): nn.init.normal_(module.patch_embedding.weight, std=self.config.initializer_range) nn.init.normal_(module.position_embedding.weight, std=self.config.initializer_range) if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Git @@ -1119,7 +1120,7 @@ def forward( """ ) class GitForCausalLM(GitPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["output.weight"] + _tied_weights_keys = {"output.weight": "git.embeddings.word_embeddings.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index f72268465ece..a4880c0145e9 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -450,7 +450,7 @@ def forward( @auto_docstring class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index 935a722fd1db..ba07da7cab54 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -454,7 +454,7 @@ def forward( @auto_docstring class Glm4ForCausalLM(Glm4PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py b/src/transformers/models/glm4_moe/modeling_glm4_moe.py index f7bc01465160..a39dcb44ad38 100644 --- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py +++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -330,19 +330,23 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Glm4MoeNaiveMoe(nn.ModuleList): - """ - ModuleList of experts. - """ +class Glm4MoeNaiveMoe(nn.Module): + """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): super().__init__() self.num_experts = config.num_local_experts - for _ in range(self.num_experts): - self.append(Glm4MoeMLP(config, intermediate_size=config.moe_intermediate_size)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: """ Args: @@ -353,14 +357,24 @@ def forward( (batch_size * sequence_length, hidden_dim) """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states @@ -486,10 +500,11 @@ class Glm4MoePreTrainedModel(PreTrainedModel): "attentions": Glm4MoeAttention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Glm4MoeTopkRouter): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring @@ -575,7 +590,7 @@ def forward( @auto_docstring class Glm4MoeForCausalLM(Glm4MoePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 147e18b7e78e..20c7212e2f65 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -1364,7 +1364,7 @@ class Glm4vCausalLMOutputWithPast(ModelOutput): class Glm4vForConditionalGeneration(Glm4vPreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = {} - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} # Reference: fix gemma3 grad acc #37208 accepts_loss_kwargs = False @@ -1424,8 +1424,6 @@ def forward( **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, Glm4vCausalLMOutputWithPast]: r""" - rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): - The rope index difference between sequence length and multimodal rope. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored @@ -1434,6 +1432,8 @@ def forward( The temporal, height and width of feature shape of each image in LLM. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. Example: diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 7afb2e0b1463..ca2a26d93392 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -351,19 +351,23 @@ def forward(self, hidden_states): return router_logits -class Glm4vMoeTextNaiveMoe(nn.ModuleList): - """ - ModuleList of experts. - """ +class Glm4vMoeTextNaiveMoe(nn.Module): + """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): super().__init__() self.num_experts = config.num_local_experts - for _ in range(self.num_experts): - self.append(Glm4vMoeTextMLP(config, intermediate_size=config.moe_intermediate_size)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: """ Args: @@ -374,14 +378,24 @@ def forward( (batch_size * sequence_length, hidden_dim) """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states @@ -547,10 +561,11 @@ class Glm4vMoePreTrainedModel(PreTrainedModel): } input_modalities = ["text", "image", "video"] + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Glm4vMoeTextTopkRouter): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) @dataclass @@ -1572,7 +1587,7 @@ def load_balancing_loss_func( class Glm4vMoeForConditionalGeneration(Glm4vMoePreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = {} - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} # Reference: fix gemma3 grad acc #37208 accepts_loss_kwargs = False diff --git a/src/transformers/models/glpn/modeling_glpn.py b/src/transformers/models/glpn/modeling_glpn.py index 17d6f5565edb..4255ae22f47f 100755 --- a/src/transformers/models/glpn/modeling_glpn.py +++ b/src/transformers/models/glpn/modeling_glpn.py @@ -389,20 +389,20 @@ class GLPNPreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = [] - # Copied from transformers.models.segformer.modeling_segformer.SegformerPreTrainedModel._init_weights + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 809926990d41..90dfa7cb6839 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -287,15 +287,16 @@ class GotOcr2PreTrainedModel(PreTrainedModel): _supports_flex_attn = False _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, GotOcr2VisionAttention): if module.use_rel_pos: - module.rel_pos_h.data.zero_() - module.rel_pos_w.data.zero_() + module.rel_pos_h.zero_() + module.rel_pos_w.zero_() elif isinstance(module, GotOcr2VisionEncoder): if module.pos_embed is not None: - module.pos_embed.data.zero_() + module.pos_embed.zero_() @dataclass @@ -663,7 +664,7 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): "^multi_modal_projector": "model.multi_modal_projector", "^language_model.lm_head": "lm_head", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: GotOcr2Config): super().__init__(config) diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py index 1b56eff7729d..9312ed42ff38 100644 --- a/src/transformers/models/got_ocr2/modular_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py @@ -289,15 +289,16 @@ class GotOcr2PreTrainedModel(LlavaPreTrainedModel): _supports_sdpa = False _supports_flex_attn = False + @torch.no_grad() def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, GotOcr2VisionAttention): if module.use_rel_pos: - module.rel_pos_h.data.zero_() - module.rel_pos_w.data.zero_() + module.rel_pos_h.zero_() + module.rel_pos_w.zero_() elif isinstance(module, GotOcr2VisionEncoder): if module.pos_embed is not None: - module.pos_embed.data.zero_() + module.pos_embed.zero_() class GotOcr2Model(LlavaModel): diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 8b134f25c6f8..b667089c2c42 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -483,19 +483,20 @@ class GPT2PreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear, Conv1D)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale @@ -503,10 +504,11 @@ def _init_weights(self, module): # > -- GPT-2 :: https://openai.com/blog/better-language-models/ # # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - for name, p in module.named_parameters(): - if name == "c_proj.weight": - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + if isinstance(module, PreTrainedModel): + for name, p in module.named_parameters(): + if name == "c_proj.weight": + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) @dataclass @@ -751,7 +753,7 @@ def forward( """ ) class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"} def __init__(self, config): super().__init__(config) @@ -854,7 +856,7 @@ def forward( """ ) class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index fbbad2c60825..8f9d2b2fac00 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -361,6 +361,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)): @@ -370,21 +371,21 @@ def _init_weights(self, module): # > -- GPT-2 :: https://openai.com/blog/better-language-models/ # # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - module.c_proj.weight.data.normal_( + module.c_proj.weight.normal_( mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)) ) module.c_proj._is_hf_initialized = True elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring @@ -576,7 +577,7 @@ def forward( """ ) class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index d758b0529d86..c591ef2ec914 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -384,19 +384,20 @@ class GPTNeoPreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear,)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring @@ -667,7 +668,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( """ ) class GPTNeoForCausalLM(GPTNeoPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 719ec08ce3e6..fc7d6fd40a80 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -517,7 +517,7 @@ def set_input_embeddings(self, value): """ ) class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["embed_out.weight"] + _tied_weights_keys = {"embed_out.weight": "gpt_neox.embed_in.weight"} _tp_plan = {"embed_out": "colwise_rep"} _pp_plan = {"embed_out": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index dfd877825363..c267753db350 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -390,7 +390,7 @@ def forward( """ ) class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["embed_out.weight"] + _tied_weights_keys = {"embed_out.weight": "gpt_neox.embed_in.weight"} _tp_plan = {"embed_out": "colwise_rep"} _pp_plan = {"embed_out": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 5120929f9b4b..a906004dd41e 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -50,22 +50,23 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, GPTNeoXJapaneseAttention): if module.dense_bias is not None: - module.dense_bias.data.zero_() + module.dense_bias.zero_() # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->GPTNeoXJapanese @@ -656,7 +657,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( """ ) class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["embed_out.weight"] + _tied_weights_keys = {"embed_out.weight": "gpt_neox_japanese.embed_in.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 92688a0ab341..11e323544806 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -71,10 +71,10 @@ def __init__(self, config): self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size - self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) - self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim)) + self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim)) self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) - self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size)) + self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size)) self.alpha = 1.702 self.limit = 7.0 @@ -146,8 +146,8 @@ def __init__(self, config): self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts self.hidden_dim = config.hidden_size - self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) - self.bias = nn.Parameter(torch.empty(self.num_experts)) + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + self.bias = nn.Parameter(torch.zeros(self.num_experts)) def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) @@ -440,30 +440,31 @@ class GptOssPreTrainedModel(PreTrainedModel): _supports_flash_attention = False _supports_flex_attention = False + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Parameter): - module.data.normal_(mean=0.0, std=std) + module.normal_(mean=0.0, std=std) elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, GptOssRMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, GptOssExperts): - module.gate_up_proj.data.normal_(mean=0.0, std=std) - module.gate_up_proj_bias.data.zero_() - module.down_proj.data.normal_(mean=0.0, std=std) - module.down_proj_bias.data.zero_() + module.gate_up_proj.normal_(mean=0.0, std=std) + module.gate_up_proj_bias.zero_() + module.down_proj.normal_(mean=0.0, std=std) + module.down_proj_bias.zero_() elif isinstance(module, GptOssAttention): - module.sinks.data.normal_(mean=0.0, std=std) + module.sinks.normal_(mean=0.0, std=std) elif isinstance(module, GptOssTopKRouter): - module.weight.data.normal_(mean=0.0, std=std) - module.bias.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) + module.bias.normal_(mean=0.0, std=std) @auto_docstring @@ -635,7 +636,7 @@ def load_balancing_loss_func( @auto_docstring class GptOssForCausalLM(GptOssPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index e44831063200..4f33517001b3 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -69,10 +69,10 @@ def __init__(self, config): self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size - self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) - self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim)) + self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim)) self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) - self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size)) + self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size)) self.alpha = 1.702 self.limit = 7.0 @@ -144,8 +144,8 @@ def __init__(self, config): self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts self.hidden_dim = config.hidden_size - self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) - self.bias = nn.Parameter(torch.empty(self.num_experts)) + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + self.bias = nn.Parameter(torch.zeros(self.num_experts)) def forward(self, hidden_states): hidden_states = hidden_states.reshape(-1, self.hidden_dim) @@ -356,30 +356,31 @@ class GptOssPreTrainedModel(LlamaPreTrainedModel): "attentions": GptOssAttention, } + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Parameter): - module.data.normal_(mean=0.0, std=std) + module.normal_(mean=0.0, std=std) elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, GptOssRMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, GptOssExperts): - module.gate_up_proj.data.normal_(mean=0.0, std=std) - module.gate_up_proj_bias.data.zero_() - module.down_proj.data.normal_(mean=0.0, std=std) - module.down_proj_bias.data.zero_() + module.gate_up_proj.normal_(mean=0.0, std=std) + module.gate_up_proj_bias.zero_() + module.down_proj.normal_(mean=0.0, std=std) + module.down_proj_bias.zero_() elif isinstance(module, GptOssAttention): - module.sinks.data.normal_(mean=0.0, std=std) + module.sinks.normal_(mean=0.0, std=std) elif isinstance(module, GptOssTopKRouter): - module.weight.data.normal_(mean=0.0, std=std) - module.bias.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) + module.bias.normal_(mean=0.0, std=std) class GptOssModel(MixtralModel): diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 24d3322ad658..8d8004577e57 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -447,19 +447,20 @@ class GPTJPreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear,)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring @@ -722,7 +723,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( """ ) class GPTJForCausalLM(GPTJPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index bf64a382700b..42de2e0724f3 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -502,7 +502,7 @@ def forward( @auto_docstring class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py index 6973124fb51f..07e7c2573e99 100644 --- a/src/transformers/models/granite_speech/modeling_granite_speech.py +++ b/src/transformers/models/granite_speech/modeling_granite_speech.py @@ -286,23 +286,24 @@ class GraniteSpeechPreTrainedModel(PreTrainedModel): _supports_flash_attn = False # `blip_2_qformer` dependency does not allow for this _supports_sdpa = True + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights.""" std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, GraniteSpeechEncoderProjector): - module.query.data.normal_() + module.query.normal_() @auto_docstring( @@ -319,9 +320,6 @@ def __init__(self, config: GraniteSpeechConfig): # model; don't need to consider it twice self.language_model = AutoModelForCausalLM.from_config(config.text_config) - if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] - self.encoder = GraniteSpeechCTCEncoder(config.encoder_config) self.projector = GraniteSpeechEncoderProjector(config) diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 0eefadc9a1b9..0b3a893b9883 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -411,10 +411,9 @@ def __init__(self, config: GraniteMoeConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = GraniteMoeAttention(config=config, layer_idx=layer_idx) - self.block_sparse_moe = GraniteMoeMoE(config) self.input_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + self.block_sparse_moe = GraniteMoeMoE(config) self.residual_multiplier = config.residual_multiplier # Only diff with mixtral! def forward( @@ -462,10 +461,11 @@ class GraniteMoePreTrainedModel(PreTrainedModel): "attentions": GraniteMoeAttention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, GraniteMoeParallelExperts): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring @@ -635,7 +635,7 @@ def load_balancing_loss_func( @auto_docstring class GraniteMoeForCausalLM(GraniteMoePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/granitemoe/modular_granitemoe.py b/src/transformers/models/granitemoe/modular_granitemoe.py index 3c5b73ebf899..53692da91773 100644 --- a/src/transformers/models/granitemoe/modular_granitemoe.py +++ b/src/transformers/models/granitemoe/modular_granitemoe.py @@ -105,7 +105,8 @@ def __init__(self, config: GraniteMoeConfig, layer_idx: int): self.block_sparse_moe = GraniteMoeMoE(config) self.input_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = GraniteMoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + del self.mlp + self.block_sparse_moe = GraniteMoeMoE(config) self.residual_multiplier = config.residual_multiplier # Only diff with mixtral! def forward( @@ -147,10 +148,11 @@ class GraniteMoePreTrainedModel(LlamaPreTrainedModel, PreTrainedModel): _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) + @torch.no_grad() def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, GraniteMoeParallelExperts): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 947d250cd134..dc39370b7559 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -1119,10 +1119,9 @@ def __init__(self, config: GraniteMoeHybridConfig, layer_idx: int): self.hidden_size = config.hidden_size # Either attention or mamba will be initialized, depending on the layer type. self.self_attn = None - self.block_sparse_moe = GraniteMoeHybridMoE(config) self.input_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = GraniteMoeHybridRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + self.block_sparse_moe = GraniteMoeHybridMoE(config) self.residual_multiplier = config.residual_multiplier # Only diff with mixtral! self.shared_mlp = GraniteMoeHybridMLP(config) self.mamba = None @@ -1202,16 +1201,17 @@ class GraniteMoeHybridPreTrainedModel(PreTrainedModel): } _is_stateful = True + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, GraniteMoeHybridParallelExperts): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if isinstance(module, GraniteMoeHybridMambaLayer): - module.dt_bias.data.fill_(1.0) - module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1)) - module.D.data.fill_(1.0) + module.dt_bias.fill_(1.0) + module.A_log.copy_(torch.log(torch.arange(1, module.num_heads + 1))) + module.D.fill_(1.0) elif isinstance(module, GraniteMoeHybridRMSNormGated): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) @auto_docstring @@ -1395,7 +1395,7 @@ def load_balancing_loss_func( @auto_docstring class GraniteMoeHybridForCausalLM(GraniteMoeHybridPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py index f1b8a5bfb110..ed0676752fbc 100644 --- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -176,14 +176,15 @@ class GraniteMoeHybridPreTrainedModel(GraniteMoeSharedPreTrainedModel): _no_split_modules = ["GraniteMoeHybridDecoderLayer"] _is_stateful = True + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, GraniteMoeHybridMambaLayer): - module.dt_bias.data.fill_(1.0) - module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1)) - module.D.data.fill_(1.0) + module.dt_bias.fill_(1.0) + module.A_log.copy_(torch.log(torch.arange(1, module.num_heads + 1))) + module.D.fill_(1.0) elif isinstance(module, GraniteMoeHybridRMSNormGated): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) class GraniteMoeHybridModel(GraniteMoeSharedModel): @@ -273,7 +274,7 @@ def _update_mamba_mask(self, attention_mask, cache_position): class GraniteMoeHybridForCausalLM(GraniteMoeSharedForCausalLM): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: GraniteMoeHybridConfig): super().__init__(config) diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 8b1569722006..d2f228d0f197 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -401,10 +401,9 @@ def __init__(self, config: GraniteMoeSharedConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = GraniteMoeSharedAttention(config=config, layer_idx=layer_idx) - self.block_sparse_moe = GraniteMoeSharedMoE(config) self.input_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = GraniteMoeSharedRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + self.block_sparse_moe = GraniteMoeSharedMoE(config) self.residual_multiplier = config.residual_multiplier # Only diff with mixtral! self.shared_mlp = None if config.shared_intermediate_size == 0 else GraniteMoeSharedMLP(config) @@ -468,10 +467,11 @@ class GraniteMoeSharedPreTrainedModel(PreTrainedModel): "attentions": GraniteMoeSharedAttention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, GraniteMoeSharedParallelExperts): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) class GraniteMoeSharedRotaryEmbedding(nn.Module): @@ -706,7 +706,7 @@ def load_balancing_loss_func( @auto_docstring class GraniteMoeSharedForCausalLM(GraniteMoeSharedPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/granitemoeshared/modular_granitemoeshared.py b/src/transformers/models/granitemoeshared/modular_granitemoeshared.py index 5c3241e71b5d..4bc8f66e85c9 100644 --- a/src/transformers/models/granitemoeshared/modular_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modular_granitemoeshared.py @@ -146,7 +146,7 @@ def __init__(self, config: GraniteMoeSharedConfig): class GraniteMoeSharedForCausalLM(GraniteMoeForCausalLM): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: GraniteMoeSharedConfig): super().__init__(config) diff --git a/src/transformers/models/grounding_dino/modeling_grounding_dino.py b/src/transformers/models/grounding_dino/modeling_grounding_dino.py index 6c53d3ba21f2..506372c73fe4 100644 --- a/src/transformers/models/grounding_dino/modeling_grounding_dino.py +++ b/src/transformers/models/grounding_dino/modeling_grounding_dino.py @@ -1369,6 +1369,7 @@ class GroundingDinoPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" input_modalities = ["image", "text"] + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std @@ -1376,7 +1377,7 @@ def _init_weights(self, module): nn.init.uniform_(module.row_embeddings.weight) nn.init.uniform_(module.column_embeddings.weight) elif isinstance(module, GroundingDinoMultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + nn.init.constant_(module.sampling_offsets.weight, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( 2.0 * math.pi / module.n_heads @@ -1391,46 +1392,46 @@ def _init_weights(self, module): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(module.attention_weights.weight.data, 0.0) - nn.init.constant_(module.attention_weights.bias.data, 0.0) - nn.init.xavier_uniform_(module.value_proj.weight.data) - nn.init.constant_(module.value_proj.bias.data, 0.0) - nn.init.xavier_uniform_(module.output_proj.weight.data) - nn.init.constant_(module.output_proj.bias.data, 0.0) + nn.init.constant_(module.attention_weights.weight, 0.0) + nn.init.constant_(module.attention_weights.bias, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight) + nn.init.constant_(module.value_proj.bias, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight) + nn.init.constant_(module.output_proj.bias, 0.0) elif isinstance(module, GroundingDinoBiMultiHeadAttention): nn.init.xavier_uniform_(module.vision_proj.weight) - module.vision_proj.bias.data.fill_(0) + module.vision_proj.bias.fill_(0) nn.init.xavier_uniform_(module.text_proj.weight) - module.text_proj.bias.data.fill_(0) + module.text_proj.bias.fill_(0) nn.init.xavier_uniform_(module.values_vision_proj.weight) - module.values_vision_proj.bias.data.fill_(0) + module.values_vision_proj.bias.fill_(0) nn.init.xavier_uniform_(module.values_text_proj.weight) - module.values_text_proj.bias.data.fill_(0) + module.values_text_proj.bias.fill_(0) nn.init.xavier_uniform_(module.out_vision_proj.weight) - module.out_vision_proj.bias.data.fill_(0) + module.out_vision_proj.bias.fill_(0) nn.init.xavier_uniform_(module.out_text_proj.weight) - module.out_text_proj.bias.data.fill_(0) + module.out_text_proj.bias.fill_(0) elif isinstance(module, GroundingDinoFusionLayer): - module.vision_param.data.fill_(1e-4) - module.text_param.data.fill_(1e-4) + module.vision_param.fill_(1e-4) + module.text_param.fill_(1e-4) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, GroundingDinoMLPPredictionHead): - nn.init.constant_(module.layers[-1].weight.data, 0) - nn.init.constant_(module.layers[-1].bias.data, 0) + nn.init.constant_(module.layers[-1].weight, 0) + nn.init.constant_(module.layers[-1].bias, 0) if hasattr(module, "reference_points") and not self.config.two_stage: - nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0) - nn.init.constant_(module.reference_points.bias.data, 0.0) + nn.init.xavier_uniform_(module.reference_points.weight, gain=1.0) + nn.init.constant_(module.reference_points.bias, 0.0) if hasattr(module, "level_embed"): nn.init.normal_(module.level_embed) @@ -2412,35 +2413,32 @@ def build_text_mask(logits, attention_mask): class GroundingDinoForObjectDetection(GroundingDinoPreTrainedModel): # When using clones, all layers > 0 will be clones, but layer 0 *is* required # the bbox_embed in the decoder are all clones though - _tied_weights_keys = [r"bbox_embed\.[1-9]\d*", r"model\.decoder\.bbox_embed\.[0-9]\d*"] + _tied_weights_keys = { + "model.decoder.bbox_embed":"bbox_embed", + "model.decoder.class_embed":"class_embed", + r"class_embed.(?![0])\d+": "class_embed.0", + } def __init__(self, config: GroundingDinoConfig): super().__init__(config) self.model = GroundingDinoModel(config) - _class_embed = GroundingDinoContrastiveEmbedding(config) - if config.decoder_bbox_embed_share: - # a single shared instance - shared_head = GroundingDinoMLPPredictionHead( - input_dim=config.d_model, hidden_dim=config.d_model, output_dim=4, num_layers=3 - ) - self.bbox_embed = nn.ModuleList([shared_head] * config.decoder_layers) - else: - # each layer has its own head (implicit deep copy through a new instance) - self.bbox_embed = nn.ModuleList( - [ - GroundingDinoMLPPredictionHead( - input_dim=config.d_model, - hidden_dim=config.d_model, - output_dim=4, - num_layers=3, - ) - for _ in range(config.decoder_layers) - ] - ) + self._tied_weights_keys[r"bbox_embed.(?![0])\d+"]= "bbox_embed.0" + + self.bbox_embed = nn.ModuleList( + [ + GroundingDinoMLPPredictionHead( + input_dim=config.d_model, + hidden_dim=config.d_model, + output_dim=4, + num_layers=3, + ) + for _ in range(config.decoder_layers) + ] + ) - self.class_embed = nn.ModuleList([_class_embed for _ in range(config.decoder_layers)]) + self.class_embed = nn.ModuleList([GroundingDinoContrastiveEmbedding(config) for _ in range(config.decoder_layers)]) # hack for box-refinement self.model.decoder.bbox_embed = self.bbox_embed # hack implementation for two-stage diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index 4c852db4668c..0c51c9052afc 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -748,22 +748,23 @@ class GroupViTPreTrainedModel(PreTrainedModel): input_modalities = ["image", "text"] supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" init_range = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=init_range) + module.weight.normal_(mean=0.0, std=init_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) factor = self.config.initializer_factor if isinstance(module, GroupViTTextEmbeddings): - module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02) elif isinstance(module, GroupViTAttention): factor = self.config.initializer_factor in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index a1d0a09e848f..2e7626714834 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -433,7 +433,7 @@ def forward( @auto_docstring class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/hiera/modeling_hiera.py b/src/transformers/models/hiera/modeling_hiera.py index af245b86220b..85cfa57ca7d8 100644 --- a/src/transformers/models/hiera/modeling_hiera.py +++ b/src/transformers/models/hiera/modeling_hiera.py @@ -776,6 +776,7 @@ class HieraPreTrainedModel(PreTrainedModel): input_modalities = "image" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module) -> None: """Initialize the weights""" std = self.config.initializer_range diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 9729e481f402..84a8c98749fc 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -638,36 +638,37 @@ class HubertPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): if is_deepspeed_zero3_enabled(): import deepspeed if hasattr(module, "weight_v") and hasattr(module, "weight_g"): with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): - nn.init.kaiming_normal_(module.weight.data) + nn.init.kaiming_normal_(module.weight) else: with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): - nn.init.kaiming_normal_(module.weight.data) + nn.init.kaiming_normal_(module.weight) else: - nn.init.kaiming_normal_(module.weight.data) + nn.init.kaiming_normal_(module.weight) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, HubertModel): if hasattr(module, "masked_spec_embed"): - module.masked_spec_embed.data.uniform_() + module.masked_spec_embed.uniform_() elif isinstance(module, HubertForSequenceClassification): if hasattr(module, "layer_weights"): - module.layer_weights.data.fill_(1.0 / (self.config.num_hidden_layers + 1)) + module.layer_weights.fill_(1.0 / (self.config.num_hidden_layers + 1)) def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ @@ -992,7 +993,7 @@ def __init__(self, config, target_lang: Optional[str] = None): # Initialize weights and apply final processing self.post_init() - def tie_weights(self): + def tie_weights(self, missing_keys=None): """ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when passing `target_lang=...` to `from_pretrained(...)`. diff --git a/src/transformers/models/hubert/modular_hubert.py b/src/transformers/models/hubert/modular_hubert.py index a0a7d805c973..d23cbc489b09 100644 --- a/src/transformers/models/hubert/modular_hubert.py +++ b/src/transformers/models/hubert/modular_hubert.py @@ -134,36 +134,37 @@ class HubertPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): if is_deepspeed_zero3_enabled(): import deepspeed if hasattr(module, "weight_v") and hasattr(module, "weight_g"): with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): - nn.init.kaiming_normal_(module.weight.data) + nn.init.kaiming_normal_(module.weight) else: with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): - nn.init.kaiming_normal_(module.weight.data) + nn.init.kaiming_normal_(module.weight) else: - nn.init.kaiming_normal_(module.weight.data) + nn.init.kaiming_normal_(module.weight) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, HubertModel): if hasattr(module, "masked_spec_embed"): - module.masked_spec_embed.data.uniform_() + module.masked_spec_embed.uniform_() elif isinstance(module, HubertForSequenceClassification): if hasattr(module, "layer_weights"): - module.layer_weights.data.fill_(1.0 / (self.config.num_hidden_layers + 1)) + module.layer_weights.fill_(1.0 / (self.config.num_hidden_layers + 1)) def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py index e3a55c296f6f..b55d9e3ccf5e 100644 --- a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +++ b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py @@ -290,16 +290,17 @@ class HunYuanDenseV1PreTrainedModel(PreTrainedModel): "attentions": HunYuanDenseV1Attention, } + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() class HunYuanDenseV1RotaryEmbedding(nn.Module): @@ -458,7 +459,7 @@ def forward( @auto_docstring class HunYuanDenseV1ForCausalLM(HunYuanDenseV1PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py index 31a03ac05cc7..945d2d1c27b1 100644 --- a/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +++ b/src/transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py @@ -120,16 +120,17 @@ def __init__(self, config: HunYuanDenseV1Config, layer_idx: int): class HunYuanDenseV1PreTrainedModel(LlamaPreTrainedModel): + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() class HunYuanDenseV1RotaryEmbedding(LlamaRotaryEmbedding): diff --git a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py index 732bafbd336d..3614c8f880c5 100644 --- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py @@ -243,20 +243,23 @@ def forward(self, hidden_states): return logits -class HunYuanMoEV1Experts(nn.ModuleList): - """ - ModuleList of experts. - """ +class HunYuanMoEV1Experts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" def __init__(self, config: HunYuanMoEV1Config): super().__init__() - self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts - for _ in range(self.num_experts): - self.append(HunYuanMoEV1MLP(config)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: """ Args: @@ -267,14 +270,24 @@ def forward( (batch_size * sequence_length, hidden_dim) """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states @@ -293,6 +306,11 @@ def route_tokens_to_experts(self, hidden_states): routing_weights = F.softmax(hidden_states, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = torch.zeros_like(hidden_states, dtype=torch.float32).scatter_( + 1, selected_experts, routing_weights + ) + return selected_experts, routing_weights.to(hidden_states.dtype) + return selected_experts, routing_weights.to(hidden_states.dtype) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -368,17 +386,6 @@ class HunYuanMoEV1PreTrainedModel(PreTrainedModel): "attentions": HunYuanMoEV1Attention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - class HunYuanMoEV1RotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` @@ -536,7 +543,7 @@ def forward( @auto_docstring class HunYuanMoEV1ForCausalLM(HunYuanMoEV1PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py index 06269fedf784..7244f761f32c 100644 --- a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py @@ -149,6 +149,11 @@ def route_tokens_to_experts(self, hidden_states): routing_weights = F.softmax(hidden_states, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = torch.zeros_like(hidden_states, dtype=torch.float32).scatter_( + 1, selected_experts, routing_weights + ) + return selected_experts, routing_weights.to(hidden_states.dtype) + return selected_experts, routing_weights.to(hidden_states.dtype) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -177,17 +182,6 @@ def __init__(self, config: HunYuanMoEV1Config, layer_idx: int): class HunYuanMoEV1PreTrainedModel(LlamaPreTrainedModel): _can_compile_fullgraph = False - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - class HunYuanMoEV1RotaryEmbedding(HunYuanDenseV1RotaryEmbedding): pass diff --git a/src/transformers/models/ibert/modeling_ibert.py b/src/transformers/models/ibert/modeling_ibert.py index bbc86018a6ea..230e8fc04d42 100644 --- a/src/transformers/models/ibert/modeling_ibert.py +++ b/src/transformers/models/ibert/modeling_ibert.py @@ -585,21 +585,22 @@ class IBertPreTrainedModel(PreTrainedModel): config: IBertConfig base_model_prefix = "ibert" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (QuantLinear, nn.Linear)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (QuantEmbedding, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, (IntLayerNorm, nn.LayerNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, IBertLMHead): - module.bias.data.zero_() + module.bias.zero_() def resize_token_embeddings(self, new_num_tokens=None): raise NotImplementedError("`resize_token_embeddings` is not supported for I-BERT.") @@ -710,7 +711,10 @@ def forward( @auto_docstring class IBertForMaskedLM(IBertPreTrainedModel): - _tied_weights_keys = ["lm_head.decoder.bias", "lm_head.decoder.weight"] + _tied_weights_keys = { + "lm_head.decoder.weight": "ibert.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) @@ -789,7 +793,6 @@ def __init__(self, config): self.decoder = nn.Linear(config.hidden_size, config.vocab_size) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - self.decoder.bias = self.bias def forward(self, features, **kwargs): x = self.dense(features) @@ -801,14 +804,6 @@ def forward(self, features, **kwargs): return x - def _tie_weights(self) -> None: - # For accelerate compatibility and to not break backward compatibility - if self.decoder.bias.device.type == "meta": - self.decoder.bias = self.bias - else: - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 5cc389b79344..1e7fdb05360c 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -831,38 +831,39 @@ class IdeficsPreTrainedModel(PreTrainedModel): "attentions": OutputRecorder(IdeficsAttention, index=1, layer_name="self_attn"), } + @torch.no_grad() def _init_weights(self, module): # important: this ported version of Idefics isn't meant for training from scratch - only # inference and fine-tuning - so the proper init weights code has been removed - the m4 code # base should be used for training from scratch and it contains the correct code. std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, IdeficsRMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, IdeficsVisionEmbeddings): - module.class_embedding.data.normal_() + module.class_embedding.normal_() elif isinstance(module, IdeficsGatedCrossAttentionLayer): if self.config.alpha_initializer == "zeros": - module.alpha_cross_attn.data.zero_() - module.alpha_dense.data.zero_() + module.alpha_cross_attn.zero_() + module.alpha_dense.zero_() elif self.config.alpha_initializer == "ones": - module.alpha_cross_attn.data.fill_(1.0) - module.alpha_dense.data.fill_(1.0) + module.alpha_cross_attn.fill_(1.0) + module.alpha_dense.fill_(1.0) elif self.config.alpha_initializer in {"normal", "gaussian", "random"}: - module.alpha_cross_attn.data.normal_(mean=0.0, std=self.config.alphas_initializer_range) - module.alpha_dense.data.normal_(mean=0.0, std=self.config.alphas_initializer_range) + module.alpha_cross_attn.normal_(mean=0.0, std=self.config.alphas_initializer_range) + module.alpha_dense.normal_(mean=0.0, std=self.config.alphas_initializer_range) elif isinstance(module, IdeficsPerceiverResampler): - module.latents.data.normal_() + module.latents.normal_() @auto_docstring @@ -1105,7 +1106,7 @@ def forward( class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config, vision_model=None): super().__init__(config) @@ -1122,7 +1123,7 @@ def __init__(self, config, vision_model=None): # Initialize weights and apply final processing self.post_init() - def tie_weights(self): + def tie_weights(self, missing_keys=None): """ Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of IdeficsDecoupledLinear and IdeficsDecoupledEmbedding. diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 0ee1ca8bac68..c7c182be3a47 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -417,28 +417,29 @@ class Idefics2PreTrainedModel(PreTrainedModel): _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, Idefics2RMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, nn.MultiheadAttention): module._reset_parameters() # native torch init elif isinstance(module, Idefics2MultiheadAttentionPoolingHead): - module.probe.data.normal_() + module.probe.normal_() elif isinstance(module, Idefics2PerceiverResampler): - module.latents.data.fill_(1.0) + module.latents.fill_(1.0) @auto_docstring( @@ -1010,7 +1011,7 @@ def forward( """ ) class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 1fe99f4e6855..208d08b23121 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -433,22 +433,23 @@ class Idefics3PreTrainedModel(PreTrainedModel): _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, Idefics3RMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) @auto_docstring( @@ -770,7 +771,7 @@ def forward( """ ) class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.__init__ with Idefics2->Idefics3 def __init__(self, config): diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index 0a0c8fbb0321..a8c5878f35ef 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -324,27 +324,32 @@ class IJepaPreTrainedModel(PreTrainedModel): "attentions": IJepaSelfAttention, } + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) + module.weight.copy_( + nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( + module.weight.dtype + ) + ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, IJepaEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) + module.position_embeddings.copy_( + nn.init.trunc_normal_( + module.position_embeddings.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + ) if module.mask_token is not None: - module.mask_token.data.zero_() + module.mask_token.zero_() class IJepaEncoder(nn.Module): diff --git a/src/transformers/models/ijepa/modular_ijepa.py b/src/transformers/models/ijepa/modular_ijepa.py index b37bc41d13bf..095945a3f39d 100644 --- a/src/transformers/models/ijepa/modular_ijepa.py +++ b/src/transformers/models/ijepa/modular_ijepa.py @@ -87,27 +87,32 @@ def forward( @auto_docstring class IJepaPreTrainedModel(ViTPreTrainedModel): + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) + module.weight.copy_( + nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( + module.weight.dtype + ) + ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, IJepaEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) + module.position_embeddings.copy_( + nn.init.trunc_normal_( + module.position_embeddings.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + ) if module.mask_token is not None: - module.mask_token.data.zero_() + module.mask_token.zero_() class IJepaModel(IJepaPreTrainedModel, ViTModel): diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index f1ae9ee0c926..b4c844eb4f49 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -369,18 +369,19 @@ class ImageGPTPreTrainedModel(PreTrainedModel): def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear, Conv1D)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, ImageGPTLayerNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale @@ -388,10 +389,11 @@ def _init_weights(self, module): # > -- GPT-2 :: https://openai.com/blog/better-language-models/ # # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - for name, p in module.named_parameters(): - if "c_proj" in name and "weight" in name: - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + if isinstance(module, PreTrainedModel): + for name, p in module.named_parameters(): + if "c_proj" in name and "weight" in name: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) @auto_docstring @@ -606,7 +608,7 @@ def forward( """ ) class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"} def __init__(self, config: ImageGPTConfig): super().__init__(config) diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 901685a074ec..a8f618a43b69 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -250,6 +250,7 @@ class InformerPreTrainedModel(PreTrainedModel): input_modalities = "time" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module: nn.Module): super()._init_weights(module) if isinstance(module, InformerSinusoidalPositionalEmbedding): diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index 16d2f2d40105..0066f41a3e47 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -86,6 +86,7 @@ class InformerPreTrainedModel(PreTrainedModel): input_modalities = "time" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module: nn.Module): super()._init_weights(module) if isinstance(module, InformerSinusoidalPositionalEmbedding): diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index ceec6a15f6ac..25b54f2d2b9f 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -324,24 +324,25 @@ class InstructBlipPreTrainedModel(PreTrainedModel): "InstructBlipQFormerSelfOutput", ] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=factor) + module.weight.normal_(mean=0.0, std=factor) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=factor) + module.weight.normal_(mean=0.0, std=factor) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, InstructBlipVisionEmbeddings): nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) elif isinstance(module, (InstructBlipForConditionalGeneration, InstructBlipModel)): - module.query_tokens.data.zero_() + module.query_tokens.zero_() # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->InstructBlip @@ -961,11 +962,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def _tie_weights(self): - if not self.config.use_decoder_only_language_model: - self.language_model.encoder.embed_tokens = self.language_model.shared - self.language_model.decoder.embed_tokens = self.language_model.shared - def _preprocess_accelerate(self): r""" Some pre-processing hacks to make the model `accelerate` compatible. Check @@ -1160,12 +1156,6 @@ def get_encoder(self): def get_decoder(self): return self.language_model.get_decoder() - # Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._tie_weights - def _tie_weights(self): - if not self.config.use_decoder_only_language_model: - self.language_model.encoder.embed_tokens = self.language_model.shared - self.language_model.decoder.embed_tokens = self.language_model.shared - # Copied from transformers.models.instructblip.modeling_instructblip.InstructBlipModel._preprocess_accelerate def _preprocess_accelerate(self): r""" diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index f2ec0fc9dbf0..f48baf11b925 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -147,24 +147,25 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel): "InstructBlipVideoQFormerSelfOutput", ] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=factor) + module.weight.normal_(mean=0.0, std=factor) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=factor) + module.weight.normal_(mean=0.0, std=factor) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, InstructBlipVideoVisionEmbeddings): nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor) nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor) elif isinstance(module, (InstructBlipVideoForConditionalGeneration, InstructBlipVideoModel)): - module.query_tokens.data.zero_() + module.query_tokens.zero_() # Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> InstructBlipVideo doesn't cast attn weights to fp32 @@ -958,11 +959,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def _tie_weights(self): - if not self.config.use_decoder_only_language_model: - self.language_model.encoder.embed_tokens = self.language_model.shared - self.language_model.decoder.embed_tokens = self.language_model.shared - def _preprocess_accelerate(self): r""" Some pre-processing hacks to make the model `accelerate` compatible. Check @@ -1190,11 +1186,6 @@ def get_encoder(self): def get_decoder(self): return self.language_model.get_decoder() - def _tie_weights(self): - if not self.config.use_decoder_only_language_model: - self.language_model.encoder.embed_tokens = self.language_model.shared - self.language_model.decoder.embed_tokens = self.language_model.shared - def _preprocess_accelerate(self): r""" Some pre-processing hacks to make the model `accelerate` compatible. Check diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 308bd8511038..6a5f82ab8a10 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -411,18 +411,19 @@ class InternVLVisionPreTrainedModel(PreTrainedModel): "attentions": InternVLVisionAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" super()._init_weights(module) if isinstance(module, InternVLVisionEmbeddings): - module.cls_token.data.zero_() + module.cls_token.zero_() if module.mask_token is not None: - module.mask_token.data.zero_() + module.mask_token.zero_() if module.position_embeddings is not None: - module.position_embeddings.data.zero_() + module.position_embeddings.zero_() elif isinstance(module, InternVLVisionLayer): - module.lambda_1.data.fill_(self.config.layer_scale_init_value) - module.lambda_2.data.fill_(self.config.layer_scale_init_value) + module.lambda_1.fill_(self.config.layer_scale_init_value) + module.lambda_2.fill_(self.config.layer_scale_init_value) @auto_docstring @@ -766,7 +767,7 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin) "^multi_modal_projector": "model.multi_modal_projector", "^language_model.lm_head": "lm_head", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: InternVLConfig): super().__init__(config) diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py index 213c4a2dd81d..62ee383ce566 100644 --- a/src/transformers/models/internvl/modular_internvl.py +++ b/src/transformers/models/internvl/modular_internvl.py @@ -368,18 +368,19 @@ class InternVLVisionPreTrainedModel(PreTrainedModel): "attentions": InternVLVisionAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" super()._init_weights(module) if isinstance(module, InternVLVisionEmbeddings): - module.cls_token.data.zero_() + module.cls_token.zero_() if module.mask_token is not None: - module.mask_token.data.zero_() + module.mask_token.zero_() if module.position_embeddings is not None: - module.position_embeddings.data.zero_() + module.position_embeddings.zero_() elif isinstance(module, InternVLVisionLayer): - module.lambda_1.data.fill_(self.config.layer_scale_init_value) - module.lambda_2.data.fill_(self.config.layer_scale_init_value) + module.lambda_1.fill_(self.config.layer_scale_init_value) + module.lambda_2.fill_(self.config.layer_scale_init_value) @auto_docstring diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 94d8cdc3f7be..a420121594ee 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -557,20 +557,23 @@ def forward(self, x): return down_proj -class JambaExperts(nn.ModuleList): - """ - ModuleList of experts. - """ +class JambaExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" def __init__(self, config: JambaConfig): super().__init__() - self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts - for _ in range(self.num_experts): - self.append(JambaMLP(config)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: """ Args: @@ -581,14 +584,24 @@ def forward( (batch_size * sequence_length, hidden_dim) """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states @@ -717,13 +730,14 @@ class JambaPreTrainedModel(PreTrainedModel): "router_logits": OutputRecorder(nn.Linear, layer_name="router"), } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, JambaMambaMixer): A = torch.arange(1, module.ssm_state_size + 1)[None, :] A = A.expand(module.intermediate_size, -1).contiguous() - module.A_log.data.copy_(torch.log(A)) - module.D.data.fill_(1.0) + module.A_log.copy_(torch.log(A)) + module.D.fill_(1.0) ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer} @@ -916,7 +930,7 @@ def load_balancing_loss_func( @auto_docstring class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/jamba/modular_jamba.py b/src/transformers/models/jamba/modular_jamba.py index c6cfe339fabb..1c362c3f802a 100644 --- a/src/transformers/models/jamba/modular_jamba.py +++ b/src/transformers/models/jamba/modular_jamba.py @@ -607,13 +607,14 @@ class JambaPreTrainedModel(PreTrainedModel): "router_logits": OutputRecorder(nn.Linear, layer_name="router"), } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, JambaMambaMixer): A = torch.arange(1, module.ssm_state_size + 1)[None, :] A = A.expand(module.intermediate_size, -1).contiguous() - module.A_log.data.copy_(torch.log(A)) - module.D.data.fill_(1.0) + module.A_log.copy_(torch.log(A)) + module.D.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 4cad10fc4216..fae49b7e3719 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -1164,7 +1164,7 @@ def forward( class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} output_modalities = ["image", "text"] _can_compile_fullgraph = True diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 87cc11d73cda..24d1598a8e2b 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -980,7 +980,7 @@ def forward( class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} output_modalities = ["image", "text"] _can_compile_fullgraph = True diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 1beb7be7626c..28a3dc151d70 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -582,22 +582,23 @@ class JetMoePreTrainedModel(PreTrainedModel): "attentions": OutputRecorder(JetMoeAttention, index=1), } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, JetMoeRMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, JetMoeParallelExperts): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, JetMoeMoA | JetMoeMoE): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring @@ -766,7 +767,7 @@ def load_balancing_loss_func( class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/jetmoe/modular_jetmoe.py b/src/transformers/models/jetmoe/modular_jetmoe.py index d994388969e3..82c8e582d070 100644 --- a/src/transformers/models/jetmoe/modular_jetmoe.py +++ b/src/transformers/models/jetmoe/modular_jetmoe.py @@ -435,22 +435,23 @@ class JetMoePreTrainedModel(MixtralPreTrainedModel): _supports_flash_attn = True _supports_sdpa = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, JetMoeRMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, JetMoeParallelExperts): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, JetMoeMoA | JetMoeMoE): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring @@ -532,7 +533,7 @@ def forward( class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 62aeb8d1d1ad..5726eeacaad6 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -1120,6 +1120,7 @@ class Kosmos2PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" if isinstance(self, Kosmos2VisionModel): @@ -1162,15 +1163,15 @@ def _init_weights(self, module: nn.Module): nn.init.normal_(module.dense.weight, std=std) nn.init.normal_(module.latent_query) elif isinstance(module, Kosmos2TextTransformer): - module.embed_tokens.weight.data.normal_(mean=0.0, std=std) + module.embed_tokens.weight.normal_(mean=0.0, std=std) if module.embed_tokens.padding_idx is not None: - module.embed_tokens.weight.data[module.embed_tokens.padding_idx].zero_() + module.embed_tokens.weight[module.embed_tokens.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() class Kosmos2VisionModel(Kosmos2PreTrainedModel): @@ -1277,7 +1278,7 @@ def forward( ) class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin): config: Kosmos2TextConfig - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: Kosmos2TextConfig): super().__init__(config) @@ -1617,7 +1618,7 @@ def forward( class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin): config: Kosmos2Config main_input_name = "pixel_values" - _tied_weights_keys = ["text_model.lm_head.weight"] + _tied_weights_keys = {"text_model.lm_head.weight": "text_model.model.embed_tokens.weight"} def __init__(self, config: Kosmos2Config): super().__init__(config) diff --git a/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py b/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py index f8756aa9b000..c0313f33eca2 100644 --- a/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py +++ b/src/transformers/models/kosmos2_5/modeling_kosmos2_5.py @@ -1227,6 +1227,7 @@ class Kosmos2_5PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(self, Kosmos2_5VisionModel): @@ -1237,19 +1238,19 @@ def _init_weights(self, module): elif isinstance(self, (Kosmos2_5Model, Kosmos2_5ForConditionalGeneration)): std = self.config.text_config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, (nn.LayerNorm, Kosmos2_5LayerNorm)): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) if getattr(module, "bias", None) is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, Kosmos2_5ImageToTextProjection): - module.latent_query.data.normal_(mean=0.0, std=1.0) + module.latent_query.normal_(mean=0.0, std=1.0) class Kosmos2_5VisionModel(Kosmos2_5PreTrainedModel): @@ -1503,7 +1504,7 @@ def forward( class Kosmos2_5TextForCausalLM(Kosmos2_5PreTrainedModel): config_class = Kosmos2_5TextConfig input_modalities = "text" - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: Kosmos2_5TextConfig): super().__init__(config) @@ -1660,7 +1661,6 @@ def prepare_inputs_for_generation( ) class Kosmos2_5ForConditionalGeneration(Kosmos2_5PreTrainedModel, GenerationMixin): config_class = Kosmos2_5Config - _tied_weights_keys = ["text_model.lm_head.weight"] def __init__(self, config: Kosmos2_5Config): super().__init__(config) diff --git a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py index e3f9824de41d..989fd9706c79 100644 --- a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py @@ -124,21 +124,22 @@ class KyutaiSpeechToTextPreTrainedModel(PreTrainedModel): main_input_name = "input_ids" + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, KyutaiSpeechToTextFlexibleLinear): - module.weight.data.normal_() + module.weight.normal_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, KyutaiSpeechToTextRMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) class KyutaiSpeechToTextConv1dPaddingCache: @@ -1090,7 +1091,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( @auto_docstring class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} _keep_in_fp32_modules_strict = ["codec_model"] diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index ec1c558dad73..146c395aa9ee 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -398,16 +398,9 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -431,21 +424,22 @@ class LayoutLMPreTrainedModel(PreTrainedModel): base_model_prefix = "layoutlm" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, LayoutLMLayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, LayoutLMLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring @@ -577,7 +571,10 @@ def forward( @auto_docstring class LayoutLMForMaskedLM(LayoutLMPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "layoutlm.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index faf3979d1edb..e276407a720b 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -458,26 +458,27 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel): base_model_prefix = "layoutlmv2" input_modalities = ["image", "text"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, LayoutLMv2SelfAttention): if self.config.fast_qkv: - module.q_bias.data.zero_() - module.v_bias.data.zero_() + module.q_bias.zero_() + module.v_bias.zero_() elif isinstance(module, LayoutLMv2Model): if hasattr(module, "visual_segment_embedding"): - module.visual_segment_embedding.data.normal_(mean=0.0, std=self.config.initializer_range) + module.visual_segment_embedding.normal_(mean=0.0, std=self.config.initializer_range) def my_convert_sync_batchnorm(module, process_group=None): diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 3aa97051f855..a04875e72646 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -203,23 +203,24 @@ class LayoutLMv3PreTrainedModel(PreTrainedModel): base_model_prefix = "layoutlmv3" input_modalities = ["image", "text"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, LayoutLMv3Model): if self.config.visual_embed: - module.cls_token.data.zero_() - module.pos_embed.data.zero_() + module.cls_token.zero_() + module.pos_embed.zero_() class LayoutLMv3SelfAttention(nn.Module): diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index f5b5787a9ddf..418f60f77a61 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1067,16 +1067,17 @@ class LEDPreTrainedModel(PreTrainedModel): base_model_prefix = "led" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() @property def dummy_inputs(self): @@ -1290,7 +1291,7 @@ class LEDEncoder(LEDPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: LEDConfig): super().__init__(config) self.dropout = config.dropout @@ -1313,10 +1314,7 @@ def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = Non f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}" ) - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) self.embed_positions = LEDLearnedPositionalEmbedding( self.max_source_positions, @@ -1553,17 +1551,14 @@ class LEDDecoder(LEDPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: LEDConfig): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop self.padding_idx = config.pad_token_id self.max_target_positions = config.max_decoder_position_embeddings - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) self.embed_positions = LEDLearnedPositionalEmbedding( self.max_target_positions, @@ -1763,7 +1758,10 @@ def forward( @auto_docstring class LEDModel(LEDPreTrainedModel): - _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: LEDConfig): super().__init__(config) @@ -1771,8 +1769,8 @@ def __init__(self, config: LEDConfig): padding_idx, vocab_size = config.pad_token_id, config.vocab_size self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) - self.encoder = LEDEncoder(config, self.shared) - self.decoder = LEDDecoder(config, self.shared) + self.encoder = LEDEncoder(config) + self.decoder = LEDDecoder(config) # Initialize weights and apply final processing self.post_init() @@ -1908,7 +1906,9 @@ def forward( class LEDForConditionalGeneration(LEDPreTrainedModel, GenerationMixin): base_model_prefix = "led" _keys_to_ignore_on_load_missing = ["final_logits_bias"] - _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "led.shared.weight", + } def __init__(self, config: LEDConfig): super().__init__(config) @@ -2106,8 +2106,6 @@ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): """ ) class LEDForSequenceClassification(LEDPreTrainedModel): - _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"] - def __init__(self, config: LEDConfig, **kwargs): warnings.warn( "The `transformers.LEDForSequenceClassification` class is deprecated and will be removed in version 5 of" @@ -2252,8 +2250,6 @@ def forward( @auto_docstring class LEDForQuestionAnswering(LEDPreTrainedModel): - _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"] - def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/levit/modeling_levit.py b/src/transformers/models/levit/modeling_levit.py index 5d331081721c..ca7cc7589be7 100644 --- a/src/transformers/models/levit/modeling_levit.py +++ b/src/transformers/models/levit/modeling_levit.py @@ -472,15 +472,16 @@ class LevitPreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = ["LevitResidualLayer"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index e8f8cf4e40e5..75b25544c750 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -695,7 +695,7 @@ def forward( @auto_docstring class Lfm2ForCausalLM(Lfm2PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py index ebc8d892bf31..c9d557457e16 100644 --- a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py @@ -24,6 +24,7 @@ import torch.nn.functional as F from torch import nn +from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub @@ -144,19 +145,23 @@ def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) -class Lfm2MoeExperts(nn.ModuleList): - """ - ModuleList of experts. - """ +class Lfm2MoeExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): super().__init__() self.num_experts = config.num_experts - for _ in range(config.num_experts): - self.append(Lfm2MoeMLP(config, intermediate_size=config.moe_intermediate_size)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: """ Args: @@ -167,14 +172,24 @@ def forward( (batch_size * sequence_length, hidden_dim) """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states @@ -762,7 +777,7 @@ def forward( @auto_docstring class Lfm2MoeForCausalLM(Lfm2MoePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/lfm2_vl/modeling_lfm2_vl.py b/src/transformers/models/lfm2_vl/modeling_lfm2_vl.py index 317786625ba8..eb761aabf4fa 100755 --- a/src/transformers/models/lfm2_vl/modeling_lfm2_vl.py +++ b/src/transformers/models/lfm2_vl/modeling_lfm2_vl.py @@ -307,7 +307,7 @@ def forward( ) class Lfm2VlForConditionalGeneration(Lfm2VlPreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = {} - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: Lfm2VlConfig): super().__init__(config) diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 31157b749e94..ec924e5000d6 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -500,19 +500,20 @@ class LiltPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = [] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 3d8340091bee..2000c8092fb2 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -438,7 +438,7 @@ def forward( @auto_docstring class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 6b012a5b096a..c58848fbf299 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -54,7 +54,7 @@ def __init__(self, config: Llama4TextConfig): self.intermediate_size = config.intermediate_size self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size - self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim)) self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) self.act_fn = ACT2FN[config.hidden_act] @@ -473,6 +473,7 @@ class Llama4PreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): std = ( self.config.initializer_range @@ -480,24 +481,24 @@ def _init_weights(self, module): else self.config.text_config.initializer_range ) if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, Llama4TextRMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, Llama4TextExperts): - module.gate_up_proj.data.normal_(mean=0.0, std=std) - module.down_proj.data.normal_(mean=0.0, std=std) + module.gate_up_proj.normal_(mean=0.0, std=std) + module.down_proj.normal_(mean=0.0, std=std) elif isinstance(module, Llama4VisionModel): - module.class_embedding.data.normal_(std=module.scale) - module.positional_embedding_vlm.data.normal_(std=module.scale) + module.class_embedding.normal_(std=module.scale) + module.positional_embedding_vlm.normal_(std=module.scale) @auto_docstring @@ -604,7 +605,7 @@ def forward( class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin): _no_split_modules = ["Llama4TextDecoderLayer"] base_model_prefix = "language_model" - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} config: Llama4TextConfig diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 0ee351b03b54..7ed86f7cd6be 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -313,7 +313,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): "^multi_modal_projector": "model.multi_modal_projector", "^language_model.lm_head": "lm_head", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: LlavaConfig): super().__init__(config) diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 7e01bbb385f8..312ae609ef01 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -235,16 +235,17 @@ class LlavaNextPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, LlavaNextModel): embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) - module.image_newline.data.normal_(mean=0.0, std=embed_std) + module.image_newline.normal_(mean=0.0, std=embed_std) @auto_docstring( @@ -540,7 +541,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi "^image_newline": "model.image_newline", "^language_model.lm_head": "lm_head", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: LlavaNextConfig): super().__init__(config) diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 98b46e13f587..32b5f8a00932 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -176,16 +176,17 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, LlavaNextVideoModel): embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) - module.image_newline.data.normal_(mean=0.0, std=embed_std) + module.image_newline.normal_(mean=0.0, std=embed_std) def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): @@ -679,7 +680,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene "^image_newline": "model.image_newline", "^language_model.lm_head": "lm_head", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: LlavaNextVideoConfig): super().__init__(config) diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 4484d4647da1..15ed2f3a6645 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -117,16 +117,17 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, LlavaOnevisionModel): embed_std = 1 / math.sqrt(self.config.text_config.hidden_size) - module.image_newline.data.normal_(mean=0.0, std=embed_std) + module.image_newline.normal_(mean=0.0, std=embed_std) class LlavaOnevisionMultiModalProjector(nn.Module): @@ -667,7 +668,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene "^image_newline": "model.image_newline", "^language_model.lm_head": "lm_head", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: LlavaOnevisionConfig): super().__init__(config) diff --git a/src/transformers/models/longcat_flash/modeling_longcat_flash.py b/src/transformers/models/longcat_flash/modeling_longcat_flash.py index c082eb43ee4d..516bfee99677 100644 --- a/src/transformers/models/longcat_flash/modeling_longcat_flash.py +++ b/src/transformers/models/longcat_flash/modeling_longcat_flash.py @@ -164,7 +164,7 @@ def forward(self, hidden_states): topk_indices = self.get_topk_indices(scores) topk_weights = scores.gather(1, topk_indices) topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights + return topk_weights.to(router_logits.dtype), topk_indices @torch.no_grad() def get_topk_indices(self, scores): @@ -173,29 +173,51 @@ def get_topk_indices(self, scores): return topk_indices -class LongcatFlashExperts(nn.ModuleList): +class LongcatFlashExperts(nn.Module): def __init__(self, config): super().__init__() self.intermediate_size = config.expert_ffn_hidden_size self.hidden_size = config.hidden_size - self.num_experts = config.n_routed_experts + config.zero_expert_num - self.zero_expert_num = config.zero_expert_num + self.num_routed_experts = config.n_routed_experts + self.zero_expert_num = config.zero_expert_num or 0 + self.total_experts = self.num_routed_experts + self.zero_expert_num + self.act_fn = ACT2FN[config.hidden_act] - self.extend( - [LongcatFlashMLP(config, intermediate_size=self.intermediate_size) for _ in range(self.num_experts)] - + [nn.Identity() for _ in range(self.zero_expert_num)] - ) + if self.num_routed_experts > 0: + self.gate_up_proj = nn.Parameter( + torch.empty(self.total_experts, 2 * self.intermediate_size, self.hidden_size) + ) + self.down_proj = nn.Parameter( + torch.empty(self.num_routed_experts, self.hidden_size, self.intermediate_size) + ) + else: + self.register_parameter("gate_up_proj", None) + self.register_parameter("down_proj", None) def forward(self, hidden_states, top_k_index, top_k_weights): final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) - - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + if top_k_index.numel() == 0: + return final_hidden_states + + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.total_experts).permute(2, 1, 0) + + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False) + for expert_idx_tensor in expert_hit: + expert_idx = int(expert_idx_tensor.item()) + selection_idx, token_idx = torch.where(expert_mask[expert_idx].squeeze(0)) + if token_idx.numel() == 0: + continue + current_state = hidden_states[token_idx] + + if expert_idx >= self.num_routed_experts or self.gate_up_proj is None: + current_hidden_states = current_state + else: + gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = F.linear(current_hidden_states, self.down_proj[expert_idx]) + + current_hidden_states = current_hidden_states * top_k_weights[token_idx, selection_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(hidden_states.dtype)) return final_hidden_states @@ -215,7 +237,7 @@ def __init__(self, config): def forward(self, hidden_states): orig_shape = hidden_states.shape - topk_indices, topk_weights = self.router(hidden_states) + topk_weights, topk_indices = self.router(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) return hidden_states @@ -535,10 +557,14 @@ class LongcatFlashPreTrainedModel(PreTrainedModel): "attentions": LongcatFlashMLA, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, LongcatFlashTopkRouter): - module.classifier.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.classifier.weight.normal_(mean=0.0, std=self.config.initializer_range) + if isinstance(module, LongcatFlashExperts): + module.gate_up_proj.normal_(mean=0.0, std=self.config.initializer_range) + module.down_proj.normal_(mean=0.0, std=self.config.initializer_range) @auto_docstring @@ -630,7 +656,7 @@ def forward( @auto_docstring class LongcatFlashForCausalLM(LongcatFlashPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"] diff --git a/src/transformers/models/longcat_flash/modular_longcat_flash.py b/src/transformers/models/longcat_flash/modular_longcat_flash.py index 588c7147cfd4..56fe0be969f6 100644 --- a/src/transformers/models/longcat_flash/modular_longcat_flash.py +++ b/src/transformers/models/longcat_flash/modular_longcat_flash.py @@ -20,6 +20,7 @@ import torch.nn.functional as F from torch import nn +from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -33,14 +34,13 @@ DeepseekV3ForCausalLM, DeepseekV3MLP, DeepseekV3Model, - DeepseekV3PreTrainedModel, DeepseekV3RMSNorm, DeepseekV3RotaryEmbedding, DeepseekV3TopkRouter, apply_rotary_pos_emb_interleave, eager_attention_forward, ) - +from .configuration_longcat_flash import LongcatFlashConfig logger = logging.get_logger(__name__) @@ -90,32 +90,54 @@ def forward(self, hidden_states): topk_indices = self.get_topk_indices(scores) topk_weights = scores.gather(1, topk_indices) topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights + return topk_weights.to(router_logits.dtype), topk_indices -class LongcatFlashExperts(nn.ModuleList): +class LongcatFlashExperts(nn.Module): def __init__(self, config): super().__init__() self.intermediate_size = config.expert_ffn_hidden_size self.hidden_size = config.hidden_size - self.num_experts = config.n_routed_experts + config.zero_expert_num - self.zero_expert_num = config.zero_expert_num - - self.extend( - [LongcatFlashMLP(config, intermediate_size=self.intermediate_size) for _ in range(self.num_experts)] - + [nn.Identity() for _ in range(self.zero_expert_num)] - ) + self.num_routed_experts = config.n_routed_experts + self.zero_expert_num = config.zero_expert_num or 0 + self.total_experts = self.num_routed_experts + self.zero_expert_num + self.act_fn = ACT2FN[config.hidden_act] + + if self.num_routed_experts > 0: + self.gate_up_proj = nn.Parameter( + torch.empty(self.total_experts, 2 * self.intermediate_size, self.hidden_size) + ) + self.down_proj = nn.Parameter( + torch.empty(self.num_routed_experts, self.hidden_size, self.intermediate_size) + ) + else: + self.register_parameter("gate_up_proj", None) + self.register_parameter("down_proj", None) def forward(self, hidden_states, top_k_index, top_k_weights): final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) - - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + if top_k_index.numel() == 0: + return final_hidden_states + + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.total_experts).permute(2, 1, 0) + + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False) + for expert_idx_tensor in expert_hit: + expert_idx = int(expert_idx_tensor.item()) + selection_idx, token_idx = torch.where(expert_mask[expert_idx].squeeze(0)) + if token_idx.numel() == 0: + continue + current_state = hidden_states[token_idx] + + if expert_idx >= self.num_routed_experts or self.gate_up_proj is None: + current_hidden_states = current_state + else: + gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = F.linear(current_hidden_states, self.down_proj[expert_idx]) + + current_hidden_states = current_hidden_states * top_k_weights[token_idx, selection_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(hidden_states.dtype)) return final_hidden_states @@ -135,7 +157,7 @@ def __init__(self, config): def forward(self, hidden_states): orig_shape = hidden_states.shape - topk_indices, topk_weights = self.router(hidden_states) + topk_weights, topk_indices = self.router(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) return hidden_states @@ -301,16 +323,31 @@ def forward( return hidden_states -class LongcatFlashPreTrainedModel(DeepseekV3PreTrainedModel): +@auto_docstring +class LongcatFlashPreTrainedModel(PreTrainedModel): + config: LongcatFlashConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LongcatFlashDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False + _supports_attention_backend = True _can_record_outputs = { "hidden_states": LongcatFlashDecoderLayer, "attentions": LongcatFlashMLA, } + @torch.no_grad() def _init_weights(self, module): - PreTrainedModel._init_weights(self, module) + super()._init_weights(module) if isinstance(module, LongcatFlashTopkRouter): - module.classifier.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.classifier.weight.normal_(mean=0.0, std=self.config.initializer_range) + if isinstance(module, LongcatFlashExperts): + module.gate_up_proj.normal_(mean=0.0, std=self.config.initializer_range) + module.down_proj.normal_(mean=0.0, std=self.config.initializer_range) class LongcatFlashModel(DeepseekV3Model): diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 8efb326c4c28..1168e9366f1d 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1273,7 +1273,6 @@ def __init__(self, config): self.decoder = nn.Linear(config.hidden_size, config.vocab_size) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - self.decoder.bias = self.bias def forward(self, features, **kwargs): x = self.dense(features) @@ -1285,14 +1284,6 @@ def forward(self, features, **kwargs): return x - def _tie_weights(self): - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - # For accelerate compatibility and to not break backward compatibility - if self.decoder.bias.device.type == "meta": - self.decoder.bias = self.bias - else: - self.bias = self.decoder.bias - @auto_docstring class LongformerPreTrainedModel(PreTrainedModel): @@ -1301,19 +1292,20 @@ class LongformerPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LongformerSelfAttention"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring @@ -1557,7 +1549,10 @@ def forward( @auto_docstring class LongformerForMaskedLM(LongformerPreTrainedModel): - _tied_weights_keys = ["lm_head.decoder"] + _tied_weights_keys = { + "lm_head.decoder.weight": "longformer.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index fbc9d4494e64..0aea13dc01b8 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1176,75 +1176,45 @@ def dummy_inputs(self): } return dummy_inputs - def _try_load_missing_tied_module(self, key): - module = self - key = key.removesuffix(".weight") - for sub_key in key.split("."): - if not hasattr(module, sub_key): - return - module = getattr(module, sub_key) - - self._tie_embedding_weights(module, self.shared) - - @classmethod - def from_pretrained(self, *args, **kwargs): - requested_loading_info = kwargs.get("output_loading_info", False) - kwargs["output_loading_info"] = True - model, loading_info = super().from_pretrained(*args, **kwargs) - missing_keys = loading_info.get("missing_keys", []) - - if hasattr(model, "shared") and hasattr(model, "_tied_weights_keys"): - for missing_key in missing_keys: - logger.warning( - f"Recovering a missing tied weight {missing_key} from a legacy LongT5 checkpoint. " - f"Consider saving {missing_key} in your checkpoint or updating the config (tie_word_embeddings=true)." - ) - model._try_load_missing_tied_module(missing_key) - - if requested_loading_info: - return model, loading_info - return model - + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, LongT5LayerNorm): - module.weight.data.fill_(factor * 1.0) + module.weight.fill_(factor * 1.0) elif isinstance(module, (LongT5Model, LongT5ForConditionalGeneration, LongT5EncoderModel)): - module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.shared.weight.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, LongT5DenseActDense): - module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wi.bias.zero_() + module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + module.wo.bias.zero_() elif isinstance(module, LongT5DenseGatedActDense): - module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi_0.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: - module.wi_0.bias.data.zero_() - module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi_0.bias.zero_() + module.wi_1.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: - module.wi_1.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wi_1.bias.zero_() + module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + module.wo.bias.zero_() elif isinstance(module, (LongT5Attention, LongT5LocalAttention, LongT5TransientGlobalAttention)): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + module.q.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) if isinstance(module, LongT5TransientGlobalAttention): - module.global_relative_attention_bias.weight.data.normal_( - mean=0.0, std=factor * ((d_model) ** -0.5) - ) + module.global_relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5 def _shift_right(self, input_ids): @@ -1270,12 +1240,10 @@ def _shift_right(self, input_ids): class LongT5Stack(LongT5PreTrainedModel): - def __init__(self, config, embed_tokens=None): + def __init__(self, config): super().__init__(config) self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight self.is_decoder = config.is_decoder self.local_radius = config.local_radius @@ -1599,7 +1567,10 @@ class LongT5Model(LongT5PreTrainedModel): _keys_to_ignore_on_load_unexpected = [ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: LongT5Config): super().__init__(config) @@ -1609,13 +1580,13 @@ def __init__(self, config: LongT5Config): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = LongT5Stack(encoder_config, self.shared) + self.encoder = LongT5Stack(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = LongT5Stack(decoder_config, self.shared) + self.decoder = LongT5Stack(decoder_config) # Initialize weights and apply final processing self.post_init() @@ -1628,11 +1599,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder @@ -1763,7 +1729,11 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_unexpected = [ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + "lm_head.weight": "shared.weight", + } def __init__(self, config: LongT5Config): super().__init__(config) @@ -1775,13 +1745,13 @@ def __init__(self, config: LongT5Config): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = LongT5Stack(encoder_config, self.shared) + self.encoder = LongT5Stack(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = LongT5Stack(decoder_config, self.shared) + self.decoder = LongT5Stack(decoder_config) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) @@ -1796,11 +1766,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder @@ -1952,7 +1917,9 @@ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): @auto_docstring class LongT5EncoderModel(LongT5PreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + } _keys_to_ignore_on_load_unexpected = [r"decoder"] def __init__(self, config: LongT5Config): @@ -1961,8 +1928,7 @@ def __init__(self, config: LongT5Config): encoder_config = copy.deepcopy(config) encoder_config.use_cache = False - encoder_config.tie_encoder_decoder = False - self.encoder = LongT5Stack(encoder_config, self.shared) + self.encoder = LongT5Stack(encoder_config) # Initialize weights and apply final processing self.post_init() @@ -1974,10 +1940,6 @@ def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings self.encoder.set_input_embeddings(new_embeddings) - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index b37b4a1e3e6d..79b63ac33d86 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -766,22 +766,23 @@ class LukePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LukeAttention", "LukeEntityEmbeddings"] + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): if module.embedding_dim == 1: # embedding for bias parameters - module.weight.data.zero_() + module.weight.zero_() else: - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring( @@ -1024,7 +1025,6 @@ def __init__(self, config): self.decoder = nn.Linear(config.hidden_size, config.vocab_size) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - self.decoder.bias = self.bias def forward(self, features, **kwargs): x = self.dense(features) @@ -1036,14 +1036,6 @@ def forward(self, features, **kwargs): return x - def _tie_weights(self): - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - # For accelerate compatibility and to not break backward compatibility - if self.decoder.bias.device.type == "meta": - self.decoder.bias = self.bias - else: - self.bias = self.decoder.bias - @auto_docstring( custom_intro=""" @@ -1052,7 +1044,10 @@ def _tie_weights(self): """ ) class LukeForMaskedLM(LukePreTrainedModel): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias", "entity_predictions.decoder.weight"] + _tied_weights_keys = { + "entity_predictions.decoder.weight": "luke.entity_embeddings.entity_embeddings.weight", + "lm_head.bias": "lm_head.decoder.bias", + } def __init__(self, config): super().__init__(config) @@ -1067,10 +1062,6 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - def tie_weights(self): - super().tie_weights() - self._tie_embedding_weights(self.entity_predictions.decoder, self.luke.entity_embeddings.entity_embeddings) - def get_output_embeddings(self): return self.lm_head.decoder diff --git a/src/transformers/models/lxmert/modeling_lxmert.py b/src/transformers/models/lxmert/modeling_lxmert.py index 08be81ae3c0e..707388f91248 100644 --- a/src/transformers/models/lxmert/modeling_lxmert.py +++ b/src/transformers/models/lxmert/modeling_lxmert.py @@ -682,21 +682,22 @@ class LxmertPreTrainedModel(PreTrainedModel): input_modalities = ["image", "text"] _supports_param_buffer_assignment = False + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, LxmertLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring @@ -851,7 +852,7 @@ def forward( @auto_docstring class LxmertForPreTraining(LxmertPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.weight"] + _tied_weights_keys = {"cls.predictions.decoder.weight": "lxmert.embeddings.word_embeddings.weight"} def __init__(self, config): super().__init__(config) @@ -908,9 +909,6 @@ def __init__(self, config): } self.visual_losses = visual_losses - def _tie_weights(self): - self.cls.predictions.decoder.weight = self.lxmert.embeddings.word_embeddings.weight - def resize_token_embeddings( self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True ) -> nn.Embedding: diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 772026b7b465..60f41cd6ad00 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -516,19 +516,20 @@ class M2M100PreTrainedModel(PreTrainedModel): # Doesn't support `compile` (dynamic control flow). Can be fixed but low usage model _can_compile_fullgraph = False + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() class M2M100Encoder(M2M100PreTrainedModel): @@ -541,7 +542,7 @@ class M2M100Encoder(M2M100PreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: M2M100Config): super().__init__(config) self.dropout = config.dropout @@ -556,9 +557,6 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale ) - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - self.embed_positions = M2M100SinusoidalPositionalEmbedding( config.max_position_embeddings, embed_dim, @@ -694,7 +692,7 @@ class M2M100Decoder(M2M100PreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: M2M100Config): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop @@ -706,9 +704,6 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale ) - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - self.embed_positions = M2M100SinusoidalPositionalEmbedding( config.max_position_embeddings, config.d_model, @@ -920,7 +915,10 @@ def forward( @auto_docstring class M2M100Model(M2M100PreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "decoder.embed_tokens.weight": "shared.weight", + "encoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: M2M100Config): super().__init__(config) @@ -929,8 +927,8 @@ def __init__(self, config: M2M100Config): embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 self.shared = M2M100ScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) - self.encoder = M2M100Encoder(config, self.shared) - self.decoder = M2M100Decoder(config, self.shared) + self.encoder = M2M100Encoder(config) + self.decoder = M2M100Decoder(config) # Initialize weights and apply final processing self.post_init() @@ -943,11 +941,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder @@ -1045,7 +1038,7 @@ def forward( ) class M2M100ForConditionalGeneration(M2M100PreTrainedModel, GenerationMixin): base_model_prefix = "model" - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.shared.weight"} def __init__(self, config: M2M100Config): super().__init__(config) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 56744f354b27..f17bd66649af 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -504,6 +504,7 @@ class MambaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _is_stateful = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" std = self.config.initializer_range @@ -513,7 +514,7 @@ def _init_weights(self, module): A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :] A = A.expand(module.intermediate_size, -1).contiguous() module.A_log.copy_(torch.log(A)) - module.D.data.fill_(1.0) + module.D.fill_(1.0) dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale if self.config.time_step_init_scheme == "constant": @@ -558,7 +559,7 @@ def _init_weights(self, module): if not getattr(module.bias, "_no_reinit", False): nn.init.zeros_(module.bias) elif isinstance(module, MambaRMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, std=std) @@ -721,7 +722,7 @@ def forward( """ ) class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "backbone.embeddings.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 6f1f31b9002c..716f62e5d1b1 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -717,6 +717,7 @@ class Mamba2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _is_stateful = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" std = self.config.initializer_range @@ -725,7 +726,7 @@ def _init_weights(self, module): # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded A = torch.arange(1, self.config.num_heads + 1) module.A_log.copy_(torch.log(A)) - module.D.data.fill_(1.0) + module.D.fill_(1.0) dt = torch.exp( torch.rand(self.config.num_heads) @@ -765,7 +766,7 @@ def _init_weights(self, module): if not getattr(module.bias, "_no_reinit", False): nn.init.zeros_(module.bias) elif isinstance(module, (Mamba2RMSNorm, MambaRMSNormGated)): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, std=std) @@ -934,7 +935,7 @@ def forward( """ ) class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin): - _tied_weights_keys = [] + _tied_weights_keys = {} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index fe0f264581bc..ced0aa6c25a6 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch MarianMTModel model, ported from the Marian C++ repo.""" -import copy import math from collections.abc import Callable from typing import Optional, Union @@ -446,21 +445,22 @@ class MarianPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]): std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, MarianSinusoidalPositionalEmbedding): module._init_weight() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() @property def dummy_inputs(self): @@ -484,7 +484,7 @@ class MarianEncoder(MarianPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: MarianConfig): super().__init__(config) self.dropout = config.dropout @@ -495,10 +495,7 @@ def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = self.max_source_positions = config.max_position_embeddings self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) self.embed_positions = MarianSinusoidalPositionalEmbedding( config.max_position_embeddings, embed_dim, self.padding_idx @@ -626,7 +623,7 @@ class MarianDecoder(MarianPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: MarianConfig): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop @@ -634,10 +631,7 @@ def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = self.max_target_positions = config.max_position_embeddings self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = nn.Embedding(config.decoder_vocab_size, config.d_model, self.padding_idx) + self.embed_tokens = nn.Embedding(config.decoder_vocab_size, config.d_model, self.padding_idx) self.embed_positions = MarianSinusoidalPositionalEmbedding( config.max_position_embeddings, config.d_model, self.padding_idx @@ -846,7 +840,10 @@ def forward( @auto_docstring class MarianModel(MarianPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "decoder.embed_tokens.weight": "shared.weight", + "encoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: MarianConfig): super().__init__(config) @@ -854,18 +851,11 @@ def __init__(self, config: MarianConfig): padding_idx, vocab_size = config.pad_token_id, config.vocab_size # We always use self.shared for token embeddings to ensure compatibility with all marian models - self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) if self.config.share_encoder_decoder_embeddings: - encoder_embed_tokens = decoder_embed_tokens = self.shared - else: - # Since the embeddings are not shared, deepcopy the embeddings here for encoder - # and decoder to make sure they are not tied. - encoder_embed_tokens = copy.deepcopy(self.shared) - decoder_embed_tokens = copy.deepcopy(self.shared) - self.shared = None + self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) - self.encoder = MarianEncoder(config, encoder_embed_tokens) - self.decoder = MarianDecoder(config, decoder_embed_tokens) + self.encoder = MarianEncoder(config) + self.decoder = MarianDecoder(config) # Initialize weights and apply final processing self.post_init() @@ -983,9 +973,9 @@ def forward( output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) - use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # If encoder_outputs are not given, pass the inputs to the encoder if encoder_outputs is None: encoder_outputs = self.encoder( input_ids=input_ids, @@ -1046,7 +1036,7 @@ class MarianMTModel(MarianPreTrainedModel, GenerationMixin): "decoder.embed_positions.weight", ] _keys_to_ignore_on_save = ["model.encoder.embed_positions.weight", "model.decoder.embed_positions.weight"] - _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.shared.weight"} def __init__(self, config: MarianConfig): super().__init__(config) @@ -1140,31 +1130,6 @@ def _resize_final_logits_bias(self, new_num_tokens: int) -> None: def set_output_embeddings(self, new_embeddings: nn.Embedding): self.lm_head = new_embeddings - def tie_weights(self): - """ - Tie the weights between the input embeddings and the output embeddings. - """ - output_embeddings = self.get_output_embeddings() - if output_embeddings is not None and getattr(self.config, "tie_word_embeddings", True): - # if embeddings are shared this will return shared embeddings otherwise decoder embed_tokens - word_embeddings = self.get_decoder().get_input_embeddings() - self._tie_embedding_weights(output_embeddings, word_embeddings) - - if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False): - if hasattr(self, self.base_model_prefix): - self = getattr(self, self.base_model_prefix) - tied_weights = self._tie_encoder_decoder_weights( - self.encoder, self.decoder, self.base_model_prefix, "encoder" - ) - # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class - # attributed not an instance member, therefore modifying it will modify the entire class - # Leading to issues on subsequent calls by different tests or subsequent calls. - self._dynamic_tied_weights_keys = tied_weights - - for module in self.modules(): - if hasattr(module, "_tie_weights"): - module._tie_weights() - @auto_docstring def forward( self, @@ -1293,7 +1258,9 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian, facebook/bart-base->Helsinki-NLP/opus-mt-fr-en class MarianForCausalLM(MarianPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.decoder.embed_tokens.weight", + } def __init__(self, config): config.is_decoder = True diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 7cd32c5cebd9..60be191c8285 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -294,16 +294,9 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -517,22 +510,22 @@ class MarkupLMPreTrainedModel(PreTrainedModel): config: MarkupLMConfig base_model_prefix = "markuplm" - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with Bert->MarkupLM + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, MarkupLMLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 278f977320ed..24b1d1078b82 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -2102,6 +2102,7 @@ class Mask2FormerPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" input_modalities = "image" + @torch.no_grad() def _init_weights(self, module: nn.Module): xavier_std = self.config.init_xavier_std std = self.config.init_std @@ -2114,7 +2115,7 @@ def _init_weights(self, module: nn.Module): nn.init.constant_(input_projection.bias, 0) elif isinstance(module, Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + nn.init.constant_(module.sampling_offsets.weight, 0.0) thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / module.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( @@ -2127,39 +2128,39 @@ def _init_weights(self, module: nn.Module): with torch.no_grad(): module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(module.attention_weights.weight.data, 0.0) - nn.init.constant_(module.attention_weights.bias.data, 0.0) - nn.init.xavier_uniform_(module.value_proj.weight.data) - nn.init.constant_(module.value_proj.bias.data, 0.0) - nn.init.xavier_uniform_(module.output_proj.weight.data) - nn.init.constant_(module.output_proj.bias.data, 0.0) + nn.init.constant_(module.attention_weights.weight, 0.0) + nn.init.constant_(module.attention_weights.bias, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight) + nn.init.constant_(module.value_proj.bias, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight) + nn.init.constant_(module.output_proj.bias, 0.0) elif isinstance(module, Mask2FormerMaskedAttentionDecoderLayer): for p in module.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p, gain=xavier_std) - module.cross_attn.in_proj_bias.data.zero_() + module.cross_attn.in_proj_bias.zero_() elif isinstance(module, Mask2FormerPixelDecoder): nn.init.normal_(module.level_embed, std=0) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() if hasattr(module, "reference_points"): - nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0) - nn.init.constant_(module.reference_points.bias.data, 0.0) + nn.init.xavier_uniform_(module.reference_points.weight, gain=1.0) + nn.init.constant_(module.reference_points.bias, 0.0) @auto_docstring diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index bc961d2eb0ec..b2dc868f0138 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -1436,6 +1436,7 @@ class MaskFormerPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" input_modalities = "image" + @torch.no_grad() def _init_weights(self, module: nn.Module): xavier_std = self.config.init_xavier_std std = self.config.init_std @@ -1461,17 +1462,17 @@ def _init_weights(self, module: nn.Module): nn.init.xavier_uniform_(submodule.weight, gain=xavier_std) nn.init.constant_(submodule.bias, 0) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) # copied from DETR if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() @auto_docstring diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index f0d5d1dc3dd8..b735b419c10d 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -701,20 +701,21 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MaskFormerSwinStage"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, MaskFormerSwinEmbeddings): if module.position_embeddings is not None: - module.position_embeddings.data.zero_() + module.position_embeddings.zero_() elif isinstance(module, MaskFormerSwinSelfAttention): - module.relative_position_bias_table.data.zero_() + module.relative_position_bias_table.zero_() class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel): diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 3f10516ed046..08cde27d7cce 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -479,19 +479,20 @@ class MBartPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() @property def dummy_inputs(self): @@ -514,7 +515,7 @@ class MBartEncoder(MBartPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: MBartConfig): super().__init__(config) self.dropout = config.dropout @@ -529,9 +530,6 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale ) - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - self.embed_positions = MBartLearnedPositionalEmbedding( config.max_position_embeddings, embed_dim, @@ -670,7 +668,7 @@ class MBartDecoder(MBartPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: MBartConfig): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop @@ -682,9 +680,6 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale ) - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - self.embed_positions = MBartLearnedPositionalEmbedding( config.max_position_embeddings, config.d_model, @@ -898,7 +893,10 @@ def forward( @auto_docstring class MBartModel(MBartPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "decoder.embed_tokens.weight": "shared.weight", + "encoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: MBartConfig): super().__init__(config) @@ -907,8 +905,8 @@ def __init__(self, config: MBartConfig): embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 self.shared = MBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) - self.encoder = MBartEncoder(config, self.shared) - self.decoder = MBartDecoder(config, self.shared) + self.encoder = MBartEncoder(config) + self.decoder = MBartDecoder(config) # Initialize weights and apply final processing self.post_init() @@ -924,11 +922,6 @@ def set_input_embeddings(self, value): def get_encoder(self): return self.encoder - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.get_input_embeddings()) - self._tie_embedding_weights(self.decoder.embed_tokens, self.get_input_embeddings()) - @auto_docstring def forward( self, @@ -1034,7 +1027,7 @@ def forward( class MBartForConditionalGeneration(MBartPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] - _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.shared.weight"} def __init__(self, config: MBartConfig): super().__init__(config) @@ -1207,8 +1200,6 @@ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): """ ) class MBartForSequenceClassification(MBartPreTrainedModel): - _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"] - def __init__(self, config: MBartConfig, **kwargs): super().__init__(config, **kwargs) self.model = MBartModel(config) @@ -1342,8 +1333,6 @@ def forward( @auto_docstring class MBartForQuestionAnswering(MBartPreTrainedModel): - _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"] - def __init__(self, config): super().__init__(config) @@ -1479,7 +1468,9 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart, facebook/bart-base->facebook/mbart-large-cc25 class MBartForCausalLM(MBartPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.decoder.embed_tokens.weight", + } def __init__(self, config): config.is_decoder = True diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 6f0a035eca95..d7a869cfd89a 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -471,16 +471,9 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -528,17 +521,18 @@ class MegatronBertPreTrainedModel(PreTrainedModel): base_model_prefix = "bert" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, MegatronBertLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @dataclass @@ -708,7 +702,10 @@ def forward( """ ) class MegatronBertForPreTraining(MegatronBertPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder"] + _tied_weights_keys = { + "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight", + "cls.predictions.decoder.bias": "cls.predictions.bias", + } def __init__(self, config, add_binary_head=True): r""" @@ -813,7 +810,10 @@ def forward( """ ) class MegatronBertForCausalLM(MegatronBertPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["cls.predictions.decoder"] + _tied_weights_keys = { + "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight", + "cls.predictions.decoder.bias": "cls.predictions.bias", + } def __init__(self, config): super().__init__(config) @@ -919,7 +919,10 @@ def forward( @auto_docstring class MegatronBertForMaskedLM(MegatronBertPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder"] + _tied_weights_keys = { + "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight", + "cls.predictions.decoder.bias": "cls.predictions.bias", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/metaclip_2/modeling_metaclip_2.py b/src/transformers/models/metaclip_2/modeling_metaclip_2.py index f352ce30e2be..c66bababfbe5 100644 --- a/src/transformers/models/metaclip_2/modeling_metaclip_2.py +++ b/src/transformers/models/metaclip_2/modeling_metaclip_2.py @@ -298,12 +298,13 @@ class MetaClip2PreTrainedModel(PreTrainedModel): "attentions": MetaClip2Attention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, MetaClip2TextEmbeddings): - module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02) elif isinstance(module, MetaClip2VisionEmbeddings): factor = self.config.initializer_factor nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) @@ -349,10 +350,10 @@ def _init_weights(self, module): ) if isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() class MetaClip2Encoder(nn.Module): diff --git a/src/transformers/models/metaclip_2/modular_metaclip_2.py b/src/transformers/models/metaclip_2/modular_metaclip_2.py index ae465d40a3aa..79cdf35be7e9 100644 --- a/src/transformers/models/metaclip_2/modular_metaclip_2.py +++ b/src/transformers/models/metaclip_2/modular_metaclip_2.py @@ -217,12 +217,13 @@ class MetaClip2MLP(CLIPMLP): class MetaClip2PreTrainedModel(CLIPPreTrainedModel): base_model_prefix = "metaclip_2" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, MetaClip2TextEmbeddings): - module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02) elif isinstance(module, MetaClip2VisionEmbeddings): factor = self.config.initializer_factor nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) @@ -268,10 +269,10 @@ def _init_weights(self, module): ) if isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() class MetaClip2TextTransformer(CLIPTextTransformer): diff --git a/src/transformers/models/mgp_str/modeling_mgp_str.py b/src/transformers/models/mgp_str/modeling_mgp_str.py index c57af7cb5f51..819d5d38fcc1 100644 --- a/src/transformers/models/mgp_str/modeling_mgp_str.py +++ b/src/transformers/models/mgp_str/modeling_mgp_str.py @@ -284,6 +284,7 @@ class MgpstrPreTrainedModel(PreTrainedModel): base_model_prefix = "mgp_str" _no_split_modules = [] + @torch.no_grad() def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" std = self.config.initializer_range @@ -291,12 +292,12 @@ def _init_weights(self, module: nn.Module) -> None: nn.init.trunc_normal_(module.pos_embed, mean=0.0, std=std) nn.init.trunc_normal_(module.cls_token, mean=0.0, std=std) elif isinstance(module, (nn.Linear, nn.Conv2d)): - nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std) + nn.init.trunc_normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 83bcbd857a0d..8182c1b7372e 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1395,22 +1395,23 @@ class MimiPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)): nn.init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-k, b=k) elif isinstance(module, MimiLayerScale): - module.scale.data.fill_(self.config.layer_scale_initial_scale) + module.scale.fill_(self.config.layer_scale_initial_scale) @auto_docstring( diff --git a/src/transformers/models/minimax/configuration_minimax.py b/src/transformers/models/minimax/configuration_minimax.py index b99a61a277ea..77b971a7d1a9 100644 --- a/src/transformers/models/minimax/configuration_minimax.py +++ b/src/transformers/models/minimax/configuration_minimax.py @@ -137,10 +137,10 @@ class MiniMaxConfig(PreTrainedConfig): "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", - "layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts - "layers.*.block_sparse_moe.experts.*.w1": "colwise", - "layers.*.block_sparse_moe.experts.*.w2": "rowwise", - "layers.*.block_sparse_moe.experts.*.w3": "colwise", + "layers.*.mlp.gate": "colwise_rep", # we need to replicate here to correctly route experts + "layers.*.mlp.experts.*.w1": "colwise", + "layers.*.mlp.experts.*.w2": "rowwise", + "layers.*.mlp.experts.*.w3": "colwise", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 7e8a499ed56e..049e650811ca 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -452,38 +452,41 @@ def forward( return attn_output, attn_weights -class MiniMaxMLP(nn.Module): - def __init__(self, config: MiniMaxConfig): +class MiniMaxTopKRouter(nn.Module): + def __init__(self, config): super().__init__() - self.ffn_dim = config.intermediate_size + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts self.hidden_dim = config.hidden_size - - self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) - self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - - self.act_fn = ACT2FN[config.hidden_act] + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) def forward(self, hidden_states): - current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits.float(), dim=-1) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + router_top_value /= router_top_value.sum(dim=-1, keepdim=True) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices -class MiniMaxExperts(nn.ModuleList): - """ - ModuleList of experts. - """ +class MiniMaxExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" def __init__(self, config: MiniMaxConfig): super().__init__() - self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts - for _ in range(self.num_experts): - self.append(MiniMaxMLP(config)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: """ Args: @@ -494,14 +497,24 @@ def forward( (batch_size * sequence_length, hidden_dim) """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states @@ -510,23 +523,16 @@ def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok self.jitter_noise = config.router_jitter_noise - self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.gate = MiniMaxTopKRouter(config) self.experts = MiniMaxExperts(config) - def route_tokens_to_experts(self, router_logits): - routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1) - top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1) - top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) - return top_k_index, top_k_weights.to(router_logits.dtype) - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape if self.training and self.jitter_noise > 0: hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - router_logits = self.gate(hidden_states) - top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits) - hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype)) + top_k_weights, top_k_index = self.gate(hidden_states) + hidden_states = self.experts(hidden_states, top_k_index, top_k_weights) hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim) return hidden_states @@ -537,8 +543,6 @@ def __init__(self, config: MiniMaxConfig, layer_idx: int): self.hidden_size = config.hidden_size self.self_attn = MiniMaxAttention(config, layer_idx) - - self.block_sparse_moe = MiniMaxSparseMoeBlock(config) self.input_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MiniMaxRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -546,7 +550,7 @@ def __init__(self, config: MiniMaxConfig, layer_idx: int): self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None self.mlp_alpha_factor = config.mlp_alpha_factor self.mlp_beta_factor = config.mlp_beta_factor - + self.mlp = MiniMaxSparseMoeBlock(config) if self.layer_type == "linear_attention": self.self_attn = MiniMaxLightningAttention(config, layer_idx) self.attn_alpha_factor = config.linear_attn_alpha_factor @@ -582,7 +586,7 @@ def forward( hidden_states = residual * self.attn_alpha_factor + hidden_states * self.attn_beta_factor hidden_states = self.post_attention_layernorm(hidden_states) residual = hidden_states - hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = self.mlp(hidden_states) hidden_states = residual * self.mlp_alpha_factor + hidden_states * self.mlp_beta_factor return hidden_states @@ -601,11 +605,21 @@ class MiniMaxPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = False _supports_attention_backend = True _can_record_outputs = { - "router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0), + "router_logits": OutputRecorder(MiniMaxTopKRouter, layer_name="mlp.gate", index=0), "hidden_states": MiniMaxDecoderLayer, "attentions": [MiniMaxAttention, MiniMaxLightningAttention], } + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + std = self.config.initializer_range + if isinstance(module, MiniMaxExperts): + module.gate_up_proj.normal_(mean=0.0, std=std) + module.down_proj.normal_(mean=0.0, std=std) + elif isinstance(module, MiniMaxTopKRouter): + module.weight.normal_(mean=0.0, std=std) + @auto_docstring class MiniMaxModel(MiniMaxPreTrainedModel): @@ -781,7 +795,7 @@ def load_balancing_loss_func( @auto_docstring class MiniMaxForCausalLM(MiniMaxPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py index d1bbb96bb5c1..fff1b3fe8745 100644 --- a/src/transformers/models/minimax/modular_minimax.py +++ b/src/transformers/models/minimax/modular_minimax.py @@ -44,6 +44,7 @@ MixtralPreTrainedModel, MixtralRMSNorm, MixtralSparseMoeBlock, + MixtralTopKRouter, ) @@ -161,10 +162,10 @@ class MiniMaxConfig(PreTrainedConfig): "layers.*.self_attn.k_proj": "colwise", "layers.*.self_attn.v_proj": "colwise", "layers.*.self_attn.o_proj": "rowwise", - "layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts - "layers.*.block_sparse_moe.experts.*.w1": "colwise", - "layers.*.block_sparse_moe.experts.*.w2": "rowwise", - "layers.*.block_sparse_moe.experts.*.w3": "colwise", + "layers.*.mlp.gate": "colwise_rep", # we need to replicate here to correctly route experts + "layers.*.mlp.experts.*.w1": "colwise", + "layers.*.mlp.experts.*.w2": "rowwise", + "layers.*.mlp.experts.*.w3": "colwise", } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), @@ -464,6 +465,10 @@ class MiniMaxAttention(MixtralAttention): pass +class MiniMaxTopKRouter(MixtralTopKRouter): + pass + + class MiniMaxSparseMoeBlock(MixtralSparseMoeBlock): pass @@ -476,7 +481,8 @@ def __init__(self, config: MiniMaxConfig, layer_idx: int): self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None self.mlp_alpha_factor = config.mlp_alpha_factor self.mlp_beta_factor = config.mlp_beta_factor - + del self.mlp + self.mlp = MiniMaxSparseMoeBlock(config) if self.layer_type == "linear_attention": self.self_attn = MiniMaxLightningAttention(config, layer_idx) self.attn_alpha_factor = config.linear_attn_alpha_factor @@ -512,7 +518,7 @@ def forward( hidden_states = residual * self.attn_alpha_factor + hidden_states * self.attn_beta_factor hidden_states = self.post_attention_layernorm(hidden_states) residual = hidden_states - hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = self.mlp(hidden_states) hidden_states = residual * self.mlp_alpha_factor + hidden_states * self.mlp_beta_factor return hidden_states @@ -521,7 +527,7 @@ def forward( class MiniMaxPreTrainedModel(MixtralPreTrainedModel): _can_compile_fullgraph = False _can_record_outputs = { - "router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0), + "router_logits": OutputRecorder(MiniMaxTopKRouter, layer_name="mlp.gate", index=0), "hidden_states": MiniMaxDecoderLayer, "attentions": [MiniMaxAttention, MiniMaxLightningAttention], } diff --git a/src/transformers/models/ministral/modeling_ministral.py b/src/transformers/models/ministral/modeling_ministral.py index 239d2fc2047b..b1c8555fd96b 100644 --- a/src/transformers/models/ministral/modeling_ministral.py +++ b/src/transformers/models/ministral/modeling_ministral.py @@ -425,7 +425,7 @@ def forward( @auto_docstring class MinistralForCausalLM(MinistralPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index ab3cae55bb6e..60c7e2d49eed 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -412,7 +412,7 @@ def forward( @auto_docstring class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index b98efd38e824..00eb7af262b6 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -364,7 +364,7 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin) "^multi_modal_projector": "model.multi_modal_projector", "^language_model.lm_head": "lm_head", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: Mistral3Config): super().__init__(config) diff --git a/src/transformers/models/mixtral/configuration_mixtral.py b/src/transformers/models/mixtral/configuration_mixtral.py index 6784b7eb5f19..7cf6afc1d342 100644 --- a/src/transformers/models/mixtral/configuration_mixtral.py +++ b/src/transformers/models/mixtral/configuration_mixtral.py @@ -115,14 +115,16 @@ class MixtralConfig(PreTrainedConfig): model_type = "mixtral" keys_to_ignore_at_inference = ["past_key_values"] base_model_tp_plan = { - "layers.*.self_attn.q_proj": "colwise", - "layers.*.self_attn.k_proj": "colwise", - "layers.*.self_attn.v_proj": "colwise", - "layers.*.self_attn.o_proj": "rowwise", - "layers.*.block_sparse_moe.gate": "colwise_rep", # we need to replicate here to correctly route experts - "layers.*.block_sparse_moe.experts.*.w1": "colwise", - "layers.*.block_sparse_moe.experts.*.w2": "rowwise", - "layers.*.block_sparse_moe.experts.*.w3": "colwise", + "layers.*.self_attn.q_proj": "local_colwise", + "layers.*.self_attn.k_proj": "local_colwise", + "layers.*.self_attn.v_proj": "local_colwise", + "layers.*.self_attn.o_proj": "local_rowwise", + "layers.*.self_attn": "gather", + "layers.*.mlp.gate": "ep_router", # we need to replicate here to correctly route experts + "layers.*.mlp.experts.gate_up_proj": "local_colwise", + "layers.*.mlp.experts.down_proj": "local_rowwise", + "layers.*.mlp.experts": "gather", + # "layers.*.mlp.experts.gate_up_proj": "local_packed_rowwise" ? if you load from } base_model_pp_plan = { "embed_tokens": (["input_ids"], ["inputs_embeds"]), diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index d6b1b1100ba0..556353e5e7fc 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -28,6 +28,7 @@ from typing import Optional, Union import torch +import torch.nn.functional as F from torch import nn from transformers.utils.generic import check_model_inputs @@ -53,57 +54,62 @@ from .configuration_mixtral import MixtralConfig -class MixtralMLP(nn.Module): +class MixtralExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + def __init__(self, config: MixtralConfig): super().__init__() - self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts self.hidden_dim = config.hidden_size - - self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) - self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_states): - current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) -class MixtralExperts(nn.ModuleList): - """ - ModuleList of experts. - """ + return final_hidden_states - def __init__(self, config: MixtralConfig): + +class MixtralTopKRouter(nn.Module): + def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts - for _ in range(self.num_experts): - self.append(MixtralMLP(config)) - - def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor - ) -> torch.Tensor: - """ - Args: - hidden_states: (batch_size * sequence_length, hidden_dim) - top_k_index: (batch_size * sequence_length, top_k) - top_k_weights: (batch_size * sequence_length, top_k) - Returns: - (batch_size * sequence_length, hidden_dim) - """ - final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - return final_hidden_states + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits.float(), dim=-1) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + router_top_value /= router_top_value.sum(dim=-1, keepdim=True) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices class MixtralSparseMoeBlock(nn.Module): @@ -111,23 +117,16 @@ def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok self.jitter_noise = config.router_jitter_noise - self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.gate = MixtralTopKRouter(config) self.experts = MixtralExperts(config) - def route_tokens_to_experts(self, router_logits): - routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1) - top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1) - top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) - return top_k_index, top_k_weights.to(router_logits.dtype) - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape if self.training and self.jitter_noise > 0: hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - router_logits = self.gate(hidden_states) - top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits) - hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype)) + top_k_weights, top_k_index = self.gate(hidden_states) + hidden_states = self.experts(hidden_states, top_k_index, top_k_weights) hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim) return hidden_states @@ -359,7 +358,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int): self.self_attn = MixtralAttention(config, layer_idx) - self.block_sparse_moe = MixtralSparseMoeBlock(config) + self.mlp = MixtralSparseMoeBlock(config) self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -387,7 +386,7 @@ def forward( hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states @@ -405,11 +404,21 @@ class MixtralPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True _can_record_outputs = { - "router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0), + "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), "hidden_states": MixtralDecoderLayer, "attentions": MixtralAttention, } + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + std = self.config.initializer_range + if isinstance(module, MixtralExperts): + module.gate_up_proj.normal_(mean=0.0, std=std) + module.down_proj.normal_(mean=0.0, std=std) + elif isinstance(module, MixtralTopKRouter): + module.weight.normal_(mean=0.0, std=std) + @auto_docstring class MixtralModel(MixtralPreTrainedModel): @@ -576,7 +585,7 @@ def load_balancing_loss_func( @auto_docstring class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index 65ea9b2e6b36..b369537fdeed 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -22,6 +22,7 @@ from typing import Optional, Union import torch +import torch.nn.functional as F from torch import nn from ...activations import ACT2FN @@ -29,6 +30,7 @@ from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging from ...utils.generic import OutputRecorder @@ -131,57 +133,62 @@ def load_balancing_loss_func( return overall_loss * num_experts -class MixtralMLP(nn.Module): +class MixtralExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + def __init__(self, config: MixtralConfig): super().__init__() - self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts self.hidden_dim = config.hidden_size - - self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) - self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_states): - current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) -class MixtralExperts(nn.ModuleList): - """ - ModuleList of experts. - """ + return final_hidden_states - def __init__(self, config: MixtralConfig): + +class MixtralTopKRouter(nn.Module): + def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts - for _ in range(self.num_experts): - self.append(MixtralMLP(config)) - - def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor - ) -> torch.Tensor: - """ - Args: - hidden_states: (batch_size * sequence_length, hidden_dim) - top_k_index: (batch_size * sequence_length, top_k) - top_k_weights: (batch_size * sequence_length, top_k) - Returns: - (batch_size * sequence_length, hidden_dim) - """ - final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - return final_hidden_states + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits.float(), dim=-1) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + router_top_value /= router_top_value.sum(dim=-1, keepdim=True) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices class MixtralSparseMoeBlock(nn.Module): @@ -189,23 +196,16 @@ def __init__(self, config): super().__init__() self.top_k = config.num_experts_per_tok self.jitter_noise = config.router_jitter_noise - self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.gate = MixtralTopKRouter(config) self.experts = MixtralExperts(config) - def route_tokens_to_experts(self, router_logits): - routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1) - top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1) - top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) - return top_k_index, top_k_weights.to(router_logits.dtype) - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape if self.training and self.jitter_noise > 0: hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - router_logits = self.gate(hidden_states) - top_k_index, top_k_weights = self.route_tokens_to_experts(router_logits) - hidden_states = self.experts(hidden_states, top_k_index, top_k_weights.to(hidden_states.dtype)) + top_k_weights, top_k_index = self.gate(hidden_states) + hidden_states = self.experts(hidden_states, top_k_index, top_k_weights) hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim) return hidden_states @@ -229,7 +229,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int): self.self_attn = MixtralAttention(config, layer_idx) - self.block_sparse_moe = MixtralSparseMoeBlock(config) + self.mlp = MixtralSparseMoeBlock(config) self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -257,7 +257,7 @@ def forward( hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states @@ -265,11 +265,21 @@ def forward( class MixtralPreTrainedModel(MistralPreTrainedModel): _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _can_record_outputs = { - "router_logits": OutputRecorder(nn.Linear, layer_name="block_sparse_moe.gate", index=0), + "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), "hidden_states": MixtralDecoderLayer, "attentions": MixtralAttention, } + @torch.no_grad() + def _init_weights(self, module): + PreTrainedModel._init_weights(self, module) + std = self.config.initializer_range + if isinstance(module, MixtralExperts): + module.gate_up_proj.normal_(mean=0.0, std=std) + module.down_proj.normal_(mean=0.0, std=std) + elif isinstance(module, MixtralTopKRouter): + module.weight.normal_(mean=0.0, std=std) + class MixtralModel(MistralModel): def forward( @@ -334,7 +344,7 @@ def forward( class MixtralForCausalLM(MistralForCausalLM): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/mlcd/modeling_mlcd.py b/src/transformers/models/mlcd/modeling_mlcd.py index fe7e8682b469..a4dd82865202 100644 --- a/src/transformers/models/mlcd/modeling_mlcd.py +++ b/src/transformers/models/mlcd/modeling_mlcd.py @@ -415,6 +415,7 @@ class MLCDPreTrainedModel(PreTrainedModel): "attentions": MLCDAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor @@ -441,10 +442,10 @@ def _init_weights(self, module): pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor nn.init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() class MLCDVisionTransformer(nn.Module): diff --git a/src/transformers/models/mlcd/modular_mlcd.py b/src/transformers/models/mlcd/modular_mlcd.py index 2be712febf2f..e3a70b798496 100644 --- a/src/transformers/models/mlcd/modular_mlcd.py +++ b/src/transformers/models/mlcd/modular_mlcd.py @@ -354,6 +354,7 @@ class MLCDPreTrainedModel(PreTrainedModel): "attentions": MLCDAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor @@ -380,10 +381,10 @@ def _init_weights(self, module): pos_emb_std = (module.config.hidden_size // module.config.num_attention_heads // 2) ** -0.5 * factor nn.init.normal_(module.class_pos_emb, mean=0.0, std=pos_emb_std) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() class MLCDVisionTransformer(CLIPVisionTransformer): diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index c3c1930e386e..a5ffcac18f76 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -816,36 +816,37 @@ class MllamaPreTrainedModel(PreTrainedModel): ], } + @torch.no_grad() def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, MllamaTextRMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, MllamaVisionModel): - nn.init.normal_(module.class_embedding.data, std=std) + nn.init.normal_(module.class_embedding, std=std) elif isinstance(module, MllamaPrecomputedPositionEmbedding): - nn.init.normal_(module.embedding.data, std=std) - nn.init.zeros_(module.gate.data) + nn.init.normal_(module.embedding, std=std) + nn.init.zeros_(module.gate) elif isinstance(module, MllamaVisionEncoderLayer) and module.is_gated: - nn.init.normal_(module.gate_attn.data, std=std) - nn.init.normal_(module.gate_ffn.data, std=std) + nn.init.normal_(module.gate_attn, std=std) + nn.init.normal_(module.gate_ffn, std=std) elif isinstance(module, MllamaCrossAttentionDecoderLayer): - module.cross_attn_attn_gate.data.zero_() - module.cross_attn_mlp_gate.data.zero_() + module.cross_attn_attn_gate.zero_() + module.cross_attn_mlp_gate.zero_() elif isinstance(module, MllamaPrecomputedAspectRatioEmbedding): if module.is_gated: - module.gate.data.zero_() + module.gate.zero_() # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask def _update_causal_mask( @@ -1326,7 +1327,6 @@ class MllamaForCausalLM(MllamaPreTrainedModel, GenerationMixin): config: MllamaTextConfig _can_compile_fullgraph = True # only the LLM without cross attn can do compile base_model_prefix = "language_model" - _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config.get_text_config()) @@ -1583,7 +1583,7 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): "^multi_modal_projector": "model.multi_modal_projector", "^language_model.lm_head": "lm_head", } - _tied_weights_keys = ["lm_head.weight"] + # _tied_weights_keys = {"lm_head.weight": "model.language_moddel.embed_tokens.weight"} def __init__(self, config: MllamaConfig): super().__init__(config) diff --git a/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py b/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py index 3af9608e0b24..9de2d64b8e06 100644 --- a/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +++ b/src/transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py @@ -506,6 +506,7 @@ class MMGroundingDinoPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" input_modalities = ["image", "text"] + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std @@ -513,7 +514,7 @@ def _init_weights(self, module): nn.init.uniform_(module.row_embeddings.weight) nn.init.uniform_(module.column_embeddings.weight) elif isinstance(module, MMGroundingDinoMultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + nn.init.constant_(module.sampling_offsets.weight, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( 2.0 * math.pi / module.n_heads @@ -528,46 +529,46 @@ def _init_weights(self, module): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(module.attention_weights.weight.data, 0.0) - nn.init.constant_(module.attention_weights.bias.data, 0.0) - nn.init.xavier_uniform_(module.value_proj.weight.data) - nn.init.constant_(module.value_proj.bias.data, 0.0) - nn.init.xavier_uniform_(module.output_proj.weight.data) - nn.init.constant_(module.output_proj.bias.data, 0.0) + nn.init.constant_(module.attention_weights.weight, 0.0) + nn.init.constant_(module.attention_weights.bias, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight) + nn.init.constant_(module.value_proj.bias, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight) + nn.init.constant_(module.output_proj.bias, 0.0) elif isinstance(module, MMGroundingDinoBiMultiHeadAttention): nn.init.xavier_uniform_(module.vision_proj.weight) - module.vision_proj.bias.data.fill_(0) + module.vision_proj.bias.fill_(0) nn.init.xavier_uniform_(module.text_proj.weight) - module.text_proj.bias.data.fill_(0) + module.text_proj.bias.fill_(0) nn.init.xavier_uniform_(module.values_vision_proj.weight) - module.values_vision_proj.bias.data.fill_(0) + module.values_vision_proj.bias.fill_(0) nn.init.xavier_uniform_(module.values_text_proj.weight) - module.values_text_proj.bias.data.fill_(0) + module.values_text_proj.bias.fill_(0) nn.init.xavier_uniform_(module.out_vision_proj.weight) - module.out_vision_proj.bias.data.fill_(0) + module.out_vision_proj.bias.fill_(0) nn.init.xavier_uniform_(module.out_text_proj.weight) - module.out_text_proj.bias.data.fill_(0) + module.out_text_proj.bias.fill_(0) elif isinstance(module, MMGroundingDinoFusionLayer): - module.vision_param.data.fill_(1e-4) - module.text_param.data.fill_(1e-4) + module.vision_param.fill_(1e-4) + module.text_param.fill_(1e-4) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, MMGroundingDinoMLPPredictionHead): - nn.init.constant_(module.layers[-1].weight.data, 0) - nn.init.constant_(module.layers[-1].bias.data, 0) + nn.init.constant_(module.layers[-1].weight, 0) + nn.init.constant_(module.layers[-1].bias, 0) if hasattr(module, "reference_points") and not self.config.two_stage: - nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0) - nn.init.constant_(module.reference_points.bias.data, 0.0) + nn.init.xavier_uniform_(module.reference_points.weight, gain=1.0) + nn.init.constant_(module.reference_points.bias, 0.0) if hasattr(module, "level_embed"): nn.init.normal_(module.level_embed) if isinstance(module, MMGroundingDinoContrastiveEmbedding): @@ -2386,12 +2387,10 @@ def build_text_mask(logits, attention_mask): """ ) class MMGroundingDinoForObjectDetection(MMGroundingDinoPreTrainedModel): - _tied_weights_keys = [ - r"bbox_embed\.[1-9]\d*", - r"model\.decoder\.bbox_embed\.[0-9]\d*", - r"class_embed\.[1-9]\d*", - r"model\.decoder\.class_embed\.[0-9]\d*", - ] + _tied_weights_keys = { + r"bbox_embed.(?![0])\d+": "bbox_embed.0", + r"class_embed.(?![0])\d+": "class_embed.0", + } def __init__(self, config: MMGroundingDinoConfig): super().__init__(config) @@ -2410,12 +2409,6 @@ def __init__(self, config: MMGroundingDinoConfig): for _ in range(config.decoder_layers) ] ) - - # hack for box-refinement - self.model.decoder.bbox_embed = self.bbox_embed - # hack implementation for two-stage - self.model.decoder.class_embed = self.class_embed - # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py b/src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py index 4aed0c1a9b64..ab7c1d16e602 100644 --- a/src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +++ b/src/transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py @@ -318,6 +318,7 @@ def forward( class MMGroundingDinoPreTrainedModel(GroundingDinoPreTrainedModel): + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, MMGroundingDinoContrastiveEmbedding): @@ -397,12 +398,11 @@ class MMGroundingDinoMLPPredictionHead(GroundingDinoMLPPredictionHead): class MMGroundingDinoForObjectDetection(GroundingDinoForObjectDetection, MMGroundingDinoPreTrainedModel): - _tied_weights_keys = [ - r"bbox_embed\.[1-9]\d*", - r"model\.decoder\.bbox_embed\.[0-9]\d*", - r"class_embed\.[1-9]\d*", - r"model\.decoder\.class_embed\.[0-9]\d*", - ] + _tied_weights_keys = { + "model.decoder.bbox_embed":"bbox_embed", + "model.decoder.class_embed":"class_embed", + r"class_embed.(?![0])\d+": "class_embed.0", + } def __init__(self, config: MMGroundingDinoConfig): MMGroundingDinoPreTrainedModel.__init__(self, config) @@ -421,12 +421,6 @@ def __init__(self, config: MMGroundingDinoConfig): for _ in range(config.decoder_layers) ] ) - - # hack for box-refinement - self.model.decoder.bbox_embed = self.bbox_embed - # hack implementation for two-stage - self.model.decoder.class_embed = self.class_embed - # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index d08b70399da2..58964f4ad234 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -500,13 +500,8 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. self.dense = nn.Linear(config.vocab_size, config.hidden_size - config.embedding_size, bias=False) - self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=False) + self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self) -> None: - self.decoder.bias = self.bias def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.transform(hidden_states) @@ -551,21 +546,22 @@ class MobileBertPreTrainedModel(PreTrainedModel): "attentions": MobileBertSelfAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, (nn.LayerNorm, NoNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, MobileBertLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @dataclass @@ -670,7 +666,10 @@ def forward( """ ) class MobileBertForPreTraining(MobileBertPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "mobilebert.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) @@ -766,7 +765,10 @@ def forward( @auto_docstring class MobileBertForMaskedLM(MobileBertPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "mobilebert.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py b/src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py index 25f8a826437c..a75da78ae3fb 100755 --- a/src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +++ b/src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py @@ -132,15 +132,16 @@ class MobileNetV1PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = False _no_split_modules = [] + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.BatchNorm2d): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py b/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py index 0a92fb2f1093..ae5979de21b2 100755 --- a/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +++ b/src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py @@ -258,15 +258,16 @@ class MobileNetV2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = False _no_split_modules = [] + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.BatchNorm2d): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index f7f30b7faf1d..e2646d6c3e46 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -607,15 +607,16 @@ class MobileViTPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MobileViTLayer"] + @torch.no_grad() def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index c637273f0395..d87aee1d7e63 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -574,15 +574,16 @@ class MobileViTV2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MobileViTV2Layer"] + @torch.no_grad() def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.GroupNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 727640ac87c8..33d9411941e4 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -621,6 +621,7 @@ class ModernBertPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = False + @torch.no_grad() def _init_weights(self, module: nn.Module): cutoff_factor = self.config.initializer_cutoff_factor if cutoff_factor is None: @@ -669,9 +670,9 @@ def init_weight(module: nn.Module, std: float): ): init_weight(module.classifier, stds["final_out"]) elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() def _check_and_adjust_attn_implementation( self, attn_implementation: Optional[str], is_init_check: bool = False @@ -1020,7 +1021,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ ) class ModernBertForMaskedLM(ModernBertPreTrainedModel): - _tied_weights_keys = ["decoder.weight"] + _tied_weights_keys = {"decoder.weight": "model.embeddings.tok_embeddings.weight"} def __init__(self, config: ModernBertConfig): super().__init__(config) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 131a01e6db5c..3cbdf0d0a6c7 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -802,6 +802,7 @@ class ModernBertPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = False + @torch.no_grad() def _init_weights(self, module: nn.Module): cutoff_factor = self.config.initializer_cutoff_factor if cutoff_factor is None: @@ -850,9 +851,9 @@ def init_weight(module: nn.Module, std: float): ): init_weight(module.classifier, stds["final_out"]) elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() def _check_and_adjust_attn_implementation( self, attn_implementation: Optional[str], is_init_check: bool = False @@ -1129,7 +1130,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ ) class ModernBertForMaskedLM(ModernBertPreTrainedModel): - _tied_weights_keys = ["decoder.weight"] + _tied_weights_keys = {"decoder.weight": "model.embeddings.tok_embeddings.weight"} def __init__(self, config: ModernBertConfig): super().__init__(config) diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py index bb5c8dad9fa4..75d46ef20df7 100644 --- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py @@ -394,6 +394,7 @@ class ModernBertDecoderPreTrainedModel(PreTrainedModel): "attentions": ModernBertDecoderAttention, } + @torch.no_grad() def _init_weights(self, module: nn.Module): cutoff_factor = self.config.initializer_cutoff_factor if cutoff_factor is None: @@ -436,9 +437,9 @@ def init_weight(module: nn.Module, std: float): elif isinstance(module, ModernBertDecoderForCausalLM): init_weight(module.decoder, stds["out"]) elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() @auto_docstring @@ -549,7 +550,7 @@ def forward( """ ) class ModernBertDecoderForCausalLM(ModernBertDecoderPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["decoder.weight"] + _tied_weights_keys = {"decoder.weight": "model.embeddings.tok_embeddings.weight"} def __init__(self, config: ModernBertDecoderConfig): super().__init__(config) diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py index e7935b9f2159..b5a38f6f716c 100644 --- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py @@ -420,6 +420,7 @@ class ModernBertDecoderPreTrainedModel(ModernBertPreTrainedModel): "attentions": ModernBertDecoderAttention, } + @torch.no_grad() def _init_weights(self, module: nn.Module): cutoff_factor = self.config.initializer_cutoff_factor if cutoff_factor is None: @@ -462,9 +463,9 @@ def init_weight(module: nn.Module, std: float): elif isinstance(module, ModernBertDecoderForCausalLM): init_weight(module.decoder, stds["out"]) elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() def _check_and_adjust_attn_implementation(self, attn_implementation, is_init_check): raise AttributeError("No need to inherit!") @@ -584,7 +585,7 @@ def forward( """ ) class ModernBertDecoderForCausalLM(ModernBertDecoderPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["decoder.weight"] + _tied_weights_keys = {"decoder.weight": "model.embeddings.tok_embeddings.weight"} def __init__(self, config: ModernBertDecoderConfig): super().__init__(config) diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 63b93f9c2651..0840c1623489 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -1009,7 +1009,7 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start """ ) class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["proj_out.weight"] + _tied_weights_keys = {"proj_out.weight": "model.decoder.embed_tokens.weight"} def __init__(self, config: MoonshineConfig): super().__init__(config) diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index bb66a7916f00..38314c4535a6 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -764,7 +764,7 @@ def forward( """ ) class MoonshineForConditionalGeneration(MoonshinePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["proj_out.weight"] + _tied_weights_keys = {"proj_out.weight": "model.decoder.embed_tokens.weight"} def __init__(self, config: MoonshineConfig): super().__init__(config) diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 01c89ecb52cc..8cb52f98e5e7 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -837,21 +837,22 @@ class MoshiPreTrainedModel(PreTrainedModel): main_input_name = "input_ids" + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, MoshiFlexibleLinear): - module.weight.data.normal_() + module.weight.normal_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, MoshiRMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): @@ -1485,7 +1486,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( ) class MoshiForCausalLM(MoshiPreTrainedModel, GenerationMixin): input_modalities = "text" - _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"] # Copied from transformers.models.gemma.modeling_gemma.GemmaForCausalLM.__init__ with Gemma->Moshi def __init__(self, config): @@ -1602,7 +1602,6 @@ def forward( """ ) class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["decoder.model.embed_tokens.weight", "decoder.lm_head.weight"] config: MoshiConfig output_modalities = ["audio", "text"] main_input_name = "input_ids" diff --git a/src/transformers/models/mpnet/modeling_mpnet.py b/src/transformers/models/mpnet/modeling_mpnet.py index 233073814388..975dd0eaff57 100644 --- a/src/transformers/models/mpnet/modeling_mpnet.py +++ b/src/transformers/models/mpnet/modeling_mpnet.py @@ -45,21 +45,22 @@ class MPNetPreTrainedModel(PreTrainedModel): config: MPNetConfig base_model_prefix = "mpnet" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, MPNetLMHead): - module.bias.data.zero_() + module.bias.zero_() class MPNetEmbeddings(nn.Module): @@ -464,7 +465,10 @@ def forward( class MPNetForMaskedLM(MPNetPreTrainedModel): - _tied_weights_keys = ["lm_head.decoder"] + _tied_weights_keys = { + "lm_head.decoder.weight": "mpnet.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) @@ -540,15 +544,9 @@ def __init__(self, config): self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, features, **kwargs): x = self.dense(features) x = gelu(x) diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 00cdac508d64..0d666447910b 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -222,25 +222,22 @@ class MptPreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" supports_gradient_checkpointing = True _no_split_modules = ["MptBlock"] - _keys_to_ignore_on_load_missing = [r"lm_head.*."] - - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights.""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, LayerNorm): if module.bias is not None: - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring @@ -396,7 +393,7 @@ def forward( """ ) class MptForCausalLM(MptPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"} def __init__(self, config: MptConfig): super().__init__(config) @@ -502,6 +499,9 @@ def __init__(self, config: MptConfig): # Initialize weights and apply final processing self.post_init() + def set_output_embeddings(self, new_embeddings: torch.Tensor): + self.score = new_embeddings + @auto_docstring def forward( self, diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index 478d66781851..9bd95879a05b 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -762,16 +762,9 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -796,22 +789,23 @@ class MraPreTrainedModel(PreTrainedModel): base_model_prefix = "mra" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, MraLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring @@ -903,7 +897,10 @@ def forward( @auto_docstring class MraForMaskedLM(MraPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "mra.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index d1268c609446..8b48ec869bbd 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -566,59 +566,60 @@ def dummy_inputs(self): } return dummy_inputs + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, MT5LayerNorm): - module.weight.data.fill_(factor * 1.0) + module.weight.fill_(factor * 1.0) elif isinstance( module, (MT5Model, MT5ForConditionalGeneration, MT5EncoderModel, MT5ForQuestionAnswering), ): - module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.shared.weight.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "qa_outputs"): - module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) - module.qa_outputs.bias.data.zero_() + module.qa_outputs.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.qa_outputs.bias.zero_() elif isinstance(module, MT5ForTokenClassification): if hasattr(module, "classifier"): - module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0) - module.classifier.bias.data.zero_() + module.classifier.weight.normal_(mean=0.0, std=factor * 1.0) + module.classifier.bias.zero_() elif isinstance(module, MT5ClassificationHead): - module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.dense.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.dense, "bias") and module.dense.bias is not None: - module.dense.bias.data.zero_() - module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.dense.bias.zero_() + module.out_proj.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: - module.out_proj.bias.data.zero_() + module.out_proj.bias.zero_() elif isinstance(module, MT5DenseActDense): - module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wi.bias.zero_() + module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + module.wo.bias.zero_() elif isinstance(module, MT5DenseGatedActDense): - module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi_0.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: - module.wi_0.bias.data.zero_() - module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi_0.bias.zero_() + module.wi_1.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: - module.wi_1.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wi_1.bias.zero_() + module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + module.wo.bias.zero_() elif isinstance(module, MT5Attention): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + module.q.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -644,10 +645,10 @@ def _shift_right(self, input_ids): # Copied from transformers.models.t5.modeling_t5.T5Stack with T5->MT5 class MT5Stack(MT5PreTrainedModel): - def __init__(self, config, embed_tokens=None): + def __init__(self, config): super().__init__(config) - self.embed_tokens = embed_tokens + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) self.is_decoder = config.is_decoder self.block = nn.ModuleList( @@ -985,7 +986,10 @@ class MT5Model(MT5PreTrainedModel): model_type = "mt5" config: MT5Config _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } # Copied from transformers.models.t5.modeling_t5.T5Model.__init__ with T5->MT5 def __init__(self, config: MT5Config): @@ -996,13 +1000,13 @@ def __init__(self, config: MT5Config): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = MT5Stack(encoder_config, self.shared) + self.encoder = MT5Stack(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = MT5Stack(decoder_config, self.shared) + self.decoder = MT5Stack(decoder_config) # Initialize weights and apply final processing self.post_init() @@ -1165,7 +1169,11 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel, GenerationMixin): model_type = "mt5" config: MT5Config _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + "lm_head.weight": "shared.weight", + } # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.__init__ with T5->MT5 def __init__(self, config: MT5Config): @@ -1178,13 +1186,13 @@ def __init__(self, config: MT5Config): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = MT5Stack(encoder_config, self.shared) + self.encoder = MT5Stack(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = MT5Stack(decoder_config, self.shared) + self.decoder = MT5Stack(decoder_config) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) @@ -1372,7 +1380,9 @@ class MT5EncoderModel(MT5PreTrainedModel): model_type = "mt5" config: MT5Config - _tied_weights_keys = ["encoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + } # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.__init__ with T5->MT5 def __init__(self, config: MT5Config): @@ -1382,7 +1392,7 @@ def __init__(self, config: MT5Config): encoder_config = config encoder_config.use_cache = False encoder_config.is_encoder_decoder = False - self.encoder = MT5Stack(encoder_config, self.shared) + self.encoder = MT5Stack(encoder_config) # Initialize weights and apply final processing self.post_init() @@ -1456,7 +1466,6 @@ def forward( ) class MT5ForSequenceClassification(MT5PreTrainedModel): _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] # Copied from transformers.models.t5.modeling_t5.T5ForSequenceClassification.__init__ with T5->MT5 def __init__(self, config: MT5Config): @@ -1600,8 +1609,6 @@ def forward( @auto_docstring class MT5ForTokenClassification(MT5PreTrainedModel): - _tied_weights_keys = ["transformer.encoder.embed_tokens.weight"] - # Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.__init__ with T5->MT5 def __init__(self, config: MT5Config): super().__init__(config) @@ -1675,7 +1682,10 @@ def forward( @auto_docstring class MT5ForQuestionAnswering(MT5PreTrainedModel): _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.__init__ with T5->MT5 def __init__(self, config: MT5Config): @@ -1688,13 +1698,13 @@ def __init__(self, config: MT5Config): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = MT5Stack(encoder_config, self.shared) + self.encoder = MT5Stack(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = MT5Stack(decoder_config, self.shared) + self.decoder = MT5Stack(decoder_config) self.num_labels = config.num_labels self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 61b5f2948e3f..86988f9da002 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -416,19 +416,20 @@ class MusicgenPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_factor if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() class MusicgenDecoder(MusicgenPreTrainedModel): @@ -1393,23 +1394,7 @@ def __init__( ) # tie text encoder, decoder weights if config set accordingly - self.tie_weights() - - def tie_weights(self): - # tie text encoder & decoder if needed - if self.config.tie_encoder_decoder: - # tie text encoder and decoder base model - decoder_base_model_prefix = self.decoder.base_model_prefix - tied_weights = self._tie_encoder_decoder_weights( - self.text_encoder, - self.decoder._modules[decoder_base_model_prefix], - self.decoder.base_model_prefix, - "text_encoder", - ) - # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class - # attributed not an instance member, therefore modifying it will modify the entire class - # Leading to issues on subsequent calls by different tests or subsequent calls. - self._dynamic_tied_weights_keys = tied_weights + self.post_init() def get_audio_encoder(self): return self.audio_encoder diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 74632ec86c81..0e48bab3a768 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -387,19 +387,20 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_factor if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder with MUSICGEN->MUSICGEN_MELODY,Musicgen->MusicgenMelody @@ -1305,30 +1306,15 @@ def __init__( # Initialize projection layers weights and tie text encoder and decoder weights if set accordingly self.post_init() + @torch.no_grad() def _init_weights(self, module): # MusicgenMelodyForConditionalGeneration is made of PreTrainedModels that have already been initialized # Projection layers still need to be initialized. std = self.decoder.config.initializer_factor if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() - - def tie_weights(self): - # tie text encoder & decoder if needed - if self.config.tie_encoder_decoder: - # tie text encoder and decoder base model - decoder_base_model_prefix = self.decoder.base_model_prefix - tied_weights = self._tie_encoder_decoder_weights( - self.text_encoder, - self.decoder._modules[decoder_base_model_prefix], - self.decoder.base_model_prefix, - "text_encoder", - ) - # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class - # attributed not an instance member, therefore modifying it will modify the entire class - # Leading to issues on subsequent calls by different tests or subsequent calls. - self._dynamic_tied_weights_keys = tied_weights + module.bias.zero_() def get_text_encoder(self): return self.text_encoder diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 6f2bf620cfe4..c4d3350dc129 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -469,16 +469,17 @@ class MvpPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() @property def dummy_inputs(self): @@ -515,10 +516,7 @@ def __init__( self.max_source_positions = config.max_position_embeddings self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) self.embed_positions = MvpLearnedPositionalEmbedding( config.max_position_embeddings, @@ -665,9 +663,7 @@ class MvpDecoder(MvpPreTrainedModel): use_prompt (bool): whether to use prompt """ - def __init__( - self, config: MvpConfig, embed_tokens: Optional[nn.Embedding] = None, use_prompt: Optional[bool] = False - ): + def __init__(self, config: MvpConfig, use_prompt: Optional[bool] = False): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop @@ -675,11 +671,7 @@ def __init__( self.max_target_positions = config.max_position_embeddings self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) - + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) self.embed_positions = MvpLearnedPositionalEmbedding( config.max_position_embeddings, config.d_model, @@ -887,7 +879,10 @@ def forward( @auto_docstring class MvpModel(MvpPreTrainedModel): _keys_to_ignore_on_load_unexpected = ["final_logits_bias"] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: MvpConfig): super().__init__(config) @@ -896,8 +891,8 @@ def __init__(self, config: MvpConfig): self.use_prompt = config.use_prompt self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) - self.encoder = MvpEncoder(config, self.shared, config.use_prompt) - self.decoder = MvpDecoder(config, self.shared, config.use_prompt) + self.encoder = MvpEncoder(config, config.use_prompt) + self.decoder = MvpDecoder(config, config.use_prompt) # Initialize weights and apply final processing self.post_init() @@ -1035,7 +1030,9 @@ def forward( """ ) class MvpForConditionalGeneration(MvpPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.shared.weight", + } def __init__(self, config: MvpConfig): super().__init__(config) @@ -1205,8 +1202,6 @@ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): """ ) class MvpForSequenceClassification(MvpPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] - def __init__(self, config: MvpConfig, **kwargs): super().__init__(config, **kwargs) self.model = MvpModel(config) @@ -1366,8 +1361,6 @@ def forward( @auto_docstring class MvpForQuestionAnswering(MvpPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] - def __init__(self, config): super().__init__(config) @@ -1537,7 +1530,7 @@ def forward(self, *args, **kwargs): class MvpForCausalLM(MvpPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.decoder.embed_tokens.weight"} def __init__(self, config): config.is_decoder = True diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 1c8c7eca861f..c9f9ade48632 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -622,19 +622,20 @@ class NemotronPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, NemotronLayerNorm1P): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() @auto_docstring @@ -881,7 +882,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( # TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index dc4fb4e22bd1..b8bdd3efb14f 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -665,20 +665,21 @@ class NllbMoePreTrainedModel(PreTrainedModel): _supports_sdpa = False _supports_flex_attn = False + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() class NllbMoeEncoder(NllbMoePreTrainedModel): @@ -688,7 +689,7 @@ class NllbMoeEncoder(NllbMoePreTrainedModel): "attentions": NllbMoeAttention, } - def __init__(self, config: NllbMoeConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: NllbMoeConfig): super().__init__(config) self.dropout = config.dropout @@ -703,9 +704,6 @@ def __init__(self, config: NllbMoeConfig, embed_tokens: Optional[nn.Embedding] = config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale ) - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - self.embed_positions = NllbMoeSinusoidalPositionalEmbedding( config.max_position_embeddings, embed_dim, @@ -775,7 +773,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel): "cross_attentions": OutputRecorder(NllbMoeAttention, layer_name="cross_attention", index=1), } - def __init__(self, config: NllbMoeConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: NllbMoeConfig): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop @@ -787,9 +785,6 @@ def __init__(self, config: NllbMoeConfig, embed_tokens: Optional[nn.Embedding] = config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale ) - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight - self.embed_positions = NllbMoeSinusoidalPositionalEmbedding( config.max_position_embeddings, config.d_model, @@ -888,7 +883,10 @@ def forward( @auto_docstring class NllbMoeModel(NllbMoePreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: NllbMoeConfig): super().__init__(config) @@ -897,8 +895,8 @@ def __init__(self, config: NllbMoeConfig): embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 self.shared = NllbMoeScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) - self.encoder = NllbMoeEncoder(config, self.shared) - self.decoder = NllbMoeDecoder(config, self.shared) + self.encoder = NllbMoeEncoder(config) + self.decoder = NllbMoeDecoder(config) # Initialize weights and apply final processing self.post_init() @@ -911,11 +909,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder @@ -1075,7 +1068,9 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start ) class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel, GenerationMixin): base_model_prefix = "model" - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.shared.weight", + } def __init__(self, config: NllbMoeConfig): super().__init__(config) diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 07902d4d1946..cbde955ecde2 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -387,16 +387,9 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -420,19 +413,20 @@ class NystromformerPreTrainedModel(PreTrainedModel): base_model_prefix = "nystromformer" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring @@ -527,7 +521,10 @@ def forward( @auto_docstring class NystromformerForMaskedLM(NystromformerPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder"] + _tied_weights_keys = { + "cls.predictions.decoder.weight": "nystromformer.embeddings.word_embeddings.weight", + "cls.predictions.decoder.bias": "cls.predictions.bias", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 6a3432c31d18..4df5dbbd5a35 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -436,7 +436,7 @@ def forward( @auto_docstring class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 7315661282c9..d1f037ce33d3 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -441,7 +441,7 @@ def forward( @auto_docstring class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/olmo3/modeling_olmo3.py b/src/transformers/models/olmo3/modeling_olmo3.py index 2888f787399b..d49570982f48 100644 --- a/src/transformers/models/olmo3/modeling_olmo3.py +++ b/src/transformers/models/olmo3/modeling_olmo3.py @@ -448,7 +448,7 @@ def forward( @auto_docstring class Olmo3ForCausalLM(Olmo3PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/olmoe/configuration_olmoe.py b/src/transformers/models/olmoe/configuration_olmoe.py index 511d7968fb78..efc04e8a56bb 100644 --- a/src/transformers/models/olmoe/configuration_olmoe.py +++ b/src/transformers/models/olmoe/configuration_olmoe.py @@ -104,6 +104,7 @@ class OlmoeConfig(PreTrainedConfig): model_type = "olmoe" keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_local_experts": "num_experts"} def __init__( self, diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index f6034bd9fc6f..2e2d334e3d7e 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -20,6 +20,7 @@ from typing import Optional, Union import torch +import torch.nn.functional as F from torch import nn from ...activations import ACT2FN @@ -294,64 +295,77 @@ def forward( return attn_output, attn_weights -class OlmoeExperts(nn.ModuleList): - """ - ModuleList of experts. - """ +class OlmoeExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" - def __init__(self, config): + def __init__(self, config: OlmoeConfig): super().__init__() - for _ in range(config.num_experts): - self.append(OlmoeMLP(config)) - self.num_experts = config.num_experts - self.top_k = config.num_experts_per_tok - self.norm_topk_prob = config.norm_topk_prob + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: - """ - Args: - hidden_states: (batch_size * sequence_length, hidden_dim) - top_k_index: (batch_size * sequence_length, top_k) - top_k_weights: (batch_size * sequence_length, top_k) - Returns: - (batch_size * sequence_length, hidden_dim) - """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states -class OlmoeSparseMoeBlock(nn.Module): +class OlmoeTopKRouter(nn.Module): def __init__(self, config): super().__init__() - self.num_experts = config.num_experts self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts self.norm_topk_prob = config.norm_topk_prob - self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False) - self.experts = OlmoeExperts(config) + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) - def route_tokens_to_experts(self, hidden_states, router_logits): - routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1) - top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1) + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) if self.norm_topk_prob: - top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) - top_k_weights = top_k_weights.to(hidden_states.dtype) - return top_k_index, top_k_weights + router_top_value /= router_top_value.sum(dim=-1, keepdim=True) + router_top_value = router_top_value.to(router_logits.dtype) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices + + +class OlmoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.gate = OlmoeTopKRouter(config) + self.experts = OlmoeExperts(config) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - router_logits = self.gate(hidden_states) - top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits) + top_k_weights, top_k_index = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape( batch_size, sequence_length, hidden_dim ) @@ -411,7 +425,7 @@ class OlmoePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _can_record_outputs = { - "router_logits": OutputRecorder(nn.Linear, layer_name="gate", index=1), + "router_logits": OutputRecorder(OlmoeTopKRouter, index=0), "hidden_states": OlmoeDecoderLayer, "attentions": OlmoeAttention, } @@ -584,7 +598,7 @@ def load_balancing_loss_func( @auto_docstring class OlmoeForCausalLM(OlmoePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/olmoe/modular_olmoe.py b/src/transformers/models/olmoe/modular_olmoe.py index 8220a0d7a0f0..ac50b93d5dc1 100644 --- a/src/transformers/models/olmoe/modular_olmoe.py +++ b/src/transformers/models/olmoe/modular_olmoe.py @@ -35,6 +35,7 @@ eager_attention_forward, ) from ..mixtral.modeling_mixtral import MixtralExperts, MixtralForCausalLM, MixtralModel +from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeTopKRouter from .configuration_olmoe import OlmoeConfig @@ -115,38 +116,24 @@ def forward( return attn_output, attn_weights -class OlmoeExperts(MixtralExperts, nn.ModuleList): - def __init__(self, config): - nn.ModuleList.__init__(self) - for _ in range(config.num_experts): - self.append(OlmoeMLP(config)) - self.num_experts = config.num_experts - self.top_k = config.num_experts_per_tok - self.norm_topk_prob = config.norm_topk_prob +class OlmoeExperts(MixtralExperts): + pass + + +class OlmoeTopKRouter(Qwen2MoeTopKRouter): + pass class OlmoeSparseMoeBlock(nn.Module): def __init__(self, config): super().__init__() - self.num_experts = config.num_experts - self.top_k = config.num_experts_per_tok - self.norm_topk_prob = config.norm_topk_prob - self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False) + self.gate = OlmoeTopKRouter(config) self.experts = OlmoeExperts(config) - def route_tokens_to_experts(self, hidden_states, router_logits): - routing_weights = torch.nn.functional.softmax(router_logits.float(), dim=-1) - top_k_weights, top_k_index = torch.topk(routing_weights, self.top_k, dim=-1) - if self.norm_topk_prob: - top_k_weights /= top_k_weights.sum(dim=-1, keepdim=True) - top_k_weights = top_k_weights.to(hidden_states.dtype) - return top_k_index, top_k_weights - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - router_logits = self.gate(hidden_states) - top_k_index, top_k_weights = self.route_tokens_to_experts(hidden_states, router_logits) + top_k_weights, top_k_index = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape( batch_size, sequence_length, hidden_dim ) @@ -173,7 +160,7 @@ class OlmoePreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _can_record_outputs = { - "router_logits": OutputRecorder(nn.Linear, layer_name="gate", index=1), + "router_logits": OutputRecorder(OlmoeTopKRouter, index=0), "hidden_states": OlmoeDecoderLayer, "attentions": OlmoeAttention, } @@ -255,7 +242,7 @@ def forward( class OlmoeForCausalLM(MixtralForCausalLM, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py index 3c552a4b5cb5..fe899ef89e98 100644 --- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -987,6 +987,7 @@ class OmDetTurboPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" input_modalities = ["image", "text"] + @torch.no_grad() def _init_weights(self, module): def linear_init_(module_to_init): bound = 1 / math.sqrt(module_to_init.weight.shape[0]) @@ -1014,12 +1015,12 @@ def linear_init_(module_to_init): elif isinstance(module, OmDetTurboLanguageBackbone): nn.init.normal_(module.text_projection, std=self.config.text_projection_in_dim**-0.5) elif isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.init_std) + module.weight.normal_(mean=0.0, std=self.config.init_std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, OmDetTurboDecoder): diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 929d21fa341a..0f4b16d072b1 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -2766,6 +2766,7 @@ class OneFormerPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" input_modalities = "image" + @torch.no_grad() def _init_weights(self, module: nn.Module): xavier_std = self.config.init_xavier_std std = self.config.init_std @@ -2779,7 +2780,7 @@ def _init_weights(self, module: nn.Module): nn.init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std) nn.init.constant_(module.query_input_projection.bias, 0) elif isinstance(module, OneFormerPixelDecoderEncoderMultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + nn.init.constant_(module.sampling_offsets.weight, 0.0) thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / module.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( @@ -2791,12 +2792,12 @@ def _init_weights(self, module: nn.Module): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(module.attention_weights.weight.data, 0.0) - nn.init.constant_(module.attention_weights.bias.data, 0.0) - nn.init.xavier_uniform_(module.value_proj.weight.data) - nn.init.constant_(module.value_proj.bias.data, 0.0) - nn.init.xavier_uniform_(module.output_proj.weight.data) - nn.init.constant_(module.output_proj.bias.data, 0.0) + nn.init.constant_(module.attention_weights.weight, 0.0) + nn.init.constant_(module.attention_weights.bias, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight) + nn.init.constant_(module.value_proj.bias, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight) + nn.init.constant_(module.output_proj.bias, 0.0) elif isinstance(module, OneFormerPixelDecoder): nn.init.normal_(module.level_embed, std=0) elif isinstance(module, (OneFormerTransformerDecoderLayer, OneFormerTransformerDecoderQueryTransformer)): @@ -2816,29 +2817,29 @@ def _init_weights(self, module: nn.Module): nn.init.normal_(module.token_embedding.weight, std=0.02) nn.init.normal_(module.positional_embedding, std=0.01) if hasattr(module, "reference_points"): - nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0) - nn.init.constant_(module.reference_points.bias.data, 0.0) + nn.init.xavier_uniform_(module.reference_points.weight, gain=1.0) + nn.init.constant_(module.reference_points.bias, 0.0) elif isinstance(module, OneFormerMLPPredictionHead): for submodule in module.modules(): if isinstance(submodule, nn.Linear): nn.init.xavier_uniform_(submodule.weight, gain=xavier_std) nn.init.constant_(submodule.bias, 0) elif isinstance(module, nn.MultiheadAttention): - module.in_proj_weight.data.normal_(mean=0.0, std=std) - module.in_proj_bias.data.zero_() + module.in_proj_weight.normal_(mean=0.0, std=std) + module.in_proj_bias.zero_() elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, OneFormerLoss): - module.logit_scale.data.fill_(np.log(1 / self.config.contrastive_temperature)) + module.logit_scale.fill_(np.log(1 / self.config.contrastive_temperature)) @auto_docstring diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index aebe5074c706..18a12bce9dc8 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -259,19 +259,20 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel): config: OpenAIGPTConfig base_model_prefix = "transformer" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear, Conv1D)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @dataclass @@ -416,7 +417,7 @@ def forward( """ ) class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "transformer.tokens_embed.weight"} def __init__(self, config): super().__init__(config) @@ -501,7 +502,7 @@ def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) - """ ) class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"transformer.tokens_embed.weight": "lm_head.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 9de23d596f3a..2d88858a6c0d 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -304,19 +304,20 @@ class OPTPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() class OPTDecoder(OPTPreTrainedModel): @@ -717,7 +718,7 @@ def forward( class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.decoder.embed_tokens.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/ovis2/modeling_ovis2.py b/src/transformers/models/ovis2/modeling_ovis2.py index 02a8af5d5865..710f0a5603bf 100644 --- a/src/transformers/models/ovis2/modeling_ovis2.py +++ b/src/transformers/models/ovis2/modeling_ovis2.py @@ -671,7 +671,7 @@ def forward( @auto_docstring class Ovis2ForConditionalGeneration(Ovis2PreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = {} - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: Ovis2Config): super().__init__(config) diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 391470ccb1de..f10631a7071a 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -567,12 +567,13 @@ class Owlv2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Owlv2EncoderLayer"] + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, Owlv2TextEmbeddings): - module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02) elif isinstance(module, Owlv2VisionEmbeddings): nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) @@ -598,14 +599,14 @@ def _init_weights(self, module: nn.Module): module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * factor, ) - module.logit_scale.data.fill_(self.config.logit_scale_init_value) + module.logit_scale.fill_(self.config.logit_scale_init_value) if isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=factor) + module.weight.normal_(mean=0.0, std=factor) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTEncoder with OwlViT->Owlv2 diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index 0eb4ddbcd445..95cd4ccb6034 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -554,12 +554,13 @@ class OwlViTPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["OwlViTEncoderLayer"] + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, OwlViTTextEmbeddings): - module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02) elif isinstance(module, OwlViTVisionEmbeddings): nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) @@ -585,14 +586,14 @@ def _init_weights(self, module: nn.Module): module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * factor, ) - module.logit_scale.data.fill_(self.config.logit_scale_init_value) + module.logit_scale.fill_(self.config.logit_scale_init_value) if isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=factor) + module.weight.normal_(mean=0.0, std=factor) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() class OwlViTEncoder(nn.Module): diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 2779022e3329..dcbe454a9867 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -226,15 +226,16 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): # important: this ported version of PaliGemmaisn't meant for training from scratch - only # inference and fine-tuning std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() @auto_docstring( @@ -447,7 +448,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi "^multi_modal_projector": "model.multi_modal_projector", "^language_model.lm_head": "lm_head", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: PaliGemmaConfig): super().__init__(config) diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 34697507ffc7..3c8698c7c9b0 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -455,6 +455,7 @@ class ParakeetPreTrainedModel(PreTrainedModel): "attentions": ParakeetEncoderAttention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) @@ -466,8 +467,8 @@ def _init_weights(self, module): if isinstance(module, ParakeetEncoderAttention): # Initialize positional bias parameters - module.bias_u.data.normal_(mean=0.0, std=std) - module.bias_v.data.normal_(mean=0.0, std=std) + module.bias_u.normal_(mean=0.0, std=std) + module.bias_v.normal_(mean=0.0, std=std) def _get_subsampling_output_length(self, input_lengths: torch.Tensor): encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 6b597e1b50a3..f792b19c9315 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -331,6 +331,7 @@ class ParakeetPreTrainedModel(PreTrainedModel): "attentions": ParakeetEncoderAttention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) @@ -342,8 +343,8 @@ def _init_weights(self, module): if isinstance(module, ParakeetEncoderAttention): # Initialize positional bias parameters - module.bias_u.data.normal_(mean=0.0, std=std) - module.bias_v.data.normal_(mean=0.0, std=std) + module.bias_u.normal_(mean=0.0, std=std) + module.bias_v.normal_(mean=0.0, std=std) def _get_subsampling_output_length(self, input_lengths: torch.Tensor): encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index 8cd4ec059473..3402386596d2 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -685,6 +685,7 @@ class PatchTSMixerPreTrainedModel(PreTrainedModel): input_modalities = "time" supports_gradient_checkpointing = False + @torch.no_grad() def _init_weights(self, module): """Initialize weights""" if isinstance(module, PatchTSMixerPositionalEncoding): @@ -692,15 +693,15 @@ def _init_weights(self, module): if self.config.positional_encoding_type == "random": nn.init.normal_(module.position_enc, mean=0.0, std=0.1) elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, PatchTSMixerBatchNorm): - module.batchnorm.bias.data.zero_() - module.batchnorm.weight.data.fill_(1.0) + module.batchnorm.bias.zero_() + module.batchnorm.weight.fill_(1.0) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.init_std) + module.weight.normal_(mean=0.0, std=self.config.init_std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() class PatchTSMixerPretrainHead(nn.Module): diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index 6411b8956743..fe99982803d9 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -555,6 +555,7 @@ class PatchTSTPreTrainedModel(PreTrainedModel): input_modalities = "time" supports_gradient_checkpointing = False + @torch.no_grad() def _init_weights(self, module: nn.Module): """ Initialize weights @@ -571,15 +572,15 @@ def _init_weights(self, module: nn.Module): # initialize positional encoding module.position_enc = module._init_pe(self.config, num_patches) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, PatchTSTBatchNorm): - module.batchnorm.bias.data.zero_() - module.batchnorm.weight.data.fill_(1.0) + module.batchnorm.bias.zero_() + module.batchnorm.weight.fill_(1.0) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.init_std) + module.weight.normal_(mean=0.0, std=self.config.init_std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (PatchTSTEncoder)): diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index a23f45bf8437..e1009cc96e5a 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -438,21 +438,22 @@ class PegasusPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, PegasusSinusoidalPositionalEmbedding): module._init_weight() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() class PegasusEncoder(PegasusPreTrainedModel): @@ -465,7 +466,7 @@ class PegasusEncoder(PegasusPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: PegasusConfig): super().__init__(config) self.dropout = config.dropout @@ -476,10 +477,7 @@ def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = self.max_source_positions = config.max_position_embeddings self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) + self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) self.embed_positions = PegasusSinusoidalPositionalEmbedding( config.max_position_embeddings, @@ -643,7 +641,7 @@ class PegasusDecoder(PegasusPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: PegasusConfig): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop @@ -651,10 +649,7 @@ def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = self.max_target_positions = config.max_position_embeddings self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) self.embed_positions = PegasusSinusoidalPositionalEmbedding( config.max_position_embeddings, @@ -898,7 +893,10 @@ def forward( @auto_docstring class PegasusModel(PegasusPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "decoder.embed_tokens.weight": "shared.weight", + "encoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: PegasusConfig): super().__init__(config) @@ -906,8 +904,8 @@ def __init__(self, config: PegasusConfig): padding_idx, vocab_size = config.pad_token_id, config.vocab_size self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) - self.encoder = PegasusEncoder(config, self.shared) - self.decoder = PegasusDecoder(config, self.shared) + self.encoder = PegasusEncoder(config) + self.decoder = PegasusDecoder(config) # Initialize weights and apply final processing self.post_init() @@ -1058,7 +1056,9 @@ def forward( class PegasusForConditionalGeneration(PegasusPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.shared.weight", + } def __init__(self, config: PegasusConfig): super().__init__(config) @@ -1242,7 +1242,9 @@ def forward(self, *args, **kwargs): class PegasusForCausalLM(PegasusPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.decoder.embed_tokens.weight", + } def __init__(self, config): config = copy.deepcopy(config) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index d76759e9104c..0e9b8bc1e255 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -747,17 +747,18 @@ class PegasusXPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() class PegasusXEncoder(PegasusXPreTrainedModel): @@ -770,7 +771,7 @@ class PegasusXEncoder(PegasusXPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: PegasusXConfig): super().__init__(config) self.dropout = config.dropout @@ -781,12 +782,9 @@ def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] self.max_source_positions = config.max_position_embeddings embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = PegasusXScaledWordEmbedding( - config.vocab_size, embed_dim, padding_idx, embed_scale=embed_scale - ) + self.embed_tokens = PegasusXScaledWordEmbedding( + config.vocab_size, embed_dim, padding_idx, embed_scale=embed_scale + ) self.embed_global = nn.Embedding(config.num_global_tokens, embed_dim) self.embed_positions = PegasusXSinusoidalPositionalEmbedding(embed_dim) @@ -972,7 +970,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: PegasusXConfig): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop @@ -980,12 +978,9 @@ def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 padding_idx = config.pad_token_id - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = PegasusXScaledWordEmbedding( - config.vocab_size, config.d_model, padding_idx=padding_idx, embed_scale=embed_scale - ) + self.embed_tokens = PegasusXScaledWordEmbedding( + config.vocab_size, config.d_model, padding_idx=padding_idx, embed_scale=embed_scale + ) self.embed_positions = PegasusXSinusoidalPositionalEmbedding(config.d_model) self.layers = nn.ModuleList([PegasusXDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) @@ -1192,7 +1187,10 @@ def forward( @auto_docstring class PegasusXModel(PegasusXPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: PegasusXConfig): super().__init__(config) @@ -1204,8 +1202,8 @@ def __init__(self, config: PegasusXConfig): vocab_size, config.d_model, padding_idx=padding_idx, embed_scale=embed_scale ) - self.encoder = PegasusXEncoder(config, self.shared) - self.decoder = PegasusXDecoder(config, self.shared) + self.encoder = PegasusXEncoder(config) + self.decoder = PegasusXDecoder(config) # Initialize weights and apply final processing self.post_init() @@ -1355,7 +1353,9 @@ def forward( ) class PegasusXForConditionalGeneration(PegasusXPreTrainedModel, GenerationMixin): base_model_prefix = "model" - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.shared.weight", + } def __init__(self, config: PegasusXConfig): super().__init__(config) diff --git a/src/transformers/models/perceiver/modeling_perceiver.py b/src/transformers/models/perceiver/modeling_perceiver.py index 0b734c0714ee..4ddad1c5b2c6 100755 --- a/src/transformers/models/perceiver/modeling_perceiver.py +++ b/src/transformers/models/perceiver/modeling_perceiver.py @@ -531,26 +531,27 @@ class PerceiverPreTrainedModel(PreTrainedModel): main_input_name = "inputs" input_modalities = "image" # techinically can be anything but HF impl has only image processor + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif hasattr(module, "latents"): - module.latents.data.normal_(mean=0.0, std=self.config.initializer_range) + module.latents.normal_(mean=0.0, std=self.config.initializer_range) elif hasattr(module, "position_embeddings") and isinstance(module, PerceiverTrainablePositionEncoding): - module.position_embeddings.data.normal_(mean=0.0, std=self.config.initializer_range) + module.position_embeddings.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.ParameterDict): for modality in module: - module[modality].data.normal_(mean=0.0, std=self.config.initializer_range) + module[modality].normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring( diff --git a/src/transformers/models/perception_lm/modeling_perception_lm.py b/src/transformers/models/perception_lm/modeling_perception_lm.py index 9fb7ede3e9f8..0a601deac183 100644 --- a/src/transformers/models/perception_lm/modeling_perception_lm.py +++ b/src/transformers/models/perception_lm/modeling_perception_lm.py @@ -323,7 +323,7 @@ def forward( @auto_docstring class PerceptionLMForConditionalGeneration(PerceptionLMPreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = {} - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: PerceptionLMConfig): super().__init__(config) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 205d5b1fc1d7..8bb936c41461 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -429,19 +429,20 @@ class PersimmonPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() @auto_docstring @@ -685,7 +686,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->PERSIMMON,Llama->Persimmon def __init__(self, config): diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 3fb8de6e32e3..4a1530b78564 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -459,7 +459,7 @@ def forward( @auto_docstring class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index d1ebf1ea99c0..29b3d2847ed1 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -446,7 +446,7 @@ def forward( @auto_docstring class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index aebf09174575..31ef21fbda1e 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -322,6 +322,7 @@ class Phi4MultimodalVisionPreTrainedModel(PreTrainedModel): "attentions": Phi4MultimodalVisionAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, Phi4MultimodalVisionEmbeddings): @@ -348,16 +349,16 @@ def _init_weights(self, module): nn.init.normal_(module.fc1.bias, std=1e-6) nn.init.normal_(module.fc2.bias, std=1e-6) elif isinstance(module, Phi4MultimodalVisionMultiheadAttentionPoolingHead): - nn.init.normal_(module.probe.data) - nn.init.normal_(module.attention.in_proj_weight.data) - nn.init.zeros_(module.attention.in_proj_bias.data) + nn.init.normal_(module.probe) + nn.init.normal_(module.attention.in_proj_weight) + nn.init.zeros_(module.attention.in_proj_bias) elif isinstance(module, (nn.Linear, nn.Conv2d)): lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) class Phi4MultimodalVisionEmbeddings(nn.Module): @@ -939,11 +940,12 @@ class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Phi4MultimodalAudioGluPointWiseConv): - module.b1.data.zero_() - module.b2.data.zero_() + module.b1.zero_() + module.b2.zero_() def unfold_tensor(tensor, max_seq_len): @@ -1497,11 +1499,12 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel): _version = "0.0.5" input_modalities = ["image", "audio", "text"] + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Phi4MultimodalImageEmbedding): - module.global_img_feature_extensor.data.zero_() - module.sub_img_feature_extensor.data.zero_() + module.global_img_feature_extensor.zero_() + module.sub_img_feature_extensor.zero_() class Phi4MultimodalRotaryEmbedding(nn.Module): @@ -1690,7 +1693,7 @@ def forward( @auto_docstring class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py index 9095c4375c7e..62c7fb50748f 100644 --- a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py @@ -546,6 +546,7 @@ class Phi4MultimodalVisionPreTrainedModel(SiglipPreTrainedModel): "attentions": Phi4MultimodalVisionAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, Phi4MultimodalVisionEmbeddings): @@ -572,16 +573,16 @@ def _init_weights(self, module): nn.init.normal_(module.fc1.bias, std=1e-6) nn.init.normal_(module.fc2.bias, std=1e-6) elif isinstance(module, Phi4MultimodalVisionMultiheadAttentionPoolingHead): - nn.init.normal_(module.probe.data) - nn.init.normal_(module.attention.in_proj_weight.data) - nn.init.zeros_(module.attention.in_proj_bias.data) + nn.init.normal_(module.probe) + nn.init.normal_(module.attention.in_proj_weight) + nn.init.zeros_(module.attention.in_proj_bias) elif isinstance(module, (nn.Linear, nn.Conv2d)): lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) class Phi4MultimodalVisionEmbeddings(SiglipVisionEmbeddings): @@ -1119,11 +1120,12 @@ class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Phi4MultimodalAudioGluPointWiseConv): - module.b1.data.zero_() - module.b2.data.zero_() + module.b1.zero_() + module.b2.zero_() class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel): @@ -1441,11 +1443,12 @@ def forward( class Phi4MultimodalPreTrainedModel(Phi3PreTrainedModel): input_modalities = ["image", "audio", "text"] + @torch.no_grad() def _init_weights(self, module): PreTrainedModel._init_weights(self, module) if isinstance(module, Phi4MultimodalImageEmbedding): - module.global_img_feature_extensor.data.zero_() - module.sub_img_feature_extensor.data.zero_() + module.global_img_feature_extensor.zero_() + module.sub_img_feature_extensor.zero_() class Phi4MultimodalModel(Phi3Model): @@ -1563,7 +1566,7 @@ def forward( class Phi4MultimodalForCausalLM(Phi3ForCausalLM): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 58733405678d..50479af0dac8 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -262,24 +262,6 @@ def forward( return attn_output, attn_weights -class PhimoeMLP(nn.Module): - def __init__(self, config: PhimoeConfig): - super().__init__() - self.ffn_dim = config.intermediate_size - self.hidden_dim = config.hidden_size - - self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) - self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states): - current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states - - class PhimoeMultiplier(torch.autograd.Function): @staticmethod def forward( @@ -342,56 +324,44 @@ def backward( ) -class PhimoeExperts(nn.ModuleList): - """ - ModuleList of experts. - """ +class PhimoeExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" def __init__(self, config: PhimoeConfig): super().__init__() - self.top_k = config.num_experts_per_tok self.num_experts = config.num_local_experts - for _ in range(self.num_experts): - self.append(PhimoeMLP(config)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: - """ - Args: - hidden_states: (batch_size * sequence_length, hidden_dim) - top_k_index: (batch_size * sequence_length, top_k) - top_k_weights: (batch_size * sequence_length, top_k) - Returns: - (batch_size * sequence_length, hidden_dim) - """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - return final_hidden_states - - -class PhimoeRouter(nn.Linear): - def __init__(self, config: PhimoeConfig): - super().__init__(config.hidden_size, config.num_local_experts, bias=False) - self.top_k = config.num_experts_per_tok - self.hidden_dim = config.hidden_size - self.router_jitter_noise = config.router_jitter_noise - self.input_jitter_noise = config.router_jitter_noise + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) - def forward(self, hidden_states): - if self.training and self.input_jitter_noise > 0: - hidden_states *= torch.empty_like(hidden_states).uniform_( - 1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise - ) - router_logits = super().forward(hidden_states) - return router_logits + return final_hidden_states def sparsemixer(scores, jitter_eps, training, top_k=2): @@ -517,6 +487,27 @@ def sparsemixer(scores, jitter_eps, training, top_k=2): ) +class PhimoeTopKRouter(nn.Linear): + def __init__(self, config: PhimoeConfig): + super().__init__(config.hidden_size, config.num_local_experts, bias=False) + self.router_jitter_noise = config.router_jitter_noise + self.input_jitter_noise = config.input_jitter_noise + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if self.training and self.input_jitter_noise > 0: + hidden_states *= torch.empty_like(hidden_states).uniform_( + 1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise + ) + router_logits = super().forward(hidden_states) + routing_weights, selected_experts = sparsemixer( + router_logits, + jitter_eps=self.router_jitter_noise, + training=self.training, + ) + routing_weights = torch.zeros_like(router_logits).scatter_(1, selected_experts, routing_weights) + return routing_weights, selected_experts + + class PhimoeSparseMoeBlock(nn.Module): """ This implementation is @@ -535,19 +526,10 @@ def __init__(self, config): self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts self.top_k = config.num_experts_per_tok - self.router_jitter_noise = config.router_jitter_noise - self.gate = PhimoeRouter(config) + self.router = PhimoeTopKRouter(config) self.experts = PhimoeExperts(config) self.input_jitter_noise = config.input_jitter_noise - def route_tokens_to_experts(self, router_logits): - routing_weights, selected_experts = sparsemixer( - router_logits, - jitter_eps=self.router_jitter_noise, - training=self.training, - ) - return routing_weights, selected_experts - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape if self.training and self.input_jitter_noise > 0: @@ -557,8 +539,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.reshape(-1, hidden_dim) - router_logits = self.gate(hidden_states) - routing_weights, selected_experts = self.route_tokens_to_experts(router_logits) + routing_weights, selected_experts = self.router(hidden_states) final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights) return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) @@ -591,7 +572,7 @@ def __init__(self, config: PhimoeConfig, layer_idx: int): self.self_attn = PhimoeAttention(config, layer_idx) - self.block_sparse_moe = PhimoeSparseMoeBlock(config) + self.mlp = PhimoeSparseMoeBlock(config) self.input_layernorm = PhimoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = PhimoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -619,7 +600,7 @@ def forward( hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states @@ -637,11 +618,21 @@ class PhimoePreTrainedModel(PreTrainedModel): _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True _can_record_outputs = { - "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), + "router_logits": OutputRecorder(PhimoeTopKRouter, layer_name="mlp.router", index=0), "hidden_states": PhimoeDecoderLayer, "attentions": PhimoeAttention, } + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + std = self.config.initializer_range + if isinstance(module, PhimoeExperts): + module.gate_up_proj.normal_(mean=0.0, std=std) + module.down_proj.normal_(mean=0.0, std=std) + elif isinstance(module, PhimoeTopKRouter): + module.weight.normal_(mean=0.0, std=std) + @auto_docstring class PhimoeModel(PhimoePreTrainedModel): @@ -808,7 +799,7 @@ def load_balancing_loss_func( @auto_docstring class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/phimoe/modular_phimoe.py b/src/transformers/models/phimoe/modular_phimoe.py index 59f5761987b9..76693282256a 100644 --- a/src/transformers/models/phimoe/modular_phimoe.py +++ b/src/transformers/models/phimoe/modular_phimoe.py @@ -30,7 +30,6 @@ MixtralDecoderLayer, MixtralExperts, MixtralForCausalLM, - MixtralMLP, MixtralModel, MixtralPreTrainedModel, MixtralRotaryEmbedding, @@ -87,10 +86,6 @@ class PhimoeAttention(LlamaAttention): pass -class PhimoeMLP(MixtralMLP): - pass - - class PhimoeMultiplier(torch.autograd.Function): @staticmethod def forward( @@ -276,30 +271,29 @@ def sparsemixer(scores, jitter_eps, training, top_k=2): ) -class PhimoeExperts(MixtralExperts, nn.ModuleList): - def __init__(self, config: PhimoeConfig): - nn.ModuleList.__init__(self) - self.top_k = config.num_experts_per_tok - self.num_experts = config.num_local_experts - for _ in range(self.num_experts): - self.append(PhimoeMLP(config)) +class PhimoeExperts(MixtralExperts): + pass -class PhimoeRouter(nn.Linear): +class PhimoeTopKRouter(nn.Linear): def __init__(self, config: PhimoeConfig): super().__init__(config.hidden_size, config.num_local_experts, bias=False) - self.top_k = config.num_experts_per_tok - self.hidden_dim = config.hidden_size self.router_jitter_noise = config.router_jitter_noise - self.input_jitter_noise = config.router_jitter_noise + self.input_jitter_noise = config.input_jitter_noise - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if self.training and self.input_jitter_noise > 0: hidden_states *= torch.empty_like(hidden_states).uniform_( 1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise ) router_logits = super().forward(hidden_states) - return router_logits + routing_weights, selected_experts = sparsemixer( + router_logits, + jitter_eps=self.router_jitter_noise, + training=self.training, + ) + routing_weights = torch.zeros_like(router_logits).scatter_(1, selected_experts, routing_weights) + return routing_weights, selected_experts class PhimoeSparseMoeBlock(nn.Module): @@ -320,19 +314,10 @@ def __init__(self, config): self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts self.top_k = config.num_experts_per_tok - self.router_jitter_noise = config.router_jitter_noise - self.gate = PhimoeRouter(config) + self.router = PhimoeTopKRouter(config) self.experts = PhimoeExperts(config) self.input_jitter_noise = config.input_jitter_noise - def route_tokens_to_experts(self, router_logits): - routing_weights, selected_experts = sparsemixer( - router_logits, - jitter_eps=self.router_jitter_noise, - training=self.training, - ) - return routing_weights, selected_experts - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape if self.training and self.input_jitter_noise > 0: @@ -342,8 +327,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.reshape(-1, hidden_dim) - router_logits = self.gate(hidden_states) - routing_weights, selected_experts = self.route_tokens_to_experts(router_logits) + routing_weights, selected_experts = self.router(hidden_states) final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights) return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) @@ -354,7 +338,7 @@ class PhimoeDecoderLayer(MixtralDecoderLayer): class PhimoePreTrainedModel(MixtralPreTrainedModel): _can_record_outputs = { - "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), + "router_logits": OutputRecorder(PhimoeTopKRouter, layer_name="mlp.router", index=0), "hidden_states": PhimoeDecoderLayer, "attentions": PhimoeAttention, } diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 09f7e5783b9c..f47e9f005e02 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -350,11 +350,12 @@ def dummy_inputs(self): } return dummy_inputs + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, Pix2StructLayerNorm): - module.weight.data.fill_(factor * 1.0) + module.weight.fill_(factor * 1.0) elif isinstance(module, Pix2StructTextDenseGatedActDense): hidden_size = ( self.config.text_config.hidden_size @@ -363,15 +364,15 @@ def _init_weights(self, module): ) d_ff = self.config.text_config.d_ff if isinstance(self.config, Pix2StructConfig) else self.config.d_ff - module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + module.wi_0.weight.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: - module.wi_0.bias.data.zero_() - module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + module.wi_0.bias.zero_() + module.wi_1.weight.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: - module.wi_1.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5)) + module.wi_1.bias.zero_() + module.wo.weight.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + module.wo.bias.zero_() elif isinstance(module, Pix2StructTextAttention): hidden_size = ( self.config.text_config.hidden_size @@ -387,12 +388,12 @@ def _init_weights(self, module): else self.config.num_heads ) - module.query.weight.data.normal_(mean=0.0, std=factor * ((hidden_size * key_value_proj_dim) ** -0.5)) - module.key.weight.data.normal_(mean=0.0, std=factor * (hidden_size**-0.5)) - module.value.weight.data.normal_(mean=0.0, std=factor * (hidden_size**-0.5)) - module.output.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + module.query.weight.normal_(mean=0.0, std=factor * ((hidden_size * key_value_proj_dim) ** -0.5)) + module.key.weight.normal_(mean=0.0, std=factor * (hidden_size**-0.5)) + module.value.weight.normal_(mean=0.0, std=factor * (hidden_size**-0.5)) + module.output.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) elif isinstance(module, nn.Embedding): hidden_size = ( self.config.text_config.hidden_size @@ -400,9 +401,9 @@ def _init_weights(self, module): else self.config.hidden_size ) - module.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + module.weight.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, Pix2StructTextModel): hidden_size = ( self.config.text_config.hidden_size @@ -410,22 +411,24 @@ def _init_weights(self, module): else self.config.hidden_size ) - module.lm_head.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) + module.lm_head.weight.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5)) elif isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) + module.weight.copy_( + nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( + module.weight.dtype + ) + ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, Pix2StructLayerNorm): if module.weight is not None: - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->Pix2Struct def _shift_right(self, input_ids): @@ -958,7 +961,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): config: Pix2StructTextConfig input_modalities = "text" _no_split_modules = ["Pix2StructTextBlock"] - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "embed_tokens.weight"} supports_gradient_checkpointing = True def __init__(self, config): @@ -1319,7 +1322,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel, GenerationMixin): config: Pix2StructConfig main_input_name = "flattened_patches" - _tied_weights_keys = ["decoder.lm_head.weight"] def __init__(self, config: Pix2StructConfig): super().__init__(config) diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index 0f237c86beac..f9a408193387 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -441,14 +441,15 @@ class PixtralPreTrainedModel(PreTrainedModel): _supports_flex_attn = True _no_split_modules = ["PixtralAttentionLayer"] + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, PixtralRMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) def generate_block_attention_mask(patch_embeds_list, tensor): diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 9a80a46f6265..028c22e180f8 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -332,7 +332,7 @@ class PLBartEncoder(PLBartPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: PLBartConfig): super().__init__(config) self.dropout = config.dropout @@ -343,12 +343,9 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = self.max_source_positions = config.max_position_embeddings embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = PLBartScaledWordEmbedding( - config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale - ) + self.embed_tokens = PLBartScaledWordEmbedding( + config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale + ) self.embed_positions = PLBartLearnedPositionalEmbedding( config.max_position_embeddings, @@ -587,7 +584,7 @@ class PLBartDecoder(PLBartPreTrainedModel): embed_tokens (nn.Embedding): output embedding """ - def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: PLBartConfig): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop @@ -595,12 +592,9 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = self.max_target_positions = config.max_position_embeddings embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = PLBartScaledWordEmbedding( - config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale - ) + self.embed_tokens = PLBartScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) self.embed_positions = PLBartLearnedPositionalEmbedding( config.max_position_embeddings, @@ -832,7 +826,10 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): @auto_docstring class PLBartModel(PLBartPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: PLBartConfig): super().__init__(config) @@ -841,8 +838,8 @@ def __init__(self, config: PLBartConfig): embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) - self.encoder = PLBartEncoder(config, self.shared) - self.decoder = PLBartDecoder(config, self.shared) + self.encoder = PLBartEncoder(config) + self.decoder = PLBartDecoder(config) self.init_weights() @@ -854,11 +851,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder @@ -968,7 +960,9 @@ def forward( class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.shared.weight", + } def __init__(self, config: PLBartConfig): super().__init__(config) @@ -1145,8 +1139,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ ) class PLBartForSequenceClassification(PLBartPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] - def __init__(self, config: PLBartConfig, **kwargs): super().__init__(config, **kwargs) self.model = PLBartModel(config) @@ -1296,7 +1288,9 @@ def forward(self, *args, **kwargs): """ ) class PLBartForCausalLM(PLBartPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.decoder.embed_tokens.weight", + } def __init__(self, config): config.is_decoder = True diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index 0d17549a2d00..e67705ef697b 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -67,7 +67,10 @@ class PLBartDecoder(BartDecoder): @auto_docstring class PLBartModel(PLBartPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: PLBartConfig): super().__init__(config) @@ -76,8 +79,8 @@ def __init__(self, config: PLBartConfig): embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) - self.encoder = PLBartEncoder(config, self.shared) - self.decoder = PLBartDecoder(config, self.shared) + self.encoder = PLBartEncoder(config) + self.decoder = PLBartDecoder(config) self.init_weights() @@ -89,11 +92,6 @@ def set_input_embeddings(self, value): self.encoder.embed_tokens = self.shared self.decoder.embed_tokens = self.shared - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder @@ -203,7 +201,9 @@ def forward( class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "model.shared.weight", + } def __init__(self, config: PLBartConfig): super().__init__(config) diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index a32b6dde21b5..0e7dc6fe24f0 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -245,19 +245,20 @@ class PoolFormerPreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = ["PoolFormerLayer"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.GroupNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, PoolFormerLayer): if hasattr(module, "layer_scale_1"): - module.layer_scale_1.data.fill_(self.config.layer_scale_init_value) - module.layer_scale_2.data.fill_(self.config.layer_scale_init_value) + module.layer_scale_1.fill_(self.config.layer_scale_init_value) + module.layer_scale_2.fill_(self.config.layer_scale_init_value) @auto_docstring diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 0fe560260d78..ea0bee57e157 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -544,44 +544,45 @@ class Pop2PianoPreTrainedModel(PreTrainedModel): _no_split_modules = ["Pop2PianoBlock"] _keep_in_fp32_modules = ["wo"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, Pop2PianoLayerNorm): - module.weight.data.fill_(factor * 1.0) + module.weight.fill_(factor * 1.0) elif isinstance(module, Pop2PianoConcatEmbeddingToMel): - module.embedding.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.embedding.weight.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, Pop2PianoForConditionalGeneration): - module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.shared.weight.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, Pop2PianoDenseActDense): - module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wi.bias.zero_() + module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + module.wo.bias.zero_() elif isinstance(module, Pop2PianoDenseGatedActDense): - module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi_0.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: - module.wi_0.bias.data.zero_() - module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi_0.bias.zero_() + module.wi_1.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: - module.wi_1.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wi_1.bias.zero_() + module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + module.wo.bias.zero_() elif isinstance(module, Pop2PianoAttention): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + module.q.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -606,10 +607,10 @@ def _shift_right(self, input_ids): class Pop2PianoStack(Pop2PianoPreTrainedModel): # Copied from transformers.models.t5.modeling_t5.T5Stack.__init__ with T5->Pop2Piano,t5->pop2piano - def __init__(self, config, embed_tokens=None): + def __init__(self, config): super().__init__(config) - self.embed_tokens = embed_tokens + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) self.is_decoder = config.is_decoder self.block = nn.ModuleList( @@ -943,7 +944,11 @@ def forward(self, feature, index_value, embedding_offset): """ ) class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + "lm_head.weight": "shared.weight", + } def __init__(self, config: Pop2PianoConfig): super().__init__(config) @@ -959,13 +964,13 @@ def __init__(self, config: Pop2PianoConfig): encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = Pop2PianoStack(encoder_config, self.shared) + self.encoder = Pop2PianoStack(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = Pop2PianoStack(decoder_config, self.shared) + self.decoder = Pop2PianoStack(decoder_config) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 8cc5eae250bc..7356740348e1 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -332,15 +332,16 @@ class ProphetNetPreTrainedModel(PreTrainedModel): base_model_prefix = "prophetnet" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.init_std) + module.weight.normal_(mean=0.0, std=self.config.init_std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.init_std) + module.weight.normal_(mean=0.0, std=self.config.init_std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -975,7 +976,7 @@ def forward( """ ) class ProphetNetEncoder(ProphetNetPreTrainedModel): - def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None): + def __init__(self, config: ProphetNetConfig): r""" word_embeddings (`torch.nn.Embeddings` of shape `(config.vocab_size, config.hidden_size)`, *optional*): The word embedding parameters. This can be used to initialize [`ProphetNetEncoder`] with pre-defined word @@ -983,11 +984,7 @@ def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedd """ super().__init__(config) - self.word_embeddings = ( - word_embeddings - if word_embeddings is not None - else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) - ) + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.position_embeddings = ProphetNetPositionalEmbeddings(config) self.embeddings_layer_norm = LayerNorm(config.hidden_size) @@ -1090,7 +1087,7 @@ def forward( """ ) class ProphetNetDecoder(ProphetNetPreTrainedModel): - def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedding] = None): + def __init__(self, config: ProphetNetConfig): r""" word_embeddings (`torch.nn.Embeddings` of shape `(config.vocab_size, config.hidden_size)`, *optional*): The word embedding parameters. This can be used to initialize [`ProphetNetEncoder`] with pre-defined word @@ -1104,11 +1101,7 @@ def __init__(self, config: ProphetNetConfig, word_embeddings: Optional[nn.Embedd self.dropout = config.dropout self.max_target_positions = config.max_position_embeddings - self.word_embeddings = ( - word_embeddings - if word_embeddings is not None - else nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) - ) + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.position_embeddings = ProphetNetPositionalEmbeddings(config) self.ngram_embeddings = nn.Embedding(self.ngram, config.hidden_size, None) @@ -1400,7 +1393,10 @@ def prepare_predict_attention_mask(self, hidden_states, attention_mask): @auto_docstring class ProphetNetModel(ProphetNetPreTrainedModel): - _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight"] + _tied_weights_keys = { + "encoder.word_embeddings.weight": "word_embeddings.weight", + "decoder.word_embeddings.weight": "word_embeddings.weight", + } def __init__(self, config: ProphetNetConfig): super().__init__(config) @@ -1409,12 +1405,12 @@ def __init__(self, config: ProphetNetConfig): encoder_config = copy.deepcopy(config) encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = ProphetNetEncoder(encoder_config, self.word_embeddings) + self.encoder = ProphetNetEncoder(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False - self.decoder = ProphetNetDecoder(decoder_config, self.word_embeddings) + self.decoder = ProphetNetDecoder(decoder_config) # Initialize weights and apply final processing self.post_init() @@ -1427,11 +1423,6 @@ def set_input_embeddings(self, value): self.encoder.word_embeddings = self.word_embeddings self.decoder.word_embeddings = self.word_embeddings - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.word_embeddings, self.word_embeddings) - self._tie_embedding_weights(self.decoder.word_embeddings, self.word_embeddings) - def get_encoder(self): return self.encoder @@ -1540,7 +1531,9 @@ def forward( """ ) class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "prophetnet.word_embeddings.weight", + } def __init__(self, config: ProphetNetConfig): super().__init__(config) @@ -1553,10 +1546,6 @@ def __init__(self, config: ProphetNetConfig): # Initialize weights and apply final processing self.post_init() - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.prophetnet.word_embeddings, self.lm_head) - def get_input_embeddings(self): return self.prophetnet.word_embeddings @@ -1718,11 +1707,9 @@ def get_decoder(self): """ ) class ProphetNetForCausalLM(ProphetNetPreTrainedModel, GenerationMixin): - _tied_weights_keys = [ - "prophetnet.word_embeddings.weight", - "prophetnet.decoder.word_embeddings.weight", - "lm_head.weight", - ] + _tied_weights_keys = { + "lm_head.weight": "prophetnet.word_embeddings.weight", + } def __init__(self, config: ProphetNetConfig): # set config for CLM @@ -1746,10 +1733,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.prophetnet.decoder.word_embeddings = value - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.prophetnet.decoder.word_embeddings, self.lm_head) - def set_decoder(self, decoder): self.prophetnet.decoder = decoder @@ -1928,18 +1911,19 @@ class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel): classes. """ + _tied_weights_keys = { + "decoder.word_embeddings.weight": "word_embeddings.weight", + } + def __init__(self, config: ProphetNetConfig): super().__init__(config) self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) - self.decoder = ProphetNetDecoder(config, word_embeddings=self.word_embeddings) + self.decoder = ProphetNetDecoder(config) # Initialize weights and apply final processing self.post_init() - def _tie_weights(self): - self._tie_embedding_weights(self.word_embeddings, self.decoder.get_input_embeddings()) - def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) diff --git a/src/transformers/models/pvt/modeling_pvt.py b/src/transformers/models/pvt/modeling_pvt.py index 4abde5266d11..2a296a5e09e8 100755 --- a/src/transformers/models/pvt/modeling_pvt.py +++ b/src/transformers/models/pvt/modeling_pvt.py @@ -421,30 +421,35 @@ class PvtPreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = [] + @torch.no_grad() def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - nn.init.trunc_normal_(module.weight.data, mean=0.0, std=std) + nn.init.trunc_normal_(module.weight, mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, PvtPatchEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data, - mean=0.0, - std=std, - ) - if module.cls_token is not None: - module.cls_token.data = nn.init.trunc_normal_( - module.cls_token.data, + module.position_embeddings.copy_( + nn.init.trunc_normal_( + module.position_embeddings, mean=0.0, std=std, ) + ) + if module.cls_token is not None: + module.cls_token.copy_( + nn.init.trunc_normal_( + module.cls_token, + mean=0.0, + std=std, + ) + ) @auto_docstring diff --git a/src/transformers/models/pvt_v2/modeling_pvt_v2.py b/src/transformers/models/pvt_v2/modeling_pvt_v2.py index 113a4a14bd95..010e91b9d479 100644 --- a/src/transformers/models/pvt_v2/modeling_pvt_v2.py +++ b/src/transformers/models/pvt_v2/modeling_pvt_v2.py @@ -368,23 +368,24 @@ class PvtV2PreTrainedModel(PreTrainedModel): input_modalities = "image" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, nn.Linear): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range) + module.weight.copy_(nn.init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv2d): fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels fan_out //= module.groups - module.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + module.weight.normal_(0, math.sqrt(2.0 / fan_out)) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() @auto_docstring diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 59e038eb2552..1215f3677603 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -427,7 +427,7 @@ def forward( @auto_docstring class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 80b23721431d..77bc48a1e19d 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -1693,7 +1693,7 @@ def forward( class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin): config: Qwen2_5OmniThinkerConfig base_model_prefix = "thinker" - _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _no_split_modules = ["Qwen2_5OmniAudioEncoder", "Qwen2_5OmniVisionEncoder"] def __init__(self, config: Qwen2_5OmniThinkerConfig): diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 329e1b798dd6..673da8201fed 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -2057,7 +2057,7 @@ def __init__(self, config: Qwen2_5OmniTextConfig): class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin): config: Qwen2_5OmniThinkerConfig base_model_prefix = "thinker" - _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _no_split_modules = ["Qwen2_5OmniAudioEncoder", "Qwen2_5OmniVisionEncoder"] def __init__(self, config: Qwen2_5OmniThinkerConfig): diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 0e6e07ff54c1..1a24d18939bb 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1373,7 +1373,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi "^visual": "model.visual", r"^model(?!\.(language_model|visual))": "model.language_model", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} # Reference: fix gemma3 grad acc #37208 accepts_loss_kwargs = False diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 736d67b1a2ad..fb84ea711ea4 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -257,6 +257,7 @@ class Qwen2AudioPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True + @torch.no_grad() def _init_weights(self, module): # important: this ported version of Qwen2Audio isn't meant for training from scratch - only # inference and fine-tuning - so the proper init weights code has been removed @@ -267,16 +268,16 @@ def _init_weights(self, module): ) if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() @auto_docstring( @@ -460,8 +461,6 @@ def __init__(self, config: Qwen2AudioConfig): self.multi_modal_projector = Qwen2AudioMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size self.language_model = AutoModelForCausalLM.from_config(config.text_config) - if self.language_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys] self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index d1e309f612c6..bf642609c9fe 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -289,66 +289,80 @@ def forward( return attn_output, attn_weights -class Qwen2MoeExperts(nn.ModuleList): - """ - ModuleList of experts. - """ +class Qwen2MoeExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): super().__init__() self.num_experts = config.num_experts - for _ in range(config.num_experts): - self.append(Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: - """ - Args: - hidden_states: (batch_size * sequence_length, hidden_dim) - top_k_index: (batch_size * sequence_length, top_k) - top_k_weights: (batch_size * sequence_length, top_k) - Returns: - (batch_size * sequence_length, hidden_dim) - """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states -class Qwen2MoeSparseMoeBlock(nn.Module): +class Qwen2MoeTopKRouter(nn.Module): def __init__(self, config): super().__init__() - # gating - self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) - self.experts = Qwen2MoeExperts(config) - self.num_experts_per_tok = config.num_experts_per_tok + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + if self.norm_topk_prob: + router_top_value /= router_top_value.sum(dim=-1, keepdim=True) + router_top_value = router_top_value.to(router_logits.dtype) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices + +class Qwen2MoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.gate = Qwen2MoeTopKRouter(config) + self.experts = Qwen2MoeExperts(config) self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size) self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) - def route_tokens_to_experts(self, hidden_states, router_logits): - routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1) - if self.norm_topk_prob: - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(router_logits.dtype) - return selected_experts, routing_weights - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_reshaped = hidden_states.view(-1, hidden_dim) shared_expert_output = self.shared_expert(hidden_states_reshaped) - router_logits = self.gate(hidden_states_reshaped) - selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits) + routing_weights, selected_experts = self.gate(hidden_states_reshaped) expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output @@ -419,11 +433,21 @@ class Qwen2MoePreTrainedModel(PreTrainedModel): _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True _can_record_outputs = { - "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), + "router_logits": OutputRecorder(Qwen2MoeTopKRouter, index=0), "hidden_states": Qwen2MoeDecoderLayer, "attentions": Qwen2MoeAttention, } + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + std = self.config.initializer_range + if isinstance(module, Qwen2MoeExperts): + module.gate_up_proj.normal_(mean=0.0, std=std) + module.down_proj.normal_(mean=0.0, std=std) + elif isinstance(module, Qwen2MoeTopKRouter): + module.weight.normal_(mean=0.0, std=std) + @auto_docstring class Qwen2MoeModel(Qwen2MoePreTrainedModel): @@ -597,7 +621,7 @@ def load_balancing_loss_func( @auto_docstring class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/qwen2_moe/modular_qwen2_moe.py b/src/transformers/models/qwen2_moe/modular_qwen2_moe.py index 56c100f94b93..fa33b78c42f5 100644 --- a/src/transformers/models/qwen2_moe/modular_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modular_qwen2_moe.py @@ -82,40 +82,47 @@ def __init__(self, config: Qwen2MoeConfig, layer_idx: int): self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) -class Qwen2MoeExperts(MixtralExperts, nn.Module): +class Qwen2MoeExperts(MixtralExperts): def __init__(self, config): - nn.ModuleList.__init__(self) + super().__init__(config) self.num_experts = config.num_experts - for _ in range(config.num_experts): - self.append(Qwen2MoeMLP(config, intermediate_size=config.moe_intermediate_size)) + self.intermediate_dim = config.moe_intermediate_size -class Qwen2MoeSparseMoeBlock(nn.Module): +class Qwen2MoeTopKRouter(nn.Module): def __init__(self, config): super().__init__() - # gating - self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) - self.experts = Qwen2MoeExperts(config) - self.num_experts_per_tok = config.num_experts_per_tok + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + if self.norm_topk_prob: + router_top_value /= router_top_value.sum(dim=-1, keepdim=True) + router_top_value = router_top_value.to(router_logits.dtype) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices + +class Qwen2MoeSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.gate = Qwen2MoeTopKRouter(config) + self.experts = Qwen2MoeExperts(config) self.shared_expert = Qwen2MoeMLP(config, intermediate_size=config.shared_expert_intermediate_size) self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) - def route_tokens_to_experts(self, hidden_states, router_logits): - routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1) - if self.norm_topk_prob: - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(router_logits.dtype) - return selected_experts, routing_weights - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_reshaped = hidden_states.view(-1, hidden_dim) shared_expert_output = self.shared_expert(hidden_states_reshaped) - router_logits = self.gate(hidden_states_reshaped) - selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits) + routing_weights, selected_experts = self.gate(hidden_states_reshaped) expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output @@ -143,7 +150,7 @@ def __init__(self, config: Qwen2MoeConfig, layer_idx: int): @auto_docstring class Qwen2MoePreTrainedModel(MixtralPreTrainedModel): _can_record_outputs = { - "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), + "router_logits": OutputRecorder(Qwen2MoeTopKRouter, index=0), "hidden_states": Qwen2MoeDecoderLayer, "attentions": Qwen2MoeAttention, } @@ -230,7 +237,7 @@ def forward( class Qwen2MoeForCausalLM(MixtralForCausalLM, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index d0074b1662e6..c1b52ff75f9f 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1273,7 +1273,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): "^visual": "model.visual", r"^model(?!\.(language_model|visual))": "model.language_model", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index 1973de1b19ef..5f0f8974eb0a 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -453,7 +453,7 @@ def forward( @auto_docstring class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index ff0855c223ee..e709a7d84709 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -209,61 +209,77 @@ def forward(self, x): return down_proj -class Qwen3MoeExperts(nn.ModuleList): - """ - ModuleList of experts. - """ +class Qwen3MoeExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" - def __init__(self, config: Qwen3MoeConfig): + def __init__(self, config): super().__init__() self.num_experts = config.num_experts - for _ in range(self.num_experts): - self.append(Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: - """ - Args: - hidden_states: (batch_size * sequence_length, hidden_dim) - top_k_index: (batch_size * sequence_length, top_k) - top_k_weights: (batch_size * sequence_length, top_k) - Returns: - (batch_size * sequence_length, hidden_dim) - """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states -class Qwen3MoeSparseMoeBlock(nn.Module): - def __init__(self, config: Qwen3MoeConfig): +class Qwen3MoeTopKRouter(nn.Module): + def __init__(self, config): super().__init__() - self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) - self.experts = Qwen3MoeExperts(config) - self.num_experts_per_tok = config.num_experts_per_tok + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) - def route_tokens_to_experts(self, hidden_states, router_logits): - routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1) + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) if self.norm_topk_prob: - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(router_logits.dtype) - return selected_experts, routing_weights + router_top_value /= router_top_value.sum(dim=-1, keepdim=True) + router_top_value = router_top_value.to(router_logits.dtype) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices + + +class Qwen3MoeSparseMoeBlock(nn.Module): + def __init__(self, config: Qwen3MoeConfig): + super().__init__() + self.experts = Qwen3MoeExperts(config) + self.router = Qwen3MoeTopKRouter(config) def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_reshaped = hidden_states.view(-1, hidden_dim) - router_logits = self.gate(hidden_states_reshaped) - selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits) + routing_weights, selected_experts = self.router(hidden_states_reshaped) final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights) return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) @@ -350,11 +366,21 @@ class Qwen3MoePreTrainedModel(PreTrainedModel): _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True _can_record_outputs = { - "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), + "router_logits": OutputRecorder(Qwen3MoeTopKRouter, layer_name="mlp.router", index=0), "hidden_states": Qwen3MoeDecoderLayer, "attentions": Qwen3MoeAttention, } + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + std = self.config.initializer_range + if isinstance(module, Qwen3MoeExperts): + module.gate_up_proj.normal_(mean=0.0, std=std) + module.down_proj.normal_(mean=0.0, std=std) + elif isinstance(module, Qwen3MoeTopKRouter): + module.weight.normal_(mean=0.0, std=std) + class Qwen3MoeRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` @@ -586,7 +612,7 @@ def load_balancing_loss_func( @auto_docstring class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py index 87a4bbfa9625..6f4d5c53b820 100644 --- a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py @@ -17,7 +17,6 @@ from typing import Optional, Union import torch -import torch.nn.functional as F from torch import nn from ...cache_utils import Cache @@ -32,13 +31,12 @@ LlamaRMSNorm, ) from ..mixtral.modeling_mixtral import ( - MixtralExperts, MixtralForCausalLM, MixtralModel, MixtralPreTrainedModel, load_balancing_loss_func, ) -from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeDecoderLayer, Qwen2MoeMLP +from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeDecoderLayer, Qwen2MoeExperts, Qwen2MoeMLP, Qwen2MoeTopKRouter from ..qwen3.modeling_qwen3 import Qwen3Attention from .configuration_qwen3_moe import Qwen3MoeConfig @@ -57,35 +55,24 @@ class Qwen3MoeMLP(Qwen2MoeMLP): pass -class Qwen3MoeExperts(MixtralExperts, nn.ModuleList): - def __init__(self, config: Qwen3MoeConfig): - nn.ModuleList.__init__(self) - self.num_experts = config.num_experts - for _ in range(self.num_experts): - self.append(Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size)) +class Qwen3MoeExperts(Qwen2MoeExperts): + pass + + +class Qwen3MoeTopKRouter(Qwen2MoeTopKRouter): + pass class Qwen3MoeSparseMoeBlock(nn.Module): def __init__(self, config: Qwen3MoeConfig): super().__init__() - self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) self.experts = Qwen3MoeExperts(config) - self.num_experts_per_tok = config.num_experts_per_tok - self.norm_topk_prob = config.norm_topk_prob - - def route_tokens_to_experts(self, hidden_states, router_logits): - routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1) - if self.norm_topk_prob: - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(router_logits.dtype) - return selected_experts, routing_weights + self.router = Qwen3MoeTopKRouter(config) def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_reshaped = hidden_states.view(-1, hidden_dim) - router_logits = self.gate(hidden_states_reshaped) - selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits) + routing_weights, selected_experts = self.router(hidden_states_reshaped) final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights) return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) @@ -100,7 +87,7 @@ class Qwen3MoeDecoderLayer(Qwen2MoeDecoderLayer): class Qwen3MoePreTrainedModel(MixtralPreTrainedModel): _can_record_outputs = { - "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), + "router_logits": OutputRecorder(Qwen3MoeTopKRouter, layer_name="mlp.router", index=0), "hidden_states": Qwen3MoeDecoderLayer, "attentions": Qwen3MoeAttention, } diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 3847c43117a3..72d097c35543 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -819,66 +819,80 @@ def forward(self, x): return down_proj -class Qwen3NextExperts(nn.ModuleList): - """ - ModuleList of experts. - """ +class Qwen3NextExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): super().__init__() self.num_experts = config.num_experts - for _ in range(config.num_experts): - self.append(Qwen3NextMLP(config, intermediate_size=config.moe_intermediate_size)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: - """ - Args: - hidden_states: (batch_size * sequence_length, hidden_dim) - top_k_index: (batch_size * sequence_length, top_k) - top_k_weights: (batch_size * sequence_length, top_k) - Returns: - (batch_size * sequence_length, hidden_dim) - """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states -class Qwen3NextSparseMoeBlock(nn.Module): +class Qwen3NextTopKRouter(nn.Module): def __init__(self, config): super().__init__() - # gating - self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) - self.experts = Qwen3NextExperts(config) - self.num_experts_per_tok = config.num_experts_per_tok + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + if self.norm_topk_prob: + router_top_value /= router_top_value.sum(dim=-1, keepdim=True) + router_top_value = router_top_value.to(router_logits.dtype) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices + +class Qwen3NextSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.gate = Qwen3NextTopKRouter(config) + self.experts = Qwen3NextExperts(config) self.shared_expert = Qwen3NextMLP(config, intermediate_size=config.shared_expert_intermediate_size) self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) - def route_tokens_to_experts(self, hidden_states, router_logits): - routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1) - if self.norm_topk_prob: - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(router_logits.dtype) - return selected_experts, routing_weights - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_reshaped = hidden_states.view(-1, hidden_dim) shared_expert_output = self.shared_expert(hidden_states_reshaped) - router_logits = self.gate(hidden_states_reshaped) - selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits) + routing_weights, selected_experts = self.gate(hidden_states_reshaped) expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output @@ -974,14 +988,18 @@ class Qwen3NextPreTrainedModel(PreTrainedModel): } _is_stateful = True + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Qwen3NextGatedDeltaNet): - module.dt_bias.data.fill_(1.0) - module.A_log.data.uniform_(0, 16).log_() + module.dt_bias.fill_(1.0) + module.A_log.uniform_(0, 16).log_() # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) elif isinstance(module, Qwen3NextRMSNorm): - module.weight.data.zero_() + module.weight.zero_() + if isinstance(module, Qwen3NextExperts): + module.gate_up_proj.normal_(mean=0.0, std=self.config.initializer_range) + module.down_proj.normal_(mean=0.0, std=self.config.initializer_range) class Qwen3NextModel(Qwen3NextPreTrainedModel): @@ -1158,7 +1176,7 @@ def load_balancing_loss_func( @auto_docstring class Qwen3NextForCausalLM(Qwen3NextPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/qwen3_next/modular_qwen3_next.py b/src/transformers/models/qwen3_next/modular_qwen3_next.py index e624a653150b..ae95f727b993 100644 --- a/src/transformers/models/qwen3_next/modular_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modular_qwen3_next.py @@ -43,7 +43,7 @@ LlamaForTokenClassification, ) from ..mixtral.modeling_mixtral import MixtralForCausalLM -from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock +from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeExperts, Qwen2MoeSparseMoeBlock from ..qwen3_moe.modeling_qwen3_moe import ( Qwen3MoeAttention, Qwen3MoeDecoderLayer, @@ -642,6 +642,10 @@ class Qwen3NextMLP(Qwen3MoeMLP): pass +class Qwen3NextExperts(Qwen2MoeExperts): + pass + + class Qwen3NextSparseMoeBlock(Qwen2MoeSparseMoeBlock): pass @@ -732,14 +736,18 @@ class Qwen3NextPreTrainedModel(PreTrainedModel): } _is_stateful = True + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Qwen3NextGatedDeltaNet): - module.dt_bias.data.fill_(1.0) - module.A_log.data.uniform_(0, 16).log_() + module.dt_bias.fill_(1.0) + module.A_log.uniform_(0, 16).log_() # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) elif isinstance(module, Qwen3NextRMSNorm): - module.weight.data.zero_() + module.weight.zero_() + if isinstance(module, Qwen3NextExperts): + module.gate_up_proj.normal_(mean=0.0, std=self.config.initializer_range) + module.down_proj.normal_(mean=0.0, std=self.config.initializer_range) class Qwen3NextModel(Qwen3NextPreTrainedModel): diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index aabd906dc3b2..df2ae424649d 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -1307,23 +1307,7 @@ def apply_interleaved_mrope(self, freqs, mrope_section): return freqs_t -class Qwen3OmniMoeThinkerTextMLP(nn.Module): - def __init__(self, config, intermediate_size=None): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -class Qwen3OmniMoeThinkerTextExperts(nn.ModuleList): +class Qwen3OmniMoeThinkerTextExperts(nn.Module): """ ModuleList of experts. """ @@ -1331,53 +1315,71 @@ class Qwen3OmniMoeThinkerTextExperts(nn.ModuleList): def __init__(self, config: Qwen3OmniMoeThinkerConfig): super().__init__() self.num_experts = config.num_experts - for _ in range(self.num_experts): - self.append(Qwen3OmniMoeThinkerTextMLP(config, intermediate_size=config.moe_intermediate_size)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: - """ - Args: - hidden_states: (batch_size * sequence_length, hidden_dim) - top_k_index: (batch_size * sequence_length, top_k) - top_k_weights: (batch_size * sequence_length, top_k) - Returns: - (batch_size * sequence_length, hidden_dim) - """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states -class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module): - def __init__(self, config: Qwen3OmniMoeThinkerConfig): +class Qwen3OmniMoeThinkerTextTopKRouter(nn.Module): + def __init__(self, config): super().__init__() - self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) - self.experts = Qwen3OmniMoeThinkerTextExperts(config) - self.num_experts_per_tok = config.num_experts_per_tok + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) - def route_tokens_to_experts(self, hidden_states, router_logits): - routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1) + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) if self.norm_topk_prob: - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(router_logits.dtype) - return selected_experts, routing_weights + router_top_value /= router_top_value.sum(dim=-1, keepdim=True) + router_top_value = router_top_value.to(router_logits.dtype) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices + + +class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module): + def __init__(self, config: Qwen3OmniMoeThinkerConfig): + super().__init__() + self.experts = Qwen3OmniMoeThinkerTextExperts(config) + self.router = Qwen3OmniMoeThinkerTextTopKRouter(config) def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_reshaped = hidden_states.view(-1, hidden_dim) - router_logits = self.gate(hidden_states_reshaped) - selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits) + routing_weights, selected_experts = self.router(hidden_states_reshaped) final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights) return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) @@ -1508,6 +1510,22 @@ def forward( return attn_output, attn_weights +class Qwen3OmniMoeThinkerTextMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + class Qwen3OmniMoeThinkerTextDecoderLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() @@ -1569,12 +1587,22 @@ class Qwen3OmniMoeThinkerTextPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True _can_record_outputs = { - "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), + "router_logits": OutputRecorder(Qwen3OmniMoeThinkerTextTopKRouter, layer_name="mlp.router", index=0), "hidden_states": Qwen3OmniMoeThinkerTextDecoderLayer, "attentions": Qwen3OmniMoeThinkerTextAttention, } config_class = Qwen3OmniMoeTextConfig + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + std = self.config.initializer_range + if isinstance(module, Qwen3OmniMoeThinkerTextExperts): + module.gate_up_proj.normal_(mean=0.0, std=std) + module.down_proj.normal_(mean=0.0, std=std) + elif isinstance(module, Qwen3OmniMoeThinkerTextTopKRouter): + module.weight.normal_(mean=0.0, std=std) + @use_kernel_forward_from_hub("RMSNorm") class Qwen3OmniMoeTextRMSNorm(nn.Module): @@ -1837,7 +1865,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ): config: Qwen3OmniMoeThinkerConfig base_model_prefix = "thinker" - _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _no_split_modules = [ "Qwen3OmniMoeAudioEncoderLayer", "Qwen3OmniMoeThinkerTextDecoderLayer", @@ -2590,7 +2618,7 @@ def get_input_embeddings(self): @auto_docstring class Qwen3OmniMoeTalkerCodePredictorModelForConditionalGeneration(Qwen3OmniMoePreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config_class = Qwen3OmniMoeTalkerCodePredictorConfig @@ -2707,68 +2735,82 @@ def forward(self, x): return down_proj -class Qwen3OmniMoeTalkerTextExperts(nn.ModuleList): - """ - ModuleList of experts. - """ +class Qwen3OmniMoeTalkerTextExperts(nn.Module): + """Collection of expert weights stored as 3D tensors.""" def __init__(self, config): super().__init__() self.num_experts = config.num_experts - for _ in range(config.num_experts): - self.append(Qwen3OmniMoeTalkerTextMLP(config, intermediate_size=config.moe_intermediate_size)) + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] def forward( - self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, ) -> torch.Tensor: - """ - Args: - hidden_states: (batch_size * sequence_length, hidden_dim) - top_k_index: (batch_size * sequence_length, top_k) - top_k_weights: (batch_size * sequence_length, top_k) - Returns: - (batch_size * sequence_length, hidden_dim) - """ final_hidden_states = torch.zeros_like(hidden_states) - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + num_experts = top_k_weights.shape[1] + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - current_state = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1]) - current_hidden_states = self[expert_idx](current_state) * top_k_weights[top_x, idx, None] - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + expert_idx = expert_idx[0] + if expert_idx == num_experts: + continue + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + return final_hidden_states -class Qwen3OmniMoeTalkerTextSparseMoeBlock(nn.Module): +class Qwen3OmniMoeTalkerTextTopKRouter(nn.Module): def __init__(self, config): super().__init__() - # gating - self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) - self.experts = Qwen3OmniMoeTalkerTextExperts(config) - self.num_experts_per_tok = config.num_experts_per_tok + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + if self.norm_topk_prob: + router_top_value /= router_top_value.sum(dim=-1, keepdim=True) + router_top_value = router_top_value.to(router_logits.dtype) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices + + +class Qwen3OmniMoeTalkerTextSparseMoeBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.gate = Qwen3OmniMoeTalkerTextTopKRouter(config) + self.experts = Qwen3OmniMoeTalkerTextExperts(config) self.shared_expert = Qwen3OmniMoeTalkerTextMLP( config, intermediate_size=config.shared_expert_intermediate_size ) self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) - def route_tokens_to_experts(self, hidden_states, router_logits): - routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1) - if self.norm_topk_prob: - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - routing_weights = routing_weights.to(router_logits.dtype) - return selected_experts, routing_weights - def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_reshaped = hidden_states.view(-1, hidden_dim) shared_expert_output = self.shared_expert(hidden_states_reshaped) - router_logits = self.gate(hidden_states_reshaped) - selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits) + routing_weights, selected_experts = self.gate(hidden_states_reshaped) expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output @@ -2969,7 +3011,7 @@ def get_input_embeddings(self): @auto_docstring class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3OmniMoeThinkerTextPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} config_class = Qwen3OmniMoeTalkerConfig diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 37f6a5146053..beb526201c27 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -1300,7 +1300,7 @@ class Qwen3VLCausalLMOutputWithPast(ModelOutput): class Qwen3VLForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = {} - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} # Reference: fix gemma3 grad acc #37208 accepts_loss_kwargs = False config: Qwen3VLConfig diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 264902c2d8a4..e4125e630911 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -71,7 +71,7 @@ def __init__(self, config): self.intermediate_size = config.moe_intermediate_size self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size - self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim)) self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) self.act_fn = ACT2FN[config.hidden_act] @@ -365,6 +365,27 @@ def forward( return hidden_states +class Qwen3VLMoeTextTopKRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_experts + self.norm_topk_prob = config.norm_topk_prob + self.hidden_dim = config.hidden_size + self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim)) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight) # (seq_len, num_experts) + router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) + if self.norm_topk_prob: + router_top_value /= router_top_value.sum(dim=-1, keepdim=True) + router_top_value = router_top_value.to(router_logits.dtype) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices + + @auto_docstring class Qwen3VLMoePreTrainedModel(PreTrainedModel): config: Qwen3VLMoeConfig @@ -378,11 +399,12 @@ class Qwen3VLMoePreTrainedModel(PreTrainedModel): _can_compile_fullgraph = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported) _supports_attention_backend = True _can_record_outputs = { - "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0), + "router_logits": OutputRecorder(Qwen3VLMoeTextTopKRouter, layer_name="mlp.router", index=0), "hidden_states": Qwen3VLMoeTextDecoderLayer, "attentions": Qwen3VLMoeTextAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" super()._init_weights(module) @@ -391,8 +413,8 @@ def _init_weights(self, module): else: std = getattr(self.config.get_text_config(), "initializer_range", 0.02) if isinstance(module, Qwen3VLMoeTextExperts): - module.gate_up_proj.data.normal_(mean=0.0, std=std) - module.down_proj.data.normal_(mean=0.0, std=std) + module.gate_up_proj.normal_(mean=0.0, std=std) + module.down_proj.normal_(mean=0.0, std=std) class Qwen3VLMoeVisionMLP(nn.Module): @@ -1487,7 +1509,7 @@ def load_balancing_loss_func( class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = {} - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} # Reference: fix gemma3 grad acc #37208 accepts_loss_kwargs = False config: Qwen3VLMoeConfig diff --git a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py index c0c4be2ddb68..459d45159fdc 100644 --- a/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py @@ -265,7 +265,7 @@ def __init__(self, config): self.intermediate_size = config.moe_intermediate_size self.hidden_size = config.hidden_size self.expert_dim = self.intermediate_size - self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim)) + self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim)) self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) self.act_fn = ACT2FN[config.hidden_act] @@ -358,6 +358,7 @@ class Qwen3VLMoePreTrainedModel(Qwen3MoePreTrainedModel): config: Qwen3VLMoeConfig _no_split_modules = ["Qwen3VLMoeTextDecoderLayer", "Qwen3VLMoeVisionBlock"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" PreTrainedModel._init_weights(self, module) @@ -366,8 +367,8 @@ def _init_weights(self, module): else: std = getattr(self.config.get_text_config(), "initializer_range", 0.02) if isinstance(module, Qwen3VLMoeTextExperts): - module.gate_up_proj.data.normal_(mean=0.0, std=std) - module.down_proj.data.normal_(mean=0.0, std=std) + module.gate_up_proj.normal_(mean=0.0, std=std) + module.down_proj.normal_(mean=0.0, std=std) class Qwen3VLMoeVisionModel(Qwen3VLVisionModel): diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 6abf3a0599ca..a1d58064207e 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -553,6 +553,7 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel): _supports_flash_attn = False _supports_sdpa = False # we can't compare with eager for now + @torch.no_grad() def _init_weights(self, module): std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width) if isinstance(module, nn.Conv1d): @@ -584,21 +585,21 @@ def _init_weights(self, module): torch.nn.init.zeros_(module.input_gate_bias) torch.nn.init.zeros_(module.recurrent_gate_bias) - module.recurrent_param.data.uniform_(0.9**2 + 1e-8, 0.999**2 + 1e-8) - module.recurrent_param.data.log_().mul_(0.5) - module.recurrent_param.data.neg_().exp_().sub_(1.0).log_() + module.recurrent_param.uniform_(0.9**2 + 1e-8, 0.999**2 + 1e-8) + module.recurrent_param.log_().mul_(0.5) + module.recurrent_param.neg_().exp_().sub_(1.0).log_() elif isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=std) if getattr(module, "bias", None) is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) elif isinstance(module, RecurrentGemmaRMSNorm): - module.weight.data.zero_() + module.weight.zero_() def _setup_cache(self, config, batch, device, dtype): layers = getattr(self, "model", self).layers @@ -728,7 +729,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): # TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma @auto_docstring class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index a880837004be..5cfeca479f51 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -1817,9 +1817,8 @@ def __init__(self, config): # Layer Norm is done over 2 * hidden_size self.seq_len_dim = 1 self.chunk_size_lm_head = config.chunk_size_lm_head - self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=False) + self.decoder = nn.Linear(2 * config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - self.decoder.bias = self.bias def forward(self, hidden_states): return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states) @@ -1828,14 +1827,6 @@ def forward_chunk(self, hidden_states): hidden_states = self.decoder(hidden_states) return hidden_states - def _tie_weights(self) -> None: - # For accelerate compatibility and to not break backward compatibility - if self.decoder.bias.device.type == "meta": - self.decoder.bias = self.bias - else: - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias - @auto_docstring class ReformerPreTrainedModel(PreTrainedModel): @@ -1852,22 +1843,23 @@ def dummy_inputs(self): } return dummy_inputs + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, AxialPositionEmbeddings): for weight in module.weights: nn.init.normal_(weight, std=self.config.axial_norm_std) elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @dataclass @@ -2149,7 +2141,9 @@ def _pad_to_mult_of_chunk_length( """ ) class ReformerModelWithLMHead(ReformerPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) @@ -2285,7 +2279,9 @@ def prepare_inputs_for_generation( @auto_docstring class ReformerForMaskedLM(ReformerPreTrainedModel): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/regnet/modeling_regnet.py b/src/transformers/models/regnet/modeling_regnet.py index 70611113885f..fd6416f46ec6 100644 --- a/src/transformers/models/regnet/modeling_regnet.py +++ b/src/transformers/models/regnet/modeling_regnet.py @@ -263,7 +263,7 @@ class RegNetPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" _no_split_modules = ["RegNetYLayer"] - # Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights + @torch.no_grad() def _init_weights(self, module): if isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index a8e4a29e806f..13651c32f5da 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -488,19 +488,20 @@ class RemBertPreTrainedModel(PreTrainedModel): base_model_prefix = "rembert" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring( @@ -638,8 +639,6 @@ def forward( @auto_docstring class RemBertForMaskedLM(RemBertPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.weight"] - def __init__(self, config): super().__init__(config) @@ -745,8 +744,6 @@ def can_generate(cls) -> bool: """ ) class RemBertForCausalLM(RemBertPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["cls.predictions.decoder.weight"] - def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index 801907aa1e63..dba8200edba1 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -250,6 +250,7 @@ class ResNetPreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = ["ResNetConvLayer", "ResNetShortCut"] + @torch.no_grad() def _init_weights(self, module): if isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index a718c3528805..f5b315f38f26 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -494,22 +494,22 @@ class RobertaPreTrainedModel(PreTrainedModel): "cross_attentions": RobertaCrossAttention, } - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->RobertaLMHead + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, RobertaLMHead): - module.bias.data.zero_() + module.bias.zero_() class RobertaEncoder(nn.Module): @@ -719,7 +719,10 @@ def _create_attention_masks( """ ) class RobertaForCausalLM(RobertaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) @@ -827,7 +830,10 @@ def forward( @auto_docstring class RobertaForMaskedLM(RobertaPreTrainedModel): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) @@ -918,7 +924,6 @@ def __init__(self, config): self.decoder = nn.Linear(config.hidden_size, config.vocab_size) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - self.decoder.bias = self.bias def forward(self, features, **kwargs): x = self.dense(features) @@ -930,14 +935,6 @@ def forward(self, features, **kwargs): return x - def _tie_weights(self): - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - # For accelerate compatibility and to not break backward compatibility - if self.decoder.bias.device.type == "meta": - self.decoder.bias = self.bias - else: - self.bias = self.decoder.bias - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/roberta/modular_roberta.py b/src/transformers/models/roberta/modular_roberta.py index 5884e893027d..54049e1189da 100644 --- a/src/transformers/models/roberta/modular_roberta.py +++ b/src/transformers/models/roberta/modular_roberta.py @@ -165,22 +165,22 @@ class RobertaPreTrainedModel(PreTrainedModel): "cross_attentions": RobertaCrossAttention, } - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->RobertaLMHead + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, RobertaLMHead): - module.bias.data.zero_() + module.bias.zero_() class RobertaModel(BertModel): @@ -194,7 +194,10 @@ def __init__(self, config, add_pooling_layer=True): """ ) class RobertaForCausalLM(RobertaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) @@ -302,7 +305,10 @@ def forward( @auto_docstring class RobertaForMaskedLM(RobertaPreTrainedModel): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) @@ -393,7 +399,6 @@ def __init__(self, config): self.decoder = nn.Linear(config.hidden_size, config.vocab_size) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - self.decoder.bias = self.bias def forward(self, features, **kwargs): x = self.dense(features) @@ -405,14 +410,6 @@ def forward(self, features, **kwargs): return x - def _tie_weights(self): - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - # For accelerate compatibility and to not break backward compatibility - if self.decoder.bias.device.type == "meta": - self.decoder.bias = self.bias - else: - self.bias = self.decoder.bias - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 17cc0ad9e3ae..31cbdbc9e762 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -554,22 +554,22 @@ class RobertaPreLayerNormPreTrainedModel(PreTrainedModel): "cross_attentions": RobertaPreLayerNormCrossAttention, } - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->RobertaPreLayerNormLMHead + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, RobertaPreLayerNormLMHead): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring( @@ -747,7 +747,10 @@ def _create_attention_masks( ) # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with FacebookAI/roberta-base->andreasmadsen/efficient_mlm_m0.40,ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm, RobertaPreLayerNormTokenizer->RobertaTokenizer class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta_prelayernorm.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) @@ -861,7 +864,10 @@ def forward( """ ) class RobertaPreLayerNormForMaskedLM(RobertaPreLayerNormPreTrainedModel): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta_prelayernorm.embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.__init__ with ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm def __init__(self, config): @@ -955,7 +961,6 @@ def __init__(self, config): self.decoder = nn.Linear(config.hidden_size, config.vocab_size) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - self.decoder.bias = self.bias def forward(self, features, **kwargs): x = self.dense(features) @@ -967,14 +972,6 @@ def forward(self, features, **kwargs): return x - def _tie_weights(self): - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - # For accelerate compatibility and to not break backward compatibility - if self.decoder.bias.device.type == "meta": - self.decoder.bias = self.bias - else: - self.bias = self.decoder.bias - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index b7ae250bd297..6800fa2fbfa5 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -579,16 +579,9 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -621,21 +614,22 @@ class RoCBertPreTrainedModel(PreTrainedModel): "cross_attentions": RoCBertCrossAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, RoCBertLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring( @@ -827,7 +821,10 @@ def _create_attention_masks( """ ) class RoCBertForPreTraining(RoCBertPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "roc_bert.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) @@ -1020,7 +1017,10 @@ def forward( @auto_docstring class RoCBertForMaskedLM(RoCBertPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "roc_bert.embeddings.word_embeddings.weight", + } # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->RoCBert,bert->roc_bert def __init__(self, config): @@ -1175,7 +1175,10 @@ def can_generate(cls) -> bool: """ ) class RoCBertForCausalLM(RoCBertPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "roc_bert.embeddings.word_embeddings.weight", + } # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->RoCBertForCausalLM,Bert->RoCBert,bert->roc_bert def __init__(self, config): diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index b7c5afa01722..0aa4cb11bf51 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -608,16 +608,10 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=False) + self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self) -> None: - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -641,23 +635,24 @@ class RoFormerPreTrainedModel(PreTrainedModel): base_model_prefix = "roformer" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, RoFormerSinusoidalPositionalEmbedding): module._init_weight() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, RoFormerLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring( @@ -796,7 +791,10 @@ def forward( @auto_docstring class RoFormerForMaskedLM(RoFormerPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "roformer.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) @@ -894,7 +892,10 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ """ ) class RoFormerForCausalLM(RoFormerPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "roformer.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/rt_detr/modeling_rt_detr.py b/src/transformers/models/rt_detr/modeling_rt_detr.py index 05159b06e335..c6c6e9645da2 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr.py @@ -1010,6 +1010,7 @@ class RTDetrPreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = [r"RTDetrHybridEncoder", r"RTDetrDecoderLayer"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (RTDetrForObjectDetection, RTDetrDecoder)): @@ -1026,7 +1027,7 @@ def _init_weights(self, module): nn.init.constant_(layer.layers[-1].bias, 0) elif isinstance(module, RTDetrMultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + nn.init.constant_(module.sampling_offsets.weight, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( 2.0 * math.pi / module.n_heads @@ -1041,12 +1042,12 @@ def _init_weights(self, module): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(module.attention_weights.weight.data, 0.0) - nn.init.constant_(module.attention_weights.bias.data, 0.0) - nn.init.xavier_uniform_(module.value_proj.weight.data) - nn.init.constant_(module.value_proj.bias.data, 0.0) - nn.init.xavier_uniform_(module.output_proj.weight.data) - nn.init.constant_(module.output_proj.bias.data, 0.0) + nn.init.constant_(module.attention_weights.weight, 0.0) + nn.init.constant_(module.attention_weights.bias, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight) + nn.init.constant_(module.value_proj.bias, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight) + nn.init.constant_(module.output_proj.bias, 0.0) elif isinstance(module, RTDetrModel): prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) @@ -1055,13 +1056,13 @@ def _init_weights(self, module): nn.init.constant_(module.enc_score_head.bias, bias) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() if hasattr(module, "weight_embedding") and self.config.learn_initial_query: nn.init.xavier_uniform_(module.weight_embedding.weight) @@ -1813,30 +1814,23 @@ def forward( ) class RTDetrForObjectDetection(RTDetrPreTrainedModel): # When using clones, all layers > 0 will be clones, but layer 0 *is* required - _tied_weights_keys = ["bbox_embed", "class_embed"] + _tied_weights_keys = {"model.decoder.bbox_embed": "bbox_embed", "model.decoder.class_embed": "class_embed"} + # We can't initialize the model on meta device as some weights are modified during the initialization _no_split_modules = None def __init__(self, config: RTDetrConfig): super().__init__(config) - - # RTDETR encoder-decoder model self.model = RTDetrModel(config) - - # Detection heads on top - self.class_embed = partial(nn.Linear, config.d_model, config.num_labels) - self.bbox_embed = partial(RTDetrMLPPredictionHead, config, config.d_model, config.d_model, 4, num_layers=3) + num_pred = config.decoder_layers + self.class_embed = nn.ModuleList([torch.nn.Linear(config.d_model, config.num_labels) for _ in range(num_pred)]) + self.bbox_embed = nn.ModuleList([RTDetrMLPPredictionHead(config, config.d_model, config.d_model, 4, num_layers=3) for _ in range(num_pred)]) # if two-stage, the last class_embed and bbox_embed is for region proposal generation - num_pred = config.decoder_layers if config.with_box_refine: - self.class_embed = _get_clones(self.class_embed, num_pred) - self.bbox_embed = _get_clones(self.bbox_embed, num_pred) - else: - self.class_embed = nn.ModuleList([self.class_embed() for _ in range(num_pred)]) - self.bbox_embed = nn.ModuleList([self.bbox_embed() for _ in range(num_pred)]) - - # hack implementation for iterative bounding box refinement + self._tied_weights_keys[r"bbox_embed.(?![0])\d+"] = "bbox_embed.0" + self._tied_weights_keys[r"class_embed.(?![0])\d+"] = "class_embed.0" + # hack implementation for iterative bounding box refinement self.model.decoder.class_embed = self.class_embed self.model.decoder.bbox_embed = self.bbox_embed diff --git a/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py b/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py index b7e56abc170c..12f9d90d8eb5 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py @@ -20,6 +20,7 @@ import math from typing import Optional +import torch from torch import Tensor, nn from ...activations import ACT2FN @@ -303,6 +304,7 @@ class RTDetrResNetPreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = ["RTDetrResNetConvLayer", "RTDetrResNetShortCut"] + @torch.no_grad() def _init_weights(self, module): if isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") diff --git a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py index 6f85dacad092..8a16dc7fbf21 100644 --- a/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +++ b/src/transformers/models/rt_detr_v2/modeling_rt_detr_v2.py @@ -457,6 +457,7 @@ class RTDetrV2PreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = [r"RTDetrV2HybridEncoder", r"RTDetrV2DecoderLayer"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (RTDetrV2ForObjectDetection, RTDetrV2Decoder)): @@ -473,7 +474,7 @@ def _init_weights(self, module): nn.init.constant_(layer.layers[-1].bias, 0) elif isinstance(module, RTDetrV2MultiscaleDeformableAttention): - nn.init.constant_(module.sampling_offsets.weight.data, 0.0) + nn.init.constant_(module.sampling_offsets.weight, 0.0) default_dtype = torch.get_default_dtype() thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * ( 2.0 * math.pi / module.n_heads @@ -488,12 +489,12 @@ def _init_weights(self, module): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) - nn.init.constant_(module.attention_weights.weight.data, 0.0) - nn.init.constant_(module.attention_weights.bias.data, 0.0) - nn.init.xavier_uniform_(module.value_proj.weight.data) - nn.init.constant_(module.value_proj.bias.data, 0.0) - nn.init.xavier_uniform_(module.output_proj.weight.data) - nn.init.constant_(module.output_proj.bias.data, 0.0) + nn.init.constant_(module.attention_weights.weight, 0.0) + nn.init.constant_(module.attention_weights.bias, 0.0) + nn.init.xavier_uniform_(module.value_proj.weight) + nn.init.constant_(module.value_proj.bias, 0.0) + nn.init.xavier_uniform_(module.output_proj.weight) + nn.init.constant_(module.output_proj.bias, 0.0) elif isinstance(module, RTDetrV2Model): prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1) @@ -502,13 +503,13 @@ def _init_weights(self, module): nn.init.constant_(module.enc_score_head.bias, bias) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() if hasattr(module, "weight_embedding") and self.config.learn_initial_query: nn.init.xavier_uniform_(module.weight_embedding.weight) @@ -1810,7 +1811,7 @@ class RTDetrV2ObjectDetectionOutput(ModelOutput): ) class RTDetrV2ForObjectDetection(RTDetrV2PreTrainedModel): # When using clones, all layers > 0 will be clones, but layer 0 *is* required - _tied_weights_keys = ["bbox_embed", "class_embed"] + _tied_weights_keys = {"model.decoder.bbox_embed": "bbox_embed", "model.decoder.class_embed": "class_embed"} # We can't initialize the model on meta device as some weights are modified during the initialization _no_split_modules = None diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 895abd981228..2f0a434720a2 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -366,6 +366,7 @@ class RwkvPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _is_stateful = True + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights.""" if isinstance(module, RwkvSelfAttention): @@ -398,12 +399,12 @@ def _init_weights(self, module: nn.Module): * 0.5 ) - module.time_decay.data = decay_speed - module.time_first.data = torch.ones_like(module.time_first * math.log(0.3) + zigzag) + module.time_decay.copy_(decay_speed) + module.time_first.copy_(torch.ones_like(module.time_first * math.log(0.3) + zigzag)) - module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) - module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 - module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0) + module.time_mix_key.copy_(torch.pow(time_weight, ratio_1_to_almost0)) + module.time_mix_value.copy_(torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + module.time_mix_receptance.copy_(torch.pow(time_weight, 0.5 * ratio_1_to_almost0)) elif isinstance(module, RwkvFeedForward): layer_id = module.layer_id num_hidden_layers = module.config.num_hidden_layers @@ -418,14 +419,14 @@ def _init_weights(self, module: nn.Module): ) time_weight = time_weight[None, None, :] - module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) - module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0) + module.time_mix_key.copy_(torch.pow(time_weight, ratio_1_to_almost0)) + module.time_mix_receptance.copy_(torch.pow(time_weight, ratio_1_to_almost0)) elif isinstance(module, nn.Linear): - shape = module.weight.data.shape + shape = module.weight.shape gain = 1.0 scale = 1.0 # extra scale for gain if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() if shape[0] > shape[1]: gain = math.sqrt(shape[0] / shape[1]) if shape[0] == self.config.vocab_size and shape[1] == self.config.hidden_size: # final projection? @@ -434,12 +435,12 @@ def _init_weights(self, module: nn.Module): gain *= scale nn.init.orthogonal_(module.weight, gain=gain) elif isinstance(module, nn.Embedding): - shape = module.weight.data.shape + shape = module.weight.shape gain = 1e-4 * math.sqrt(max(shape[0], shape[1])) nn.init.orthogonal_(module.weight, gain=gain) elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() @dataclass @@ -666,7 +667,7 @@ def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id): """ ) class RwkvForCausalLM(RwkvPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["head.weight"] + _tied_weights_keys = {"head.weight": "rwkv.embeddings.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index cd59721180ba..eaaf534da364 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1014,15 +1014,16 @@ class SamPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_sdpa = True + @torch.no_grad() def _init_weights(self, module: nn.Module): super()._init_weights(module) if isinstance(module, SamVisionAttention): if module.use_rel_pos: - module.rel_pos_h.data.zero_() - module.rel_pos_w.data.zero_() + module.rel_pos_h.zero_() + module.rel_pos_w.zero_() elif isinstance(module, SamVisionEncoder): if self.config.use_abs_pos: - module.pos_embed.data.zero_() + module.pos_embed.zero_() class SamVisionEncoder(SamPreTrainedModel): @@ -1113,7 +1114,9 @@ def forward( ) class SamModel(SamPreTrainedModel): input_modalities = ["image", "text"] - _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + _tied_weights_keys = { + "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding" + } # need to be ignored, as it's a buffer and will not be correctly detected as tied weight _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(SamTwoWayAttentionBlock, index=2)} @@ -1130,11 +1133,6 @@ def __init__(self, config: SamConfig): self.post_init() - def _tie_weights(self): - self.prompt_encoder.shared_embedding.positional_embedding.data = ( - self.shared_image_embedding.positional_embedding.data - ) - def get_input_embeddings(self): return self.vision_encoder.get_input_embeddings() diff --git a/src/transformers/models/sam2/modeling_sam2.py b/src/transformers/models/sam2/modeling_sam2.py index f7ec0da2d319..ef9bbd9600f7 100644 --- a/src/transformers/models/sam2/modeling_sam2.py +++ b/src/transformers/models/sam2/modeling_sam2.py @@ -556,27 +556,28 @@ class Sam2PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, (nn.LayerNorm, Sam2LayerNorm)): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() if isinstance(module, Sam2HieraDetModel): if module.pos_embed is not None: - module.pos_embed.data.zero_() + module.pos_embed.zero_() if module.pos_embed_window is not None: - module.pos_embed_window.data.zero_() + module.pos_embed_window.zero_() if isinstance(module, Sam2Model): if module.no_memory_embedding is not None: - module.no_memory_embedding.data.zero_() + module.no_memory_embedding.zero_() class Sam2HieraDetModel(Sam2PreTrainedModel): @@ -1278,7 +1279,9 @@ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): ) class Sam2Model(Sam2PreTrainedModel): input_modalities = ["image", "text"] - _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + _tied_weights_keys = { + "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding" + } # need to be ignored, as it's a buffer and will not be correctly detected as tied weight _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2TwoWayAttentionBlock, index=2)} @@ -1309,11 +1312,6 @@ def __init__(self, config: Sam2Config): self.post_init() - def _tie_weights(self): - self.prompt_encoder.shared_embedding.positional_embedding.data = ( - self.shared_image_embedding.positional_embedding.data - ) - def get_input_embeddings(self): return self.vision_encoder.get_input_embeddings() diff --git a/src/transformers/models/sam2/modular_sam2.py b/src/transformers/models/sam2/modular_sam2.py index 1e6bb7f006be..20b617768a97 100644 --- a/src/transformers/models/sam2/modular_sam2.py +++ b/src/transformers/models/sam2/modular_sam2.py @@ -677,27 +677,28 @@ class Sam2PreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, (nn.LayerNorm, Sam2LayerNorm)): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() if isinstance(module, Sam2HieraDetModel): if module.pos_embed is not None: - module.pos_embed.data.zero_() + module.pos_embed.zero_() if module.pos_embed_window is not None: - module.pos_embed_window.data.zero_() + module.pos_embed_window.zero_() if isinstance(module, Sam2Model): if module.no_memory_embedding is not None: - module.no_memory_embedding.data.zero_() + module.no_memory_embedding.zero_() class Sam2HieraDetModel(Sam2PreTrainedModel): diff --git a/src/transformers/models/sam2_video/modeling_sam2_video.py b/src/transformers/models/sam2_video/modeling_sam2_video.py index 751e9c0445cb..5d5c1b52a1e1 100644 --- a/src/transformers/models/sam2_video/modeling_sam2_video.py +++ b/src/transformers/models/sam2_video/modeling_sam2_video.py @@ -666,31 +666,32 @@ class Sam2VideoPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, (nn.LayerNorm, Sam2VideoLayerNorm)): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, Sam2VideoModel): if module.no_memory_positional_encoding is not None: - module.no_memory_positional_encoding.data.zero_() + module.no_memory_positional_encoding.zero_() if module.memory_temporal_positional_encoding is not None: - module.memory_temporal_positional_encoding.data.zero_() + module.memory_temporal_positional_encoding.zero_() if module.no_object_pointer is not None: - module.no_object_pointer.data.zero_() + module.no_object_pointer.zero_() if module.occlusion_spatial_embedding_parameter is not None: - module.occlusion_spatial_embedding_parameter.data.zero_() + module.occlusion_spatial_embedding_parameter.zero_() if isinstance(module, Sam2VideoMemoryFuserCXBlock): if module.scale is not None: - module.scale.data.zero_() + module.scale.zero_() class Sam2VideoVisionRotaryEmbedding(nn.Module): @@ -1560,7 +1561,9 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000): @auto_docstring class Sam2VideoModel(Sam2VideoPreTrainedModel): input_modalities = ["video", "text"] - _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + _tied_weights_keys = { + "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding" + } # need to be ignored, as it's a buffer and will not be correctly detected as tied weight _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(Sam2VideoTwoWayAttentionBlock, index=2)} @@ -1616,11 +1619,6 @@ def __init__(self, config: Sam2VideoConfig): self.post_init() - def _tie_weights(self): - self.prompt_encoder.shared_embedding.positional_embedding.data = ( - self.shared_image_embedding.positional_embedding.data - ) - def get_input_embeddings(self): return self.vision_encoder.get_input_embeddings() diff --git a/src/transformers/models/sam2_video/modular_sam2_video.py b/src/transformers/models/sam2_video/modular_sam2_video.py index 6caef802aa20..bb03658756f8 100644 --- a/src/transformers/models/sam2_video/modular_sam2_video.py +++ b/src/transformers/models/sam2_video/modular_sam2_video.py @@ -991,31 +991,32 @@ class Sam2VideoPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, (nn.LayerNorm, Sam2VideoLayerNorm)): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, Sam2VideoModel): if module.no_memory_positional_encoding is not None: - module.no_memory_positional_encoding.data.zero_() + module.no_memory_positional_encoding.zero_() if module.memory_temporal_positional_encoding is not None: - module.memory_temporal_positional_encoding.data.zero_() + module.memory_temporal_positional_encoding.zero_() if module.no_object_pointer is not None: - module.no_object_pointer.data.zero_() + module.no_object_pointer.zero_() if module.occlusion_spatial_embedding_parameter is not None: - module.occlusion_spatial_embedding_parameter.data.zero_() + module.occlusion_spatial_embedding_parameter.zero_() if isinstance(module, Sam2VideoMemoryFuserCXBlock): if module.scale is not None: - module.scale.data.zero_() + module.scale.zero_() class Sam2VideoVisionRotaryEmbedding(nn.Module): @@ -1449,7 +1450,9 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000): @auto_docstring class Sam2VideoModel(Sam2Model): input_modalities = ["video", "text"] - _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + _tied_weights_keys = { + "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding" + } # need to be ignored, as it's a buffer and will not be correctly detected as tied weight _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] _keys_to_ignore_on_load_unexpected = [] diff --git a/src/transformers/models/sam_hq/modeling_sam_hq.py b/src/transformers/models/sam_hq/modeling_sam_hq.py index 5dee354b2600..5f39effe1bab 100644 --- a/src/transformers/models/sam_hq/modeling_sam_hq.py +++ b/src/transformers/models/sam_hq/modeling_sam_hq.py @@ -429,15 +429,16 @@ class SamHQPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_sdpa = True + @torch.no_grad() def _init_weights(self, module: nn.Module): super()._init_weights(module) if isinstance(module, SamHQVisionAttention): if module.use_rel_pos: - module.rel_pos_h.data.zero_() - module.rel_pos_w.data.zero_() + module.rel_pos_h.zero_() + module.rel_pos_w.zero_() elif isinstance(module, SamHQVisionEncoder): if self.config.use_abs_pos: - module.pos_embed.data.zero_() + module.pos_embed.zero_() class SamHQPatchEmbeddings(nn.Module): @@ -1236,7 +1237,9 @@ def forward( ) class SamHQModel(SamHQPreTrainedModel): input_modalities = ["image", "text"] - _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + _tied_weights_keys = { + "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding" + } _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] _can_record_outputs = {"mask_decoder_attentions": OutputRecorder(SamHQTwoWayAttentionBlock, index=2)} @@ -1252,11 +1255,6 @@ def __init__(self, config): self.post_init() - def _tie_weights(self): - self.prompt_encoder.shared_embedding.positional_embedding.data = ( - self.shared_image_embedding.positional_embedding.data - ) - def get_input_embeddings(self): return self.vision_encoder.get_input_embeddings() diff --git a/src/transformers/models/sam_hq/modular_sam_hq.py b/src/transformers/models/sam_hq/modular_sam_hq.py index 5e259fd1cece..e7cea1598e78 100644 --- a/src/transformers/models/sam_hq/modular_sam_hq.py +++ b/src/transformers/models/sam_hq/modular_sam_hq.py @@ -442,7 +442,9 @@ class SamHQVisionModel(SamVisionModel): """ ) class SamHQModel(SamModel): - _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"] + _tied_weights_keys = { + "prompt_encoder.shared_embedding.positional_embedding": "shared_image_embedding.positional_embedding" + } _keys_to_ignore_on_load_missing = ["prompt_encoder.shared_embedding.positional_embedding"] def __init__(self, config): diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 2388556f06e3..7efe8936d837 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -1342,17 +1342,18 @@ class SeamlessM4TPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["SeamlessM4TEncoderLayer", "SeamlessM4TDecoderLayer", "SeamlessM4TConformerEncoderLayer"] + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, SeamlessM4TConformerSelfAttention): if hasattr(module, "pos_bias_u"): nn.init.xavier_uniform_(module.pos_bias_u) @@ -1370,8 +1371,8 @@ def _init_weights(self, module: nn.Module): nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) if module.bias is not None: @@ -1978,7 +1979,7 @@ class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel, "text_encoder", "text_decoder", ] - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.decoder.embed_tokens.weight"} def __init__( self, @@ -2092,12 +2093,6 @@ def forward( def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return shift_tokens_right(labels, self.config.t2u_pad_token_id, self.config.t2u_decoder_start_token_id) - def _tie_weights(self) -> None: - if getattr(self.config, "tie_word_embeddings", True): - output_embeddings = self.get_output_embeddings() - if output_embeddings is not None: - self._tie_embedding_weights(output_embeddings, self.get_input_embeddings()) - ############ VOCODER related code ################ @@ -2405,20 +2400,21 @@ def forward( return hidden_states, lengths + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights.""" std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() def apply_weight_norm(self): weight_norm = nn.utils.weight_norm @@ -2453,19 +2449,19 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = ["speech_encoder", "t2u_model", "vocoder"] main_input_name = "input_ids" - _tied_weights_keys = [ - "lm_head.weight", - "text_encoder.embed_tokens.weight", - "text_decoder.embed_tokens.weight", - ] + _tied_weights_keys = { + "lm_head.weight": "shared.weight", + "text_encoder.embed_tokens.weight": "shared.weight", + "text_decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: SeamlessM4TConfig): super().__init__(config) self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) - self.text_encoder = SeamlessM4TEncoder(config, self.shared) - self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.text_encoder = SeamlessM4TEncoder(config) + self.text_decoder = SeamlessM4TDecoder(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing @@ -2485,12 +2481,6 @@ def set_input_embeddings(self, value): self.text_decoder.embed_tokens = value self.shared = value - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.text_encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.text_decoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.lm_head, self.shared) - @auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS) def forward( self, @@ -2711,17 +2701,17 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = ["text_encoder", "t2u_model", "vocoder"] main_input_name = "input_features" - _tied_weights_keys = [ - "lm_head.weight", - "text_decoder.embed_tokens.weight", - ] + _tied_weights_keys = { + "lm_head.weight": "shared.weight", + "text_decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: SeamlessM4TConfig): super().__init__(config) self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) self.speech_encoder = SeamlessM4TSpeechEncoder(config) - self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.text_decoder = SeamlessM4TDecoder(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing @@ -2739,11 +2729,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.text_decoder.embed_tokens = value - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.text_decoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.lm_head, self.shared) - @auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS) def forward( self, @@ -2973,19 +2958,19 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = ["speech_encoder"] main_input_name = "input_ids" - _tied_weights_keys = [ - "lm_head.weight", - "text_encoder.embed_tokens.weight", - "text_decoder.embed_tokens.weight", - ] + _tied_weights_keys = { + "lm_head.weight": "shared.weight", + "text_encoder.embed_tokens.weight": "shared.weight", + "text_decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: SeamlessM4TConfig): super().__init__(config) self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) - self.text_encoder = SeamlessM4TEncoder(config, self.shared) - self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.text_encoder = SeamlessM4TEncoder(config) + self.text_decoder = SeamlessM4TDecoder(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing @@ -3008,12 +2993,6 @@ def set_input_embeddings(self, value): self.text_decoder.embed_tokens = value self.shared = value - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.text_encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.text_decoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.lm_head, self.shared) - @auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS) def forward( self, @@ -3298,24 +3277,19 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = ["text_encoder"] main_input_name = "input_features" - _tied_weights_keys = [ - "lm_head.weight", - "text_decoder.embed_tokens.weight", - ] + _tied_weights_keys = {"lm_head.weight": "shared.weight", "text_decoder.embed_tokens.weight": "shared.weight"} def __init__(self, config): super().__init__(config) self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) self.speech_encoder = SeamlessM4TSpeechEncoder(config) - self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.text_decoder = SeamlessM4TDecoder(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Initialize weights and apply final processing - self.post_init() - self.t2u_model = SeamlessM4TTextToUnitForConditionalGeneration(config) self.vocoder = SeamlessM4TCodeHifiGan(config) + self.post_init() def get_encoder(self): return self.speech_encoder @@ -3329,11 +3303,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.text_decoder.embed_tokens = value - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.text_decoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.lm_head, self.shared) - @auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS) def forward( self, @@ -3628,11 +3597,11 @@ def generate( class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin): input_modalities = ["audio", "text"] output_modalities = ["audio", "text"] - _tied_weights_keys = [ - "lm_head.weight", - "text_encoder.embed_tokens.weight", - "text_decoder.embed_tokens.weight", - ] + _tied_weights_keys = { + "lm_head.weight": "shared.weight", + "text_encoder.embed_tokens.weight": "shared.weight", + "text_decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config, current_modality="text"): r""" @@ -3643,9 +3612,9 @@ def __init__(self, config, current_modality="text"): self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) - self.text_encoder = SeamlessM4TEncoder(config, self.shared) + self.text_encoder = SeamlessM4TEncoder(config) self.speech_encoder = SeamlessM4TSpeechEncoder(config) - self.text_decoder = SeamlessM4TDecoder(config, self.shared) + self.text_decoder = SeamlessM4TDecoder(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing @@ -3683,12 +3652,6 @@ def set_input_embeddings(self, value): self.text_decoder.embed_tokens = value self.shared = value - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.text_encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.text_decoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.lm_head, self.shared) - @auto_docstring(custom_args=SEAMLESS_M4T_COMMON_CUSTOM_ARGS) def forward( self, diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index 2775f8297f65..16aba775566c 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -1258,17 +1258,18 @@ class SeamlessM4Tv2PreTrainedModel(PreTrainedModel): "SeamlessM4Tv2TextToUnitDecoderLayer", ] + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, SeamlessM4Tv2ConformerSelfAttention): if hasattr(module, "pos_bias_u"): nn.init.xavier_uniform_(module.pos_bias_u) @@ -1279,11 +1280,11 @@ def _init_weights(self, module: nn.Module): nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, SeamlessM4Tv2TextToUnitDecoder): - module.pos_emb_alpha_char.data.fill_(1) - module.pos_emb_alpha.data.fill_(1) + module.pos_emb_alpha_char.fill_(1) + module.pos_emb_alpha.fill_(1) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)): nn.init.kaiming_normal_(module.weight) if module.bias is not None: @@ -2179,7 +2180,7 @@ class SeamlessM4Tv2TextToUnitForConditionalGeneration(SeamlessM4Tv2PreTrainedMod "text_encoder", "text_decoder", ] - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.decoder.embed_tokens.weight"} # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration.__init__ with SeamlessM4T->SeamlessM4Tv2 def __init__( @@ -2287,13 +2288,6 @@ def forward( loss=masked_lm_loss, ) - # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TTextToUnitForConditionalGeneration._tie_weights - def _tie_weights(self) -> None: - if getattr(self.config, "tie_word_embeddings", True): - output_embeddings = self.get_output_embeddings() - if output_embeddings is not None: - self._tie_embedding_weights(output_embeddings, self.get_input_embeddings()) - ############ VOCODER related code ################ @@ -2608,21 +2602,21 @@ def forward( return hidden_states, lengths - # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan._init_weights + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights.""" std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TCodeHifiGan.apply_weight_norm def apply_weight_norm(self): @@ -2660,19 +2654,19 @@ class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = ["speech_encoder", "t2u_model", "vocoder"] main_input_name = "input_ids" - _tied_weights_keys = [ - "lm_head.weight", - "text_encoder.embed_tokens.weight", - "text_decoder.embed_tokens.weight", - ] + _tied_weights_keys = { + "lm_head.weight": "shared.weight", + "text_encoder.embed_tokens.weight": "shared.weight", + "text_decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: SeamlessM4Tv2Config): super().__init__(config) self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) - self.text_encoder = SeamlessM4Tv2Encoder(config, self.shared) - self.text_decoder = SeamlessM4Tv2Decoder(config, self.shared) + self.text_encoder = SeamlessM4Tv2Encoder(config) + self.text_decoder = SeamlessM4Tv2Decoder(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing @@ -2692,12 +2686,6 @@ def set_input_embeddings(self, value): self.text_decoder.embed_tokens = value self.shared = value - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.text_encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.text_decoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.lm_head, self.shared) - @auto_docstring(custom_args=SEAMLESS_M4T_V2_COMMON_CUSTOM_ARGS) def forward( self, @@ -2918,10 +2906,10 @@ class SeamlessM4Tv2ForSpeechToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin _keys_to_ignore_on_load_missing = ["text_encoder", "t2u_model", "vocoder"] main_input_name = "input_features" - _tied_weights_keys = [ - "lm_head.weight", - "text_decoder.embed_tokens.weight", - ] + _tied_weights_keys = { + "lm_head.weight": "shared.weight", + "text_decoder.embed_tokens.weight": "shared.weight", + } # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToText.__init__ with SeamlessM4T->SeamlessM4Tv2 def __init__(self, config: SeamlessM4Tv2Config): @@ -2929,7 +2917,7 @@ def __init__(self, config: SeamlessM4Tv2Config): self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) self.speech_encoder = SeamlessM4Tv2SpeechEncoder(config) - self.text_decoder = SeamlessM4Tv2Decoder(config, self.shared) + self.text_decoder = SeamlessM4Tv2Decoder(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing @@ -2951,12 +2939,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.text_decoder.embed_tokens = value - # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToText._tie_weights - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.text_decoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.lm_head, self.shared) - @auto_docstring(custom_args=SEAMLESS_M4T_V2_COMMON_CUSTOM_ARGS) # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToText.forward def forward( @@ -3188,11 +3170,11 @@ class SeamlessM4Tv2ForTextToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMixin _keys_to_ignore_on_load_missing = ["speech_encoder"] main_input_name = "input_ids" - _tied_weights_keys = [ - "lm_head.weight", - "text_encoder.embed_tokens.weight", - "text_decoder.embed_tokens.weight", - ] + _tied_weights_keys = { + "lm_head.weight": "shared.weight", + "text_encoder.embed_tokens.weight": "shared.weight", + "text_decoder.embed_tokens.weight": "shared.weight", + } # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.__init__ with SeamlessM4T->SeamlessM4Tv2 def __init__(self, config: SeamlessM4Tv2Config): @@ -3200,8 +3182,8 @@ def __init__(self, config: SeamlessM4Tv2Config): self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) - self.text_encoder = SeamlessM4Tv2Encoder(config, self.shared) - self.text_decoder = SeamlessM4Tv2Decoder(config, self.shared) + self.text_encoder = SeamlessM4Tv2Encoder(config) + self.text_decoder = SeamlessM4Tv2Decoder(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing @@ -3228,13 +3210,6 @@ def set_input_embeddings(self, value): self.text_decoder.embed_tokens = value self.shared = value - # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech._tie_weights - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.text_encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.text_decoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.lm_head, self.shared) - @auto_docstring(custom_args=SEAMLESS_M4T_V2_COMMON_CUSTOM_ARGS) # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToSpeech.forward with SeamlessM4T->SeamlessM4Tv2 def forward( @@ -3551,10 +3526,7 @@ class SeamlessM4Tv2ForSpeechToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMix _keys_to_ignore_on_load_missing = ["text_encoder"] main_input_name = "input_features" - _tied_weights_keys = [ - "lm_head.weight", - "text_decoder.embed_tokens.weight", - ] + _tied_weights_keys = {"lm_head.weight": "shared.weight", "text_decoder.embed_tokens.weight": "shared.weight"} # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToSpeech.__init__ with SeamlessM4T->SeamlessM4Tv2 def __init__(self, config): @@ -3562,14 +3534,12 @@ def __init__(self, config): self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) self.speech_encoder = SeamlessM4Tv2SpeechEncoder(config) - self.text_decoder = SeamlessM4Tv2Decoder(config, self.shared) + self.text_decoder = SeamlessM4Tv2Decoder(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Initialize weights and apply final processing - self.post_init() - self.t2u_model = SeamlessM4Tv2TextToUnitForConditionalGeneration(config) self.vocoder = SeamlessM4Tv2CodeHifiGan(config) + self.post_init() # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToSpeech.get_encoder def get_encoder(self): @@ -3587,12 +3557,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.text_decoder.embed_tokens = value - # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToSpeech._tie_weights - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.text_decoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.lm_head, self.shared) - @auto_docstring(custom_args=SEAMLESS_M4T_V2_COMMON_CUSTOM_ARGS) # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForSpeechToSpeech.forward with SeamlessM4T->SeamlessM4Tv2 def forward( @@ -3918,11 +3882,11 @@ def generate( class SeamlessM4Tv2Model(SeamlessM4Tv2PreTrainedModel, GenerationMixin): input_modalities = ["audio", "text"] output_modalities = ["audio", "text"] - _tied_weights_keys = [ - "lm_head.weight", - "text_encoder.embed_tokens.weight", - "text_decoder.embed_tokens.weight", - ] + _tied_weights_keys = { + "lm_head.weight": "shared.weight", + "text_encoder.embed_tokens.weight": "shared.weight", + "text_decoder.embed_tokens.weight": "shared.weight", + } # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TModel.__init__ with SeamlessM4T->SeamlessM4Tv2 def __init__(self, config, current_modality="text"): @@ -3934,9 +3898,9 @@ def __init__(self, config, current_modality="text"): self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) - self.text_encoder = SeamlessM4Tv2Encoder(config, self.shared) + self.text_encoder = SeamlessM4Tv2Encoder(config) self.speech_encoder = SeamlessM4Tv2SpeechEncoder(config) - self.text_decoder = SeamlessM4Tv2Decoder(config, self.shared) + self.text_decoder = SeamlessM4Tv2Decoder(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing @@ -3978,13 +3942,6 @@ def set_input_embeddings(self, value): self.text_decoder.embed_tokens = value self.shared = value - # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TModel._tie_weights - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.text_encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.text_decoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.lm_head, self.shared) - @auto_docstring(custom_args=SEAMLESS_M4T_V2_COMMON_CUSTOM_ARGS) # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TModel.forward with SeamlessM4T->SeamlessM4Tv2 def forward( diff --git a/src/transformers/models/seed_oss/modeling_seed_oss.py b/src/transformers/models/seed_oss/modeling_seed_oss.py index 7e645e3ce052..7cd0093b9e69 100644 --- a/src/transformers/models/seed_oss/modeling_seed_oss.py +++ b/src/transformers/models/seed_oss/modeling_seed_oss.py @@ -439,7 +439,7 @@ def forward( @auto_docstring class SeedOssForCausalLM(SeedOssPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py index 99382806bedd..ea0a58568101 100755 --- a/src/transformers/models/segformer/modeling_segformer.py +++ b/src/transformers/models/segformer/modeling_segformer.py @@ -414,19 +414,20 @@ class SegformerPreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" input_modalities = "image" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/seggpt/modeling_seggpt.py b/src/transformers/models/seggpt/modeling_seggpt.py index 9de5ad3a0729..80f98707757d 100644 --- a/src/transformers/models/seggpt/modeling_seggpt.py +++ b/src/transformers/models/seggpt/modeling_seggpt.py @@ -595,39 +595,46 @@ class SegGptPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["SegGptEmbeddings", "SegGptLayer"] + @torch.no_grad() def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_(module.weight.data.to(torch.float32), mean=0.0, std=std).to( - module.weight.dtype + module.weight.copy_( + nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=std).to(module.weight.dtype) ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, SegGptLayerNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, SegGptAttention): - module.rel_pos_h.data = nn.init.trunc_normal_( - module.rel_pos_h.data.to(torch.float32), - mean=0.0, - std=std, - ).to(module.rel_pos_h.dtype) - - module.rel_pos_w.data = nn.init.trunc_normal_( - module.rel_pos_w.data.to(torch.float32), - mean=0.0, - std=std, - ).to(module.rel_pos_w.dtype) + module.rel_pos_h.copy_( + nn.init.trunc_normal_( + module.rel_pos_h.to(torch.float32), + mean=0.0, + std=std, + ).to(module.rel_pos_h.dtype) + ) + + module.rel_pos_w.copy_( + nn.init.trunc_normal_( + module.rel_pos_w.to(torch.float32), + mean=0.0, + std=std, + ).to(module.rel_pos_w.dtype) + ) elif isinstance(module, SegGptEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=std, - ).to(module.position_embeddings.dtype) + module.position_embeddings.copy_( + nn.init.trunc_normal_( + module.position_embeddings.to(torch.float32), + mean=0.0, + std=std, + ).to(module.position_embeddings.dtype) + ) torch.nn.init.normal_(module.mask_token, std=std) torch.nn.init.normal_(module.segment_token_input, std=std) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 8cf3e2d24036..728b63d408a5 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -518,6 +518,7 @@ class SEWPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = False # needs a proper look into the mask creation + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, SEWPositionalConvEmbedding): @@ -528,25 +529,25 @@ def _init_weights(self, module): ) nn.init.constant_(module.conv.bias, 0) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): if is_deepspeed_zero3_enabled(): import deepspeed if hasattr(module, "weight_v") and hasattr(module, "weight_g"): with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): - nn.init.kaiming_normal_(module.weight.data) + nn.init.kaiming_normal_(module.weight) else: with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): - nn.init.kaiming_normal_(module.weight.data) + nn.init.kaiming_normal_(module.weight) else: - nn.init.kaiming_normal_(module.weight.data) + nn.init.kaiming_normal_(module.weight) if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ @@ -856,7 +857,7 @@ def __init__(self, config, target_lang: Optional[str] = None): # Initialize weights and apply final processing self.post_init() - def tie_weights(self): + def tie_weights(self, missing_keys=None): """ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when passing `target_lang=...` to `from_pretrained(...)`. diff --git a/src/transformers/models/sew/modular_sew.py b/src/transformers/models/sew/modular_sew.py index 8a2cfc3a2689..4db3783036e5 100644 --- a/src/transformers/models/sew/modular_sew.py +++ b/src/transformers/models/sew/modular_sew.py @@ -255,6 +255,7 @@ class SEWPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = False # needs a proper look into the mask creation + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, SEWPositionalConvEmbedding): @@ -265,25 +266,25 @@ def _init_weights(self, module): ) nn.init.constant_(module.conv.bias, 0) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): if is_deepspeed_zero3_enabled(): import deepspeed if hasattr(module, "weight_v") and hasattr(module, "weight_g"): with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): - nn.init.kaiming_normal_(module.weight.data) + nn.init.kaiming_normal_(module.weight) else: with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): - nn.init.kaiming_normal_(module.weight.data) + nn.init.kaiming_normal_(module.weight) else: - nn.init.kaiming_normal_(module.weight.data) + nn.init.kaiming_normal_(module.weight) if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 7dda40514663..e14224e12c1f 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -1187,6 +1187,7 @@ class SEWDPreTrainedModel(PreTrainedModel): input_modalities = "audio" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, SEWDPositionalConvEmbedding): @@ -1197,29 +1198,29 @@ def _init_weights(self, module): ) nn.init.constant_(module.conv.bias, 0) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): if is_deepspeed_zero3_enabled(): import deepspeed if hasattr(module, "weight_v") and hasattr(module, "weight_g"): with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): - nn.init.kaiming_normal_(module.weight.data) + nn.init.kaiming_normal_(module.weight) else: with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): - nn.init.kaiming_normal_(module.weight.data) + nn.init.kaiming_normal_(module.weight) else: - nn.init.kaiming_normal_(module.weight.data) + nn.init.kaiming_normal_(module.weight) elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ @@ -1409,7 +1410,7 @@ def __init__(self, config, target_lang: Optional[str] = None): # Initialize weights and apply final processing self.post_init() - def tie_weights(self): + def tie_weights(self, missing_keys=None): """ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when passing `target_lang=...` to `from_pretrained(...)`. diff --git a/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py b/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py index 36fd972de140..fba702ecc342 100644 --- a/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py +++ b/src/transformers/models/shieldgemma2/modeling_shieldgemma2.py @@ -76,9 +76,6 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model.language_model.get_decoder() - def tie_weights(self): - return self.model.language_model.tie_weights() - @auto_docstring def forward( self, diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index 9fbfb286a2a0..f414444e663f 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -485,6 +485,7 @@ class SiglipPreTrainedModel(PreTrainedModel): "attentions": SiglipAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, SiglipVisionEmbeddings): @@ -511,13 +512,13 @@ def _init_weights(self, module): nn.init.normal_(module.fc1.bias, std=1e-6) nn.init.normal_(module.fc2.bias, std=1e-6) elif isinstance(module, SiglipMultiheadAttentionPoolingHead): - nn.init.xavier_uniform_(module.probe.data) - nn.init.xavier_uniform_(module.attention.in_proj_weight.data) - nn.init.zeros_(module.attention.in_proj_bias.data) + nn.init.xavier_uniform_(module.probe) + nn.init.xavier_uniform_(module.attention.in_proj_weight) + nn.init.zeros_(module.attention.in_proj_bias) elif isinstance(module, SiglipModel): logit_scale_init = torch.log(torch.tensor(1.0)) - module.logit_scale.data.fill_(logit_scale_init) - module.logit_bias.data.zero_() + module.logit_scale.fill_(logit_scale_init) + module.logit_bias.zero_() elif isinstance(module, SiglipForImageClassification): nn.init.normal_( module.classifier.weight, @@ -528,8 +529,8 @@ def _init_weights(self, module): if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip diff --git a/src/transformers/models/siglip2/modeling_siglip2.py b/src/transformers/models/siglip2/modeling_siglip2.py index a50e13329e83..e9b56fa58e6c 100644 --- a/src/transformers/models/siglip2/modeling_siglip2.py +++ b/src/transformers/models/siglip2/modeling_siglip2.py @@ -560,6 +560,7 @@ class Siglip2PreTrainedModel(PreTrainedModel): "attentions": Siglip2Attention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, Siglip2VisionEmbeddings): @@ -586,13 +587,13 @@ def _init_weights(self, module): nn.init.normal_(module.fc1.bias, std=1e-6) nn.init.normal_(module.fc2.bias, std=1e-6) elif isinstance(module, Siglip2MultiheadAttentionPoolingHead): - nn.init.xavier_uniform_(module.probe.data) - nn.init.xavier_uniform_(module.attention.in_proj_weight.data) - nn.init.zeros_(module.attention.in_proj_bias.data) + nn.init.xavier_uniform_(module.probe) + nn.init.xavier_uniform_(module.attention.in_proj_weight) + nn.init.zeros_(module.attention.in_proj_bias) elif isinstance(module, Siglip2Model): logit_scale_init = torch.log(torch.tensor(1.0)) - module.logit_scale.data.fill_(logit_scale_init) - module.logit_bias.data.zero_() + module.logit_scale.fill_(logit_scale_init) + module.logit_bias.zero_() elif isinstance(module, Siglip2ForImageClassification): nn.init.normal_( module.classifier.weight, @@ -603,8 +604,8 @@ def _init_weights(self, module): if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) class Siglip2TextEmbeddings(nn.Module): diff --git a/src/transformers/models/smollm3/_tied_weights_keys = { b/src/transformers/models/smollm3/_tied_weights_keys = { new file mode 100644 index 000000000000..dd370f5fea56 --- /dev/null +++ b/src/transformers/models/smollm3/_tied_weights_keys = { @@ -0,0 +1,142 @@ + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight" + } + + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } + + _tied_weights_keys = { + "lm_head.weight": "model.shared.weight", + } + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias" + } + + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embedding.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias" + } + +tests/models/albert/test_modeling_albert.py : 2 failures +tests/models/bert/test_modeling_bert.py : 2 failures +tests/models/bert_generation/test_modeling_bert_generation.py : 2 failures +tests/models/big_bird/test_modeling_big_bird.py : 2 failures +tests/models/blip_2/test_modeling_blip_2.py : 2 failures +tests/models/codegen/test_modeling_codegen.py : 2 failures +tests/models/convbert/test_modeling_convbert.py : 2 failures +tests/models/d_fine/test_modeling_d_fine.py : 2 failures +tests/models/dab_detr/test_modeling_dab_detr.py : 2 failures +tests/models/data2vec/test_modeling_data2vec_audio.py : 2 failures +tests/models/data2vec/test_modeling_data2vec_text.py : 2 failures +tests/models/deberta/test_modeling_deberta.py : 2 failures +tests/models/deberta_v2/test_modeling_deberta_v2.py : 2 failures +tests/models/distilbert/test_modeling_distilbert.py : 2 failures +tests/models/electra/test_modeling_electra.py : 2 failures +tests/models/ernie/test_modeling_ernie.py : 2 failures +tests/models/flaubert/test_modeling_flaubert.py : 2 failures +tests/models/fnet/test_modeling_fnet.py : 2 failures +tests/models/git/test_modeling_git.py : 2 failures +tests/models/gptj/test_modeling_gptj.py : 2 failures +tests/models/layoutlm/test_modeling_layoutlm.py : 2 failures +tests/models/longformer/test_modeling_longformer.py : 2 failures +tests/models/marian/test_modeling_marian.py : 2 failures +tests/models/megatron_bert/test_modeling_megatron_bert.py : 2 failures +tests/models/mpnet/test_modeling_mpnet.py : 2 failures +tests/models/musicgen/test_modeling_musicgen.py : 2 failures +tests/models/musicgen_melody/test_modeling_musicgen_melody.py : 2 failures +tests/models/nystromformer/test_modeling_nystromformer.py : 2 failures +tests/models/reformer/test_modeling_reformer.py : 2 failures +tests/models/roberta/test_modeling_roberta.py : 2 failures +tests/models/roberta_prelayernorm/test_modeling_roberta_prelayernorm.py : 2 failures +tests/models/roc_bert/test_modeling_roc_bert.py : 2 failures +tests/models/roformer/test_modeling_roformer.py : 2 failures +tests/models/rt_detr/test_modeling_rt_detr.py : 2 failures +tests/models/rt_detr_v2/test_modeling_rt_detr_v2.py : 2 failures +tests/models/squeezebert/test_modeling_squeezebert.py : 2 failures +tests/models/tapas/test_modeling_tapas.py : 2 failures +tests/models/visual_bert/test_modeling_visual_bert.py : 2 failures +tests/models/xmod/test_modeling_xmod.py : 2 failures +tests/models/yoso/test_modeling_yoso.py : 2 failures +tests/models/apertus/test_modeling_apertus.py : 3 failures +tests/models/arcee/test_modeling_arcee.py : 3 failures +tests/models/cwm/test_modeling_cwm.py : 3 failures +tests/models/deepseek_v2/test_modeling_deepseek_v2.py : 3 failures +tests/models/dots1/test_modeling_dots1.py : 3 failures +tests/models/ernie4_5/test_modeling_ernie4_5.py : 3 failures +tests/models/exaone4/test_modeling_exaone4.py : 3 failures +tests/models/flex_olmo/test_modeling_flex_olmo.py : 3 failures +tests/models/funnel/test_modeling_funnel.py : 3 failures +tests/models/glm/test_modeling_glm.py : 3 failures +tests/models/glm4/test_modeling_glm4.py : 3 failures +tests/models/glm4_moe/test_modeling_glm4_moe.py : 3 failures +tests/models/gpt_oss/test_modeling_gpt_oss.py : 3 failures +tests/models/helium/test_modeling_helium.py : 3 failures +tests/models/ibert/test_modeling_ibert.py : 3 failures +tests/models/lfm2/test_modeling_lfm2.py : 3 failures +tests/models/lfm2_moe/test_modeling_lfm2_moe.py : 3 failures +tests/models/llama/test_modeling_llama.py : 3 failures +tests/models/longcat_flash/test_modeling_longcat_flash.py : 3 failures +tests/models/ministral/test_modeling_ministral.py : 3 failures +tests/models/mistral/test_modeling_mistral.py : 3 failures +tests/models/modernbert/test_modeling_modernbert.py : 3 failures +tests/models/modernbert_decoder/test_modeling_modernbert_decoder.py : 3 failures +tests/models/olmo3/test_modeling_olmo3.py : 3 failures +tests/models/phi3/test_modeling_phi3.py : 3 failures +tests/models/pop2piano/test_modeling_pop2piano.py : 3 failures +tests/models/qwen2/test_modeling_qwen2.py : 3 failures +tests/models/qwen2_moe/test_modeling_qwen2_moe.py : 3 failures +tests/models/qwen3/test_modeling_qwen3.py : 3 failures +tests/models/qwen3_moe/test_modeling_qwen3_moe.py : 3 failures +tests/models/seed_oss/test_modeling_seed_oss.py : 3 failures +tests/models/smollm3/test_modeling_smollm3.py : 3 failures +tests/models/starcoder2/test_modeling_starcoder2.py : 3 failures +tests/models/unispeech/test_modeling_unispeech.py : 3 failures +tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py : 3 failures +tests/models/zamba/test_modeling_zamba.py : 3 failures +tests/models/blt/test_modeling_blt.py : 4 failures +tests/models/edgetam/test_modeling_edgetam.py : 4 failures +tests/models/ernie4_5_moe/test_modeling_ernie4_5_moe.py : 4 failures +tests/models/imagegpt/test_modeling_imagegpt.py : 4 failures +tests/models/mamba/test_modeling_mamba.py : 4 failures +tests/models/mixtral/test_modeling_mixtral.py : 4 failures +tests/models/mra/test_modeling_mra.py : 4 failures +tests/models/sam/test_modeling_sam.py : 4 failures +tests/models/sam2/test_modeling_sam2.py : 4 failures +tests/models/sam_hq/test_modeling_sam_hq.py : 4 failures +tests/models/speecht5/test_modeling_speecht5.py : 4 failures +tests/models/tvp/test_modeling_tvp.py : 4 failures +tests/models/phi/test_modeling_phi.py : 5 failures +tests/models/timm_wrapper/test_modeling_timm_wrapper.py : 5 failures +tests/models/unispeech_sat/test_modeling_unispeech_sat.py : 5 failures +tests/models/grounding_dino/test_modeling_grounding_dino.py : 6 failures +tests/models/mm_grounding_dino/test_modeling_mm_grounding_dino.py : 6 failures +tests/models/udop/test_modeling_udop.py : 6 failures +tests/models/auto/test_modeling_auto.py : 7 failures +tests/models/deformable_detr/test_modeling_deformable_detr.py : 7 failures +tests/models/flava/test_modeling_flava.py : 7 failures +tests/models/minimax/test_modeling_minimax.py : 8 failures +tests/models/bark/test_modeling_bark.py : 10 failures +tests/models/blip/test_modeling_blip.py : 10 failures +tests/models/mllama/test_modeling_mllama.py : 11 failures + +tests/models/encoder_decoder/test_modeling_encoder_decoder.py : 12 failures +tests/models/seamless_m4t/test_modeling_seamless_m4t.py : 12 failures + + +# PROBABLY just + if isinstance(input_embeddings, nn.Module): + for k, v in input_embeddings.named_parameters(): + module, param_type = get_module_from_name(output_embeddings, k) + setattr(output_embeddings, k, v) + + +tests/models/d_fine/test_modeling_d_fine.py : 25 failures +tests/models/dab_detr/test_modeling_dab_detr.py : 25 failures +tests/models/rt_detr/test_modeling_rt_detr.py : 25 failures +tests/models/rt_detr_v2/test_modeling_rt_detr_v2.py : 25 failures \ No newline at end of file diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index e11c1138b490..e23d4993e84c 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -456,7 +456,7 @@ def forward( @auto_docstring class SmolLM3ForCausalLM(SmolLM3PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index e7b120369a7b..95983cc1c305 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -83,22 +83,23 @@ class SmolVLMPreTrainedModel(PreTrainedModel): _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range) if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, SmolVLMRMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) class SmolVLMVisionEmbeddings(nn.Module): @@ -774,7 +775,7 @@ class SmolVLMCausalLMOutputWithPast(ModelOutput): """ ) class SmolVLMForConditionalGeneration(SmolVLMPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 00e77d7465ed..0176ef4fa636 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -495,16 +495,17 @@ class Speech2TextPreTrainedModel(PreTrainedModel): _supports_sdpa = False _supports_flex_attn = False + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ @@ -1023,7 +1024,7 @@ def forward( class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel, GenerationMixin): input_modalities = ["audio", "text"] base_model_prefix = "model" - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.decoder.embed_tokens.weight"} def __init__(self, config: Speech2TextConfig): super().__init__(config) diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 72c63fb86d43..74744a42e6f5 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -1170,6 +1170,7 @@ class SpeechT5PreTrainedModel(PreTrainedModel): input_modalities = "audio" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.initializer_range @@ -1181,27 +1182,27 @@ def _init_weights(self, module: nn.Module): ) nn.init.constant_(module.conv.bias, 0) elif isinstance(module, SpeechT5ScaledPositionalEncoding): - module.alpha.data.fill_(1.0) + module.alpha.fill_(1.0) elif isinstance(module, SpeechT5FeatureProjection): k = math.sqrt(1 / module.projection.in_features) nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-k, b=k) elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() if hasattr(module, "masked_spec_embed"): nn.init.uniform_(module.masked_spec_embed) @@ -1996,7 +1997,7 @@ def forward( """ ) class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["text_decoder_postnet.lm_head.weight"] + _tied_weights_keys = {"text_decoder_postnet.lm_head.weight": "speecht5.decoder.prenet.embed_tokens.weight"} def __init__(self, config: SpeechT5Config): super().__init__(config) @@ -3014,12 +3015,13 @@ def __init__(self, config: SpeechT5HifiGanConfig): # Initialize weights and apply final processing self.post_init() + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights.""" if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() def apply_weight_norm(self): weight_norm = nn.utils.weight_norm diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 176ed5f479c7..d0fa3699207b 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -331,19 +331,20 @@ class SplinterPreTrainedModel(PreTrainedModel): base_model_prefix = "splinter" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/squeezebert/modeling_squeezebert.py b/src/transformers/models/squeezebert/modeling_squeezebert.py index 7b2244b42b28..b5418e34a575 100644 --- a/src/transformers/models/squeezebert/modeling_squeezebert.py +++ b/src/transformers/models/squeezebert/modeling_squeezebert.py @@ -378,15 +378,11 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self) -> None: - self.decoder.bias = self.bias def forward(self, hidden_states): hidden_states = self.transform(hidden_states) @@ -409,21 +405,22 @@ class SqueezeBertPreTrainedModel(PreTrainedModel): config: SqueezeBertConfig base_model_prefix = "transformer" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, SqueezeBertLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring @@ -507,7 +504,10 @@ def forward( @auto_docstring class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "transformer.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 6698273cfae3..f2ab414ff30c 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -452,19 +452,20 @@ class StableLmPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() @auto_docstring @@ -710,7 +711,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM with PERSIMMON->STABLELM,Persimmon->StableLm class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->STABLELM,Llama->StableLm def __init__(self, config): diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 6b93c18a3d17..042033fe3565 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -420,7 +420,7 @@ def forward( @auto_docstring class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/superglue/modeling_superglue.py b/src/transformers/models/superglue/modeling_superglue.py index 61495fc31164..fbba759df1b5 100644 --- a/src/transformers/models/superglue/modeling_superglue.py +++ b/src/transformers/models/superglue/modeling_superglue.py @@ -469,18 +469,19 @@ class SuperGluePreTrainedModel(PreTrainedModel): main_input_name = "pixel_values" input_modalities = "image" + @torch.no_grad() def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.BatchNorm1d): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) if hasattr(module, "bin_score"): - module.bin_score.data.fill_(1.0) + module.bin_score.fill_(1.0) @auto_docstring( diff --git a/src/transformers/models/superpoint/modeling_superpoint.py b/src/transformers/models/superpoint/modeling_superpoint.py index c211705aaefd..9e2abdeb863f 100644 --- a/src/transformers/models/superpoint/modeling_superpoint.py +++ b/src/transformers/models/superpoint/modeling_superpoint.py @@ -328,15 +328,16 @@ class SuperPointPreTrainedModel(PreTrainedModel): input_modalities = "image" supports_gradient_checkpointing = False + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor: """ diff --git a/src/transformers/models/swiftformer/modeling_swiftformer.py b/src/transformers/models/swiftformer/modeling_swiftformer.py index 9eed87cd4166..5742e0c52e1e 100644 --- a/src/transformers/models/swiftformer/modeling_swiftformer.py +++ b/src/transformers/models/swiftformer/modeling_swiftformer.py @@ -388,6 +388,7 @@ class SwiftFormerPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["SwiftFormerEncoderBlock"] + @torch.no_grad() def _init_weights(self, module: nn.Module) -> None: """Initialize the weights""" if isinstance(module, (nn.Conv2d, nn.Linear)): @@ -398,11 +399,11 @@ def _init_weights(self, module: nn.Module) -> None: nn.init.constant_(module.bias, 0) nn.init.constant_(module.weight, 1.0) elif isinstance(module, (SwiftFormerConvEncoder, SwiftFormerLocalRepresentation)): - module.layer_scale.data.fill_(1.0) + module.layer_scale.fill_(1.0) elif isinstance(module, SwiftFormerEncoderBlock): if self.config.use_layer_scale: - module.layer_scale_1.data.fill_(self.config.layer_scale_init_value) - module.layer_scale_2.data.fill_(self.config.layer_scale_init_value) + module.layer_scale_1.fill_(self.config.layer_scale_init_value) + module.layer_scale_2.fill_(self.config.layer_scale_init_value) elif isinstance(module, SwiftFormerEfficientAdditiveAttention): nn.init.normal_(module.w_g) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 9835a395e936..82bf2bfbc173 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -811,22 +811,23 @@ class SwinPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["SwinStage"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, SwinEmbeddings): if module.mask_token is not None: - module.mask_token.data.zero_() + module.mask_token.zero_() if module.position_embeddings is not None: - module.position_embeddings.data.zero_() + module.position_embeddings.zero_() elif isinstance(module, SwinSelfAttention): - module.relative_position_bias_table.data.zero_() + module.relative_position_bias_table.zero_() @auto_docstring diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index 4fb1267f47cd..093d34994b3a 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -691,15 +691,16 @@ class Swin2SRPreTrainedModel(PreTrainedModel): input_modalities = "image" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - torch.nn.init.trunc_normal_(module.weight.data, std=self.config.initializer_range) + torch.nn.init.trunc_normal_(module.weight, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index 0d87c23ffc69..ffbeff3456ca 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -886,22 +886,23 @@ class Swinv2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Swinv2Stage"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, Swinv2Embeddings): if module.mask_token is not None: - module.mask_token.data.zero_() + module.mask_token.zero_() if module.position_embeddings is not None: - module.position_embeddings.data.zero_() + module.position_embeddings.zero_() elif isinstance(module, Swinv2SelfAttention): - module.logit_scale.data.fill_(math.log(10)) + module.logit_scale.fill_(math.log(10)) @auto_docstring diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 29f5e9c2c99a..07ffd1c280c3 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -587,43 +587,44 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = False _no_split_modules = ["SwitchTransformersBlock"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, SwitchTransformersLayerNorm): - module.weight.data.fill_(factor * 1.0) + module.weight.fill_(factor * 1.0) elif isinstance( module, (SwitchTransformersModel, SwitchTransformersForConditionalGeneration, SwitchTransformersEncoderModel), ): - module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.shared.weight.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, SwitchTransformersDenseActDense): - module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wi.bias.zero_() + module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + module.wo.bias.zero_() elif isinstance(module, SwitchTransformersAttention): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + module.q.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) elif isinstance(module, SwitchTransformersSparseMLP): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.router.classifier.weight.data.normal_(mean=0.0, std=factor * 1) + module.router.classifier.weight.normal_(mean=0.0, std=factor * 1) for idx in range(self.config.num_experts): - module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.experts[f"expert_{idx}"].wi.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.experts[f"expert_{idx}"].wo.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -655,11 +656,9 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): "router_logits": SwitchTransformersTop1Router, } - def __init__(self, config, embed_tokens=None): + def __init__(self, config): super().__init__(config) self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight self.is_decoder = config.is_decoder @@ -910,7 +909,10 @@ def _prepare_4d_causal_attention_mask_with_cache_position( @auto_docstring class SwitchTransformersModel(SwitchTransformersPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: SwitchTransformersConfig): super().__init__(config) @@ -920,12 +922,12 @@ def __init__(self, config: SwitchTransformersConfig): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = SwitchTransformersStack(encoder_config, self.shared) + self.encoder = SwitchTransformersStack(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False - self.decoder = SwitchTransformersStack(decoder_config, self.shared) + self.decoder = SwitchTransformersStack(decoder_config) # Initialize weights and apply final processing self.post_init() @@ -938,11 +940,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder @@ -1063,7 +1060,11 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T """ ) class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + "lm_head.weight": "shared.weight", + } def __init__(self, config: SwitchTransformersConfig): super().__init__(config) @@ -1075,13 +1076,13 @@ def __init__(self, config: SwitchTransformersConfig): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = SwitchTransformersStack(encoder_config, self.shared) + self.encoder = SwitchTransformersStack(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = SwitchTransformersStack(decoder_config, self.shared) + self.decoder = SwitchTransformersStack(decoder_config) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) @@ -1097,11 +1098,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder @@ -1224,7 +1220,9 @@ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + } _can_record_outputs = { "hidden_states": SwitchTransformersBlock, "attentions": OutputRecorder(SwitchTransformersAttention, index=-1, layer_name="layer.0"), @@ -1238,7 +1236,7 @@ def __init__(self, config: SwitchTransformersConfig): encoder_config = copy.deepcopy(config) encoder_config.use_cache = False encoder_config.is_encoder_decoder = False - self.encoder = SwitchTransformersStack(encoder_config, self.shared) + self.encoder = SwitchTransformersStack(encoder_config) self.post_init() def get_input_embeddings(self): @@ -1248,10 +1246,6 @@ def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings self.encoder.set_input_embeddings(new_embeddings) - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder diff --git a/src/transformers/models/switch_transformers/modular_switch_transformers.py b/src/transformers/models/switch_transformers/modular_switch_transformers.py index 274dc6ca44b7..d1a9f3788290 100644 --- a/src/transformers/models/switch_transformers/modular_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modular_switch_transformers.py @@ -343,43 +343,44 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = False _no_split_modules = ["SwitchTransformersBlock"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, SwitchTransformersLayerNorm): - module.weight.data.fill_(factor * 1.0) + module.weight.fill_(factor * 1.0) elif isinstance( module, (SwitchTransformersModel, SwitchTransformersForConditionalGeneration, SwitchTransformersEncoderModel), ): - module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.shared.weight.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, SwitchTransformersDenseActDense): - module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wi.bias.zero_() + module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + module.wo.bias.zero_() elif isinstance(module, SwitchTransformersAttention): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + module.q.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) elif isinstance(module, SwitchTransformersSparseMLP): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.router.classifier.weight.data.normal_(mean=0.0, std=factor * 1) + module.router.classifier.weight.normal_(mean=0.0, std=factor * 1) for idx in range(self.config.num_experts): - module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.experts[f"expert_{idx}"].wi.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.experts[f"expert_{idx}"].wo.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -411,11 +412,9 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): "router_logits": SwitchTransformersTop1Router, } - def __init__(self, config, embed_tokens=None): + def __init__(self, config): super().__init__(config) self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) - if embed_tokens is not None: - self.embed_tokens.weight = embed_tokens.weight self.is_decoder = config.is_decoder @@ -666,7 +665,10 @@ def _prepare_4d_causal_attention_mask_with_cache_position( @auto_docstring class SwitchTransformersModel(SwitchTransformersPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: SwitchTransformersConfig): super().__init__(config) @@ -676,12 +678,12 @@ def __init__(self, config: SwitchTransformersConfig): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = SwitchTransformersStack(encoder_config, self.shared) + self.encoder = SwitchTransformersStack(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False - self.decoder = SwitchTransformersStack(decoder_config, self.shared) + self.decoder = SwitchTransformersStack(decoder_config) # Initialize weights and apply final processing self.post_init() @@ -694,11 +696,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder @@ -754,7 +751,11 @@ def forward( """ ) class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + "lm_head.weight": "shared.weight", + } def __init__(self, config: SwitchTransformersConfig): super().__init__(config) @@ -766,13 +767,13 @@ def __init__(self, config: SwitchTransformersConfig): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = SwitchTransformersStack(encoder_config, self.shared) + self.encoder = SwitchTransformersStack(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = SwitchTransformersStack(decoder_config, self.shared) + self.decoder = SwitchTransformersStack(decoder_config) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) @@ -788,11 +789,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder @@ -915,7 +911,9 @@ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + } _can_record_outputs = { "hidden_states": SwitchTransformersBlock, "attentions": OutputRecorder(SwitchTransformersAttention, index=-1, layer_name="layer.0"), @@ -929,7 +927,7 @@ def __init__(self, config: SwitchTransformersConfig): encoder_config = copy.deepcopy(config) encoder_config.use_cache = False encoder_config.is_encoder_decoder = False - self.encoder = SwitchTransformersStack(encoder_config, self.shared) + self.encoder = SwitchTransformersStack(encoder_config) self.post_init() def get_input_embeddings(self): @@ -939,10 +937,6 @@ def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings self.encoder.set_input_embeddings(new_embeddings) - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 4a0f60dfacaf..c6fca843efed 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -570,59 +570,60 @@ def dummy_inputs(self): } return dummy_inputs + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, T5LayerNorm): - module.weight.data.fill_(factor * 1.0) + module.weight.fill_(factor * 1.0) elif isinstance( module, (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5ForQuestionAnswering), ): - module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.shared.weight.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "qa_outputs"): - module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) - module.qa_outputs.bias.data.zero_() + module.qa_outputs.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.qa_outputs.bias.zero_() elif isinstance(module, T5ForTokenClassification): if hasattr(module, "classifier"): - module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0) - module.classifier.bias.data.zero_() + module.classifier.weight.normal_(mean=0.0, std=factor * 1.0) + module.classifier.bias.zero_() elif isinstance(module, T5ClassificationHead): - module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.dense.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.dense, "bias") and module.dense.bias is not None: - module.dense.bias.data.zero_() - module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.dense.bias.zero_() + module.out_proj.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: - module.out_proj.bias.data.zero_() + module.out_proj.bias.zero_() elif isinstance(module, T5DenseActDense): - module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wi.bias.zero_() + module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + module.wo.bias.zero_() elif isinstance(module, T5DenseGatedActDense): - module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi_0.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: - module.wi_0.bias.data.zero_() - module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi_0.bias.zero_() + module.wi_1.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: - module.wi_1.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wi_1.bias.zero_() + module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + module.wo.bias.zero_() elif isinstance(module, T5Attention): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + module.q.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -647,10 +648,10 @@ def _shift_right(self, input_ids): class T5Stack(T5PreTrainedModel): - def __init__(self, config, embed_tokens=None): + def __init__(self, config): super().__init__(config) - self.embed_tokens = embed_tokens + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) self.is_decoder = config.is_decoder self.block = nn.ModuleList( @@ -971,7 +972,10 @@ class T5Model(T5PreTrainedModel): _keys_to_ignore_on_load_unexpected = [ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: T5Config): super().__init__(config) @@ -981,13 +985,13 @@ def __init__(self, config: T5Config): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = T5Stack(encoder_config, self.shared) + self.encoder = T5Stack(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = T5Stack(decoder_config, self.shared) + self.decoder = T5Stack(decoder_config) # Initialize weights and apply final processing self.post_init() @@ -1000,11 +1004,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder @@ -1135,7 +1134,11 @@ class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_unexpected = [ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "lm_head.weight": "shared.weight", + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: T5Config): super().__init__(config) @@ -1147,13 +1150,13 @@ def __init__(self, config: T5Config): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = T5Stack(encoder_config, self.shared) + self.encoder = T5Stack(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = T5Stack(decoder_config, self.shared) + self.decoder = T5Stack(decoder_config) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) @@ -1168,11 +1171,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder @@ -1327,7 +1325,7 @@ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): @auto_docstring class T5EncoderModel(T5PreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight"] + _tied_weights_keys = {"encoder.embed_tokens.weight": "shared.weight"} _keys_to_ignore_on_load_unexpected = [r"decoder"] def __init__(self, config: T5Config): @@ -1337,7 +1335,7 @@ def __init__(self, config: T5Config): encoder_config = config encoder_config.use_cache = False encoder_config.is_encoder_decoder = False - self.encoder = T5Stack(encoder_config, self.shared) + self.encoder = T5Stack(encoder_config) # Initialize weights and apply final processing self.post_init() @@ -1349,10 +1347,6 @@ def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings self.encoder.set_input_embeddings(new_embeddings) - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder @@ -1411,7 +1405,6 @@ def forward( ) class T5ForSequenceClassification(T5PreTrainedModel): _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] def __init__(self, config: T5Config): super().__init__(config) @@ -1553,8 +1546,6 @@ def forward( @auto_docstring class T5ForTokenClassification(T5PreTrainedModel): - _tied_weights_keys = ["transformer.encoder.embed_tokens.weight"] - def __init__(self, config: T5Config): super().__init__(config) self.num_labels = config.num_labels @@ -1626,7 +1617,10 @@ def forward( @auto_docstring class T5ForQuestionAnswering(T5PreTrainedModel): _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config: T5Config): super().__init__(config) @@ -1638,13 +1632,13 @@ def __init__(self, config: T5Config): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = T5Stack(encoder_config, self.shared) + self.encoder = T5Stack(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = T5Stack(decoder_config, self.shared) + self.decoder = T5Stack(decoder_config) self.num_labels = config.num_labels self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) @@ -1660,11 +1654,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - def get_encoder(self): return self.encoder diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index 81ba072a2c72..c5b38b3f4374 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -554,22 +554,23 @@ class T5GemmaPreTrainedModel(PreTrainedModel): ], } + @torch.no_grad() def _init_weights(self, module): # TODO: support initialization for encoders and decoders separately(?) super()._init_weights(module) std = self.config.initializer_range if isinstance(module, T5GemmaClassificationHead): scale = module.out_proj.weight.shape[0] ** -0.5 - module.out_proj.weight.data.normal_(mean=0.0, std=std * scale) + module.out_proj.weight.normal_(mean=0.0, std=std * scale) if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: - module.out_proj.bias.data.zero_() + module.out_proj.bias.zero_() elif isinstance(module, T5GemmaLMHead): if not self.config.tie_word_embeddings: scale = module.out_proj.weight.shape[0] ** -0.5 - module.out_proj.weight.data.normal_(mean=0.0, std=std * scale) + module.out_proj.weight.normal_(mean=0.0, std=std * scale) # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) elif "RMSNorm" in module.__class__.__name__: - module.weight.data.zero_() + module.weight.zero_() def _shift_right(self, input_ids): """ @@ -963,7 +964,7 @@ def forward( class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["model.decoder.embed_tokens.weight", "lm_head.out_proj.weight"] + _tied_weights_keys = {"lm_head.out_proj.weight": "model.decoder.embed_tokens.weight"} _tp_plan = {"lm_head.out_proj": "colwise_rep"} _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])} @@ -984,11 +985,6 @@ def set_output_embeddings(self, new_embeddings): def get_output_embeddings(self): return self.lm_head.out_proj - def _tie_weights(self): - # Decoder input and output embeddings are tied. - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.lm_head.out_proj, self.get_decoder().get_input_embeddings()) - def get_encoder(self): return self.model.encoder diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py index 86ecf53ae6e4..68b65f0eabd8 100644 --- a/src/transformers/models/t5gemma/modular_t5gemma.py +++ b/src/transformers/models/t5gemma/modular_t5gemma.py @@ -616,22 +616,23 @@ class T5GemmaPreTrainedModel(Gemma2PreTrainedModel): ], } + @torch.no_grad() def _init_weights(self, module): # TODO: support initialization for encoders and decoders separately(?) PreTrainedModel._init_weights(self, module) std = self.config.initializer_range if isinstance(module, T5GemmaClassificationHead): scale = module.out_proj.weight.shape[0] ** -0.5 - module.out_proj.weight.data.normal_(mean=0.0, std=std * scale) + module.out_proj.weight.normal_(mean=0.0, std=std * scale) if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: - module.out_proj.bias.data.zero_() + module.out_proj.bias.zero_() elif isinstance(module, T5GemmaLMHead): if not self.config.tie_word_embeddings: scale = module.out_proj.weight.shape[0] ** -0.5 - module.out_proj.weight.data.normal_(mean=0.0, std=std * scale) + module.out_proj.weight.normal_(mean=0.0, std=std * scale) # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) elif "RMSNorm" in module.__class__.__name__: - module.weight.data.zero_() + module.weight.zero_() def _shift_right(self, input_ids): """ @@ -1001,7 +1002,7 @@ def forward( class T5GemmaForConditionalGeneration(T5GemmaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["model.decoder.embed_tokens.weight", "lm_head.out_proj.weight"] + _tied_weights_keys = {"lm_head.out_proj.weight": "model.decoder.embed_tokens.weight"} _tp_plan = {"lm_head.out_proj": "colwise_rep"} _pp_plan = {"lm_head.out_proj": (["hidden_states"], ["logits"])} @@ -1022,11 +1023,6 @@ def set_output_embeddings(self, new_embeddings): def get_output_embeddings(self): return self.lm_head.out_proj - def _tie_weights(self): - # Decoder input and output embeddings are tied. - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.lm_head.out_proj, self.get_decoder().get_input_embeddings()) - def get_encoder(self): return self.model.encoder diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index 90e687b14ffd..dd47df827ee6 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -694,6 +694,7 @@ class TableTransformerPreTrainedModel(PreTrainedModel): r"TableTransformerDecoderLayer", ] + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std @@ -701,13 +702,13 @@ def _init_weights(self, module): nn.init.uniform_(module.row_embeddings.weight) nn.init.uniform_(module.column_embeddings.weight) if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() class TableTransformerEncoder(TableTransformerPreTrainedModel): diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 779a7e96301a..e0206fc5c0a8 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -481,16 +481,9 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -515,22 +508,22 @@ class TapasPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_param_buffer_assignment = False - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with Bert->Tapas + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, TapasLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring @@ -684,7 +677,10 @@ class for more info. @auto_docstring class TapasForMaskedLM(TapasPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "tapas.embeddings.word_embeddings.weight", + } config: TapasConfig base_model_prefix = "tapas" diff --git a/src/transformers/models/textnet/modeling_textnet.py b/src/transformers/models/textnet/modeling_textnet.py index ca39fdc0f2aa..616a1a8327c6 100644 --- a/src/transformers/models/textnet/modeling_textnet.py +++ b/src/transformers/models/textnet/modeling_textnet.py @@ -221,15 +221,16 @@ class TextNetPreTrainedModel(PreTrainedModel): base_model_prefix = "textnet" main_input_name = "pixel_values" + @torch.no_grad() def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.BatchNorm2d): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() @auto_docstring diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index c5c9b94a7d97..33dc932e01b4 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -615,18 +615,19 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel): _supports_sdpa = False _supports_flex_attn = False + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, TimeSeriesSinusoidalPositionalEmbedding): module._init_weight() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel): diff --git a/src/transformers/models/timesfm/modeling_timesfm.py b/src/transformers/models/timesfm/modeling_timesfm.py index 814f045c61b8..d8042a82bea9 100644 --- a/src/transformers/models/timesfm/modeling_timesfm.py +++ b/src/transformers/models/timesfm/modeling_timesfm.py @@ -306,6 +306,7 @@ class TimesFmPreTrainedModel(PreTrainedModel): input_modalities = "time" _supports_sdpa = True + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, TimesFmAttention): diff --git a/src/transformers/models/timesfm/modular_timesfm.py b/src/transformers/models/timesfm/modular_timesfm.py index f88973c420e9..dc5e05e33714 100644 --- a/src/transformers/models/timesfm/modular_timesfm.py +++ b/src/transformers/models/timesfm/modular_timesfm.py @@ -262,6 +262,7 @@ class TimesFmPreTrainedModel(PreTrainedModel): input_modalities = "time" _supports_sdpa = True + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, TimesFmAttention): diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index 556bbe4ade09..5d463c73da91 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -455,6 +455,7 @@ class TimesformerPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["TimesformerLayer"] + @torch.no_grad() def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv2d)): nn.init.trunc_normal_(module.weight, std=self.config.initializer_range) diff --git a/src/transformers/models/timm_backbone/modeling_timm_backbone.py b/src/transformers/models/timm_backbone/modeling_timm_backbone.py index 50e577e1838c..d0ad3dd401bf 100644 --- a/src/transformers/models/timm_backbone/modeling_timm_backbone.py +++ b/src/transformers/models/timm_backbone/modeling_timm_backbone.py @@ -114,6 +114,7 @@ def freeze_batch_norm_2d(self): def unfreeze_batch_norm_2d(self): timm.utils.model.unfreeze_batch_norm_2d(self._backbone) + @torch.no_grad() def _init_weights(self, module): """ Empty init weights function to ensure compatibility of the class in the library. diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index 970349054697..40481d26fbac 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -79,6 +79,7 @@ def _create_timm_model_with_error_handling(config: "TimmWrapperConfig", **model_ @auto_docstring class TimmWrapperPreTrainedModel(PreTrainedModel): + base_model_prefix = "timm_model" main_input_name = "pixel_values" input_modalities = "image" config: TimmWrapperConfig @@ -122,6 +123,7 @@ def load_state_dict(self, state_dict, *args, **kwargs): state_dict = {self._fix_state_dict_key_on_load(k)[0]: v for k, v in state_dict.items()} return super().load_state_dict(state_dict, *args, **kwargs) + @torch.no_grad() def _init_weights(self, module): """ Initialize weights function to properly initialize Linear layer weights. @@ -129,9 +131,9 @@ def _init_weights(self, module): initialization, while all other weights should be loaded from the checkpoint. """ if isinstance(module, (nn.Linear)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() def _timm_model_supports_gradient_checkpointing(self): """ diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 9caecd7ada72..78cc9206511d 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -406,16 +406,17 @@ class TrOCRPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["TrOCRDecoderLayer"] + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() class TrOCRDecoder(TrOCRPreTrainedModel): @@ -657,7 +658,7 @@ def forward(self, *args, **kwargs): """ ) class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["output_projection.weight"] + _tied_weights_keys = {"output_projection.weight": "model.decoder.embed_tokens.weight"} def __init__(self, config): config.is_decoder = True diff --git a/src/transformers/models/tvp/modeling_tvp.py b/src/transformers/models/tvp/modeling_tvp.py index 9e6a038197fb..303ddfbfb9cb 100644 --- a/src/transformers/models/tvp/modeling_tvp.py +++ b/src/transformers/models/tvp/modeling_tvp.py @@ -522,13 +522,14 @@ class TvpPreTrainedModel(PreTrainedModel): input_modalities = ["video", "text"] supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") if module.bias is not None: @@ -537,7 +538,7 @@ def _init_weights(self, module: nn.Module): nn.init.normal_(module.text_prompt) if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() if hasattr(module, "pad_up"): nn.init.normal_(module.pad_up) if hasattr(module, "pad_down"): diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index f749d0ce740c..a64ecc1afb25 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -257,59 +257,60 @@ class UdopPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = False _keep_in_fp32_modules = ["wo"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, UdopLayerNorm): - module.weight.data.fill_(factor * 1.0) + module.weight.fill_(factor * 1.0) elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=factor) + module.weight.normal_(mean=0.0, std=factor) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.Conv2d): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_(module.weight.data.to(torch.float32), mean=0.0, std=factor).to( - module.weight.dtype + module.weight.copy_( + nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=factor).to(module.weight.dtype) ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, RelativePositionBiasBase): factor = self.config.initializer_factor d_model = self.config.d_model - module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) elif isinstance(module, UdopModel): - module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.shared.weight.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, UdopForConditionalGeneration): if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) elif isinstance(module, UdopDenseActDense): - module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wi.bias.zero_() + module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + module.wo.bias.zero_() elif isinstance(module, UdopDenseGatedActDense): - module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi_0.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: - module.wi_0.bias.data.zero_() - module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi_0.bias.zero_() + module.wi_1.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: - module.wi_1.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wi_1.bias.zero_() + module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + module.wo.bias.zero_() elif isinstance(module, UdopAttention): d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + module.q.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) # Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetPreTrainedModel._shift_right with ProphetNet->Udop def _shift_right(self, input_ids): @@ -1055,11 +1056,15 @@ class UdopStack(UdopPreTrainedModel): embeddings. """ - def __init__(self, config, embed_tokens=None, embed_patches=None): - super().__init__(config) + _tied_weights_keys = { + r"relative_bias.biases.(\d+).relative_attention_bias.weight": "block.0.layer.0.SelfAttention.relative_attention_bias.weight", + } # TODO IN THIS PR ARTHUR TODO support glob or re but better than iterating - self.embed_tokens = embed_tokens - self.embed_patches = embed_patches + def __init__(self, config): + super().__init__(config) + # text and image embeddings + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) + self.embed_patches = UdopPatchEmbeddings(config) self.is_decoder = config.is_decoder self._max_length = config.max_length self.num_layers = config.num_layers @@ -1077,13 +1082,6 @@ def __init__(self, config, embed_tokens=None, embed_patches=None): # get weights from encoder position bias self.relative_bias = self._get_relative_bias(config) - def _tie_weights(self): - for bias in self.relative_bias.biases: - if isinstance(bias, RelativePositionBias1D): - self._tie_embedding_weights( - bias.relative_attention_bias, self.block[0].layer[0].SelfAttention.relative_attention_bias - ) - @staticmethod def _get_relative_bias(config: UdopConfig) -> RelativePositionBiasAggregated: relative_bias_list = create_relative_bias(config) @@ -1426,14 +1424,12 @@ def _prepare_4d_causal_attention_mask_with_cache_position( @auto_docstring class UdopModel(UdopPreTrainedModel): - _tied_weights_keys = [ - "encoder.embed_tokens.weight", - "decoder.embed_tokens.weight", - "encoder.embed_patches.proj.weight", - "encoder.embed_patches.proj.bias", - "encoder.relative_bias.biases.0.relative_attention_bias.weight", - "decoder.relative_bias.biases.0.relative_attention_bias.weight", - ] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + "encoder.embed_patches.proj.weight": "patch_embed.proj.weight", # TODO tie weights for patch embeddings not working + "encoder.embed_patches.proj.bias": "patch_embed.proj.bias", # TODO tie weights for patch embeddings not working + } def __init__(self, config): super().__init__(config) @@ -1446,13 +1442,13 @@ def __init__(self, config): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = UdopStack(encoder_config, self.shared, self.patch_embed) + self.encoder = UdopStack(encoder_config) decoder_config = deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = UdopStack(decoder_config, self.shared) + self.decoder = UdopStack(decoder_config) # Initialize weights and apply final processing self.post_init() @@ -1602,15 +1598,15 @@ def forward( """ ) class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin): - _tied_weights_keys = [ - "encoder.embed_tokens.weight", - "decoder.embed_tokens.weight", - "encoder.embed_patches.proj.weight", - "encoder.embed_patches.proj.bias", - "encoder.relative_bias.biases.0.relative_attention_bias.weight", - "decoder.relative_bias.biases.0.relative_attention_bias.weight", - "lm_head.weight", - ] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + "encoder.embed_patches.proj.weight": "patch_embed.proj.weight", + "encoder.embed_patches.proj.bias": "patch_embed.proj.bias", + "encoder.relative_bias.biases.0.relative_attention_bias.weight": "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight", + "decoder.relative_bias.biases.0.relative_attention_bias.weight": "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight", + "lm_head.weight": "shared.weight", + } def __init__(self, config): super().__init__(config) @@ -1623,13 +1619,13 @@ def __init__(self, config): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = UdopStack(encoder_config, self.shared, self.patch_embed) + self.encoder = UdopStack(encoder_config) decoder_config = deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = UdopStack(decoder_config, self.shared) + self.decoder = UdopStack(decoder_config) # The weights of the language modeling head are shared with those of the encoder and decoder self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) @@ -1795,12 +1791,12 @@ def forward( @auto_docstring class UdopEncoderModel(UdopPreTrainedModel): - _tied_weights_keys = [ - "encoder.embed_tokens.weight", - "encoder.embed_patches.proj.weight", - "encoder.embed_patches.proj.bias", - "encoder.relative_bias.biases.0.relative_attention_bias.weight", - ] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "encoder.embed_patches.proj.weight": "patch_embed.proj.weight", + "encoder.embed_patches.proj.bias": "patch_embed.proj.bias", + "encoder.relative_bias.biases.0.relative_attention_bias.weight": "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight", + } def __init__(self, config: UdopConfig): super().__init__(config) @@ -1813,7 +1809,7 @@ def __init__(self, config: UdopConfig): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.is_encoder_decoder = False - self.encoder = UdopStack(encoder_config, self.shared, self.patch_embed) + self.encoder = UdopStack(encoder_config) # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index a1873b99f5cd..d5a0f955049d 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -502,11 +502,12 @@ def dummy_inputs(self): } return dummy_inputs + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor # Used for testing weights initialization if isinstance(module, UMT5LayerNorm): - module.weight.data.fill_(factor * 1.0) + module.weight.fill_(factor * 1.0) elif isinstance( module, ( @@ -518,55 +519,55 @@ def _init_weights(self, module): ): # Mesh TensorFlow embeddings initialization # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 - module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.shared.weight.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: - module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + module.lm_head.weight.normal_(mean=0.0, std=factor * 1.0) if hasattr(module, "qa_outputs"): - module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) - module.qa_outputs.bias.data.zero_() + module.qa_outputs.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.qa_outputs.bias.zero_() elif isinstance(module, UMT5ForTokenClassification): if hasattr(module, "classifier"): - module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0) - module.classifier.bias.data.zero_() + module.classifier.weight.normal_(mean=0.0, std=factor * 1.0) + module.classifier.bias.zero_() elif isinstance(module, UMT5ClassificationHead): - module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.dense.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.dense, "bias") and module.dense.bias is not None: - module.dense.bias.data.zero_() - module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.dense.bias.zero_() + module.out_proj.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None: - module.out_proj.bias.data.zero_() + module.out_proj.bias.zero_() elif isinstance(module, UMT5DenseActDense): # Mesh TensorFlow FF initialization # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 - module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi, "bias") and module.wi.bias is not None: - module.wi.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wi.bias.zero_() + module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + module.wo.bias.zero_() elif isinstance(module, UMT5DenseGatedActDense): - module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi_0.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: - module.wi_0.bias.data.zero_() - module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + module.wi_0.bias.zero_() + module.wi_1.weight.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: - module.wi_1.bias.data.zero_() - module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + module.wi_1.bias.zero_() + module.wo.weight.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: - module.wo.bias.data.zero_() + module.wo.bias.zero_() elif isinstance(module, UMT5Attention): # Mesh TensorFlow attention initialization to avoid scaling before softmax # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 d_model = self.config.d_model key_value_proj_dim = self.config.d_kv n_heads = self.config.num_heads - module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) - module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + module.q.weight.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) if module.has_relative_attention_bias: - module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + module.relative_attention_bias.weight.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -591,9 +592,9 @@ def _shift_right(self, input_ids): class UMT5Stack(UMT5PreTrainedModel): - def __init__(self, config, embed_tokens=None): + def __init__(self, config): super().__init__(config) - self.embed_tokens = embed_tokens + self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) self.is_decoder = config.is_decoder self.block = nn.ModuleList([UMT5Block(config, layer_idx=i) for i in range(config.num_layers)]) self.final_layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) @@ -914,7 +915,10 @@ class UMT5Model(UMT5PreTrainedModel): model_type = "umt5" config: UMT5Config - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config): super().__init__(config) @@ -924,13 +928,13 @@ def __init__(self, config): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = UMT5Stack(encoder_config, self.shared) + self.encoder = UMT5Stack(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = UMT5Stack(decoder_config, self.shared) + self.decoder = UMT5Stack(decoder_config) # Initialize weights and apply final processing self.post_init() @@ -945,12 +949,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - # Copied from transformers.models.t5.modeling_t5.T5Model._tie_weights - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - # Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder def get_encoder(self): return self.encoder @@ -1096,7 +1094,11 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel, GenerationMixin): ```""" model_type = "umt5" - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + "lm_head.weight": "shared.weight", + } def __init__(self, config): super().__init__(config) @@ -1108,13 +1110,13 @@ def __init__(self, config): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = UMT5Stack(encoder_config, self.shared) + self.encoder = UMT5Stack(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = UMT5Stack(decoder_config, self.shared) + self.decoder = UMT5Stack(decoder_config) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) @@ -1131,12 +1133,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._tie_weights - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_encoder def get_encoder(self): return self.encoder @@ -1308,7 +1304,9 @@ class UMT5EncoderModel(UMT5PreTrainedModel): model_type = "umt5" # config_class = UMT5Config - _tied_weights_keys = ["encoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config): super().__init__(config) @@ -1317,7 +1315,7 @@ def __init__(self, config): encoder_config = copy.deepcopy(config) encoder_config.use_cache = False encoder_config.is_encoder_decoder = False - self.encoder = UMT5Stack(encoder_config, self.shared) + self.encoder = UMT5Stack(encoder_config) # Initialize weights and apply final processing self.post_init() @@ -1331,11 +1329,6 @@ def set_input_embeddings(self, new_embeddings): self.shared = new_embeddings self.encoder.set_input_embeddings(new_embeddings) - # Copied from transformers.models.t5.modeling_t5.T5EncoderModel._tie_weights - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_encoder def get_encoder(self): return self.encoder @@ -1396,7 +1389,6 @@ def forward( ) class UMT5ForSequenceClassification(UMT5PreTrainedModel): _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] # Copied from transformers.models.t5.modeling_t5.T5ForSequenceClassification.__init__ with T5->UMT5 def __init__(self, config: UMT5Config): @@ -1540,7 +1532,6 @@ def forward( @auto_docstring class UMT5ForTokenClassification(UMT5PreTrainedModel): _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"] - _tied_weights_keys = ["transformer.encoder.embed_tokens.weight"] # Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.__init__ with T5->UMT5 def __init__(self, config: UMT5Config): @@ -1614,7 +1605,10 @@ def forward( @auto_docstring class UMT5ForQuestionAnswering(UMT5PreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + _tied_weights_keys = { + "encoder.embed_tokens.weight": "shared.weight", + "decoder.embed_tokens.weight": "shared.weight", + } def __init__(self, config): super().__init__(config) @@ -1626,13 +1620,13 @@ def __init__(self, config): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.tie_encoder_decoder = False - self.encoder = UMT5Stack(encoder_config, self.shared) + self.encoder = UMT5Stack(encoder_config) decoder_config = copy.deepcopy(config) decoder_config.is_decoder = True decoder_config.tie_encoder_decoder = False decoder_config.num_layers = config.num_decoder_layers - self.decoder = UMT5Stack(decoder_config, self.shared) + self.decoder = UMT5Stack(decoder_config) self.num_labels = config.num_labels self.qa_outputs = nn.Linear(config.d_model, config.num_labels) @@ -1650,12 +1644,6 @@ def set_input_embeddings(self, new_embeddings): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) - # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering._tie_weights - def _tie_weights(self): - if self.config.tie_word_embeddings: - self._tie_embedding_weights(self.encoder.embed_tokens, self.shared) - self._tie_embedding_weights(self.decoder.embed_tokens, self.shared) - # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_encoder def get_encoder(self): return self.encoder diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 8bdec6b3cae8..bee61f38fc58 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -740,12 +740,13 @@ class UniSpeechPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" # gumbel softmax requires special init if isinstance(module, UniSpeechGumbelVectorQuantizer): - module.weight_proj.weight.data.normal_(mean=0.0, std=1) - module.weight_proj.bias.data.zero_() + module.weight_proj.weight.normal_(mean=0.0, std=1) + module.weight_proj.bias.zero_() nn.init.uniform_(module.codevectors) elif isinstance(module, UniSpeechPositionalConvEmbedding): nn.init.normal_( @@ -759,13 +760,13 @@ def _init_weights(self, module): nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) @@ -1221,7 +1222,7 @@ def __init__(self, config, target_lang: Optional[str] = None): # Initialize weights and apply final processing self.post_init() - def tie_weights(self): + def tie_weights(self, missing_keys=None): """ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when passing `target_lang=...` to `from_pretrained(...)`. diff --git a/src/transformers/models/unispeech/modular_unispeech.py b/src/transformers/models/unispeech/modular_unispeech.py index 534490235db1..73724c5351b6 100644 --- a/src/transformers/models/unispeech/modular_unispeech.py +++ b/src/transformers/models/unispeech/modular_unispeech.py @@ -147,12 +147,13 @@ class UniSpeechPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" # gumbel softmax requires special init if isinstance(module, UniSpeechGumbelVectorQuantizer): - module.weight_proj.weight.data.normal_(mean=0.0, std=1) - module.weight_proj.bias.data.zero_() + module.weight_proj.weight.normal_(mean=0.0, std=1) + module.weight_proj.bias.zero_() nn.init.uniform_(module.codevectors) elif isinstance(module, UniSpeechPositionalConvEmbedding): nn.init.normal_( @@ -166,13 +167,13 @@ def _init_weights(self, module): nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 57e5d3cdbcc0..01de810850e7 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -745,12 +745,13 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" # gumbel softmax requires special init if isinstance(module, UniSpeechSatGumbelVectorQuantizer): - module.weight_proj.weight.data.normal_(mean=0.0, std=1) - module.weight_proj.bias.data.zero_() + module.weight_proj.weight.normal_(mean=0.0, std=1) + module.weight_proj.bias.zero_() nn.init.uniform_(module.codevectors) elif isinstance(module, UniSpeechSatPositionalConvEmbedding): nn.init.normal_( @@ -764,13 +765,13 @@ def _init_weights(self, module): nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) @@ -1216,7 +1217,7 @@ def __init__(self, config, target_lang: Optional[str] = None): # Initialize weights and apply final processing self.post_init() - def tie_weights(self): + def tie_weights(self, missing_keys=None): """ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when passing `target_lang=...` to `from_pretrained(...)`. diff --git a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py index e209c7c18ea3..cb94ec81a3db 100644 --- a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py @@ -159,12 +159,13 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" # gumbel softmax requires special init if isinstance(module, UniSpeechSatGumbelVectorQuantizer): - module.weight_proj.weight.data.normal_(mean=0.0, std=1) - module.weight_proj.bias.data.zero_() + module.weight_proj.weight.normal_(mean=0.0, std=1) + module.weight_proj.bias.zero_() nn.init.uniform_(module.codevectors) elif isinstance(module, UniSpeechSatPositionalConvEmbedding): nn.init.normal_( @@ -178,13 +179,13 @@ def _init_weights(self, module): nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) diff --git a/src/transformers/models/univnet/modeling_univnet.py b/src/transformers/models/univnet/modeling_univnet.py index 048d68e7276a..1b208acdc5d9 100644 --- a/src/transformers/models/univnet/modeling_univnet.py +++ b/src/transformers/models/univnet/modeling_univnet.py @@ -591,12 +591,13 @@ def forward( waveform_lengths=waveform_lengths, ) + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() def apply_weight_norm(self): weight_norm = nn.utils.weight_norm diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py index 5c9521766379..64bd7e958f7b 100644 --- a/src/transformers/models/upernet/modeling_upernet.py +++ b/src/transformers/models/upernet/modeling_upernet.py @@ -272,14 +272,15 @@ class UperNetPreTrainedModel(PreTrainedModel): input_modalities = "image" _no_split_modules = [] + @torch.no_grad() def _init_weights(self, module): if isinstance(module, nn.Conv2d): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.BatchNorm2d): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() @auto_docstring( diff --git a/src/transformers/models/vaultgemma/modeling_vaultgemma.py b/src/transformers/models/vaultgemma/modeling_vaultgemma.py index 51071e59997b..40977bfc2c42 100644 --- a/src/transformers/models/vaultgemma/modeling_vaultgemma.py +++ b/src/transformers/models/vaultgemma/modeling_vaultgemma.py @@ -370,12 +370,13 @@ class VaultGemmaPreTrainedModel(PreTrainedModel): "attentions": VaultGemmaAttention, } + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight) if "RMSNorm" in module.__class__.__name__: - module.weight.data.zero_() + module.weight.zero_() @auto_docstring @@ -508,7 +509,7 @@ def forward( @auto_docstring class VaultGemmaForCausalLM(VaultGemmaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} _tp_plan = {"lm_head": "colwise_rep"} _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} diff --git a/src/transformers/models/video_llama_3/modeling_video_llama_3.py b/src/transformers/models/video_llama_3/modeling_video_llama_3.py index 6454da2a73c4..37370bd91266 100644 --- a/src/transformers/models/video_llama_3/modeling_video_llama_3.py +++ b/src/transformers/models/video_llama_3/modeling_video_llama_3.py @@ -745,7 +745,7 @@ class VideoLlama3CausalLMOutputWithPast(ModelOutput): class VideoLlama3ForConditionalGeneration(VideoLlama3PreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = {} - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} _can_compile_fullgraph = False def __init__(self, config: VideoLlama3Config): diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 3f874c2e9353..495719cb22c7 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -136,6 +136,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True _supports_attention_backend = True + @torch.no_grad() def _init_weights(self, module): std = ( self.config.initializer_range @@ -144,16 +145,16 @@ def _init_weights(self, module): ) if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) + module.class_embedding.normal_(mean=0.0, std=std) if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() @auto_docstring( @@ -424,7 +425,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi "^multi_modal_projector": "model.multi_modal_projector", "^language_model.lm_head": "lm_head", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: VideoLlavaConfig): super().__init__(config) diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 95163da0311f..b1a7179771d6 100755 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -392,15 +392,16 @@ class VideoMAEPreTrainedModel(PreTrainedModel): "attentions": VideoMAESelfAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv3d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 9a32ee12be13..4c525b2d8f92 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -516,19 +516,20 @@ class ViltPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["ViltEmbeddings", "ViltSelfAttention"] + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring @@ -688,7 +689,10 @@ def forward(self, hidden_states): """ ) class ViltForMaskedLM(ViltPreTrainedModel): - _tied_weights_keys = ["mlm_score.decoder.weight", "mlm_score.decoder.bias"] + _tied_weights_keys = { + "mlm_score.decoder.weight": "vilt.embeddings.text_embeddings.weight", + "mlm_score.decoder.bias": "mlm_score.bias", + } def __init__(self, config): super().__init__(config) @@ -846,12 +850,6 @@ def __init__(self, config, weight=None): if weight is not None: self.decoder.weight = weight - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, x): x = self.transform(x) x = self.decoder(x) diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 16606f8ccf4d..791ae03a3aec 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -291,7 +291,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) "^multi_modal_projector": "model.multi_modal_projector", "^language_model.lm_head": "lm_head", } - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: VipLlavaConfig): super().__init__(config) diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index b8a68cd257ae..a085f8954f03 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -431,16 +431,9 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -467,17 +460,18 @@ class VisualBertPreTrainedModel(PreTrainedModel): input_modalities = ["image", "text"] supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, VisualBertLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @dataclass @@ -702,7 +696,10 @@ def forward( """ ) class VisualBertForPreTraining(VisualBertPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "visual_bert.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) @@ -1341,7 +1338,10 @@ def forward(self, query, key, attention_mask): """ ) class VisualBertForRegionToPhraseAlignment(VisualBertPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "visual_bert.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 7923264d7e01..bef55534d577 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -365,34 +365,41 @@ class ViTPreTrainedModel(PreTrainedModel): "attentions": ViTSelfAttention, } + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) + module.weight.copy_( + nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( + module.weight.dtype + ) + ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, ViTEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - - module.cls_token.data = nn.init.trunc_normal_( - module.cls_token.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) + module.position_embeddings.copy_( + nn.init.trunc_normal_( + module.position_embeddings.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + ) + + module.cls_token.copy_( + nn.init.trunc_normal_( + module.cls_token.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + ) if module.mask_token is not None: - module.mask_token.data.zero_() + module.mask_token.zero_() @auto_docstring diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 159fca54943e..479f84ab77ed 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -186,7 +186,7 @@ def initialize_weights(self): pos_embed = get_2d_sincos_pos_embed( self.position_embeddings.shape[-1], int(self.patch_embeddings.num_patches**0.5), add_cls_token=True ) - self.position_embeddings.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + self.position_embeddings.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d) w = self.patch_embeddings.projection.weight.data @@ -530,20 +530,21 @@ class ViTMAEPreTrainedModel(PreTrainedModel): "attentions": ViTMAESelfAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, ViTMAEEmbeddings): module.initialize_weights() elif isinstance(module, ViTMAEDecoder): - module.mask_token.data.zero_() - module.decoder_pos_embed.data.zero_() + module.mask_token.zero_() + module.decoder_pos_embed.zero_() @auto_docstring @@ -682,7 +683,7 @@ def initialize_weights(self, num_patches): decoder_pos_embed = get_2d_sincos_pos_embed( self.decoder_pos_embed.shape[-1], int(num_patches**0.5), add_cls_token=True ) - self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) + self.decoder_pos_embed.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) torch.nn.init.normal_(self.mask_token, std=self.config.initializer_range) diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 1ed50e9da579..e10dfb6d123f 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -370,20 +370,21 @@ class ViTMSNPreTrainedModel(PreTrainedModel): # todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211 # when creating pre-training scripts. + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, ViTMSNEmbeddings): - module.cls_token.data.zero_() - module.position_embeddings.data.zero_() + module.cls_token.zero_() + module.position_embeddings.zero_() if module.mask_token is not None: - module.mask_token.data.zero_() + module.mask_token.zero_() @auto_docstring diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index b02b66f4d52c..a235b25a57c5 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -593,48 +593,57 @@ class VitDetPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = [] + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) + module.weight.copy_( + nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( + module.weight.dtype + ) + ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, VitDetEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) + module.position_embeddings.copy_( + nn.init.trunc_normal_( + module.position_embeddings.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + ) elif isinstance(module, VitDetAttention) and self.config.use_relative_position_embeddings: - module.rel_pos_h.data = nn.init.trunc_normal_( - module.rel_pos_h.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, + module.rel_pos_h.copy_( + nn.init.trunc_normal_( + module.rel_pos_h.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ) ) - module.rel_pos_w.data = nn.init.trunc_normal_( - module.rel_pos_w.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, + module.rel_pos_w.copy_( + nn.init.trunc_normal_( + module.rel_pos_w.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ) ) elif isinstance(module, VitDetResBottleneckBlock): for layer in [module.conv1, module.conv2, module.conv3]: caffe2_msra_fill(layer) for layer in [module.norm1, module.norm2]: - layer.weight.data.fill_(1.0) - layer.bias.data.zero_() + layer.weight.fill_(1.0) + layer.bias.zero_() # zero init last norm layer. - module.norm3.weight.data.zero_() - module.norm3.bias.data.zero_() + module.norm3.weight.zero_() + module.norm3.bias.zero_() @auto_docstring diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py index 8863056c5190..8cf9841d1e47 100644 --- a/src/transformers/models/vitmatte/modeling_vitmatte.py +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -58,11 +58,12 @@ class VitMattePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = [] + @torch.no_grad() def _init_weights(self, module: nn.Module): if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() class VitMatteBasicConv3x3(nn.Module): diff --git a/src/transformers/models/vitpose/modeling_vitpose.py b/src/transformers/models/vitpose/modeling_vitpose.py index 247e7b47ccec..f87396b564f7 100644 --- a/src/transformers/models/vitpose/modeling_vitpose.py +++ b/src/transformers/models/vitpose/modeling_vitpose.py @@ -66,19 +66,22 @@ class VitPosePreTrainedModel(PreTrainedModel): input_modalities = "image" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) + module.weight.copy_( + nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( + module.weight.dtype + ) + ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"): diff --git a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py index e4fb4276a313..c5c5d8ffbe02 100644 --- a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +++ b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py @@ -357,25 +357,30 @@ class VitPoseBackbonePreTrainedModel(PreTrainedModel): "attentions": VitPoseBackboneSelfAttention, } + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm, VitPoseBackboneEmbeddings]): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) + module.weight.copy_( + nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to( + module.weight.dtype + ) + ) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, VitPoseBackboneEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) + module.position_embeddings.copy_( + nn.init.trunc_normal_( + module.position_embeddings.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + ) @auto_docstring( diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index bae8d44e0d13..dd9117e309a3 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -1201,33 +1201,34 @@ class VitsPreTrainedModel(PreTrainedModel): main_input_name = "input_ids" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)): nn.init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-k, b=k) elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, VitsAttention): if self.config.window_size: head_dim = self.config.hidden_size // self.config.num_attention_heads nn.init.normal_(module.emb_rel_k, std=head_dim**-0.5) nn.init.normal_(module.emb_rel_v, std=head_dim**-0.5) elif isinstance(module, VitsElementwiseAffine): - module.translate.data.zero_() - module.log_scale.data.zero_() + module.translate.zero_() + module.log_scale.zero_() @auto_docstring( diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index 098c891922e2..ed55faac7aa0 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -375,22 +375,23 @@ class VivitPreTrainedModel(PreTrainedModel): "attentions": VivitSelfAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv3d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, VivitEmbeddings): - module.cls_token.data.zero_() - module.position_embeddings.data.zero_() + module.cls_token.zero_() + module.position_embeddings.zero_() @auto_docstring diff --git a/src/transformers/models/vjepa2/modeling_vjepa2.py b/src/transformers/models/vjepa2/modeling_vjepa2.py index 86d002ede4be..f2ab5b1f2cf8 100644 --- a/src/transformers/models/vjepa2/modeling_vjepa2.py +++ b/src/transformers/models/vjepa2/modeling_vjepa2.py @@ -941,6 +941,7 @@ class VJEPA2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flash_attn = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" @@ -949,9 +950,9 @@ def _init_weights(self, module): # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid # `trunc_normal_cpu` not implemented in `half` issues def trunc_normal_f32_(weight, std): - data_float_32 = weight.data.to(torch.float32) + data_float_32 = weight.to(torch.float32) data_init = nn.init.trunc_normal_(data_float_32, mean=0.0, std=std) - weight.data = data_init.to(weight.dtype) + weight.copy_(data_init.to(weight.dtype)) if isinstance(module, VJEPA2AttentivePooler): trunc_normal_f32_(module.query_tokens, std=init_std) @@ -963,16 +964,16 @@ def trunc_normal_f32_(weight, std): trunc_normal_f32_(module.cross_attention_layer.mlp.fc2.weight, std=std) elif isinstance(module, VJEPA2PredictorEmbeddings): if module.zero_init_mask_tokens: - module.mask_tokens.data.zero_() + module.mask_tokens.zero_() else: trunc_normal_f32_(module.mask_tokens, std=init_std) elif isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)): trunc_normal_f32_(module.weight, std=init_std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/voxtral/modeling_voxtral.py b/src/transformers/models/voxtral/modeling_voxtral.py index bc309bddf006..bec9ffe55641 100644 --- a/src/transformers/models/voxtral/modeling_voxtral.py +++ b/src/transformers/models/voxtral/modeling_voxtral.py @@ -231,6 +231,7 @@ class VoxtralPreTrainedModel(PreTrainedModel): _supports_attention_backend = True _can_compile_fullgraph = True + @torch.no_grad() def _init_weights(self, module): # important: this ported version of Voxtral isn't meant for training from scratch - only # inference and fine-tuning - so the proper init weights code has been removed @@ -241,16 +242,16 @@ def _init_weights(self, module): ) if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() @auto_docstring( @@ -391,9 +392,6 @@ def forward(self, audio_features): """ ) class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} _keep_in_fp32_modules_strict = ["embed_positions"] def __init__(self, config): diff --git a/src/transformers/models/voxtral/modular_voxtral.py b/src/transformers/models/voxtral/modular_voxtral.py index a3df19390892..36e705fcc770 100644 --- a/src/transformers/models/voxtral/modular_voxtral.py +++ b/src/transformers/models/voxtral/modular_voxtral.py @@ -132,9 +132,6 @@ def forward(self, audio_features): """ ) class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} _keep_in_fp32_modules_strict = ["embed_positions"] def __init__(self, config): diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 82399d0933dc..e77cc49fe208 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -980,6 +980,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init. @@ -990,8 +991,8 @@ def _init_weights(self, module): module.project_q._is_hf_initialized = True # gumbel softmax requires special init elif isinstance(module, Wav2Vec2GumbelVectorQuantizer): - module.weight_proj.weight.data.normal_(mean=0.0, std=1) - module.weight_proj.bias.data.zero_() + module.weight_proj.weight.normal_(mean=0.0, std=1) + module.weight_proj.bias.zero_() nn.init.uniform_(module.codevectors) elif isinstance(module, Wav2Vec2PositionalConvEmbedding): nn.init.normal_( @@ -1005,13 +1006,13 @@ def _init_weights(self, module): nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) @@ -1720,7 +1721,7 @@ def __init__(self, config, target_lang: Optional[str] = None): # Initialize weights and apply final processing self.post_init() - def tie_weights(self): + def tie_weights(self, missing_keys=None): """ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when passing `target_lang=...` to `from_pretrained(...)`. diff --git a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py index c8593d38d131..65c53653c191 100644 --- a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py @@ -711,6 +711,7 @@ class Wav2Vec2BertPreTrainedModel(PreTrainedModel): input_modalities = "audio" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, Wav2Vec2BertSelfAttention): @@ -723,13 +724,13 @@ def _init_weights(self, module): nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) @@ -738,15 +739,15 @@ def _init_weights(self, module): nn.init.uniform_(module.bias, a=-k, b=k) elif isinstance(module, Wav2Vec2BertModel): if hasattr(module, "masked_spec_embed"): - module.masked_spec_embed.data.uniform_() + module.masked_spec_embed.uniform_() elif isinstance( module, (Wav2Vec2BertForSequenceClassification, Wav2Vec2BertForAudioFrameClassification, Wav2Vec2BertForXVector), ): if hasattr(module, "layer_weights"): - module.layer_weights.data.fill_(1.0 / (self.config.num_hidden_layers + 1)) + module.layer_weights.fill_(1.0 / (self.config.num_hidden_layers + 1)) elif isinstance(module, AMSoftmaxLoss): # noqa: F821 - module.weight.data.normal_() + module.weight.normal_() # Ignore copy def _get_feat_extract_output_lengths( diff --git a/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py index 3bce99771f55..b9949c62368c 100644 --- a/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py @@ -583,6 +583,7 @@ class Wav2Vec2BertPreTrainedModel(PreTrainedModel): input_modalities = "audio" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, Wav2Vec2BertSelfAttention): @@ -595,13 +596,13 @@ def _init_weights(self, module): nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) @@ -610,15 +611,15 @@ def _init_weights(self, module): nn.init.uniform_(module.bias, a=-k, b=k) elif isinstance(module, Wav2Vec2BertModel): if hasattr(module, "masked_spec_embed"): - module.masked_spec_embed.data.uniform_() + module.masked_spec_embed.uniform_() elif isinstance( module, (Wav2Vec2BertForSequenceClassification, Wav2Vec2BertForAudioFrameClassification, Wav2Vec2BertForXVector), ): if hasattr(module, "layer_weights"): - module.layer_weights.data.fill_(1.0 / (self.config.num_hidden_layers + 1)) + module.layer_weights.fill_(1.0 / (self.config.num_hidden_layers + 1)) elif isinstance(module, AMSoftmaxLoss): # noqa: F821 - module.weight.data.normal_() + module.weight.normal_() # Ignore copy def _get_feat_extract_output_lengths( diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 9fddc1ce224f..f3ee90ba8576 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -851,18 +851,17 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel): input_modalities = "audio" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init. if isinstance(module, Wav2Vec2ConformerForPreTraining): module.project_hid.reset_parameters() module.project_q.reset_parameters() - module.project_hid._is_hf_initialized = True - module.project_q._is_hf_initialized = True # gumbel softmax requires special init elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer): - module.weight_proj.weight.data.normal_(mean=0.0, std=1) - module.weight_proj.bias.data.zero_() + module.weight_proj.weight.normal_(mean=0.0, std=1) + module.weight_proj.bias.zero_() nn.init.uniform_(module.codevectors) elif isinstance(module, Wav2Vec2ConformerSelfAttention): if hasattr(module, "pos_bias_u"): @@ -881,13 +880,13 @@ def _init_weights(self, module): nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) diff --git a/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py index 7a0e757a8496..55203180dc9c 100644 --- a/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py @@ -550,18 +550,17 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel): input_modalities = "audio" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init. if isinstance(module, Wav2Vec2ConformerForPreTraining): module.project_hid.reset_parameters() module.project_q.reset_parameters() - module.project_hid._is_hf_initialized = True - module.project_q._is_hf_initialized = True # gumbel softmax requires special init elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer): - module.weight_proj.weight.data.normal_(mean=0.0, std=1) - module.weight_proj.bias.data.zero_() + module.weight_proj.weight.normal_(mean=0.0, std=1) + module.weight_proj.bias.zero_() nn.init.uniform_(module.codevectors) elif isinstance(module, Wav2Vec2ConformerSelfAttention): if hasattr(module, "pos_bias_u"): @@ -580,13 +579,13 @@ def _init_weights(self, module): nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 274d83fa8914..3a251db3258a 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -603,12 +603,13 @@ class WavLMPreTrainedModel(PreTrainedModel): _supports_sdpa = False _supports_flex_attn = False + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" # gumbel softmax requires special init if isinstance(module, WavLMGumbelVectorQuantizer): - module.weight_proj.weight.data.normal_(mean=0.0, std=1) - module.weight_proj.bias.data.zero_() + module.weight_proj.weight.normal_(mean=0.0, std=1) + module.weight_proj.bias.zero_() nn.init.uniform_(module.codevectors) elif isinstance(module, WavLMPositionalConvEmbedding): nn.init.normal_( @@ -622,13 +623,13 @@ def _init_weights(self, module): nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) @@ -1145,7 +1146,7 @@ def __init__(self, config, target_lang: Optional[str] = None): # Initialize weights and apply final processing self.post_init() - def tie_weights(self): + def tie_weights(self, missing_keys=None): """ This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when passing `target_lang=...` to `from_pretrained(...)`. diff --git a/src/transformers/models/wavlm/modular_wavlm.py b/src/transformers/models/wavlm/modular_wavlm.py index 4020f0b3335b..c50f2a4ec7e1 100644 --- a/src/transformers/models/wavlm/modular_wavlm.py +++ b/src/transformers/models/wavlm/modular_wavlm.py @@ -513,12 +513,13 @@ class WavLMPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel): _supports_sdpa = False _supports_flex_attn = False + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" # gumbel softmax requires special init if isinstance(module, WavLMGumbelVectorQuantizer): - module.weight_proj.weight.data.normal_(mean=0.0, std=1) - module.weight_proj.bias.data.zero_() + module.weight_proj.weight.normal_(mean=0.0, std=1) + module.weight_proj.bias.zero_() nn.init.uniform_(module.codevectors) elif isinstance(module, WavLMPositionalConvEmbedding): nn.init.normal_( @@ -532,13 +533,13 @@ def _init_weights(self, module): nn.init.uniform_(module.projection.weight, a=-k, b=k) nn.init.uniform_(module.projection.bias, a=-k, b=k) elif isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 3fc03b3d54d5..6e91445ca961 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -538,24 +538,25 @@ class WhisperPreTrainedModel(PreTrainedModel): _can_compile_fullgraph = True + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.weight.data.fill_(1.0) - module.bias.data.zero_() + module.weight.fill_(1.0) + module.bias.zero_() elif isinstance(module, WhisperEncoder): module.embed_positions.weight.copy_(sinusoids(*module.embed_positions.weight.shape)) elif isinstance(module, WhisperForAudioClassification): if self.config.use_weighted_layer_sum: - module.layer_weights.data.fill_(1.0 / (self.config.num_hidden_layers + 1)) + module.layer_weights.fill_(1.0 / (self.config.num_hidden_layers + 1)) def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ @@ -1097,7 +1098,7 @@ def forward( ) class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedModel): base_model_prefix = "model" - _tied_weights_keys = ["proj_out.weight"] + _tied_weights_keys = {"proj_out.weight": "model.decoder.embed_tokens.weight"} def __init__(self, config: WhisperConfig): super().__init__(config) @@ -1278,7 +1279,7 @@ def forward(self, *args, **kwargs): """ ) class WhisperForCausalLM(WhisperPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["proj_out.weight"] + _tied_weights_keys = {"proj_out.weight": "model.decoder.embed_tokens.weight"} main_input_name = "input_ids" def __init__(self, config): diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 7d59d57341e8..36be6ad43294 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -504,12 +504,13 @@ class XCLIPPreTrainedModel(PreTrainedModel): input_modalities = ["image", "text"] supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" factor = self.config.initializer_factor if isinstance(module, XCLIPTextEmbeddings): - module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) - module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) + module.token_embedding.weight.normal_(mean=0.0, std=factor * 0.02) + module.position_embedding.weight.normal_(mean=0.0, std=factor * 0.02) elif isinstance(module, XCLIPVisionEmbeddings): factor = self.config.initializer_factor nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) @@ -544,12 +545,12 @@ def _init_weights(self, module): nn.init.normal_(module.position_embedding, std=self.config.initializer_factor) if isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor) + module.weight.normal_(mean=0.0, std=self.config.initializer_factor) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->XCLIP diff --git a/src/transformers/models/xcodec/modeling_xcodec.py b/src/transformers/models/xcodec/modeling_xcodec.py index 774f9c74b8de..7e5b802e72f7 100644 --- a/src/transformers/models/xcodec/modeling_xcodec.py +++ b/src/transformers/models/xcodec/modeling_xcodec.py @@ -327,26 +327,27 @@ class XcodecPreTrainedModel(PreTrainedAudioTokenizerBase): main_input_name = "input_values" input_modalities = "audio" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, nn.Conv1d): nn.init.kaiming_normal_(module.weight) if module.bias is not None: k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-k, b=k) elif module.__class__.__name__ == "Snake1d": - module.alpha.data.fill_(1.0) + module.alpha.fill_(1.0) elif isinstance(module, nn.ConvTranspose1d): module.reset_parameters() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=0.02) + module.weight.normal_(mean=0.0, std=0.02) elif isinstance(module, XcodecModel): # The conv1d are not handled correctly, as `self.acoustic_encoder/decoder` are initialized from a PreTrainedModel, # but then only the submodules are used (which are not PreTrainedModels...) -> here we reinit them as in DacModel @@ -354,10 +355,12 @@ def _init_weights(self, module): if isinstance(submodule, nn.Conv1d): nn.init.trunc_normal_(submodule.weight, std=0.02) nn.init.constant_(submodule.bias, 0) + submodule._is_hf_initialized = True for submodule in module.acoustic_decoder.modules(): if isinstance(submodule, nn.Conv1d): nn.init.trunc_normal_(submodule.weight, std=0.02) nn.init.constant_(submodule.bias, 0) + submodule._is_hf_initialized = True def apply_weight_norm(self): """Apply weight norm in the acoustic encoder and decoder because the original checkpoint has weight norm applied.""" @@ -401,9 +404,8 @@ def __init__(self, config): super().__init__(config) self.config = config self.pad = config.hop_length // 2 - acoustic_model = AutoModel.from_config(config.acoustic_model_config) - self.acoustic_encoder = acoustic_model.encoder - self.acoustic_decoder = acoustic_model.decoder + self.acoustic_model = AutoModel.from_config(config.acoustic_model_config) + self._adjust_dac_decoder(self.acoustic_decoder) self.encoder_semantic = SemanticEncoder(config) self.decoder_semantic = SemanticDecoder(config) @@ -416,6 +418,14 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + @property + def acoustic_encoder(self): + return self.acoustic_model.encoder + + @property + def acoustic_decoder(self): + return self.acoustic_model.decoder + @staticmethod def _adjust_dac_decoder(decoder: nn.Module): r""" diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index c5a59fe8b3d9..6edd50844c25 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -361,21 +361,22 @@ class XGLMPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["XGLMDecoderLayer"] + @torch.no_grad() def _init_weights(self, module): std = self.config.init_std if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() @auto_docstring class XGLMModel(XGLMPreTrainedModel): - def __init__(self, config: XGLMConfig, embed_tokens: Optional[nn.Embedding] = None): + def __init__(self, config: XGLMConfig): r""" embed_tokens (`nn.Embedding`, *optional*): output embeddings @@ -387,12 +388,9 @@ def __init__(self, config: XGLMConfig, embed_tokens: Optional[nn.Embedding] = No self.max_target_positions = config.max_position_embeddings embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 - if embed_tokens is not None: - self.embed_tokens = embed_tokens - else: - self.embed_tokens = XGLMScaledWordEmbedding( - config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale - ) + self.embed_tokens = XGLMScaledWordEmbedding( + config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale + ) self.embed_positions = XGLMSinusoidalPositionalEmbedding( config.max_position_embeddings, @@ -559,7 +557,7 @@ def forward( ) class XGLMForCausalLM(XGLMPreTrainedModel, GenerationMixin): base_model_prefix = "model" - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index 856a84c76007..5ed343824902 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -614,21 +614,22 @@ def dummy_inputs(self): langs_list = None return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list} + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Embedding): if self.config is not None and self.config.embed_init_std is not None: nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() if isinstance(module, nn.Linear): if self.config is not None and self.config.init_std is not None: nn.init.normal_(module.weight, mean=0, std=self.config.init_std) if module.bias is not None: nn.init.constant_(module.bias, 0.0) if isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) if isinstance(module, XLMModel) and self.config.sinusoidal_embeddings: create_sinusoidal_embeddings( self.config.max_position_embeddings, self.config.emb_dim, out=module.position_embeddings.weight @@ -921,7 +922,7 @@ def forward(self, x, y=None): """ ) class XLMWithLMHeadModel(XLMPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["pred_layer.proj.weight"] + _tied_weights_keys = {"pred_layer.proj.weight": "transformer.embeddings.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 074755d68362..05fa46b23f54 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -383,7 +383,6 @@ def __init__(self, config): self.decoder = nn.Linear(config.hidden_size, config.vocab_size) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - self.decoder.bias = self.bias def forward(self, features, **kwargs): x = self.dense(features) @@ -395,14 +394,6 @@ def forward(self, features, **kwargs): return x - def _tie_weights(self): - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - # For accelerate compatibility and to not break backward compatibility - if self.decoder.bias.device.type == "meta": - self.decoder.bias = self.bias - else: - self.bias = self.decoder.bias - @auto_docstring class XLMRobertaPreTrainedModel(PreTrainedModel): @@ -419,21 +410,22 @@ class XLMRobertaPreTrainedModel(PreTrainedModel): "cross_attentions": XLMRobertaCrossAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, XLMRobertaLMHead): - module.bias.data.zero_() + module.bias.zero_() class XLMRobertaEmbeddings(nn.Module): @@ -738,7 +730,10 @@ def _create_attention_masks( """ ) class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) @@ -746,7 +741,6 @@ def __init__(self, config): if not config.is_decoder: logger.warning("If you want to use `XLMRobertaLMHeadModel` as a standalone, add `is_decoder=True.`") self.lm_head = XLMRobertaLMHead(config) - self.roberta = XLMRobertaModel(config, add_pooling_layer=False) # Initialize weights and apply final processing @@ -844,7 +838,10 @@ def forward( @auto_docstring class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/xlm_roberta/modular_xlm_roberta.py b/src/transformers/models/xlm_roberta/modular_xlm_roberta.py index 4b61a30f7190..fa42c3e9123f 100644 --- a/src/transformers/models/xlm_roberta/modular_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modular_xlm_roberta.py @@ -60,10 +60,14 @@ class XLMRobertaModel(RobertaModel): """ ) class XLMRobertaForCausalLM(RobertaForCausalLM): + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } + def __init__(self, config): super().__init__(config) del self.xlm_roberta - self.roberta = XLMRobertaModel(config, add_pooling_layer=False) @can_return_tuple @@ -152,6 +156,11 @@ def forward( @auto_docstring class XLMRobertaForMaskedLM(RobertaForMaskedLM): + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } + def __init__(self, config): super().__init__(config) del self.xlm_roberta diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index a0f13d505d6e..a6200dc1ddde 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -535,21 +535,22 @@ class XLMRobertaXLPreTrainedModel(PreTrainedModel): "cross_attentions": XLMRobertaXLCrossAttention, } + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, XLMRobertaXLLMHead): - module.bias.data.zero_() + module.bias.zero_() class XLMRobertaXLPooler(nn.Module): @@ -729,7 +730,6 @@ def __init__(self, config): self.decoder = nn.Linear(config.hidden_size, config.vocab_size) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - self.decoder.bias = self.bias def forward(self, features, **kwargs): x = self.dense(features) @@ -741,14 +741,6 @@ def forward(self, features, **kwargs): return x - def _tie_weights(self) -> None: - # For accelerate compatibility and to not break backward compatibility - if self.decoder.bias.device.type == "meta": - self.decoder.bias = self.bias - else: - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias - class XLMRobertaXLClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" @@ -778,7 +770,10 @@ def forward(self, features, **kwargs): """ ) class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) @@ -875,7 +870,10 @@ def forward( @auto_docstring class XLMRobertaXLForMaskedLM(XLMRobertaXLPreTrainedModel): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py index bca175a6934e..ec2dcf9a0a39 100644 --- a/src/transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py @@ -244,7 +244,6 @@ def __init__(self, config): self.decoder = nn.Linear(config.hidden_size, config.vocab_size) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - self.decoder.bias = self.bias def forward(self, features, **kwargs): x = self.dense(features) @@ -256,14 +255,6 @@ def forward(self, features, **kwargs): return x - def _tie_weights(self) -> None: - # For accelerate compatibility and to not break backward compatibility - if self.decoder.bias.device.type == "meta": - self.decoder.bias = self.bias - else: - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias - class XLMRobertaXLClassificationHead(RobertaClassificationHead): pass @@ -275,7 +266,10 @@ class XLMRobertaXLClassificationHead(RobertaClassificationHead): """ ) class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) @@ -372,7 +366,10 @@ def forward( @auto_docstring class XLMRobertaXLForMaskedLM(XLMRobertaXLPreTrainedModel): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index 67f9f1bf7874..a52ae140e77d 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -635,19 +635,20 @@ class XLNetPreTrainedModel(PreTrainedModel): config: XLNetConfig base_model_prefix = "transformer" + @torch.no_grad() def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, XLNetRelativeAttention): for param in [ module.q, @@ -660,9 +661,9 @@ def _init_weights(self, module): module.r_w_bias, module.seg_embed, ]: - param.data.normal_(mean=0.0, std=self.config.initializer_range) + param.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, XLNetModel): - module.mask_emb.data.normal_(mean=0.0, std=self.config.initializer_range) + module.mask_emb.normal_(mean=0.0, std=self.config.initializer_range) @dataclass @@ -1233,7 +1234,7 @@ def forward( """ ) class XLNetLMHeadModel(XLNetPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_loss.weight"] + _tied_weights_keys = {"lm_loss.weight": "transformer.word_embedding.weight"} def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/xlstm/modeling_xlstm.py b/src/transformers/models/xlstm/modeling_xlstm.py index ee9627c404a6..9cfd5e95f3da 100644 --- a/src/transformers/models/xlstm/modeling_xlstm.py +++ b/src/transformers/models/xlstm/modeling_xlstm.py @@ -1241,6 +1241,7 @@ def _module_name_map(self, module): return name return "" + @torch.no_grad() def _init_weights(self, module): if isinstance(module, nn.Embedding): small_init_method(self.config.hidden_size)(self.embeddings.weight) diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index fc9cfca7359d..b50d4fb64600 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -627,22 +627,22 @@ class XmodPreTrainedModel(PreTrainedModel): "cross_attentions": XmodCrossAttention, } - # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->XmodLMHead + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, XmodLMHead): - module.bias.data.zero_() + module.bias.zero_() def set_default_language(self, language: str): """ @@ -852,7 +852,10 @@ def _create_attention_masks( """ ) class XmodForCausalLM(XmodPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.__init__ with Roberta->Xmod def __init__(self, config): @@ -960,7 +963,10 @@ def forward( @auto_docstring class XmodForMaskedLM(XmodPreTrainedModel): - _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + _tied_weights_keys = { + "lm_head.decoder.weight": "roberta.embeddings.word_embeddings.weight", + "lm_head.decoder.bias": "lm_head.bias", + } # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.__init__ with Roberta->Xmod def __init__(self, config): @@ -1049,7 +1055,6 @@ def __init__(self, config): self.decoder = nn.Linear(config.hidden_size, config.vocab_size) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - self.decoder.bias = self.bias def forward(self, features, **kwargs): x = self.dense(features) @@ -1061,14 +1066,6 @@ def forward(self, features, **kwargs): return x - def _tie_weights(self): - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - # For accelerate compatibility and to not break backward compatibility - if self.decoder.bias.device.type == "meta": - self.decoder.bias = self.bias - else: - self.bias = self.decoder.bias - @auto_docstring( custom_intro=""" diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 527b4d34c3b1..edd6cfd5b10e 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -445,15 +445,16 @@ class YolosPreTrainedModel(PreTrainedModel): "attentions": YolosSelfAttention, } + @torch.no_grad() def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index ac79fe54b4c4..ce945d24bdb9 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -578,16 +578,9 @@ def __init__(self, config): # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def _tie_weights(self): - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -611,22 +604,23 @@ class YosoPreTrainedModel(PreTrainedModel): base_model_prefix = "yoso" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module: nn.Module): """Initialize the weights""" std = self.config.initializer_range if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) elif isinstance(module, YosoLMPredictionHead): - module.bias.data.zero_() + module.bias.zero_() @auto_docstring @@ -717,7 +711,10 @@ def forward( @auto_docstring class YosoForMaskedLM(YosoPreTrainedModel): - _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + _tied_weights_keys = { + "cls.predictions.decoder.bias": "cls.predictions.bias", + "cls.predictions.decoder.weight": "yoso.embeddings.word_embeddings.weight", + } def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index a144fbd589cf..322a762495c5 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -792,20 +792,21 @@ class ZambaPreTrainedModel(PreTrainedModel): # Note: only supports ZambaHybridDynamicCache _is_stateful = True + @torch.no_grad() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) + module.weight.normal_(mean=0.0, std=std) if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + module.weight[module.padding_idx].zero_() elif isinstance(module, ZambaRMSNorm): - module.weight.data.fill_(1.0) + module.weight.fill_(1.0) elif isinstance(module, ZambaMambaMixer): - module.x_proj_weight.data.normal_(mean=0.0, std=std) + module.x_proj_weight.normal_(mean=0.0, std=std) dt_init_std = self.config.mamba_dt_rank**-0.5 nn.init.uniform_(module.dt_proj_weight, -dt_init_std, dt_init_std) @@ -817,12 +818,12 @@ def _init_weights(self, module): ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) - module.dt_proj_bias.data.copy_(inv_dt) + module.dt_proj_bias.copy_(inv_dt) A = torch.arange(1, module.ssm_state_size + 1, dtype=torch.float32)[None, :] A = A.expand(module.intermediate_size, -1).contiguous() - module.A_log.data.copy_(torch.log(A).reshape(module.n_mamba_heads, module.mamba_head_dim, -1)) - module.D.data.fill_(1.0) + module.A_log.copy_(torch.log(A).reshape(module.n_mamba_heads, module.mamba_head_dim, -1)) + module.D.fill_(1.0) @auto_docstring @@ -853,7 +854,7 @@ def __init__(self, config: ZambaConfig): mamba_layers = iter(mamba_layers) linear_layers = iter(linear_layers) layers = [] - self._tied_weights_keys = [] + self._tied_weights_keys = {} for layer_id, layer_type in enumerate(self.layers_block_type): if layer_type == "hybrid": prefix_name = f"layers.{layer_id}." @@ -868,7 +869,7 @@ def __init__(self, config: ZambaConfig): "shared_transf.input_layernorm.weight", "shared_transf.pre_ff_layernorm.weight", ] - self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] + self._tied_weights_keys.update({prefix_name + key: f"layers.0.{key}" for key in tied_keys}) layers.append(ZambaHybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) @@ -1033,10 +1034,11 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): # Adapted from transformers.models.jamba.modeling_jamba.JambaForCausalLM with Jamba->Zamba, JAMBA->ZAMBA class ZambaForCausalLM(ZambaPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + def __init__(self, config: ZambaConfig): super().__init__(config) self.model = ZambaModel(config) - self._tied_weights_keys = ["lm_head.weight", *self.model._tied_weights_keys] self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 8f6efc7dbe1c..6a6544f15856 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -20,7 +20,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -import re from collections.abc import Callable from itertools import cycle from typing import Any, Optional, Union @@ -1215,6 +1214,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): # Note: only supports Zamba2HybridDynamicCache _is_stateful = True + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Zamba2MambaMixer): @@ -1225,11 +1225,11 @@ def _init_weights(self, module): ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) - module.dt_bias.data.copy_(inv_dt) + module.dt_bias.copy_(inv_dt) A = torch.arange(1, module.num_heads + 1) - module.A_log.data.copy_(torch.log(A)) - module.D.data.fill_(1.0) + module.A_log.copy_(torch.log(A)) + module.D.fill_(1.0) @auto_docstring @@ -1436,47 +1436,14 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): def get_layers(self, blocks, linear_layers, mamba_layers): layers = [] - self._tied_weights_keys = [] + self._tied_weights_keys = {} self.first_transformer_layer_id = 0 for layer_id, layer_type in enumerate(self.layers_block_type): if layer_type == "hybrid": - if self.first_transformer_layer_id == 0: - self.first_transformer_layer_id = layer_id block = next(blocks) if self.config.num_mem_blocks * len(self.config.hybrid_layer_ids) > 1: - prefix_pattern = rf"^layers\.{layer_id}\.shared_transformer\." - main_keys_pattern = re.compile( - prefix_pattern - + r"(?:" - + r"self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|" - + r"feed_forward\.(?:gate_up_proj|down_proj)\.weight|" - + r"(?:input_layernorm|pre_ff_layernorm)\.weight" - + r")$" - ) - self._tied_weights_keys.append(main_keys_pattern) - - adapter_id = 0 - for _layer_type in self.layers_block_type: - if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - adapter_pattern = re.compile( - r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\." - + str(adapter_id) - + r"\.(?:0|1)\.weight$" - ) - self._tied_weights_keys.append(adapter_pattern) - adapter_id += 1 - if self.config.use_shared_attention_adapter: - adapter_id = 0 - for _layer_type in self.layers_block_type: - if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - attn_adapter_pattern = re.compile( - r"^shared_transformer\.self_attn\." - + r"(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\." - + str(adapter_id) - + r"\.(?:0|1)\.weight$" - ) - self._tied_weights_keys.append(attn_adapter_pattern) - adapter_id += 1 + prefix_pattern = f"layers.{layer_id}.shared_transformer" + self._tied_weights_keys.update({prefix_pattern: "layers.0.shared_transformer"}) layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) @@ -1485,10 +1452,11 @@ def get_layers(self, blocks, linear_layers, mamba_layers): # Adapted from transformers.models.jamba.modeling_jamba.JambaForCausalLM with Jamba->Zamba2, JAMBA->ZAMBA2 class Zamba2ForCausalLM(Zamba2PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + def __init__(self, config: Zamba2Config): super().__init__(config) self.model = Zamba2Model(config) - self._tied_weights_keys = ["lm_head.weight", *self.model._tied_weights_keys] self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index b884e2b38e4a..33499e6bdef5 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -import re from collections.abc import Callable from itertools import cycle from typing import Optional, Union @@ -904,6 +903,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): # Note: only supports Zamba2HybridDynamicCache _is_stateful = True + @torch.no_grad() def _init_weights(self, module): super()._init_weights(module) if isinstance(module, Zamba2MambaMixer): @@ -914,11 +914,11 @@ def _init_weights(self, module): ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) - module.dt_bias.data.copy_(inv_dt) + module.dt_bias.copy_(inv_dt) A = torch.arange(1, module.num_heads + 1) - module.A_log.data.copy_(torch.log(A)) - module.D.data.fill_(1.0) + module.A_log.copy_(torch.log(A)) + module.D.fill_(1.0) class Zamba2Model(ZambaModel, Zamba2PreTrainedModel): @@ -967,47 +967,14 @@ def __init__(self, config: Zamba2Config): def get_layers(self, blocks, linear_layers, mamba_layers): layers = [] - self._tied_weights_keys = [] + self._tied_weights_keys = {} self.first_transformer_layer_id = 0 for layer_id, layer_type in enumerate(self.layers_block_type): if layer_type == "hybrid": - if self.first_transformer_layer_id == 0: - self.first_transformer_layer_id = layer_id block = next(blocks) if self.config.num_mem_blocks * len(self.config.hybrid_layer_ids) > 1: - prefix_pattern = rf"^layers\.{layer_id}\.shared_transformer\." - main_keys_pattern = re.compile( - prefix_pattern - + r"(?:" - + r"self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|" - + r"feed_forward\.(?:gate_up_proj|down_proj)\.weight|" - + r"(?:input_layernorm|pre_ff_layernorm)\.weight" - + r")$" - ) - self._tied_weights_keys.append(main_keys_pattern) - - adapter_id = 0 - for _layer_type in self.layers_block_type: - if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - adapter_pattern = re.compile( - r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\." - + str(adapter_id) - + r"\.(?:0|1)\.weight$" - ) - self._tied_weights_keys.append(adapter_pattern) - adapter_id += 1 - if self.config.use_shared_attention_adapter: - adapter_id = 0 - for _layer_type in self.layers_block_type: - if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - attn_adapter_pattern = re.compile( - r"^shared_transformer\.self_attn\." - + r"(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\." - + str(adapter_id) - + r"\.(?:0|1)\.weight$" - ) - self._tied_weights_keys.append(attn_adapter_pattern) - adapter_id += 1 + prefix_pattern = f"layers.{layer_id}.shared_transformer" + self._tied_weights_keys.update({prefix_pattern: "layers.0.shared_transformer"}) layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) diff --git a/src/transformers/models/zoedepth/modeling_zoedepth.py b/src/transformers/models/zoedepth/modeling_zoedepth.py index eb2cc630c021..f077fd387dd3 100644 --- a/src/transformers/models/zoedepth/modeling_zoedepth.py +++ b/src/transformers/models/zoedepth/modeling_zoedepth.py @@ -1211,15 +1211,16 @@ class ZoeDepthPreTrainedModel(PreTrainedModel): input_modalities = "image" supports_gradient_checkpointing = True + @torch.no_grad() def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + module.weight.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: - module.bias.data.zero_() + module.bias.zero_() elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) + module.bias.zero_() + module.weight.fill_(1.0) @auto_docstring( diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py index 5ba372a41fcb..e09dbb751e4b 100644 --- a/src/transformers/quantizers/base.py +++ b/src/transformers/quantizers/base.py @@ -12,13 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod +from copy import deepcopy from typing import TYPE_CHECKING, Any, Optional, Union -from ..utils import is_torch_available, logging +from ..utils import is_accelerate_available, is_torch_available, logging from ..utils.quantization_config import QuantizationConfigMixin, QuantizationMethod from .quantizers_utils import get_module_from_name +if is_accelerate_available(): + from accelerate.utils import find_tied_parameters + if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel @@ -41,6 +45,52 @@ def _assign_original_dtype(module, original_dtype): _assign_original_dtype(child, original_dtype) +def get_keys_to_not_convert(model): + r""" + An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules + we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want + to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in + int8. + + Parameters: + model (`torch.nn.Module`): + Input model + """ + # Create a copy of the model and tie the weights, then + # check if it contains tied weights + tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager` + tied_model.tie_weights() + + tied_params = find_tied_parameters(tied_model) + tied_keys = sum(tied_params, []) + has_tied_params = len(tied_keys) > 0 + + # If there is not tied weights, we want to keep the lm_head(output_embedding) in full precision + if not has_tied_params: + output_emb = model.get_output_embeddings() + if output_emb is not None: + list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)] + return list_last_module + + # otherwise, no tied weights, no output embedding defined, simply keep the last module in full precision + list_modules = list(model.named_parameters()) + list_last_module = [list_modules[-1][0]] + # add last module together with tied weights + intersection = set(list_last_module) - set(tied_keys) + list_untouched = list(set(tied_keys)) + list(intersection) + + # remove ".weight" from the keys + names_to_remove = [".weight", ".bias"] + filtered_module_names = [] + for name in list_untouched: + for name_to_remove in names_to_remove: + if name_to_remove in name: + name = name.replace(name_to_remove, "") + filtered_module_names.append(name) + + return filtered_module_names + + class HfQuantizer(ABC): """ Abstract class of the HuggingFace quantizer. Supports for now quantizing HF transformers models for inference and/or quantization. @@ -315,8 +365,6 @@ def get_modules_to_not_convert( keep_in_fp32_modules: Optional[list[str]] = None, add_default_skips: bool = False, ): - from ..integrations import get_keys_to_not_convert - if skip_modules is None or add_default_skips: modules_to_not_convert = get_keys_to_not_convert(model) else: @@ -370,6 +418,17 @@ def _convert_model_for_quantization(self, model): model.config.get_text_config() ) + def get_quantize_ops(self): + raise NotImplementedError( + f"{self.quantization_config.quant_method} is not available yet and will be supported soon." + ) + + def is_valid_unexpected_keys(self, k): + """ + Check if the keys is valid or not even if it is not in the state_dict of the meta model. + This is because the state dict of the model might change after quantization like for 4bit bnb + """ + return False class SequentialLlama4TextExperts(ModuleList): """ diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py index d1aefcd3a988..ce9eff223ed0 100644 --- a/src/transformers/quantizers/quantizer_bnb_4bit.py +++ b/src/transformers/quantizers/quantizer_bnb_4bit.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict -from functools import cached_property from typing import TYPE_CHECKING, Optional, Union from .base import HfQuantizer @@ -118,7 +117,10 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": return CustomDtype.INT4 def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]: - return [k for k in unexpected_keys if not any(k.endswith(x) for x in self.bnb_keys)] + return [k for k in unexpected_keys if not self.is_valid_unexpected_keys(k)] + + def is_valid_unexpected_keys(self, k): + return any(k.endswith(x) for x in self.bnb_keys) def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: import bitsandbytes as bnb @@ -172,17 +174,13 @@ def create_quantized_param( # We are ready for quantization in this case (note, the +1 is for the weight itself) if len(self.param_quant_stats[module_name]) == len(self.bnb_keys) + 1: - param_kwargs = {} - if self.is_bnb_supports_quant_storage_module: - param_kwargs["module"] = module - weight = self.param_quant_stats[module_name].pop(f"{module_name}.weight") new_value = bnb.nn.Params4bit.from_prequantized( data=weight, quantized_stats=self.param_quant_stats[module_name], requires_grad=False, device=target_device, - **param_kwargs, + module=module ) # Set it module._parameters[tensor_name] = new_value @@ -285,15 +283,6 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs def is_serializable(self, safe_serialization=None): return True - @cached_property - def is_bnb_supports_quant_storage_module(self) -> bool: - """ - determines if the current version of bitsandbytes supports - the `module` parameter in `Params4bit.from_prequantized` - :return: - """ - return True - @property def is_trainable(self) -> bool: return True @@ -305,3 +294,7 @@ def _dequantize(self, model): model, self.modules_to_not_convert, quantization_config=self.quantization_config ) return model + + def get_quantize_ops(self): + from ..integrations.bitsandbytes import Bnb4bitQuantize + return Bnb4bitQuantize(self) diff --git a/src/transformers/quantizers/quantizer_finegrained_fp8.py b/src/transformers/quantizers/quantizer_finegrained_fp8.py index 326ee8c015ab..d9a73e4c7a9a 100644 --- a/src/transformers/quantizers/quantizer_finegrained_fp8.py +++ b/src/transformers/quantizers/quantizer_finegrained_fp8.py @@ -75,6 +75,8 @@ def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype": dtype = torch.float32 return dtype + # TODO: make this into a `ConversionType` ops -> potentially requires all weights on all ranks + # depending on the layer type (moe -> no if ep) def create_quantized_param( self, model: "PreTrainedModel", @@ -93,8 +95,9 @@ def create_quantized_param( if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn: raise ValueError("Expect quantized weights but got an unquantized weight") else: - if tensor_name == "weight_scale_inv": - raise ValueError("Expect unquantized weights but got a quantized weight_scale") + return + # if tensor_name == "weight_scale_inv": + # raise ValueError("Expect unquantized weights but got a quantized weight_scale") param_value = param_value.to(target_device) @@ -137,10 +140,10 @@ def create_quantized_param( _load_parameter_into_model(model, param_name.rsplit(".", 1)[0] + ".weight_scale_inv", scale) def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: - from ..integrations.finegrained_fp8 import FP8Linear + from ..integrations.finegrained_fp8 import FP8Expert, FP8Linear module, tensor_name = get_module_from_name(model, param_name) - if isinstance(module, FP8Linear): + if isinstance(module, (FP8Linear, FP8Expert)): if self.pre_quantized or tensor_name == "bias": return False else: @@ -155,10 +158,12 @@ def _process_model_before_weight_loading( ): from ..integrations.finegrained_fp8 import replace_with_fp8_linear + # takes 2 fucking seconds self.modules_to_not_convert = self.get_modules_to_not_convert( model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules ) + # while this one is 81ms :) model = replace_with_fp8_linear( model, modules_to_not_convert=self.modules_to_not_convert, @@ -182,6 +187,10 @@ def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> li not_missing_keys.append(missing) return [k for k in missing_keys if k not in not_missing_keys] + # NOTE: TP is applied before quantization so this is only to add hooks. + # Quantization is incompatible with DTensors, so we have to anyway have + # gathers! But it should be model independant -> figure out where to put + # the gather and that's it. def update_tp_plan(self, config): if "Qwen3" in config.__class__.__name__: text_plan = { @@ -217,3 +226,7 @@ def is_trainable(self) -> bool: def get_accelerator_warm_up_factor(self): # Pre-processing is done cleanly, so we can allocate everything here return 2 + + def get_quantize_ops(self): + from ..integrations.finegrained_fp8 import Fp8Quantize + return Fp8Quantize(self) diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py index c1c5f66f4aac..2e00e4516195 100644 --- a/src/transformers/quantizers/quantizer_mxfp4.py +++ b/src/transformers/quantizers/quantizer_mxfp4.py @@ -156,7 +156,6 @@ def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype": def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool: from ..integrations import Mxfp4GptOssExperts from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts - # if we are dequantizing, the model doesn't have scales, and blocks only params like gate_up_proj and down_proj so we need to handle this case differently if self.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name): module, tensor_name = get_module_from_name(model, param_name[: -len("_blocks")]) @@ -417,6 +416,17 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False): metadata = {} return state_dict, metadata + def is_valid_unexpected_keys(self, k): + mxfp4_keys = ["_blocks", "_scales"] + if self.pre_quantized: + return any(k.endswith(x) for x in mxfp4_keys) + else: + return ["gate_up_proj", "down_proj"] + + def get_quantize_ops(self): + from ..integrations import Mxfp4Quantize + return Mxfp4Quantize(self) + def is_serializable(self, safe_serialization=None): return True diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 9bc51f1bac65..4bafc3e84403 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -859,7 +859,7 @@ def wrapper(self, *args, **kwargs): # Check attention implementation is properly set for capturing attention outputs if recordable_keys.get("output_attentions", False): - supported_attn = ["eager", "eager_paged", "flex_attention"] + supported_attn = ["eager", "eager_paged", "flex_attention", "sdpa"] config_attn = getattr(self.config, "_attn_implementation", None) sub_configs = [getattr(self.config, key, None) for key in self.config.sub_configs] sub_configs_attn = [ @@ -877,13 +877,7 @@ def make_capture_wrapper(module, orig_forward, key, index): def wrapped_forward(*args, **kwargs): if key == "hidden_states" and len(collected_outputs[key]) == 0: collected_outputs[key] += (args[0],) - if kwargs.get("debug_io", False): - with model_addition_debugger_context( - module, kwargs.get("debug_io_dir", "~/model_debug"), kwargs.get("prune_layers") - ): - output = orig_forward(*args, **kwargs) - else: - output = orig_forward(*args, **kwargs) + output = orig_forward(*args, **kwargs) if not isinstance(output, tuple): collected_outputs[key] += (output,) elif output[index] is not None: @@ -924,7 +918,13 @@ def wrapped_forward(*args, **kwargs): monkey_patched_layers.append((module, original_forward)) try: - outputs = func(self, *args, **kwargs) + if kwargs.get("debug_io", False): + with model_addition_debugger_context( + self, kwargs.get("debug_io_dir", "model_debug"), kwargs.get("prune_layers") + ): + outputs = func(self, *args, **kwargs) + else: + outputs = func(self, *args, **kwargs) except TypeError as original_exception: # If we get a TypeError, it's possible that the model is not receiving the recordable kwargs correctly. # Get a TypeError even after removing the recordable kwargs -> re-raise the original exception diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index b38ea64cc4ff..bf2fba35fd0e 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1178,9 +1178,12 @@ def is_mistral_common_available() -> bool: @lru_cache def is_opentelemetry_available() -> bool: - return _is_package_available("opentelemetry") and version.parse( - importlib.metadata.version("opentelemetry-api") - ) >= version.parse("1.30.0") + try: + return _is_package_available("opentelemetry") and version.parse( + importlib.metadata.version("opentelemetry-api") + ) >= version.parse("1.30.0") + except Exception as _: + return False def check_torch_load_is_safe() -> None: diff --git a/src/transformers/utils/loading_report.py b/src/transformers/utils/loading_report.py new file mode 100644 index 000000000000..17171af319ed --- /dev/null +++ b/src/transformers/utils/loading_report.py @@ -0,0 +1,243 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import re +import shutil +import sys +from collections import OrderedDict, defaultdict +from collections.abc import Iterable +from typing import Any, Optional + + +_DIGIT_RX = re.compile(r"(?<=\.)(\d+)(?=\.|$)") # numbers between dots or at the end + + +def _pattern_of(key: str) -> str: + """Replace every dot-delimited integer with '*' to get the structure.""" + return _DIGIT_RX.sub("*", key) + + +def _fmt_indices(values: list[int], cutoff=10) -> str: + """Format a list of ints as single number, {a, ..., b}, or first...last.""" + if len(values) == 1: + return str(values[0]) + values = sorted(values) + if len(values) > cutoff: + return f"{values[0]}...{values[-1]}" + return ", ".join(map(str, values)) + + +def update_key_name(mapping: dict[str, Any]) -> dict[str, Any]: + """ + Merge keys like 'layers.0.x', 'layers.1.x' into 'layers.{0, 1}.x' + BUT only merge together keys that have the exact same value. + Returns a new dict {merged_key: value}. + """ + # (pattern, value) -> list[set[int]] (per-star index values) + not_mapping = False + if not isinstance(mapping, dict): + mapping = {k: k for k in mapping} + not_mapping = True + + bucket: dict[tuple[str, Any], list[set[int]]] = defaultdict(list) + for key, val in mapping.items(): + digs = _DIGIT_RX.findall(key) + patt = _pattern_of(key) + for i, d in enumerate(digs): + if len(bucket[patt]) <= i: + bucket[patt].append(set()) + bucket[patt][i].add(int(d)) + bucket[patt].append(val) + + out_items = {} + for patt, values in bucket.items(): + sets, val = values[:-1], values[-1] + parts = patt.split("*") # stars are between parts + final = parts[0] + for i in range(1, len(parts)): + if i - 1 < len(sets) and sets[i - 1]: + insert = _fmt_indices(sorted(sets[i - 1])) + if len(sets[i - 1]) > 1: + final += "{" + insert + "}" + else: + final += insert + else: + final += "*" + final += parts[i] + + out_items[final] = val + out = OrderedDict(out_items) + if not_mapping: + return out.keys() + return out + + +# We have a class to simplify disabling ANSI colors +class ANSI: + palette = { + "reset": "", + "red": "", + "yellow": "", + "orange": "", + "purple": "", + "bold": "", + "italic": "", + "dim": "", + } + + def __init__(self, enable): + self.enable = enable + + def __getitem__(self, key): + return self.palette[key] if self.enable else "" + + +_ansi_re = re.compile(r"\x1b\[[0-9;]*m") + + +def _strip_ansi(s: str) -> str: + return _ansi_re.sub("", str(s)) + + +def _pad(text, width): + t = str(text) + pad = max(0, width - len(_strip_ansi(t))) + return t + " " * pad + + +def _make_table(rows, headers): + # compute display widths while ignoring ANSI codes + cols = list(zip(*([headers] + rows))) if rows else [headers] + widths = [max(len(_strip_ansi(x)) for x in col) for col in cols] + header_line = " | ".join(_pad(h, w) for h, w in zip(headers, widths)) + sep_line = "-+-".join("-" * w for w in widths) + body = [" | ".join(_pad(c, w) for c, w in zip(r, widths)) for r in rows] + return "\n".join([header_line, sep_line] + body) + + +def _color(s, color, ansi): + return f"{ansi[color]}{s}{ansi['reset']}" + + +def _get_terminal_width(default=80): + try: + return shutil.get_terminal_size().columns + except Exception: + return default + + +def log_state_dict_report( + *, + model, + pretrained_model_name_or_path, + logger: Optional[logging.Logger] = None, + error_msgs: Optional[Iterable[str]] = None, + unexpected_keys=None, + missing_keys=None, + mismatched_keys=None, + mismatched_shapes=None, + ignore_mismatched_sizes=True, + misc=None, + color=True, # allow disabling for plain logs + min_width_full_table=60, # terminal min width to attempt full table +): + """Log a readable report about state_dict loading issues. + + This version is terminal-size aware: for very small terminals it falls back to a compact + Key | Status view so output doesn't wrap badly. + """ + if logger is None: + logger = logging.getLogger(__name__) + + error_msgs = error_msgs or [] + unexpected_keys = unexpected_keys or [] + missing_keys = missing_keys or [] + mismatched_keys = mismatched_keys or [] + mismatched_shapes = mismatched_shapes or [] + misc = misc or {} + + # Detect whether the current stdout supports ANSI colors; allow callers to pass `color=False` to force no color + color_enabled = bool(color and sys.stdout.isatty()) + ansi = ANSI(color_enabled) + + if error_msgs: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` to `from_pretrained(...)` if appropriate." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + term_w = _get_terminal_width() + rows = [] + if unexpected_keys: + for k in update_key_name(unexpected_keys): + status = "UNEXPECTED" + status = _color(status, "orange", ansi) + rows.append([k, status, "", ""]) + + if missing_keys: + for k in update_key_name(missing_keys): + status = "MISSING" + status = _color(status, "red", ansi) + rows.append([k, status, ""]) + + if mismatched_keys: + iterator = {a: (b, c) for a, b, c in mismatched_shapes} + for key, (shape_ckpt, shape_model) in update_key_name(iterator).items(): + status = "MISMATCH" + status = _color(status, "yellow", ansi) + data = [key, status] + data.append( + " ".join(["Reinit due to size mismatch", f"ckpt: {str(shape_ckpt)} vs model:{str(shape_model)}"]) + ) + rows.append(data) + + if misc: + for k, v in update_key_name(misc).items(): + status = "MISC" + status = _color(status, "purple", ansi) + _details = v[:term_w] + rows.append([k, status, _details]) + + if not rows: + return + + headers = ["Key", "Status"] + if term_w > 200: + headers += ["Details"] + else: + headers += ["", ""] + table = _make_table(rows, headers=headers) + + prelude = ( + f"{ansi['bold']}{model.__class__.__name__} LOAD REPORT{ansi['reset']} from: {pretrained_model_name_or_path}\n" + ) + tips = f"\n\n{ansi['italic']}Notes:" + if unexpected_keys: + tips += f"\n- {_color('UNEXPECTED', 'orange', ansi) + ansi['italic']}\t:can be ignored when loading from different task/architecture; not ok if you expect identical arch." + if missing_keys: + tips += f"\n- {_color('MISSING', 'red', ansi) + ansi['italic']}\t:those params were newly initialized because missing form the checkpoint. Consider training on your downstream task." + if mismatched_keys: + tips += f"\n- {_color('MISMATCH', 'yellow', ansi) + ansi['italic']}\t:ckpt weights were loaded, but they did not match the original empty weight." + if misc: + tips += f"\n- {_color('MISC', 'purple', ansi) + ansi['italic']}\t:originate from the conversion scheme" + tips += f"{ansi['reset']}" + + logger.warning(prelude + table + tips) + if not ignore_mismatched_sizes and mismatched_keys: + raise RuntimeError( + "You set `ignore_mismatched_sizes` to `False`, thus raising an error. For details look at the above report!" + ) + return prelude + table + tips diff --git a/src/transformers/utils/pytest_helpers.py b/src/transformers/utils/pytest_helpers.py new file mode 100644 index 000000000000..5f22e01ba508 --- /dev/null +++ b/src/transformers/utils/pytest_helpers.py @@ -0,0 +1,111 @@ +import argparse +import json +import re +from collections import Counter +from pathlib import Path + + +def _base_test_name(nodeid: str) -> str: + # Strip parameters like [param=..] from the last component + name = nodeid.split("::")[-1] + return re.sub(r"\[.*\]$", "", name) + + +def _class_name(nodeid: str) -> str | None: + parts = nodeid.split("::") + # nodeid can be: file::Class::test or file::test + if len(parts) >= 3: + return parts[-2] + return None + + +def _file_path(nodeid: str) -> str: + return nodeid.split("::")[0] + + +def _modeling_key(file_path: str) -> str | None: + # Extract "xxx" from test_modeling_xxx.py + m = re.search(r"test_modeling_([A-Za-z0-9_]+)\.py$", file_path) + if m: + return m.group(1) + return None + + +def summarize(report_path: str): + p = Path(report_path) + if not p.exists(): + raise FileNotFoundError(f"Report file not found: {p.resolve()}") + + data = json.loads(p.read_text()) + tests = data.get("tests", []) + + # Overall counts + outcomes = Counter(t.get("outcome", "unknown") for t in tests) + + # Filter failures (pytest-json-report uses "failed" and may have "error") + failed = [t for t in tests if t.get("outcome") in ("failed", "error")] + + # 1) Failures per test file + failures_per_file = Counter(_file_path(t.get("nodeid", "")) for t in failed) + + # 2) Failures per class (if any; otherwise "NO_CLASS") + failures_per_class = Counter((_class_name(t.get("nodeid", "")) or "NO_CLASS") for t in failed) + + # 3) Failures per base test name (function), aggregating parametrized cases + failures_per_testname = Counter(_base_test_name(t.get("nodeid", "")) for t in failed) + + # 4) Failures per test_modeling_xxx (derived from filename) + failures_per_modeling_key = Counter() + for t in failed: + key = _modeling_key(_file_path(t.get("nodeid", ""))) + if key: + failures_per_modeling_key[key] += 1 + + return { + "outcomes": outcomes, + "failures_per_file": failures_per_file, + "failures_per_class": failures_per_class, + "failures_per_testname": failures_per_testname, + "failures_per_modeling_key": failures_per_modeling_key, + } + + +def main(): + parser = argparse.ArgumentParser(description="Summarize pytest JSON report failures") + parser.add_argument( + "--report", default="report.json", help="Path to pytest JSON report file (default: report.json)" + ) + args = parser.parse_args() + + try: + summary = summarize(args.report) + except FileNotFoundError as e: + print(str(e)) + return + + outcomes = summary["outcomes"] + print("=== Overall ===") + total = sum(outcomes.values()) + print(f"Total tests: {total}") + for k in sorted(outcomes): + print(f"{k:>10}: {outcomes[k]}") + + def _print_counter(title, counter: Counter, label=""): + print(f"\n=== {title} ===") + if not counter: + print("None") + return + for key, cnt in sorted(counter.items(), key=lambda x: (x[1], x[0])): + if label: + print(f"{cnt:4d} {label}{key}") + else: + print(f"{cnt:4d} {key}") + + _print_counter("Failures per test class", summary["failures_per_class"], label="class ") + _print_counter("Failures per test_modeling_xxx", summary["failures_per_modeling_key"], label="model ") + _print_counter("Failures per test file", summary["failures_per_file"]) + _print_counter("Failures per test name (base)", summary["failures_per_testname"]) + + +if __name__ == "__main__": + main() diff --git a/tests/causal_lm_tester.py b/tests/causal_lm_tester.py index 5c25223428d6..bbcdadc9b2ca 100644 --- a/tests/causal_lm_tester.py +++ b/tests/causal_lm_tester.py @@ -248,6 +248,7 @@ def __init__( self.mamba_d_conv = mamba_d_conv self.mamba_expand = mamba_expand self.mamba_chunk_size = mamba_chunk_size + self.tie_word_embeddings = False def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) diff --git a/tests/models/autoformer/test_modeling_autoformer.py b/tests/models/autoformer/test_modeling_autoformer.py index 9da8abce9665..fd2345f3e94e 100644 --- a/tests/models/autoformer/test_modeling_autoformer.py +++ b/tests/models/autoformer/test_modeling_autoformer.py @@ -232,7 +232,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_encoder_decoder_model_standalone(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py index 40991788e346..607a8cd848f1 100644 --- a/tests/models/bark/test_modeling_bark.py +++ b/tests/models/bark/test_modeling_bark.py @@ -539,7 +539,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -625,7 +625,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -708,7 +708,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_inputs_embeds(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index eabff66bc6bc..28511b62d4f1 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -438,7 +438,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py index 3076634362ac..54544f090992 100644 --- a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py +++ b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py @@ -297,7 +297,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/blenderbot/test_modeling_blenderbot.py b/tests/models/blenderbot/test_modeling_blenderbot.py index 4e906a4dceb8..31a569b5cdd6 100644 --- a/tests/models/blenderbot/test_modeling_blenderbot.py +++ b/tests/models/blenderbot/test_modeling_blenderbot.py @@ -241,7 +241,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py index aef6aaa70318..52b8e79768f8 100644 --- a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py +++ b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py @@ -246,7 +246,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/data2vec/test_modeling_data2vec_audio.py b/tests/models/data2vec/test_modeling_data2vec_audio.py index 83e5c838bc16..57aa199415bf 100644 --- a/tests/models/data2vec/test_modeling_data2vec_audio.py +++ b/tests/models/data2vec/test_modeling_data2vec_audio.py @@ -459,13 +459,13 @@ def test_retain_grad_hidden_states_attentions(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) if hasattr(module, "codevectors") and module.codevectors is not None: module.codevectors.data.fill_(3) if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py index 68f84986054f..b0b3a4844225 100644 --- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py @@ -552,8 +552,6 @@ def create_and_check_encoder_decoder_shared_weights( decoder_attention_mask=decoder_attention_mask, ) - # check that models has less parameters - self.assertLess(sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())) random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() # check that outputs are equal @@ -570,10 +568,6 @@ def create_and_check_encoder_decoder_shared_weights( tied_model.to(torch_device) tied_model.eval() - # check that models has less parameters - self.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() tied_model_result = tied_model( diff --git a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py index f2e042c11748..8feef8d3eb75 100644 --- a/tests/models/falcon_mamba/test_modeling_falcon_mamba.py +++ b/tests/models/falcon_mamba/test_modeling_falcon_mamba.py @@ -66,7 +66,7 @@ def __init__( num_labels=3, num_choices=4, scope=None, - tie_word_embeddings=True, + tie_word_embeddings=False, ): self.parent = parent self.batch_size = batch_size diff --git a/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py b/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py index 5e729eae8eb0..9ab763c9df0a 100644 --- a/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py +++ b/tests/models/fastspeech2_conformer/test_modeling_fastspeech2_conformer.py @@ -200,7 +200,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) _, info = FastSpeech2ConformerModel.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -618,7 +618,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) _, info = FastSpeech2ConformerWithHifiGan.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/fsmt/test_modeling_fsmt.py b/tests/models/fsmt/test_modeling_fsmt.py index 095b91286575..102711eb8bc1 100644 --- a/tests/models/fsmt/test_modeling_fsmt.py +++ b/tests/models/fsmt/test_modeling_fsmt.py @@ -248,7 +248,7 @@ def test_save_load_missing_keys(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_ensure_weights_are_shared(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/funnel/test_modeling_funnel.py b/tests/models/funnel/test_modeling_funnel.py index e285d7fe87ec..654f9e106dbb 100644 --- a/tests/models/funnel/test_modeling_funnel.py +++ b/tests/models/funnel/test_modeling_funnel.py @@ -417,9 +417,9 @@ def test_for_question_answering(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) for param in ["r_w_bias", "r_r_bias", "r_kernel", "r_s_bias", "seg_embed"]: if hasattr(module, param) and getattr(module, param) is not None: @@ -470,9 +470,9 @@ def test_training(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) for param in ["r_w_bias", "r_r_bias", "r_kernel", "r_s_bias", "seg_embed"]: if hasattr(module, param) and getattr(module, param) is not None: diff --git a/tests/models/hubert/test_modeling_hubert.py b/tests/models/hubert/test_modeling_hubert.py index f47d20239f2a..7eccaea93daa 100644 --- a/tests/models/hubert/test_modeling_hubert.py +++ b/tests/models/hubert/test_modeling_hubert.py @@ -402,13 +402,13 @@ def test_retain_grad_hidden_states_attentions(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: module.masked_spec_embed.data.fill_(3) @@ -525,13 +525,13 @@ def test_retain_grad_hidden_states_attentions(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: module.masked_spec_embed.data.fill_(3) diff --git a/tests/models/informer/test_modeling_informer.py b/tests/models/informer/test_modeling_informer.py index 43c8ff471e03..c30cc27b34c9 100644 --- a/tests/models/informer/test_modeling_informer.py +++ b/tests/models/informer/test_modeling_informer.py @@ -218,7 +218,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_encoder_decoder_model_standalone(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/kosmos2/test_modeling_kosmos2.py b/tests/models/kosmos2/test_modeling_kosmos2.py index b63076e8f2b4..e8ec40c8c716 100644 --- a/tests/models/kosmos2/test_modeling_kosmos2.py +++ b/tests/models/kosmos2/test_modeling_kosmos2.py @@ -345,7 +345,7 @@ def test_load_save_without_tied_weights(self): v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}" ) # Checking there was no complain of missing weights - self.assertEqual(infos["missing_keys"], []) + self.assertEqual(infos["missing_keys"], set()) # overwrite from common in order to use `self.model_tester.text_model_tester.num_hidden_layers` def test_hidden_states_output(self): diff --git a/tests/models/kosmos2_5/test_modeling_kosmos2_5.py b/tests/models/kosmos2_5/test_modeling_kosmos2_5.py index ac8be1982721..751166f1775a 100644 --- a/tests/models/kosmos2_5/test_modeling_kosmos2_5.py +++ b/tests/models/kosmos2_5/test_modeling_kosmos2_5.py @@ -411,7 +411,7 @@ def test_load_save_without_tied_weights(self): msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}", ) # Checking there was no complain of missing weights - self.assertEqual(infos["missing_keys"], []) + self.assertEqual(infos["missing_keys"], set()) # overwrite from common in order to use `self.model_tester.text_model_tester.num_hidden_layers` def test_hidden_states_output(self): diff --git a/tests/models/led/test_modeling_led.py b/tests/models/led/test_modeling_led.py index 2a17cb4d8a41..c1f3fd31a8f7 100644 --- a/tests/models/led/test_modeling_led.py +++ b/tests/models/led/test_modeling_led.py @@ -316,7 +316,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/longt5/test_modeling_longt5.py b/tests/models/longt5/test_modeling_longt5.py index 25b769e715b7..e73ef8596a20 100644 --- a/tests/models/longt5/test_modeling_longt5.py +++ b/tests/models/longt5/test_modeling_longt5.py @@ -430,10 +430,6 @@ def create_and_check_encoder_decoder_shared_weights( decoder_attention_mask=decoder_attention_mask, ) - # check that models has less parameters - self.parent.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() # check that outputs are equal @@ -450,10 +446,6 @@ def create_and_check_encoder_decoder_shared_weights( tied_model.to(torch_device) tied_model.eval() - # check that models has less parameters - self.parent.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() tied_model_result = tied_model( diff --git a/tests/models/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py index 718b5cca2956..f2a6a7c4b9e7 100644 --- a/tests/models/m2m_100/test_modeling_m2m_100.py +++ b/tests/models/m2m_100/test_modeling_m2m_100.py @@ -272,7 +272,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py index b897cb76c6d8..72aa45ad358f 100644 --- a/tests/models/marian/test_modeling_marian.py +++ b/tests/models/marian/test_modeling_marian.py @@ -246,7 +246,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/mbart/test_modeling_mbart.py b/tests/models/mbart/test_modeling_mbart.py index 73c28e9ed573..2f997370c64c 100644 --- a/tests/models/mbart/test_modeling_mbart.py +++ b/tests/models/mbart/test_modeling_mbart.py @@ -265,7 +265,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/mt5/test_modeling_mt5.py b/tests/models/mt5/test_modeling_mt5.py index b5fd56813845..45a5ad01ab76 100644 --- a/tests/models/mt5/test_modeling_mt5.py +++ b/tests/models/mt5/test_modeling_mt5.py @@ -456,10 +456,6 @@ def create_and_check_encoder_decoder_shared_weights( decoder_attention_mask=decoder_attention_mask, ) - # check that models has less parameters - self.parent.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() # check that outputs are equal @@ -476,10 +472,6 @@ def create_and_check_encoder_decoder_shared_weights( tied_model.to(torch_device) tied_model.eval() - # check that models has less parameters - self.parent.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() tied_model_result = tied_model( diff --git a/tests/models/mvp/test_modeling_mvp.py b/tests/models/mvp/test_modeling_mvp.py index e50039d68fe6..e8e959aec813 100644 --- a/tests/models/mvp/test_modeling_mvp.py +++ b/tests/models/mvp/test_modeling_mvp.py @@ -462,7 +462,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/nllb_moe/test_modeling_nllb_moe.py b/tests/models/nllb_moe/test_modeling_nllb_moe.py index 2040b5ca435a..aa761797ea5c 100644 --- a/tests/models/nllb_moe/test_modeling_nllb_moe.py +++ b/tests/models/nllb_moe/test_modeling_nllb_moe.py @@ -274,7 +274,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_decoder_model_past_with_large_inputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/opt/test_modeling_opt.py b/tests/models/opt/test_modeling_opt.py index d195385ecdd5..6ecef7519f8f 100644 --- a/tests/models/opt/test_modeling_opt.py +++ b/tests/models/opt/test_modeling_opt.py @@ -256,7 +256,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/patchtsmixer/test_modeling_patchtsmixer.py b/tests/models/patchtsmixer/test_modeling_patchtsmixer.py index a4e245dc85e1..22650ea34829 100644 --- a/tests/models/patchtsmixer/test_modeling_patchtsmixer.py +++ b/tests/models/patchtsmixer/test_modeling_patchtsmixer.py @@ -276,7 +276,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_hidden_states_output(self): def check_hidden_states_output(inputs_dict, config, model_class): diff --git a/tests/models/patchtst/test_modeling_patchtst.py b/tests/models/patchtst/test_modeling_patchtst.py index 13feefcc207f..ba29095bf8ac 100644 --- a/tests/models/patchtst/test_modeling_patchtst.py +++ b/tests/models/patchtst/test_modeling_patchtst.py @@ -208,7 +208,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_hidden_states_output(self): def check_hidden_states_output(inputs_dict, config, model_class): diff --git a/tests/models/pegasus/test_modeling_pegasus.py b/tests/models/pegasus/test_modeling_pegasus.py index 1289acefd315..9753d15a3a08 100644 --- a/tests/models/pegasus/test_modeling_pegasus.py +++ b/tests/models/pegasus/test_modeling_pegasus.py @@ -253,7 +253,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/pegasus_x/test_modeling_pegasus_x.py b/tests/models/pegasus_x/test_modeling_pegasus_x.py index 34e0828fd6dd..34f28a38d6b4 100644 --- a/tests/models/pegasus_x/test_modeling_pegasus_x.py +++ b/tests/models/pegasus_x/test_modeling_pegasus_x.py @@ -231,7 +231,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/plbart/test_modeling_plbart.py b/tests/models/plbart/test_modeling_plbart.py index 5da6225b03d5..1320edf35f1f 100644 --- a/tests/models/plbart/test_modeling_plbart.py +++ b/tests/models/plbart/test_modeling_plbart.py @@ -261,7 +261,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/pop2piano/test_modeling_pop2piano.py b/tests/models/pop2piano/test_modeling_pop2piano.py index 3177df3ca89c..e68ea243df23 100644 --- a/tests/models/pop2piano/test_modeling_pop2piano.py +++ b/tests/models/pop2piano/test_modeling_pop2piano.py @@ -404,10 +404,6 @@ def create_and_check_encoder_decoder_shared_weights( decoder_attention_mask=decoder_attention_mask, ) - # check that models has less parameters - self.parent.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() # check that outputs are equal @@ -424,10 +420,6 @@ def create_and_check_encoder_decoder_shared_weights( tied_model.to(torch_device) tied_model.eval() - # check that models has less parameters - self.parent.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() tied_model_result = tied_model( diff --git a/tests/models/prophetnet/test_modeling_prophetnet.py b/tests/models/prophetnet/test_modeling_prophetnet.py index 38b74c9c0a30..0fc05a5dc3be 100644 --- a/tests/models/prophetnet/test_modeling_prophetnet.py +++ b/tests/models/prophetnet/test_modeling_prophetnet.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import tempfile import unittest @@ -332,86 +331,6 @@ def create_and_check_model_fp16_forward( output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"] self.parent.assertFalse(torch.isnan(output).any().item()) - def create_and_check_encoder_decoder_shared_weights( - self, - config, - input_ids, - decoder_input_ids, - attention_mask, - decoder_attention_mask, - lm_labels, - ): - for model_class in [ProphetNetModel, ProphetNetForConditionalGeneration]: - torch.manual_seed(0) - model = model_class(config=config).to(torch_device).eval() - # load state dict copies weights but does not tie them - - if model_class == ProphetNetForConditionalGeneration: - model.prophetnet.encoder.load_state_dict(model.prophetnet.decoder.state_dict(), strict=False) - else: - model.encoder.load_state_dict(model.decoder.state_dict(), strict=False) - - torch.manual_seed(0) - tied_config = copy.deepcopy(config) - tied_config.tie_encoder_decoder = True - tied_model = model_class(config=tied_config).to(torch_device).eval() - - model_result = model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - - tied_model_result = tied_model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - - # check that models has less parameters - self.parent.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) - random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() - - # check that outputs are equal - self.parent.assertTrue( - torch.allclose( - model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4 - ) - ) - - # check that outputs after saving and loading are equal - with tempfile.TemporaryDirectory() as tmpdirname: - tied_model.save_pretrained(tmpdirname) - tied_model = model_class.from_pretrained(tmpdirname) - tied_model.to(torch_device) - tied_model.eval() - - # check that models has less parameters - self.parent.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) - random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() - - tied_model_result = tied_model( - input_ids=input_ids, - decoder_input_ids=decoder_input_ids, - attention_mask=attention_mask, - decoder_attention_mask=decoder_attention_mask, - ) - - # check that outputs are equal - self.parent.assertTrue( - torch.allclose( - model_result[0][0, :, random_slice_idx], - tied_model_result[0][0, :, random_slice_idx], - atol=1e-4, - ) - ) - def check_fast_integration( self, config, @@ -943,10 +862,6 @@ def test_fast_integration(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_fast_integration(*config_and_inputs) - def test_shared_weights(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs) - def test_shift_labels_via_shift_left(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs) diff --git a/tests/models/sew/test_modeling_sew.py b/tests/models/sew/test_modeling_sew.py index a195a9b3d158..75998c11f168 100644 --- a/tests/models/sew/test_modeling_sew.py +++ b/tests/models/sew/test_modeling_sew.py @@ -376,13 +376,13 @@ def test_seq_classifier_train(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: module.masked_spec_embed.data.fill_(3) diff --git a/tests/models/sew_d/test_modeling_sew_d.py b/tests/models/sew_d/test_modeling_sew_d.py index fe8bff0e37e9..b0c0853a7d0a 100644 --- a/tests/models/sew_d/test_modeling_sew_d.py +++ b/tests/models/sew_d/test_modeling_sew_d.py @@ -386,13 +386,13 @@ def test_retain_grad_hidden_states_attentions(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: module.masked_spec_embed.data.fill_(3) diff --git a/tests/models/speech_to_text/test_modeling_speech_to_text.py b/tests/models/speech_to_text/test_modeling_speech_to_text.py index 0307f5c634da..835d371389e5 100644 --- a/tests/models/speech_to_text/test_modeling_speech_to_text.py +++ b/tests/models/speech_to_text/test_modeling_speech_to_text.py @@ -282,7 +282,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_model_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py index 8d608ce0ff82..fd2a885a9639 100644 --- a/tests/models/speecht5/test_modeling_speecht5.py +++ b/tests/models/speecht5/test_modeling_speecht5.py @@ -354,7 +354,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_model_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -664,13 +664,13 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: module.masked_spec_embed.data.fill_(3) @@ -859,7 +859,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_model_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -951,13 +951,13 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) @require_torch @@ -966,15 +966,15 @@ def _mock_init_weights(self, module): class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase): @cached_property def default_model(self): - return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(torch_device) + return SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts", revision="refs/pr/19").to(torch_device) @cached_property def default_processor(self): - return SpeechT5Processor.from_pretrained("microsoft/speecht5_tts") + return SpeechT5Processor.from_pretrained("microsoft/speecht5_tts", revision="refs/pr/19") @cached_property def default_vocoder(self): - return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(torch_device) + return SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", revision="refs/pr/1").to(torch_device) def test_generation(self): model = self.default_model @@ -1359,7 +1359,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_model_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -1608,13 +1608,13 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: module.masked_spec_embed.data.fill_(3) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 65eb103c1fc4..37202848242d 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -473,10 +473,6 @@ def create_and_check_encoder_decoder_shared_weights( decoder_attention_mask=decoder_attention_mask, ) - # check that models has less parameters - self.parent.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() # check that outputs are equal @@ -493,10 +489,6 @@ def create_and_check_encoder_decoder_shared_weights( tied_model.to(torch_device) tied_model.eval() - # check that models has less parameters - self.parent.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() tied_model_result = tied_model( diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index 8345cd63b036..52f85f17d9fb 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -465,10 +465,6 @@ def create_and_check_encoder_decoder_shared_weights( decoder_attention_mask=decoder_attention_mask, ) - # check that models has less parameters - self.parent.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() # check that outputs are equal @@ -485,10 +481,6 @@ def create_and_check_encoder_decoder_shared_weights( tied_model.to(torch_device) tied_model.eval() - # check that models has less parameters - self.parent.assertLess( - sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters()) - ) random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item() tied_model_result = tied_model( diff --git a/tests/models/time_series_transformer/test_modeling_time_series_transformer.py b/tests/models/time_series_transformer/test_modeling_time_series_transformer.py index 2aba8c17303a..7cf421a10404 100644 --- a/tests/models/time_series_transformer/test_modeling_time_series_transformer.py +++ b/tests/models/time_series_transformer/test_modeling_time_series_transformer.py @@ -205,7 +205,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_encoder_decoder_model_standalone(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/unispeech/test_modeling_unispeech.py b/tests/models/unispeech/test_modeling_unispeech.py index 116690992c39..d0490cd4900b 100644 --- a/tests/models/unispeech/test_modeling_unispeech.py +++ b/tests/models/unispeech/test_modeling_unispeech.py @@ -421,13 +421,13 @@ def test_retain_grad_hidden_states_attentions(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) if hasattr(module, "codevectors") and module.codevectors is not None: module.codevectors.data.fill_(3) if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: diff --git a/tests/models/unispeech_sat/test_modeling_unispeech_sat.py b/tests/models/unispeech_sat/test_modeling_unispeech_sat.py index dc4b64e4d83c..084801161f1f 100644 --- a/tests/models/unispeech_sat/test_modeling_unispeech_sat.py +++ b/tests/models/unispeech_sat/test_modeling_unispeech_sat.py @@ -460,13 +460,13 @@ def test_retain_grad_hidden_states_attentions(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) if hasattr(module, "codevectors") and module.codevectors is not None: module.codevectors.data.fill_(3) if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: @@ -634,13 +634,13 @@ def test_retain_grad_hidden_states_attentions(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) if hasattr(module, "codevectors") and module.codevectors is not None: module.codevectors.data.fill_(3) if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: diff --git a/tests/models/vits/test_modeling_vits.py b/tests/models/vits/test_modeling_vits.py index 46b417f04b00..1e19ae38d4e9 100644 --- a/tests/models/vits/test_modeling_vits.py +++ b/tests/models/vits/test_modeling_vits.py @@ -350,13 +350,13 @@ def check_save_load(out1, out2): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) @require_torch diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py index e645070ffa31..c2767583c6cd 100644 --- a/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -602,13 +602,13 @@ def test_retain_grad_hidden_states_attentions(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) if hasattr(module, "codevectors") and module.codevectors is not None: module.codevectors.data.fill_(3) if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: @@ -807,13 +807,13 @@ def test_retain_grad_hidden_states_attentions(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) if hasattr(module, "codevectors") and module.codevectors is not None: module.codevectors.data.fill_(3) if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: diff --git a/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py b/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py index 966b2c50d7b8..71b24e406524 100644 --- a/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py +++ b/tests/models/wav2vec2_bert/test_modeling_wav2vec2_bert.py @@ -574,13 +574,13 @@ def test_retain_grad_hidden_states_attentions(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) if hasattr(module, "pos_bias_u") and module.pos_bias_u is not None: module.pos_bias_u.data.fill_(3) if hasattr(module, "pos_bias_v") and module.pos_bias_v is not None: diff --git a/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py index ba0752927521..416a6d3cb537 100644 --- a/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py +++ b/tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py @@ -546,13 +546,13 @@ def test_retain_grad_hidden_states_attentions(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) if hasattr(module, "pos_bias_u") and module.pos_bias_u is not None: module.pos_bias_u.data.fill_(3) if hasattr(module, "pos_bias_v") and module.pos_bias_v is not None: diff --git a/tests/models/wavlm/test_modeling_wavlm.py b/tests/models/wavlm/test_modeling_wavlm.py index fc422db7206f..247c2b3fe5d2 100644 --- a/tests/models/wavlm/test_modeling_wavlm.py +++ b/tests/models/wavlm/test_modeling_wavlm.py @@ -398,13 +398,13 @@ def test_retain_grad_hidden_states_attentions(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "weight_g") and module.weight_g is not None: module.weight_g.data.fill_(3) if hasattr(module, "weight_v") and module.weight_v is not None: module.weight_v.data.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) if hasattr(module, "codevectors") and module.codevectors is not None: module.codevectors.data.fill_(3) if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None: diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 35d4a8ffd3ca..732814eaed29 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -422,7 +422,7 @@ def test_save_load_strict(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True) - self.assertEqual(info["missing_keys"], []) + self.assertEqual(info["missing_keys"], set()) def test_model_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/models/xlnet/test_modeling_xlnet.py b/tests/models/xlnet/test_modeling_xlnet.py index 54b59c55d4cc..e973d0f16f81 100644 --- a/tests/models/xlnet/test_modeling_xlnet.py +++ b/tests/models/xlnet/test_modeling_xlnet.py @@ -617,9 +617,9 @@ def test_retain_grad_hidden_states_attentions(self): # overwrite from test_modeling_common def _mock_init_weights(self, module): if hasattr(module, "weight") and module.weight is not None: - module.weight.data.fill_(3) + module.weight.fill_(3) if hasattr(module, "bias") and module.bias is not None: - module.bias.data.fill_(3) + module.bias.fill_(3) for param in ["q", "k", "v", "o", "r", "r_r_bias", "r_s_bias", "r_w_bias", "seg_embed", "mask_emb"]: if hasattr(module, param) and getattr(module, param) is not None: diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index b28cf248ca97..feb191983381 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -202,22 +202,6 @@ def test_linear_are_4bit(self): # 4-bit parameters are packed in uint8 variables self.assertTrue(module.weight.dtype == torch.uint8) - def test_rwkv_4bit(self): - r""" - A simple test to check if 4-bit RWKV inference works as expected. - """ - model_id = "RWKV/rwkv-4-169m-pile" - - quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True) - - model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config) - tok = AutoTokenizer.from_pretrained(model_id) - - text = "Hello my name is" - input_ids = tok.encode(text, return_tensors="pt").to(torch_device) - - _ = model.generate(input_ids, max_new_tokens=30) - def test_generate_quality(self): r""" Test the generation quality of the quantized model and see that we are matching the expected output. @@ -607,7 +591,7 @@ def setUp(self): def test_training(self): # Step 1: freeze all parameters model = AutoModelForCausalLM.from_pretrained( - self.model_name, quantization_config=BitsAndBytesConfig(load_in_4bit=True) + self.model_name, quantization_config=BitsAndBytesConfig(load_in_4bit=True), revision="refs/pr/40" ) if torch_device in ["cuda", "xpu"]: @@ -671,7 +655,7 @@ def tearDown(self): gc.collect() backend_empty_cache(torch_device) - def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True): + def test_serialization(self, quant_type="nf4", double_quant=True): r""" Test whether it is possible to serialize a model in 4-bit. Uses most typical params as default. See ExtendedSerializationTest class for more params combinations. @@ -685,14 +669,22 @@ def test_serialization(self, quant_type="nf4", double_quant=True, safe_serializa bnb_4bit_use_double_quant=double_quant, bnb_4bit_compute_dtype=torch.bfloat16, ) + + # for now, we should be able to fetch those in from_pretrained directly + if self.model_name == "facebook/opt-125m": + revision="refs/pr/49" + else: + revision="main" + model_0 = AutoModelForCausalLM.from_pretrained( self.model_name, quantization_config=self.quantization_config, device_map=torch_device, + revision=revision ) with tempfile.TemporaryDirectory() as tmpdirname: - model_0.save_pretrained(tmpdirname, safe_serialization=safe_serialization) + model_0.save_pretrained(tmpdirname) config = AutoConfig.from_pretrained(tmpdirname) self.assertTrue(hasattr(config, "quantization_config")) @@ -758,28 +750,15 @@ class ExtendedSerializationTest(BaseSerializationTest): tests more combinations of parameters """ - def test_nf4_single_unsafe(self): - self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=False) - def test_nf4_single_safe(self): - self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=True) - - def test_nf4_double_unsafe(self): - self.test_serialization(quant_type="nf4", double_quant=True, safe_serialization=False) - + self.test_serialization(quant_type="nf4", double_quant=False) # nf4 double safetensors quantization is tested in test_serialization() method from the parent class - def test_fp4_single_unsafe(self): - self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=False) - def test_fp4_single_safe(self): - self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=True) - - def test_fp4_double_unsafe(self): - self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=False) + self.test_serialization(quant_type="fp4", double_quant=False) def test_fp4_double_safe(self): - self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True) + self.test_serialization(quant_type="fp4", double_quant=True) class BloomSerializationTest(BaseSerializationTest): diff --git a/tests/repo_utils/test_check_copies.py b/tests/repo_utils/test_check_copies.py index f6ae669c4cc1..cc1c28a6eda6 100644 --- a/tests/repo_utils/test_check_copies.py +++ b/tests/repo_utils/test_check_copies.py @@ -36,13 +36,9 @@ # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. - self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4ba9e1240e48..53189a809d93 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -23,6 +23,7 @@ import warnings from collections import defaultdict from contextlib import contextmanager +from copy import deepcopy import numpy as np import pytest @@ -116,6 +117,7 @@ if is_torch_available(): import torch + from safetensors import safe_open from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file from torch import nn @@ -260,7 +262,11 @@ def _can_output_attn(model): model_sdpa = model_class.from_pretrained(**model_from_pretrained_kwargs) model_sdpa = model_sdpa.eval().to(torch_device) - model_eager = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="eager") + try: + model_eager = deepcopy(model_sdpa) + model_eager.set_attn_implementation("eager") + except Exception as _: + model_eager = model_class.from_pretrained(**model_from_pretrained_kwargs, attn_implementation="eager") model_eager = model_eager.eval().to(torch_device) set_model_for_less_flaky_test(model_eager) @@ -752,8 +758,16 @@ def test_from_pretrained_no_checkpoint(self): new_model = model_class.from_pretrained( pretrained_model_name_or_path=None, config=config, state_dict=state_dict ) - for p1, p2 in zip(model.parameters(), new_model.parameters()): - self.assertTrue(torch.equal(p1, p2)) + new_state_dict = new_model.state_dict() + assert state_dict.keys() == new_state_dict.keys() + keys = state_dict.keys() + for k in keys: + p1, p2 = new_state_dict[k], state_dict[k] + torch.testing.assert_close(p1, p2) + new_params = dict(new_model.named_parameters()) + for k, v in list(model.named_parameters()): + with self.subTest(k): + torch.testing.assert_close(v, new_params[k], msg=f"failed on {k}") def test_keep_in_fp32_modules(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -768,10 +782,11 @@ def test_keep_in_fp32_modules(self): model = model_class.from_pretrained(tmpdirname, dtype=torch.float16) for name, param in model.named_parameters(): - if any(n in model_class._keep_in_fp32_modules for n in name.split(".")): - self.assertTrue(param.dtype == torch.float32) - else: - self.assertTrue(param.dtype == torch.float16, name) + with self.subTest(name): + if re.search("|".join(model_class._keep_in_fp32_modules), name): + self.assertTrue(param.dtype == torch.float32) + else: + self.assertTrue(param.dtype == torch.float16, name) def test_save_load_keys_to_ignore_on_save(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -899,7 +914,7 @@ def test_can_init_all_missing_weights(self): if match_object := re.search(r"^# Copyright (\d{4})", source_code, re.MULTILINE | re.IGNORECASE): addition_year = int(match_object.group(1)) - for model_class in self.all_model_classes: + for model_class in self.all_model_classes[::-1]: # For now, skip everything older than 2024 and "important models" (too much models to patch otherwise) # TODO: relax this as we patch more and more models if addition_year < 2023: @@ -917,10 +932,10 @@ def seeded_initialize_weights(self, module): # First, initialize the model from config -> this ensure everything is correctly initialized, even if # _init_weights() does not take all weights into account correctly - model_from_config = model_class(copy.deepcopy(config)) + model_from_config = model_class(copy.deepcopy(config)).eval() # Here, passing an empty state dict will force all weights to be moved from meta to cpu, then be initialized # by _init_weights() - model_from_pretrained = model_class.from_pretrained(None, config=config, state_dict={}) + model_from_pretrained = model_class.from_pretrained(None, config=config, state_dict={}).eval() # Back to original method to avoid issues if running several other tests PreTrainedModel._initialize_weights = original_initialize_weights @@ -938,15 +953,13 @@ def seeded_initialize_weights(self, module): # Everything must be exactly the same as we set the same seed for each init different_weights = [] - for (k1, v1), (k2, v2) in zip( - model_from_config.state_dict().items(), model_from_pretrained.state_dict().items() - ): - self.assertEqual(k1, k2, "The keys from each model should be the same") + from_pre_state = dict(model_from_pretrained.state_dict()) + for (k1, v1) in model_from_config.state_dict().items(): # In case using torch.nn.utils.parametrizations on a module, we should skip the resulting keys if re.search(r"\.parametrizations\..*?\.original[01]", k1): continue - + v2 = from_pre_state[k1] # Since we added the seed, they should be exactly the same (i.e. using allclose maybe be wrong due # to very low std in init function) if not (v1 == v2).all(): @@ -1175,6 +1188,10 @@ def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=No print( f"None for {k}, Probaby running a MOE, make sure grad is not NONE on EVERY layer. At LEAST 1 of the expert layer should have grads!" ) + if "shared" in k: + print( + f"None for {k}, Probaby a model that does not default to tie the encoder and decoder!" + ) else: with self.subTest(f"{k}"): self.assertTrue( @@ -1764,76 +1781,77 @@ def test_resize_embeddings_untied(self): self.skipTest(reason="Model cannot untied embeddings") for model_class in self.all_model_classes: - config = copy.deepcopy(original_config) - if is_deepspeed_zero3_enabled(): - with deepspeed.zero.Init(): - model = model_class(config) - else: - model = model_class(config).to(torch_device) - model.eval() - - # if no output embeddings -> leave test - if model.get_output_embeddings() is None: - continue + with self.subTest(model_class): + config = copy.deepcopy(original_config) + if is_deepspeed_zero3_enabled(): + with deepspeed.zero.Init(): + model = model_class(config) + else: + model = model_class(config).to(torch_device) + model.eval() - # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size - model_vocab_size = config.get_text_config().vocab_size - model.resize_token_embeddings(model_vocab_size + 10) - new_model_vocab_size = model.config.get_text_config().vocab_size - self.assertEqual(new_model_vocab_size, model_vocab_size + 10) - output_embeds = model.get_output_embeddings() - self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10) - # Check bias if present - if output_embeds.bias is not None: - self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10) - # Check that the model can still do a forward pass successfully (every parameter should be resized) - if not is_deepspeed_zero3_enabled(): - # A distriputed launcher is needed for the forward pass when deepspeed is enabled - model(**self._prepare_for_class(inputs_dict, model_class)) + # if no output embeddings -> leave test + if model.get_output_embeddings() is None: + continue - # Test multivariate resizing. - model.resize_token_embeddings(model_vocab_size + 10) - output_embeds = model.get_output_embeddings() - # Check that added embeddings mean is close to the old embeddings mean - if is_deepspeed_zero3_enabled(): - with deepspeed.zero.GatheredParameters(output_embeds.weight, modifier_rank=None): + # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size + model_vocab_size = config.get_text_config().vocab_size + model.resize_token_embeddings(model_vocab_size + 10) + new_model_vocab_size = model.config.get_text_config().vocab_size + self.assertEqual(new_model_vocab_size, model_vocab_size + 10) + output_embeds = model.get_output_embeddings() + self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10) + # Check bias if present + if output_embeds.bias is not None: + self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + if not is_deepspeed_zero3_enabled(): + # A distriputed launcher is needed for the forward pass when deepspeed is enabled + model(**self._prepare_for_class(inputs_dict, model_class)) + + # Test multivariate resizing. + model.resize_token_embeddings(model_vocab_size + 10) + output_embeds = model.get_output_embeddings() + # Check that added embeddings mean is close to the old embeddings mean + if is_deepspeed_zero3_enabled(): + with deepspeed.zero.GatheredParameters(output_embeds.weight, modifier_rank=None): + old_embeddings_mean = torch.mean(output_embeds.weight.data[:-10, :], axis=0) + new_embeddings_mean = torch.mean(output_embeds.weight.data[-10:, :], axis=0) + else: old_embeddings_mean = torch.mean(output_embeds.weight.data[:-10, :], axis=0) new_embeddings_mean = torch.mean(output_embeds.weight.data[-10:, :], axis=0) - else: - old_embeddings_mean = torch.mean(output_embeds.weight.data[:-10, :], axis=0) - new_embeddings_mean = torch.mean(output_embeds.weight.data[-10:, :], axis=0) - torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, rtol=1e-3, atol=1e-3) - # check if the old bias mean close to added bias mean. - if output_embeds.bias is not None: - if is_deepspeed_zero3_enabled(): - with deepspeed.zero.GatheredParameters(output_embeds.bias, modifier_rank=None): + torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, rtol=1e-3, atol=1e-3) + # check if the old bias mean close to added bias mean. + if output_embeds.bias is not None: + if is_deepspeed_zero3_enabled(): + with deepspeed.zero.GatheredParameters(output_embeds.bias, modifier_rank=None): + old_bias_mean = torch.mean(output_embeds.bias.data[:-10], axis=0) + new_bias_mean = torch.mean(output_embeds.bias.data[-10:], axis=0) + else: old_bias_mean = torch.mean(output_embeds.bias.data[:-10], axis=0) new_bias_mean = torch.mean(output_embeds.bias.data[-10:], axis=0) - else: - old_bias_mean = torch.mean(output_embeds.bias.data[:-10], axis=0) - new_bias_mean = torch.mean(output_embeds.bias.data[-10:], axis=0) - - torch.testing.assert_close(old_bias_mean, new_bias_mean, rtol=1e-5, atol=1e-5) - # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size - model.resize_token_embeddings(model_vocab_size - 15) - new_model_vocab_size = model.config.get_text_config().vocab_size - self.assertEqual(new_model_vocab_size, model_vocab_size - 15) - # Check that it actually resizes the embeddings matrix - output_embeds = model.get_output_embeddings() - self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15) - # Check bias if present - if output_embeds.bias is not None: - self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15) - # Check that the model can still do a forward pass successfully (every parameter should be resized) - # Input ids should be clamped to the maximum size of the vocabulary - inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1) - if "decoder_input_ids" in inputs_dict: - inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1) - # Check that the model can still do a forward pass successfully (every parameter should be resized) - if not is_deepspeed_zero3_enabled(): - # A distriputed launcher is needed for the forward pass when deepspeed is enabled - model(**self._prepare_for_class(inputs_dict, model_class)) + torch.testing.assert_close(old_bias_mean, new_bias_mean, rtol=1e-5, atol=1e-5) + + # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size + model.resize_token_embeddings(model_vocab_size - 15) + new_model_vocab_size = model.config.get_text_config().vocab_size + self.assertEqual(new_model_vocab_size, model_vocab_size - 15) + # Check that it actually resizes the embeddings matrix + output_embeds = model.get_output_embeddings() + self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15) + # Check bias if present + if output_embeds.bias is not None: + self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + # Input ids should be clamped to the maximum size of the vocabulary + inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1) + if "decoder_input_ids" in inputs_dict: + inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + if not is_deepspeed_zero3_enabled(): + # A distriputed launcher is needed for the forward pass when deepspeed is enabled + model(**self._prepare_for_class(inputs_dict, model_class)) @require_deepspeed @require_torch_accelerator @@ -1914,51 +1932,71 @@ def test_can_use_safetensors(self): model_tied.save_pretrained(d, safe_serialization=True) except Exception as e: raise Exception(f"Class {model_class.__name__} cannot be saved using safetensors: {e}") + with self.subTest(model_class): + model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True) + # Checking the state dicts are correct + reloaded_state = model_reloaded.state_dict() + for k, v in model_tied.state_dict().items(): + with self.subTest(f"{model_class.__name__}.{k}"): + torch.testing.assert_close( + v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}.\n" + "This probably means that it was not set with the correct value when tying." + ) - model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True) - # Checking the state dicts are correct - reloaded_state = model_reloaded.state_dict() - for k, v in model_tied.state_dict().items(): - self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded") - torch.testing.assert_close( - v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}" - ) - # Checking there was no complain of missing weights - self.assertEqual(infos["missing_keys"], []) + # Checking the tensor sharing are correct on the new model (weights are properly tied in both cases) + ptrs = defaultdict(list) + for k, v in model_tied.state_dict().items(): + ptrs[v.data_ptr()].append(k) - # Checking the tensor sharing are correct - ptrs = defaultdict(list) - for k, v in model_tied.state_dict().items(): - ptrs[v.data_ptr()].append(k) + shared_ptrs = {k: v for k, v in ptrs.items() if len(v) > 1} - shared_ptrs = {k: v for k, v in ptrs.items() if len(v) > 1} + for shared_names in shared_ptrs.values(): + reloaded_ptrs = {reloaded_state[k].data_ptr() for k in shared_names} + self.assertEqual( + len(reloaded_ptrs), + 1, + f"The shared pointers are incorrect, found different pointers for keys {shared_names}. `__init__` and `from_pretrained` end up not tying the weights the same way.", + ) - for shared_names in shared_ptrs.values(): - reloaded_ptrs = {reloaded_state[k].data_ptr() for k in shared_names} - self.assertEqual( - len(reloaded_ptrs), - 1, - f"The shared pointers are incorrect, found different pointers for keys {shared_names}", - ) + # Checking there was no complain of missing weights + self.assertEqual(infos["missing_keys"], set(), "These keys were removed when serializing, and were not properly loaded by `from_pretrained`.") def test_load_save_without_tied_weights(self): for model_class in self.all_model_classes: config, _ = self.model_tester.prepare_config_and_inputs_for_common() config.tie_word_embeddings = False - model = model_class(config) + try: + config.get_text_config().tie_word_embeddings = False + except Exception as _: + pass + + # config.tie_encoder_decoder = False + model = model_class(config) # we init the model without tie + # if this test fails later on, it means init tied the weights with tempfile.TemporaryDirectory() as d: model.save_pretrained(d) + with safe_open(f"{d}/model.safetensors", framework="pt") as f: + serialized_keys = f.keys() + + model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True) + # Checking the state dicts are correct + + reloaded_state = model_reloaded.state_dict() + for k, v in model.state_dict().items(): + with self.subTest(k): + torch.testing.assert_close( + v, + reloaded_state[k], + msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}. Key {k} was serialized: {k in serialized_keys}. If `False`, this means it was probably aliased and safetensors removed it. If `True` it means `_init_weights` overwrote that key", + ) - model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True) - # Checking the state dicts are correct - reloaded_state = model_reloaded.state_dict() - for k, v in model.state_dict().items(): - self.assertIn(k, reloaded_state, f"Key {k} is missing from reloaded") - torch.testing.assert_close( - v, reloaded_state[k], msg=lambda x: f"{model_class.__name__}: Tensor {k}: {x}" - ) # Checking there was no complain of missing weights - self.assertEqual(infos["missing_keys"], []) + self.assertEqual( + infos["missing_keys"], + set(), + "Given that the loaded weights are the same, the issue is in `tie_weights`: it tied these keys and removed them from serialization. But because of tiying (hardcoded or not) the previous check is fine.\ + This can happen if `save_pretrained` remove the targets and not the keys from serialiazation", + ) def test_tied_weights_keys(self): original_config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -2017,7 +2055,7 @@ def test_model_weights_reload_no_missing_tied_weights(self): missing_keys = set(infos["missing_keys"]) extra_missing = missing_keys - param_names - # Remove tied weights from extra missing: they are normally not warned as missing if their tied + # IMPORTANT Remove tied weights from extra missing: they are normally not warned as missing if their tied # counterpart is present but here there are no weights at all so we do get the warning. ptrs = collections.defaultdict(list) for name, tensor in model_reloaded.state_dict().items(): @@ -2472,7 +2510,7 @@ def test_load_with_mismatched_shapes(self): with tempfile.TemporaryDirectory() as tmp_dir: model = model_class(config) model.save_pretrained(tmp_dir) - + num_labels = config.num_labels # Fails when we don't set ignore_mismatched_sizes=True with self.assertRaises(RuntimeError): new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42) @@ -2485,7 +2523,7 @@ def test_load_with_mismatched_shapes(self): new_model = AutoModelForSequenceClassification.from_pretrained( tmp_dir, num_labels=42, ignore_mismatched_sizes=True ) - self.assertIn("the shapes did not match", cl.out) + self.assertIn("Reinit due to size mismatch", cl.out) new_model.to(torch_device) inputs = self._prepare_for_class(inputs_dict, model_class) logits = new_model(**inputs).logits @@ -2495,7 +2533,7 @@ def test_load_with_mismatched_shapes(self): new_model_without_prefix = AutoModel.from_pretrained( tmp_dir, vocab_size=10, ignore_mismatched_sizes=True ) - self.assertIn("the shapes did not match", cl.out) + self.assertIn("Reinit due to size mismatch", cl.out) input_ids = ids_tensor((2, 8), 10) new_model_without_prefix.to(torch_device) if self.is_encoder_decoder: @@ -2536,7 +2574,7 @@ def test_can_load_ignoring_mismatched_shapes(self): with CaptureLogger(logger) as cl: new_model = model_class.from_pretrained(tmp_dir, num_labels=42, ignore_mismatched_sizes=True) - self.assertIn("the shapes did not match", cl.out) + self.assertIn("Reinit due to size mismatch", cl.out) # Find the name of the module with the mismatched size top_linear_modules = [ @@ -2570,18 +2608,21 @@ def test_can_load_ignoring_mismatched_shapes(self): ] # Usually we have only 1, but swiftformer and deit have 2 Linear layers using `num_labels` mismatched_modules = [name for name, module in top_linear_modules if module.out_features == 42] - - for (k1, v1), (k2, v2) in zip(new_model.named_parameters(), model.named_parameters()): - # Sanity check: params must have all the same name - self.assertEqual(k1, k2) + old = dict(model.named_parameters()) + new = dict(new_model.named_parameters()) + assert not set(old.keys()) - set(new.keys()) + for k1 in new.keys(): + k2 = k1 + v1 = old[k1] + v2 = new[k2] # Each param except the mismatched ones must be exactly similar if not any(k1.startswith(mismatched_module) for mismatched_module in mismatched_modules): - self.assertTrue((v1 == v2).all()) + torch.testing.assert_close(v1, v2, msg=f"{k1} and {k2} do not match: {v1} != {v2}") # Check that the dims are indeed mismatched between old and new models else: # The old model should have `num_labels=3` (here it's the first dim of shape, as Linear layers # are transposed) - self.assertEqual(v2.shape[0], 3) + self.assertEqual(v2.shape[0], 42) # Make sure the mean of the new Linear layer is correctly centered around 0 (we cannot use # a lower value for the check as some models hardcode a std of 0.02 instead of using the # config, which we set very small with `config_no_init`) @@ -3895,7 +3936,125 @@ def test_bc_torch_dtype(self): ): self.assertEqual(k1, k2) self.assertEqual(v1.dtype, v2.dtype) - self.assertTrue((v1 == v2).all()) + torch.testing.assert_close(v1, v2, msg=f"{k1} and {k2} do not match: {v1} != {v2}") + + +@require_torch +def test_weight_conversion_operations_roundtrip(): + import torch + + from transformers.core_model_loading import ( + Chunk, + Concatenate, + Fp8Dequantize, + Fp8Quantize, + MergeModuleList, + Shard, + WeightConversion, + convert_state_dict, + ) + + state_dict = { + "experts.0.w1.weight": torch.tensor([[0.0, 1.0], [2.0, 3.0]]), + "experts.1.w1.weight": torch.tensor([[4.0, 5.0], [6.0, 7.0]]), + "experts.0.w3.weight": torch.tensor([[10.0, 11.0], [12.0, 13.0]]), + "experts.1.w3.weight": torch.tensor([[14.0, 15.0], [16.0, 17.0]]), + "self_attn.q_proj.weight": torch.tensor([[1.0, 2.0], [3.0, 4.0]]), + "self_attn.k_proj.weight": torch.tensor([[5.0, 6.0], [7.0, 8.0]]), + "self_attn.v_proj.weight": torch.tensor([[9.0, 10.0], [11.0, 12.0]]), + "self_attn.out_proj.weight": torch.arange(12.0).reshape(6, 2), + "mlp.w2.weight": torch.tensor([[1.0, 0.0], [0.0, 1.0]]), + } + + forward_mapping = [ + WeightConversion( + ["experts.*.w1.weight", "experts.*.w3.weight"], + "experts.gate_up_proj.weight", + [MergeModuleList(dim=0), Concatenate(dim=0), Fp8Quantize(block_size=(1, 1))], + ), + WeightConversion( + ["self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight"], + "self_attn.qkv_proj.weight", + Concatenate(dim=0), + ), + WeightConversion( + "self_attn.out_proj.weight", + ["self_attn.out_proj.weight.shard0", "self_attn.out_proj.weight.shard1"], + Shard(dim=0, world_size=2, return_all=True), + ), + WeightConversion("mlp.w2.weight", "mlp.down_proj.weight"), + ] + + converted_state, _ = convert_state_dict(None, state_dict, forward_mapping, tp_plan=None, quantization_config=None) + + expected_qkv = torch.cat( + ( + state_dict["self_attn.q_proj.weight"], + state_dict["self_attn.k_proj.weight"], + state_dict["self_attn.v_proj.weight"], + ), + dim=0, + ) + torch.testing.assert_close(converted_state["self_attn.qkv_proj.weight"], expected_qkv) + + reconstructed_out_proj = torch.cat( + (converted_state["self_attn.out_proj.weight.shard0"], converted_state["self_attn.out_proj.weight.shard1"]), + dim=0, + ) + torch.testing.assert_close(reconstructed_out_proj, state_dict["self_attn.out_proj.weight"]) + torch.testing.assert_close(converted_state["mlp.down_proj.weight"], state_dict["mlp.w2.weight"]) + + inverse_mapping = [ + WeightConversion( + ["experts.gate_up_proj.weight", "experts.gate_up_proj.scale"], + "experts.gate_up_proj.dequantized", + Fp8Dequantize(block_size=(1, 1)), + ), + WeightConversion( + "experts.gate_up_proj.dequantized", + ["experts.w1.concat", "experts.w3.concat"], + Chunk(dim=0, sizes=[4, 4]), + ), + WeightConversion( + "experts.w1.concat", + ["experts.0.w1.weight", "experts.1.w1.weight"], + Chunk(dim=0, sizes=[2, 2]), + ), + WeightConversion( + "experts.w3.concat", + ["experts.0.w3.weight", "experts.1.w3.weight"], + Chunk(dim=0, sizes=[2, 2]), + ), + WeightConversion( + "self_attn.qkv_proj.weight", + [ + "self_attn.q_proj.weight", + "self_attn.k_proj.weight", + "self_attn.v_proj.weight", + ], + Chunk(dim=0, sizes=[2, 2, 2]), + ), + WeightConversion( + ["self_attn.out_proj.weight.shard0", "self_attn.out_proj.weight.shard1"], + "self_attn.out_proj.weight", + Concatenate(dim=0), + ), + WeightConversion("mlp.down_proj.weight", "mlp.w2.weight"), + ] + + roundtrip_state, _ = convert_state_dict( + None, converted_state, inverse_mapping, tp_plan=None, quantization_config=None + ) + + torch.testing.assert_close(roundtrip_state["experts.0.w1.weight"], state_dict["experts.0.w1.weight"]) + torch.testing.assert_close(roundtrip_state["experts.1.w1.weight"], state_dict["experts.1.w1.weight"]) + torch.testing.assert_close(roundtrip_state["experts.0.w3.weight"], state_dict["experts.0.w3.weight"]) + torch.testing.assert_close(roundtrip_state["experts.1.w3.weight"], state_dict["experts.1.w3.weight"]) + torch.testing.assert_close(roundtrip_state["self_attn.q_proj.weight"], state_dict["self_attn.q_proj.weight"]) + torch.testing.assert_close(roundtrip_state["self_attn.k_proj.weight"], state_dict["self_attn.k_proj.weight"]) + torch.testing.assert_close(roundtrip_state["self_attn.v_proj.weight"], state_dict["self_attn.v_proj.weight"]) + torch.testing.assert_close(roundtrip_state["self_attn.out_proj.weight"], state_dict["self_attn.out_proj.weight"]) + torch.testing.assert_close(roundtrip_state["mlp.w2.weight"], state_dict["mlp.w2.weight"]) global_rng = random.Random() diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py new file mode 100644 index 000000000000..1f16e66f42c6 --- /dev/null +++ b/tests/utils/test_core_model_loading.py @@ -0,0 +1,305 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re +import unittest + +import torch +import torch.nn as nn + +from transformers.core_model_loading import ( + Chunk, + Concatenate, + MergeModulelist, + WeightConverter, + _apply_star_subst, + _glob_to_regex_src, + build_glob_alt, + convert_and_load_state_dict_in_model, + glob_to_re, + match_glob, +) + + +class TestWeightGlobMatching(unittest.TestCase): + def setUp(self): + self.weight_globs_digits = [ + "model.layers.*.mlp.gate_up_proj.weight", + "model.layers.*.self_attn.q_proj.weight", + "embed_tokens.weight", + ] + self.alt_digits, self.map_digits = build_glob_alt(self.weight_globs_digits) + + self.weight_globs_any = [ + "model.layers.*.mlp.gate_up_proj.weight", + "model.layers.*.self_attn.q_proj.weight", + "embed_tokens.weight", + ] + self.alt_any, self.map_any = build_glob_alt(self.weight_globs_any) + + def test_exact_match(self): + self.assertEqual(match_glob("embed_tokens.weight", self.alt_digits, self.map_digits), "embed_tokens.weight") + + def test_digits_only_star_accepts_digits(self): + self.assertEqual( + match_glob("model.layers.0.mlp.gate_up_proj.weight", self.alt_digits, self.map_digits), + "model.layers.*.mlp.gate_up_proj.weight", + ) + self.assertEqual( + match_glob("model.layers.12.self_attn.q_proj.weight", self.alt_digits, self.map_digits), + "model.layers.*.self_attn.q_proj.weight", + ) + + def test_digits_only_star_rejects_nondigits(self): + # 'a' is not digits, so it should not match with + self.assertIsNone(match_glob("model.layers.a.mlp.gate_up_proj.weight", self.alt_digits, self.map_digits)) + + def test_anychar_star_accepts_nondigits(self): + self.assertEqual( + match_glob("model.layers.a.mlp.gate_up_proj.weight", self.alt_any, self.map_any), + "model.layers.*.mlp.gate_up_proj.weight", + ) + self.assertEqual( + match_glob("model.layers.00x.mlp.gate_up_proj.weight", self.alt_any, self.map_any), + "model.layers.*.mlp.gate_up_proj.weight", + ) + + def test_no_match(self): + self.assertIsNone(match_glob("model.layers.0.mlp.up_proj.weight", self.alt_digits, self.map_digits)) + + def test_leftmost_alternative_wins_for_overlapping_patterns(self): + # Overlapping patterns: both could match; ensure leftmost wins + globs = [ + "model.layers.*.mlp.*.weight", # broader (first) + "model.layers.0.mlp.gate_up_proj.weight", # more specific (second) + ] + alt, mapping = build_glob_alt(globs, digits_only=False) + + # Both branches match; Python's regex picks the leftmost alternative → index 0 + self.assertEqual( + match_glob("model.layers.0.mlp.gate_up_proj.weight", alt, mapping), "model.layers.*.mlp.*.weight" + ) + + def test_multiple_patterns_same_prefix(self): + globs = [ + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ] + alt, mapping = build_glob_alt( + globs, + ) + + self.assertEqual( + match_glob("model.layers.3.self_attn.q_proj.weight", alt, mapping), + "model.layers.*.self_attn.q_proj.weight", + ) + self.assertEqual( + match_glob("model.layers.3.self_attn.k_proj.weight", alt, mapping), + "model.layers.*.self_attn.k_proj.weight", + ) + self.assertEqual( + match_glob("model.layers.3.self_attn.v_proj.weight", alt, mapping), + "model.layers.*.self_attn.v_proj.weight", + ) + + def test_anchor_full_match_only(self): + # Make sure partial strings don't match—anchors ^...$ are in each branch + self.assertIsNone(match_glob("foo.model.layers.0.mlp.gate_up_proj.weight.bar", self.alt_any, self.map_any)) + + def test_large_batch_performance_smoke(self): + # Not a perf benchmark, but ensures building and matching a larger alternation is OK + globs = [f"model.layers.*.mlp.block{i}.weight" for i in range(200)] + alt, mapping = build_glob_alt( + globs, + ) + key = "model.layers.123.mlp.block57.weight" + self.assertEqual(match_glob(key, alt, mapping), "model.layers.*.mlp.block57.weight") + + +class TestGlobRegexHelpers(unittest.TestCase): + def test_glob_to_regex_src_digits_only(self): + pattern = _glob_to_regex_src( + "model.layers.*.mlp.weight", + ) + self.assertEqual(pattern, r"model\.layers\.(\d+)\.mlp\.weight") + + def test_glob_to_regex_src_any_chars(self): + pattern = _glob_to_regex_src("model.layers.*.mlp.weight", digits_only=False) + self.assertEqual(pattern, r"model\.layers\.(.+)\.mlp\.weight") + + def test_glob_to_re_fullmatch(self): + regex_src = glob_to_re( + "model.layers.*.mlp.weight", + ) + regex = re.compile(f"^{regex_src}$") + self.assertIsNotNone(regex.fullmatch("model.layers.12.mlp.weight")) + self.assertIsNone(regex.fullmatch("model.layers.foo.mlp.weight")) + + def test_apply_star_subst(self): + pattern = "model.layers.*.block.*.weight" + replaced = _apply_star_subst(pattern, ["03", "attn"]) + self.assertEqual(replaced, "model.layers.03.block.attn.weight") + + +class DummyParamModule(nn.Module): + def __init__(self, shape): + super().__init__() + self.weight = nn.Parameter(torch.zeros(shape)) + + +class DummySelfAttn(nn.Module): + def __init__(self): + super().__init__() + self.q_proj = DummyParamModule((1, 2)) + self.k_proj = DummyParamModule((1, 2)) + self.v_proj = DummyParamModule((1, 2)) + + +class DummyExperts(nn.Module): + def __init__(self): + super().__init__() + self.gate_up_proj = DummyParamModule((2, 4, 2)) + self.down_proj = DummyParamModule((2, 2, 2)) + + +class DummyLayer(nn.Module): + def __init__(self): + super().__init__() + self.self_attn = DummySelfAttn() + self.experts = DummyExperts() + + +class DummyTopModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleList([DummyLayer(), DummyLayer()]) + + +class DummyMLP(nn.Module): + def __init__(self): + super().__init__() + self.down_proj = DummyParamModule((2, 2)) + + +class DummyRoot(nn.Module): + def __init__(self): + super().__init__() + self.model = DummyTopModel() + self.mlp = DummyMLP() + + +class TestConvertAndLoadStateDict(unittest.TestCase): + def test_moe_and_qkv_conversion(self): + model = DummyRoot() + + raw_tensors = { + "model.layers.0.experts.0.w1.weight": torch.tensor([[0.0, 1.0], [2.0, 3.0]]), + "model.layers.0.experts.1.w1.weight": torch.tensor([[10.0, 11.0], [12.0, 13.0]]), + "model.layers.0.experts.0.w3.weight": torch.tensor([[4.0, 5.0], [6.0, 7.0]]), + "model.layers.0.experts.1.w3.weight": torch.tensor([[14.0, 15.0], [16.0, 17.0]]), + "model.layers.0.experts.0.w2.weight": torch.tensor([[20.0, 21.0], [22.0, 23.0]]), + "model.layers.0.experts.1.w2.weight": torch.tensor([[24.0, 25.0], [26.0, 27.0]]), + "model.layers.1.experts.0.w1.weight": torch.tensor([[30.0, 31.0], [32.0, 33.0]]), + "model.layers.1.experts.1.w1.weight": torch.tensor([[34.0, 35.0], [36.0, 37.0]]), + "model.layers.1.experts.0.w3.weight": torch.tensor([[38.0, 39.0], [40.0, 41.0]]), + "model.layers.1.experts.1.w3.weight": torch.tensor([[42.0, 43.0], [44.0, 45.0]]), + "model.layers.1.experts.0.w2.weight": torch.tensor([[46.0, 47.0], [48.0, 49.0]]), + "model.layers.1.experts.1.w2.weight": torch.tensor([[50.0, 51.0], [52.0, 53.0]]), + "model.layers.0.self_attn.qkv_proj.weight": torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + "model.layers.1.self_attn.qkv_proj.weight": torch.tensor([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]), + "mlp.w2.weight": torch.tensor([[60.0, 61.0], [62.0, 63.0]]), + } + state_dict = {k: (k, v.clone()) for k, v in raw_tensors.items()} + + weight_mapping = [ + WeightConverter( + ["model.layers.*.experts.*.w1.weight", "model.layers.*.experts.*.w3.weight"], + "model.layers.*.experts.gate_up_proj.weight", + operations=[MergeModulelist(dim=0), Concatenate(dim=1)], + ), + WeightConverter( + "model.layers.*.experts.*.w2.weight", + "model.layers.*.experts.down_proj.weight", + operations=[MergeModulelist(dim=0)], + ), + WeightConverter( + "model.layers.*.self_attn.qkv_proj.weight", + [ + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ], + operations=[Concatenate(dim=0), Chunk(dim=0, chunks=3)], + ), + WeightConverter("mlp.w2.weight", "mlp.down_proj.weight"), + ] + + missing, unexpected, mismatch, misc = convert_and_load_state_dict_in_model( + model, state_dict, weight_mapping, tp_plan=None, quantizer=None + ) + + self.assertEqual(missing, set()) + self.assertEqual(unexpected, set()) + self.assertEqual(mismatch, set()) + self.assertEqual(misc, {}) + + model_state = model.state_dict() + + def cat_gate(layer_prefix: str) -> torch.Tensor: + w1 = [ + raw_tensors[f"{layer_prefix}.experts.0.w1.weight"], + raw_tensors[f"{layer_prefix}.experts.1.w1.weight"], + ] + w3 = [ + raw_tensors[f"{layer_prefix}.experts.0.w3.weight"], + raw_tensors[f"{layer_prefix}.experts.1.w3.weight"], + ] + return torch.cat([torch.stack(w1, dim=0), torch.stack(w3, dim=0)], dim=1) + + torch.testing.assert_close( + model_state["model.layers.0.experts.gate_up_proj.weight"], cat_gate("model.layers.0") + ) + torch.testing.assert_close( + model_state["model.layers.1.experts.gate_up_proj.weight"], cat_gate("model.layers.1") + ) + + def stack_down(layer_prefix: str) -> torch.Tensor: + return torch.stack( + [ + raw_tensors[f"{layer_prefix}.experts.0.w2.weight"], + raw_tensors[f"{layer_prefix}.experts.1.w2.weight"], + ], + dim=0, + ) + + torch.testing.assert_close( + model_state["model.layers.0.experts.down_proj.weight"], stack_down("model.layers.0") + ) + torch.testing.assert_close( + model_state["model.layers.1.experts.down_proj.weight"], stack_down("model.layers.1") + ) + + for layer_idx in range(2): + key = f"model.layers.{layer_idx}.self_attn.qkv_proj.weight" + expected_q, expected_k, expected_v = torch.chunk(raw_tensors[key], chunks=3, dim=0) + prefix = f"model.layers.{layer_idx}.self_attn" + torch.testing.assert_close(model_state[f"{prefix}.q_proj.weight"], expected_q) + torch.testing.assert_close(model_state[f"{prefix}.k_proj.weight"], expected_k) + torch.testing.assert_close(model_state[f"{prefix}.v_proj.weight"], expected_v) + + torch.testing.assert_close(model_state["mlp.down_proj.weight"], raw_tensors["mlp.w2.weight"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/utils/check_init_weights_data.py b/utils/check_init_weights_data.py new file mode 100644 index 000000000000..93aebd9f5b2d --- /dev/null +++ b/utils/check_init_weights_data.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utility that ensures `_init_weights(self, module)` implementations do not use `.data`. + +Direct `.data` access breaks the lazy-initialization safeguards handled by `HFParameter`, so the library forbids it. +""" + +import ast +import sys +from pathlib import Path + + +MODELING_ROOT = Path("src/transformers/models") +MODELING_PATTERNS = ("modeling_*.py", "modular_*.py") + + +def iter_modeling_files(): + for pattern in MODELING_PATTERNS: + yield from MODELING_ROOT.rglob(pattern) + + +def function_has_forbidden_data_usage(fn: ast.FunctionDef) -> int | None: + """ + Returns the first offending line number if `.data` is used, otherwise `None`. + """ + + args = fn.args.args + if len(args) < 2 or getattr(args[0], "arg", None) != "self" or getattr(args[1], "arg", None) != "module": + return None + + for node in ast.walk(fn): + if isinstance(node, ast.Attribute) and node.attr == "data": + return node.lineno + + return None + + +def main() -> int: + violations: list[str] = [] + + for file_path in iter_modeling_files(): + try: + text = file_path.read_text(encoding="utf-8") + tree = ast.parse(text, filename=str(file_path)) + except Exception as exc: + violations.append(f"{file_path}: failed to parse ({exc}).") + continue + + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name == "_init_weights": + offending_line = function_has_forbidden_data_usage(node) + if offending_line is not None: + violations.append( + f"{file_path}:{offending_line}: `_init_weights(self, module)` uses `.data`. " + "Use tensor ops directly to remain compatible with HFParameter." + ) + break + + if violations: + print("Found forbidden `.data` usage inside `_init_weights(self, module)`:\n", file=sys.stderr) + print("\n".join(violations), file=sys.stderr) + return 1 + + return 0 + + +if __name__ == "__main__": + sys.exit(main())