Fix multi-GPU runtime error with multi-scale netD #40

Merged
1 commit merged into junyanz:master on Nov 22, 2018
Conversation

cuihaoleo (Contributor)

This PR fixes #34.

About the original issue: when torch.nn.DataParallel replicates a D_NLayersMulti object to multiple devices, all registered submodules are replicated, but self.model (an instance of ListModule) is not, because PyTorch does not know how to copy it correctly. As a result, every replicated D_NLayersMulti instance has self.model pointing to the same ListModule object, whose module attribute still references the original D_NLayersMulti from before replication. You can print id(self.model) inside D_NLayersMulti.forward to verify this.
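For context, here is a minimal sketch of the pattern that breaks (the class shape follows the usual ListModule recipe; it is not the repo's exact code). ListModule is a plain Python object that registers children on the owning module and keeps a back reference to it:

```python
import torch.nn as nn

class ListModule(object):
    """Plain-Python container (not an nn.Module), so DataParallel
    has no rule for replicating it per device."""
    def __init__(self, module, prefix):
        self.module = module  # back reference to the owning nn.Module
        self.prefix = prefix
        self.num_module = 0

    def append(self, new_module):
        # The child is registered on the parent, so the child itself IS
        # replicated; the ListModule wrapper is only shallow-copied.
        self.module.add_module(self.prefix + str(self.num_module), new_module)
        self.num_module += 1

    def __getitem__(self, i):
        # After replication this still resolves against the ORIGINAL
        # module, whose parameters live on the original device.
        return getattr(self.module, self.prefix + str(i))
```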

I removed the use of the ListModule class and made D_NLayersMulti.forward call each submodule directly.
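A minimal sketch of the fixed pattern (the layer contents below are placeholders, not the real multi-layer PatchGAN discriminators the repo builds): each scale's discriminator is registered directly on self via add_module, and forward looks it up with getattr, so the lookup resolves against the per-device replica rather than a stale pre-replication object:

```python
import torch.nn as nn

class D_NLayersMulti(nn.Module):
    def __init__(self, num_D=2):
        super(D_NLayersMulti, self).__init__()
        self.num_D = num_D
        for i in range(num_D):
            # Registered directly on self, so DataParallel replicates it
            # like any other submodule.
            self.add_module('model_%d' % i, self._make_net())
        self.down = nn.AvgPool2d(3, stride=2, padding=1)

    def _make_net(self):
        # Placeholder for the real discriminator layers.
        return nn.Sequential(nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1))

    def forward(self, x):
        result = []
        for i in range(self.num_D):
            # getattr on self resolves to this replica's own submodule.
            model = getattr(self, 'model_%d' % i)
            result.append(model(x))
            if i != self.num_D - 1:
                x = self.down(x)  # downsample input for the next scale
        return result
```

On recent PyTorch versions, wrapping the per-scale discriminators in nn.ModuleList achieves the same effect, since DataParallel replicates it correctly as well.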

junyanz (Owner) commented Nov 22, 2018

Thanks for the fix. Really helpful.

@junyanz junyanz merged commit dc9aa60 into junyanz:master Nov 22, 2018