Skip to content
247 changes: 147 additions & 100 deletions src/transformers/modeling_utils.py

Large diffs are not rendered by default.

6 changes: 0 additions & 6 deletions src/transformers/models/esm/modeling_esm.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,8 +732,6 @@ def __init__(self, config):
self.esm = EsmModel(config, add_pooling_layer=False)
self.lm_head = EsmLMHead(config)

self.init_weights()

self.post_init()

def get_output_embeddings(self):
Expand Down Expand Up @@ -828,8 +826,6 @@ def __init__(self, config):
self.esm = EsmModel(config, add_pooling_layer=False)
self.classifier = EsmClassificationHead(config)

self.init_weights()

self.post_init()

@can_return_tuple
Expand Down Expand Up @@ -903,8 +899,6 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

self.init_weights()

self.post_init()

@can_return_tuple
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/hubert/modeling_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,7 @@ def __init__(self, config, target_lang: Optional[str] = None):
# Initialize weights and apply final processing
self.post_init()

def tie_weights(self, missing_keys=None):
def tie_weights(self, **kwargs):
"""
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
passing `target_lang=...` to `from_pretrained(...)`.
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/idefics/modeling_idefics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,7 +1111,7 @@ def __init__(self, config, vision_model=None):
# Initialize weights and apply final processing
self.post_init()

def tie_weights(self, missing_keys=None):
def tie_weights(self, **kwargs):
"""
Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of
IdeficsDecoupledLinear and IdeficsDecoupledEmbedding.
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/sew/modeling_sew.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ def __init__(self, config, target_lang: Optional[str] = None):
# Initialize weights and apply final processing
self.post_init()

def tie_weights(self, missing_keys=None):
def tie_weights(self, **kwargs):
"""
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
passing `target_lang=...` to `from_pretrained(...)`.
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/sew_d/modeling_sew_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,7 +1400,7 @@ def __init__(self, config, target_lang: Optional[str] = None):
# Initialize weights and apply final processing
self.post_init()

def tie_weights(self, missing_keys=None):
def tie_weights(self, **kwargs):
"""
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
passing `target_lang=...` to `from_pretrained(...)`.
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/unispeech/modeling_unispeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,7 +1210,7 @@ def __init__(self, config, target_lang: Optional[str] = None):
# Initialize weights and apply final processing
self.post_init()

def tie_weights(self, missing_keys=None):
def tie_weights(self, **kwargs):
"""
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
passing `target_lang=...` to `from_pretrained(...)`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1206,7 +1206,7 @@ def __init__(self, config, target_lang: Optional[str] = None):
# Initialize weights and apply final processing
self.post_init()

def tie_weights(self, missing_keys=None):
def tie_weights(self, **kwargs):
"""
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
passing `target_lang=...` to `from_pretrained(...)`.
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/wav2vec2/modeling_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,7 +1684,7 @@ def __init__(self, config, target_lang: Optional[str] = None):
# Initialize weights and apply final processing
self.post_init()

def tie_weights(self, missing_keys=None):
def tie_weights(self, **kwargs):
"""
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
passing `target_lang=...` to `from_pretrained(...)`.
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/wavlm/modeling_wavlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,7 @@ def __init__(self, config, target_lang: Optional[str] = None):
# Initialize weights and apply final processing
self.post_init()

def tie_weights(self, missing_keys=None):
def tie_weights(self, **kwargs):
"""
This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
passing `target_lang=...` to `from_pretrained(...)`.
Expand Down
7 changes: 7 additions & 0 deletions tests/models/openai/test_modeling_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,13 @@ def test_model_from_pretrained(self):
model = OpenAIGPTModel.from_pretrained(model_name)
self.assertIsNotNone(model)

@unittest.skip("Tied weights mapping is reversed, so this is supposed to error out")
def test_correct_missing_keys(self):
# OpenAI GPT defines `_tied_weights_keys = {"transformer.tokens_embed.weight": "lm_head.weight"}` instead
# of the usual `_tied_weights_keys = {"lm_head.weight": "transformer.tokens_embed.weight"}`. Because the
# mapping is reversed, removing the head parameters actually removes the source weight, which is why
# this test is expected to fail.
pass


@require_torch
class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def __init__(self, config):
def forward(self, x):
return self.linear_2(self.linear(x))

def tie_weights(self, missing_keys=None):
def tie_weights(self, missing_keys=None, **kwargs):
self.linear_2.weight = self.linear.weight
if missing_keys is not None:
missing_keys.discard("linear_2.weight")
Expand Down Expand Up @@ -254,7 +254,7 @@ def __init__(self, config):
def forward(self, x):
return self.decoder(self.base(x))

def tie_weights(self, missing_keys=None):
def tie_weights(self, missing_keys=None, **kwargs):
self.decoder.weight = self.base.linear.weight
if missing_keys is not None:
missing_keys.discard("decoder.weight")
Expand Down