How to use l2l.algorithms.MAML correctly with nn.DistributedDataParallel? #170
Hello @AyanamiReiFan, and thanks for the kind words. Parallelizing MAML with nn.DistributedDataParallel is tricky. To distribute training across processes, I would wrap the outer-loop optimizer:

```python
opt = optim.Adam(model.parameters())
opt = Distributed(model.parameters(), opt, sync=1)
# Training code
opt.step()
```

If you want to parallelize the model over GPUs, I would use:

```python
learner = maml.clone()
learner = torch.nn.DataParallel(learner, device_ids=[0, 1])
# Training code
```

Let me know if you ever find a solution to using nn.DistributedDataParallel directly.
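For readers unfamiliar with the learn2learn API, here is roughly where those snippets sit in a full meta-training loop. This is a minimal single-process sketch with a toy linear model and random task tensors standing in for real data; none of it is from this thread:

```python
import torch
import learn2learn as l2l

model = torch.nn.Linear(10, 1)                      # toy model for illustration
maml = l2l.algorithms.MAML(model, lr=0.5)           # inner-loop learning rate
opt = torch.optim.Adam(maml.parameters(), lr=1e-3)  # outer-loop optimizer
loss_fn = torch.nn.MSELoss()

for iteration in range(100):
    opt.zero_grad()
    learner = maml.clone()                          # differentiable copy per task
    x_support, y_support = torch.randn(5, 10), torch.randn(5, 1)
    x_query, y_query = torch.randn(5, 10), torch.randn(5, 1)
    learner.adapt(loss_fn(learner(x_support), y_support))  # inner-loop step
    query_loss = loss_fn(learner(x_query), y_query)
    query_loss.backward()         # meta-gradients accumulate on maml.parameters()
    opt.step()                    # outer-loop update (this is the step that the
                                  # Distributed wrapper above would synchronize)
```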
Thanks very much!
I'm training a 1-way 5-shot segmentation model with MAML, so the training batch size can only be 5. Because of that, I don't think parallelizing the adapt and evaluate steps within each iteration will speed things up much.
This is why I tried to parallelize the MAML module itself, so that MAML computes different batches on multiple GPUs, but it seems my attempt did not work. Thanks very much!
Hello @AyanamiReiFan, maybe this helps.
Thanks very much! @janbolle
@janbolle That's exciting work!
Thank you!
@janbolle Thanks for the reply!
:D
Closing since dormant. Feel free to reopen.
I have a large batch that cannot fit on a single 2080Ti GPU (11 GB). I have tried the suggestions above, but all the memory still goes to one GPU. Is there an easy way to get around this? Thanks.
@zhaozj89 This worked for me:

```python
learner = model.clone()
learner.module = torch.nn.DataParallel(learner.module, device_ids=[0, 1])
```
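A sketch of how that workaround might slot into the inner/outer loop, reusing the placeholder names from the sketch earlier in this thread (`maml`, `opt`, `loss_fn`, and the task tensors are illustrative, not from this comment). Wrapping `learner.module` rather than `learner` keeps the clone's `adapt()` method reachable while forward passes are split across GPUs:

```python
learner = maml.clone()
learner.module = torch.nn.DataParallel(learner.module, device_ids=[0, 1])

# Inner loop: the forward pass runs data-parallel; adapt() updates the clone.
learner.adapt(loss_fn(learner(x_support), y_support))

# Outer loop: evaluate the adapted clone and meta-update.
loss_fn(learner(x_query), y_query).backward()
opt.step()
```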
See this thread: #197. It seems you can use a Lightning wrapper to parallelize MAML. I haven't tried it myself yet, but I assume it works. DDP seems to be tricky to get working for technical reasons I don't understand.
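For anyone who wants to avoid wrappers entirely, a standard workaround (not from this thread) is to reproduce by hand the synchronization DDP performs: run one process per GPU and all-reduce the meta-gradients before the outer step. A minimal sketch, assuming `torch.distributed` is already initialized and that `maml`, `opt`, `loss_fn`, and the per-process task list exist:

```python
import torch.distributed as dist

def meta_step(maml, opt, tasks, loss_fn):
    opt.zero_grad()
    # Each process accumulates meta-gradients on its own tasks.
    for x_support, y_support, x_query, y_query in tasks:
        learner = maml.clone()
        learner.adapt(loss_fn(learner(x_support), y_support))
        loss_fn(learner(x_query), y_query).backward()
    # Average gradients across processes, exactly what DDP would do.
    world_size = dist.get_world_size()
    for p in maml.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad.div_(world_size)
    opt.step()
```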
Hi, here's my implementation of ParallelMAML using Learn2Learn's LightningMAML + PyTorch Lightning DDP: https://gist.github.com/SungFeng-Huang/dec22eef5650f5a74d24a732ffd0080f
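The Lightning side of that recipe is standard: once MAML is wrapped in a LightningModule (as LightningMAML does in the gist; its constructor is not reproduced here), multi-GPU training comes from the Trainer's DDP strategy. A sketch, where `lit_maml` and `train_loader` are assumptions standing in for the gist's objects:

```python
import pytorch_lightning as pl

# lit_maml: a LightningModule wrapping MAML, e.g. built as in the gist above.
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp")
trainer.fit(lit_maml, train_dataloaders=train_loader)  # hypothetical loader
```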
This work is awesome!
Using nn.DistributedDataParallel in the following way raises an error when executing `learner = maml.clone()`:

```python
model = MyModel()
maml = l2l.algorithms.MAML(model, lr=0.5)
model = nn.DistributedDataParallel(model, device_ids=[rank])
...
learner = maml.clone()
```

How do I use it correctly? Should I apply nn.DistributedDataParallel to MyModel first and then wrap it with MAML?
Thanks!