[integrating megablocks with open_lm] Question about megablocks + FSDP #57
Hello! I'm trying to integrate megablocks with our open-source LLM training library (open_lm), which uses native PyTorch FSDP.
For some reason I am consistently seeing worse performance than my dense baselines at the same compute budget. I'm a bit stumped as to why, and I was wondering if you could provide any pointers on things to watch out for wrt integrations.
To integrate your library:
- I replace our feedforward layer with a Megablocks MoE: https://github.com/mlfoundations/open_lm/pull/115/files#diff-ae20e8018bed746c1e2ec41d171f99755ef714e018909a94b5097687ea80c3a4R230-R239
- I add a megablock_moe at every other layer (as in Switch/GShard): https://github.com/mlfoundations/open_lm/pull/115/files#diff-ae20e8018bed746c1e2ec41d171f99755ef714e018909a94b5097687ea80c3a4R319-R322
- I add the load-balancing loss: https://github.com/mlfoundations/open_lm/pull/115/files#diff-52f1d1b2a425f93f57435d70bdac50ded09d6495ee2b417f999869c18cbbcbd9R198-R203
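For reference, the load-balancing term I'm adding follows the Switch-Transformer-style auxiliary loss, `N * sum_i f_i * P_i`. Here is a minimal pure-Python sketch of what I believe is being computed (this is my understanding of the formula, not megablocks' actual implementation; the function and argument names are made up):

```python
def load_balancing_loss(router_probs, expert_indices, num_experts):
    """Switch-style auxiliary loss (sketch, not megablocks' code).

    router_probs: per-token lists of router probabilities (len == num_experts)
    expert_indices: the expert each token was dispatched to
    """
    num_tokens = len(expert_indices)
    # f_i: fraction of tokens dispatched to expert i
    f = [expert_indices.count(i) / num_tokens for i in range(num_experts)]
    # P_i: mean router probability assigned to expert i
    p = [sum(probs[i] for probs in router_probs) / num_tokens
         for i in range(num_experts)]
    # N * sum_i f_i * P_i — minimized (== 1.0) under uniform routing
    return num_experts * sum(fi * pi for fi, pi in zip(f, p))
```

With two experts and perfectly uniform routing this evaluates to 1.0, and anything skewed toward one expert comes out larger, which is how I've been sanity-checking my wiring.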
Am I missing anything else?
One hypothesis I have is that something is going wrong when I use Megablocks with FSDP.
Here is our FSDP wrapper:
https://github.com/mlfoundations/open_lm/blob/main/open_lm/main.py#L454-L462
Is there anything I need to change in the FSDP arguments to make sure FSDP doesn't interfere with the all-to-alls? Currently it wraps the transformer block module, of which the MoE is a part.
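One thing I'm considering trying is giving each MoE module its own FSDP unit, in addition to the per-block wrapping, so the expert parameters aren't flattened together with the block's dense weights. A sketch of the wrap predicate (the class names `Block` and `MoE` are placeholders, not open_lm's actual module names):

```python
# Hypothetical wrap predicate: wrap each transformer block AND each MoE
# layer as its own FSDP unit, leaving other submodules to their parents.
# `Block` / `MoE` are placeholder class names for illustration only.
def should_wrap(module) -> bool:
    return type(module).__name__ in {"Block", "MoE"}
```

This could then be handed to FSDP via `torch.distributed.fsdp.wrap.lambda_auto_wrap_policy`, e.g. `functools.partial(lambda_auto_wrap_policy, lambda_fn=should_wrap)`. I'm not sure this is the right fix, so pointers are welcome.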
Thanks for your help!