MOE gate fixes and enhancements #5156

Merged

merged 13 commits into microsoft:master from the moe/gate branch Mar 7, 2024

Conversation

mosheisland
Contributor

Fixes the following issues:

  • Fix capacity when using TP for non-MoE by aligning the capacity to TP
  • Fix TopKGate.wg (gate weight) when using ZeRO with fp16 or bf16
  • Fix top2 aux loss to be similar to top1 aux loss

The following enhancements are configurable:

  • Support top2 with token dropping disabled
  • Support disabling 2nd expert sampling in top2

@tjruwase tjruwase requested review from tohtana and removed request for tohtana February 20, 2024 16:17
@mosheisland mosheisland force-pushed the moe/gate branch 2 times, most recently from 4706fa2 to 8f9d75c on February 21, 2024 07:07
Currently, disabling token dropping is only integrated into the top1 gate logic.
This commit integrates it into the top2 logic as well.

Signed-off-by: Moshe Island <misland@habana.ai>
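
A minimal sketch of how a no-drop capacity could be derived for top2, assuming mask1 and mask2 are the one-hot [tokens, experts] assignment masks (the function name and signature are illustrative, not DeepSpeed's actual code):

    import torch

    def no_drop_capacity_top2(mask1, mask2, capacity):
        # With token dropping disabled, grow the capacity so the busiest expert
        # can hold every token assigned to it as either a 1st or a 2nd choice.
        tokens_per_expert = (mask1 + mask2).sum(dim=0)            # [num_experts]
        new_capacity = int(torch.max(tokens_per_expert).item())
        return max(capacity, new_capacity)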
MoE aux loss is based on https://arxiv.org/pdf/2006.16668.pdf, Algo 1.
For top1, the aux loss is implemented as:
    l_aux = torch.sum(me * ce) * num_experts
whereas for top2, the aux loss is implemented as:
    l_aux = torch.sum(me * ce) * num_experts * num_experts

Based on Algo 1, there is no reason for the extra multiplication by num_experts.

Signed-off-by: Moshe Island <misland@habana.ai>
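
For reference, a minimal sketch of the aux loss in the top1 form above, assuming gates is the softmax output of shape [tokens, experts] and mask1 is the one-hot top1 assignment (names are illustrative):

    import torch

    def aux_loss(gates, mask1, num_experts):
        # GShard (arXiv:2006.16668), Algorithm 1: me is the mean gate probability
        # per expert, ce is the fraction of tokens dispatched to each expert.
        me = torch.mean(gates, dim=0)              # [num_experts]
        ce = torch.mean(mask1.float(), dim=0)      # [num_experts]
        return torch.sum(me * ce) * num_experts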
DeepSpeed's MoE top2 gating performs sampling to select the 2nd expert.
Support disabling this sampling (i.e. selecting the 2nd expert with argmax).
This is configurable; the default is still to sample the 2nd expert.

Signed-off-by: Moshe Island <misland@habana.ai>
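
A hedged sketch of what configurable 2nd expert selection could look like, assuming logits of shape [tokens, experts] and a precomputed top1 index; the function name and flag are illustrative and not DeepSpeed's actual API:

    import torch
    import torch.nn.functional as F

    def select_second_expert(logits, top1_idx, second_expert_sampling=True):
        # Exclude the top-1 expert, then pick the 2nd expert either by adding
        # Gumbel noise and taking argmax (sampling) or by a plain argmax.
        excluded = F.one_hot(top1_idx, logits.shape[1]).bool()
        masked = logits.masked_fill(excluded, float('-inf'))
        if second_expert_sampling:
            gumbel = -torch.log(-torch.log(torch.rand_like(masked)))
            masked = masked + gumbel
        return torch.argmax(masked, dim=1)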
When non-expert layers use TP and experts do not use TP, we drop the duplicate
tokens sent to the experts.
Dropping duplicate tokens is done by slicing the tokens tensor sent to the experts,
so that each expert handles only 1/TP of the tokens.
For this to work, the capacity must be divisible by TP.

Signed-off-by: Moshe Island <misland@habana.ai>
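
A minimal sketch of the capacity alignment, assuming tp is the tensor-parallel degree of the non-expert layers (illustrative only):

    def align_capacity_to_tp(capacity: int, tp: int) -> int:
        # Round the per-expert capacity up to the next multiple of tp so the
        # dispatched tokens tensor can be sliced evenly into 1/tp parts.
        if tp > 1 and capacity % tp != 0:
            capacity += tp - capacity % tp
        return capacity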
Currently, during forward, the TopKGate gate linear layer is converted to fp32.
This is forbidden because the linear layer's parameters are a view into DeepSpeed's
flat parameter buffer.
To fix this, use torch.nn.functional.linear with gate.weight.float() instead.

Signed-off-by: Moshe Island <misland@habana.ai>
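
A sketch of the described fix, assuming input holds the flattened hidden states; only a temporary fp32 copy of the weight is used, so the fp16/bf16 parameter view stays untouched:

    import torch
    import torch.nn.functional as F

    def gate_logits(wg: torch.nn.Linear, input: torch.Tensor) -> torch.Tensor:
        # wg.weight may be a view into ZeRO's flat fp16/bf16 parameter buffer,
        # so cast copies of the weight and input to fp32 for the gating matmul
        # instead of converting the Linear module itself.
        return F.linear(input.float(), wg.weight.float())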
  super().__init__()

  # Only top-1 and top-2 are supported at the moment.
  if k != 1 and k != 2:
      raise ValueError('Only top-1 and top-2 gatings are supported.')
- self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
+ self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
Contributor

@ykim362 - can you please review this part? I remember we forced .float() here because we wanted the gate weight to always be fp32 even if everything else was fp16.

  # Compute l_aux
  me = torch.mean(gates, dim=0)
  ce = torch.mean(mask1.float(), dim=0)
- l_aux = torch.mean(me * ce) * num_experts * num_experts
+ l_aux = torch.mean(me * ce) * num_experts
Contributor


@ykim362 - FYI. I think the current value of l_aux in top-2 was giving us good convergence. I am not sure; if we change it, we will need to run training experiments to check for any regression in loss.

@mosheisland - have you trained the models using top-2 and you see that l_aux with your change gives better convergence/loss?

Contributor Author


I did not run full training before and after my changes to compare.
I added this change to better align with the original paper.
Since you saw better results with the current formulation, let's keep it as-is.
I will upload a new commit that reverts this change.
BTW, when the same num_experts is used across all MoE layers, this can be compensated for by changing --moe-loss-coeff.

@awan-10
Contributor

awan-10 commented Feb 21, 2024

@mosheisland -- Thank you very much for working on this PR! Overall, it looks good to me, but given that we don't have the bandwidth to test convergence, it will really help us and users if you can do two things:

  1. Protect any changes to l_aux, capacity padding, etc. with user-configurable flags and set their default values to False.

    • e.g. add a flag like pad_capacity=False as the default and only do the padding if pad_capacity=True is passed by the user. This avoids regressions, as many users are using this in its current form. If we don't know whether the new code or the previous code gives better results, we should at least have the option to keep the new behavior disabled by default.
  2. Share before/after convergence plots for any model that you have trained with the new changes. That data will help users who refer back to this PR in case setting pad_capacity=True or using the new l_aux value changes their training/loss curves.

@awan-10 awan-10 self-assigned this Feb 21, 2024
@mosheisland
Contributor Author

  1. Protect any changes to l_aux, capacity padding, etc. with user-configurable flags and set their default values to False.
  • e.g. add a flag like pad_capacity=False as the default and only do the padding if pad_capacity=True is passed by the user. This avoids regressions, as many users are using this in its current form. If we don't know whether the new code or the previous code gives better results, we should at least have the option to keep the new behavior disabled by default.

l_aux - I will revert this change.

capacity padding - this bug reproduces when you use drop-tokens=False and non-expert-tp > 1.
The padding itself adds zeroed rows and has no real effect on training.

  2. Share before/after convergence plots for any model that you have trained with the new changes. That data will help users who refer back to this PR in case setting pad_capacity=True or using the new l_aux value changes their training/loss curves.

l_aux - no need, I will revert it.
capacity padding - I can't share a "before" plot since the run crashes without the fix.

@mosheisland
Contributor Author

Reverted l_aux commit

@tjruwase tjruwase added this pull request to the merge queue Mar 6, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Mar 6, 2024
@tjruwase tjruwase added this pull request to the merge queue Mar 7, 2024
Merged via the queue into microsoft:master with commit 5a2e705 Mar 7, 2024
12 checks passed
ShellyNR pushed a commit to ShellyNR/DeepSpeed that referenced this pull request Mar 11, 2024
rraminen pushed a commit to ROCm/DeepSpeed that referenced this pull request May 9, 2024