MOE gate fixes and enhancements #5156
Conversation
Force-pushed from 4706fa2 to 8f9d75c (compare)
Currently, disabling token dropping is only integrated into the top1 gate logic. This commit integrates disabled token dropping into the top2 logic. Signed-off-by: Moshe Island <misland@habana.ai>
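A minimal sketch of the idea (the helper name and group argument are illustrative, not the PR's actual code): when token dropping is disabled, the capacity is grown to the load of the busiest expert, and the value is all-reduced so every rank agrees on the same capacity.

    import torch
    import torch.distributed as dist

    def adjust_capacity_no_drop(exp_counts: torch.Tensor, capacity: int, group=None) -> int:
        """Grow the capacity to the busiest expert's token count so no token is dropped."""
        new_capacity = torch.max(exp_counts)
        if dist.is_available() and dist.is_initialized():
            # All ranks must use the same capacity, so take the global max.
            dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=group)
        return max(int(new_capacity.item()), capacity)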
MoE aux loss is based on https://arxiv.org/pdf/2006.16668.pdf, Algo 1. For top1, the aux loss is implemented as: l_aux = torch.sum(me * ce) * num_experts, whereas for top2 the aux loss is implemented as: l_aux = torch.sum(me * ce) * num_experts * num_experts. Based on Algo 1, there is no reason for the extra multiplication by num_experts. Signed-off-by: Moshe Island <misland@habana.ai>
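A small self-contained sketch of the top1 formulation referenced above (variable names follow the snippet in the commit message; this is illustrative rather than the file's exact code):

    import torch

    def gshard_aux_loss(gates: torch.Tensor, mask1: torch.Tensor, num_experts: int) -> torch.Tensor:
        # gates: [tokens, experts] softmax routing probabilities
        # mask1: [tokens, experts] one-hot mask of each token's top-1 expert
        me = torch.mean(gates, dim=0)          # mean gate probability per expert
        ce = torch.mean(mask1.float(), dim=0)  # fraction of tokens routed to each expert
        return torch.sum(me * ce) * num_experts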
DeepSpeed's MoE top2 gating performs sampling to select the 2nd expert. Support disabling this sampling (i.e. using plain argmax). This is configurable, with the default being to perform 2nd expert sampling. Signed-off-by: Moshe Island <misland@habana.ai>
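A sketch of the two selection modes (the function name and flag are hypothetical; DeepSpeed's actual code uses its own Gumbel-sampling helper): with sampling, the logits are perturbed with Gumbel noise before the argmax; without sampling, a plain argmax is taken over the experts that remain after masking out each token's top-1 choice.

    import torch

    def select_second_expert(logits: torch.Tensor, mask1: torch.Tensor, sample: bool = True) -> torch.Tensor:
        if sample:
            # Gumbel noise makes the 2nd-expert choice stochastic.
            u = torch.rand_like(logits).clamp_(1e-10, 1.0 - 1e-10)
            logits = logits + (-torch.log(-torch.log(u)))
        # Exclude each token's 1st expert, then pick the best remaining one.
        masked = logits.masked_fill(mask1.bool(), float("-inf"))
        return torch.argmax(masked, dim=1)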
When non-expert layers use TP and the experts do not use TP, we drop duplicate tokens sent to the experts. Dropping duplicate tokens is done by slicing the tokens tensor sent to the experts, where each expert handles only 1/TP of the tokens. For that, however, we need to make sure the capacity is divisible by TP. Signed-off-by: Moshe Island <misland@habana.ai>
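For illustration, aligning the capacity could look like the following (hypothetical helper name, not the PR's actual code):

    import math

    def align_capacity_to_tp(capacity: int, tp: int) -> int:
        # Round up so the capacity dimension can be split evenly into TP slices.
        return int(math.ceil(capacity / tp)) * tp

    # e.g. align_capacity_to_tp(10, 4) == 12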
Currently, during forward, the TopKGate gate linear layer is converted to fp32. This is not allowed, since the linear layer's params are a view into DeepSpeed's flat parameter buffer. To fix it, use torch.nn.functional.linear with gate.weight.float(). Signed-off-by: Moshe Island <misland@habana.ai>
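The fix described above amounts to something like this sketch (assumed names; the point is casting a temporary copy of the weight at forward time instead of converting the layer, so the parameter stays a view into the flat fp16/bf16 buffer):

    import torch
    import torch.nn.functional as F

    def gate_logits(wg: torch.nn.Linear, hidden: torch.Tensor) -> torch.Tensor:
        # Do the gating matmul in fp32 without touching wg.weight's storage or dtype.
        return F.linear(hidden.float(), wg.weight.float())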
Force-pushed from 8f9d75c to aab9fc3 (compare)
        super().__init__()

        # Only top-1 and top-2 are supported at the moment.
        if k != 1 and k != 2:
            raise ValueError('Only top-1 and top-2 gatings are supported.')
-       self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
+       self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
@ykim362 - can you please review this part? I remember we forced .float() here because we wanted the gate weight to always be fp32 even if everything else was fp16.
deepspeed/moe/sharded_moe.py (outdated diff)
        # Compute l_aux
        me = torch.mean(gates, dim=0)
        ce = torch.mean(mask1.float(), dim=0)
-       l_aux = torch.mean(me * ce) * num_experts * num_experts
+       l_aux = torch.mean(me * ce) * num_experts
@ykim362 - FYI. I think the current value of l_aux in top-2 was giving us good convergence. I am not sure about changing it; if we do, we will need to run training experiments to verify there is no regression in loss.
@mosheisland - have you trained models using top-2 and seen that l_aux with your change gives better convergence/loss?
I did not run full training before and after my changes to compare.
I added this change to better align with the original paper.
Since you saw better results with the current formulation, let's keep it as-is.
I will upload a new commit that reverts this change.
BTW, when the same num_experts is used across all MoE layers, this can be compensated for by changing --moe-loss-coeff.
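In other words (a simple worked example, assuming the aux loss is multiplied by the coefficient and num_experts is identical in every MoE layer), the extra factor of num_experts can be folded into the coefficient:

    # coeff * (sum(me * ce) * E * E)  ==  (coeff * E) * (sum(me * ce) * E)
    E = 8
    old_coeff = 0.01                        # used with the E*E formulation
    equivalent_new_coeff = old_coeff * E    # used with the single-E formulation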
@mosheisland -- Thank you very much for working on this PR! Overall, it looks good to me, but given that we don't have the bandwidth to test convergence, it will really help us and users if you can do two things:
l_aux - I will revert this change.
capacity padding - this bug is reproduced when you use drop-tokens=False and non-expert TP > 1.
l_aux - no need, I will revert.
This reverts commit 692d42d.
Reverted l_aux commit.
Fixes the following issues:
- Fix capacity when using TP for non-MoE by aligning the capacity to TP
- Fix TopKGate.wg (gate weight) when using ZeRO with fp16 or bf16
- Fix top2 aux loss to be similar to top1 aux loss

Following are a few configurable enhancements:
- Support top2 with disabled token dropping
- Support disabling top2 2nd expert sampling

---------

Signed-off-by: Moshe Island <misland@habana.ai>
Co-authored-by: Moshe Island <misland@habana.ai>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>