[Lazy init] Fix edge cases #11615
Conversation
```diff
         retrieved_modules = []
         # retrieve all modules that has at least one missing weight name
         for name, module in self.named_modules():
             if remove_prefix:
                 name = ".".join(name.split(".")[1:]) if name.startswith(self.base_model_prefix) else name
             elif add_prefix:
-                name = ".".join([self.base_model_prefix, name])
+                name = ".".join([self.base_model_prefix, name]) if len(name) > 0 else self.base_model_prefix
```
Weird edge case in which the module name is just an empty string `''`, which would then lead to the key `bert.` instead of the correct key `bert`.
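A minimal sketch of the edge case (the standalone helper name and the `bert` prefix are illustrative, not part of the PR):

```python
def add_prefix(base_model_prefix, name):
    # Joining unconditionally would turn the root module's empty name ""
    # into "bert." -- the fix falls back to the bare prefix instead.
    return ".".join([base_model_prefix, name]) if len(name) > 0 else base_model_prefix

print(add_prefix("bert", "encoder.layer.0"))  # bert.encoder.layer.0
print(add_prefix("bert", ""))                 # bert, not "bert."
```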
```diff
@@ -1347,13 +1350,17 @@ def load(module: nn.Module, prefix=""):
     def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
         module_keys = set([".".join(key.split(".")[:-1]) for key in names])
+
+        # torch.nn.ParameterList is a special case where two parameter keywords
+        # are appended to the module name, *e.g.* bert.special_embeddings.0
+        module_keys = module_keys.union(set([".".join(key.split(".")[:-2]) for key in names if key[-1].isdigit()]))
```
`torch.nn.ParameterList` is special in that it appends two `'.'`-separated words to a module name, so we need to strip the last two `'.'`-separated parts here.
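A quick sketch of how the two splits interact, using made-up weight names (`bert.special_embeddings.0` stands in for a `torch.nn.ParameterList` entry):

```python
names = ["bert.encoder.layer.0.attention.weight", "bert.special_embeddings.0"]

# Dropping only the last "."-part handles ordinary parameters like
# ".weight", but leaves "bert.special_embeddings" without its parent.
module_keys = set(".".join(key.split(".")[:-1]) for key in names)

# Keys ending in a digit come from ParameterList-style indexing, so their
# owning module is found by stripping the last two "."-parts instead.
module_keys |= set(".".join(key.split(".")[:-2]) for key in names if key[-1].isdigit())

print(sorted(module_keys))
# ['bert', 'bert.encoder.layer.0.attention', 'bert.special_embeddings']
```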
LGTM, thanks!
This PR fixes the flaky CircleCI failures in tests such as `tests/test_modeling_xlnet.py::XLNetModelTest::test_save_load_fast_init_to_base`. Luckily, the test caught those edge cases ;-)