Skip to content

Commit

Permalink
Fix from_pretrained with default base_model_prefix (#15814)
Browse files · Browse the repository at this point in the history
  • Loading branch information
sgugger committed Feb 24, 2022
1 parent 7f921bc commit d1fcc90
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
12 changes: 8 additions & 4 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,8 +1580,12 @@ def _load_state_dict_into_model(
loaded_keys = list(state_dict.keys())
prefix = model.base_model_prefix

has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
if len(prefix) > 0:
has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
else:
has_prefix_module = False
expects_prefix_module = False

# key re-naming operations are never done on the keys
# that are loaded, but always on the keys of the newly initialized model
Expand Down Expand Up @@ -1669,9 +1673,9 @@ def load(module: nn.Module, prefix=""):
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ""
model_to_load = model
if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module:
start_prefix = cls.base_model_prefix + "."
if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module:
model_to_load = getattr(model, cls.base_model_prefix)
if any(key in expected_keys_not_prefixed for key in loaded_keys):
raise ValueError(
Expand Down
5 changes: 4 additions & 1 deletion tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2105,7 +2105,10 @@ def test_no_super_init_config_and_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)

model = NoSuperInitModel.from_pretrained(tmp_dir)
new_model = NoSuperInitModel.from_pretrained(tmp_dir)

for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))


@require_torch
Expand Down
2 changes: 0 additions & 2 deletions utils/test_module/custom_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

class CustomModel(PreTrainedModel):
config_class = CustomConfig
base_model_prefix = "custom"

def __init__(self, config):
super().__init__(config)
Expand All @@ -22,7 +21,6 @@ def _init_weights(self, module):

class NoSuperInitModel(PreTrainedModel):
config_class = NoSuperInitConfig
base_model_prefix = "custom"

def __init__(self, config):
super().__init__(config)
Expand Down

0 comments on commit d1fcc90

Please sign in to comment.