
[Bug] MMSeparateDistributedDataParallel skip init_weights #1042

Closed
makecent opened this issue Apr 3, 2023 · 5 comments · Fixed by #1045
Labels
bug Something isn't working

Comments


makecent commented Apr 3, 2023

Prerequisite

Reproduces the problem - code sample

I encountered this problem when running my own custom project based on mmengine; the project is too complicated to present here.

Additional information

MMSeparateDistributedDataParallel is a model wrapper of model wrappers: the sub-modules of its module may themselves be wrapped in MMDistributedDataParallel.

sub_module = MMDistributedDataParallel(
    module=sub_module.to(device),
    broadcast_buffers=broadcast_buffers,
    find_unused_parameters=find_unused_parameters,
    **kwargs)
module._modules[name] = sub_module

When initializing the weights of a wrapped model, the runner initializes the weights of runner.model.module:

def _init_model_weights(self) -> None:
    """Initialize the model weights if the model has
    :meth:`init_weights`"""
    model = self.model.module if is_model_wrapper(
        self.model) else self.model
    if hasattr(model, 'init_weights'):
        model.init_weights()
        # sync params and buffers
        for name, params in model.state_dict().items():
            broadcast(params)

The above code prevents the model's init_weights functions from running, because in this case the children of runner.model.module are of type MMDistributedDataParallel, which does NOT have an init_weights function.
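
To make the failure concrete, here is a minimal self-contained sketch (plain PyTorch rather than the mmengine classes; Backbone, Detector, and Wrapper are hypothetical names) showing how replacing a child module with a wrapper that lacks init_weights silently skips the child's initialization:

import torch.nn as nn


class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)

    def init_weights(self):
        nn.init.constant_(self.conv.weight, 1.0)


class Detector(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = Backbone()

    def init_weights(self):
        # Recurse into children: a child is only initialized
        # if it exposes init_weights itself.
        for child in self.children():
            if hasattr(child, 'init_weights'):
                child.init_weights()


class Wrapper(nn.Module):  # stand-in for MMDistributedDataParallel
    def __init__(self, module):
        super().__init__()
        self.module = module


model = Detector()
# Emulate what MMSeparateDistributedDataParallel does: replace the child
# with a wrapper that has no init_weights of its own.
model._modules['backbone'] = Wrapper(model.backbone)
model.init_weights()  # Backbone.init_weights is now never reached

Before the replacement, Detector.init_weights would reach Backbone.init_weights; after it, the hasattr check on the wrapper fails and the constant initialization never runs.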

makecent added the bug label Apr 3, 2023
zhouzaida (Member)

Hi, the module is defined in

Therefore, if you implement the init_weights method in the module, the runner can call it as expected.
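
For reference, a minimal sketch of that suggestion (MyModel is a hypothetical name): since _init_model_weights above unwraps runner.model via .module and then looks for init_weights on the result, defining the method on the top-level module is enough for the runner to invoke it.

import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def init_weights(self):
        # Custom initialization the runner is expected to trigger.
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)


model = MyModel()
if hasattr(model, 'init_weights'):  # mirrors the runner's check above
    model.init_weights()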

makecent (Author) commented Apr 3, 2023

@zhouzaida But the sub-modules in self.module are replaced by the model wrapper MMDistributedDataParallel:

for name, sub_module in module._modules.items():
    # module without parameters.
    if next(sub_module.parameters(), None) is None:
        sub_module = sub_module.to(device)
    elif all(not p.requires_grad for p in sub_module.parameters()):
        sub_module = sub_module.to(device)
    else:
        sub_module = MMDistributedDataParallel(
            module=sub_module.to(device),
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=find_unused_parameters,
            **kwargs)
    module._modules[name] = sub_module

Therefore, MMSeparateDistributedDataParallel.module does have an init_weights function, but its sub-modules (MMDistributedDataParallel) do not, which makes the init_weights functions of the modules wrapped in MMDistributedDataParallel inaccessible during initialization.

zhouzaida (Member) commented Apr 3, 2023

Oh, the init_weights method is defined in the sub-module, so there might be an issue with that.

HAOCHENYE (Collaborator)

@makecent Hi, I created a PR; will this solve the problem?
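
The content of the linked PR is not shown in this thread, so the following is only a sketch of one possible direction, not the actual fix: when recursing over children to call init_weights, look through model wrappers via their .module attribute so that a wrapped module's init_weights stays reachable.

import torch.nn as nn
from mmengine.model import is_model_wrapper


def init_child_weights(module: nn.Module) -> None:
    # Hypothetical helper for illustration: unwrap each child before
    # checking for init_weights, so wrappers such as
    # MMDistributedDataParallel no longer hide the method.
    for child in module.children():
        target = child.module if is_model_wrapper(child) else child
        if hasattr(target, 'init_weights'):
            target.init_weights()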

makecent (Author) commented Apr 4, 2023

@HAOCHENYE LGTM. The initialization works as expected with the PR.

makecent closed this as completed Apr 4, 2023