MOE: Fix top2 aux loss
The MoE aux loss is based on Algorithm 1 of https://arxiv.org/pdf/2006.16668.pdf (GShard).
For top1, the aux loss is implemented as:
    l_aux = torch.sum(me * ce) * num_experts
whereas for top2 it is implemented as:
    l_aux = torch.mean(me * ce) * num_experts * num_experts

Based on Algorithm 1, there is no reason for the extra multiplication by num_experts.

Signed-off-by: Moshe Island <misland@habana.ai>
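
As a point of reference, the quantities in the formulas above can be reproduced with a small, self-contained sketch. The gates and mask1 tensors below are hypothetical stand-ins for the values computed inside DeepSpeed's gating functions, so this illustrates the formula rather than the library code:

    import torch
    import torch.nn.functional as F

    num_experts = 4
    num_tokens = 8

    # Hypothetical inputs: softmax gate probabilities ([tokens, experts]) and the
    # one-hot mask of each token's top-1 expert assignment.
    gates = torch.softmax(torch.randn(num_tokens, num_experts), dim=-1)
    mask1 = F.one_hot(gates.argmax(dim=-1), num_classes=num_experts)

    me = torch.mean(gates, dim=0)          # mean gate probability per expert
    ce = torch.mean(mask1.float(), dim=0)  # fraction of tokens routed to each expert

    # torch.mean(me * ce) is the load-balancing term of Algorithm 1; the top1
    # implementation quoted above scales the summed form by num_experts.
    l_aux_top1 = torch.sum(me * ce) * num_experts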
misland-habana authored and mosheisland committed Feb 21, 2024
1 parent 3e0c35f commit 692d42d
Showing 1 changed file with 1 addition and 1 deletion.
deepspeed/moe/sharded_moe.py: 1 addition & 1 deletion
@@ -309,7 +309,7 @@ def top2gating(logits: Tensor,
     # Compute l_aux
     me = torch.mean(gates, dim=0)
     ce = torch.mean(mask1.float(), dim=0)
-    l_aux = torch.mean(me * ce) * num_experts * num_experts
+    l_aux = torch.mean(me * ce) * num_experts

     # gating decisions
     exp_counts = torch.sum(mask1 + mask2, dim=0)
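
A quick, purely illustrative check of what the one-line change does numerically: for any per-expert vectors me and ce, dropping one num_experts factor rescales the top2 aux loss by exactly 1/num_experts.

    import torch

    # Removing one num_experts factor rescales the top2 aux loss by 1/num_experts.
    num_experts = 8
    me = torch.rand(num_experts)
    ce = torch.rand(num_experts)

    old_l_aux = torch.mean(me * ce) * num_experts * num_experts
    new_l_aux = torch.mean(me * ce) * num_experts

    assert torch.allclose(old_l_aux, new_l_aux * num_experts)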
