From 692d42df29a9450ca69e56d1aea82a904d5665de Mon Sep 17 00:00:00 2001
From: Moshe Island
Date: Sun, 18 Feb 2024 17:01:24 +0200
Subject: [PATCH] MOE: Fix top2 aux loss

MoE aux loss is based on https://arxiv.org/pdf/2006.16668.pdf, Algo 1.

For top1, the aux loss is implemented as:
l_aux = torch.sum(me * ce) * num_experts

whereas for top2, the aux loss is implemented as:
l_aux = torch.sum(me * ce) * num_experts * num_experts

Based on Algo 1, there is no reason for the extra multiplication by num_experts.

Signed-off-by: Moshe Island
---
 deepspeed/moe/sharded_moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index 23cd9137587e..ca45472efa18 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -309,7 +309,7 @@ def top2gating(logits: Tensor,
     # Compute l_aux
     me = torch.mean(gates, dim=0)
     ce = torch.mean(mask1.float(), dim=0)
-    l_aux = torch.mean(me * ce) * num_experts * num_experts
+    l_aux = torch.mean(me * ce) * num_experts
 
     # gating decisions
     exp_counts = torch.sum(mask1 + mask2, dim=0)
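
Note (illustrative only, not part of the patch): a minimal sketch of how the aux-loss quantities in the hunk above fit together after this change. The function name top2_aux_loss and the toy tensors are assumptions made for the example; in DeepSpeed these values are computed inline inside top2gating().

    import torch
    import torch.nn.functional as F

    def top2_aux_loss(gates: torch.Tensor, mask1: torch.Tensor, num_experts: int) -> torch.Tensor:
        # gates: [tokens, num_experts] softmax of the router logits
        # mask1: [tokens, num_experts] one-hot mask of each token's top-1 expert
        me = torch.mean(gates, dim=0)           # mean gate probability per expert
        ce = torch.mean(mask1.float(), dim=0)   # fraction of tokens routed to each expert
        return torch.mean(me * ce) * num_experts  # single num_experts factor, as in the patched line

    # Toy usage: 8 tokens, 4 experts
    logits = torch.randn(8, 4)
    gates = F.softmax(logits, dim=1)
    mask1 = F.one_hot(torch.argmax(gates, dim=1), num_classes=4)
    l_aux = top2_aux_loss(gates, mask1, num_experts=4)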