Commit
MoE aux loss is based on https://arxiv.org/pdf/2006.16668.pdf, Algo 1.

For top1, the aux loss is implemented as:
    l_aux = torch.sum(me * ce) * num_experts
whereas for top2, the aux loss is implemented as:
    l_aux = torch.sum(me * ce) * num_experts * num_experts

Based on Algo 1, there is no reason for the extra multiplication by num_experts.

Signed-off-by: Moshe Island <misland@habana.ai>
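To illustrate the scaling difference, here is a minimal, dependency-free sketch that uses plain Python lists in place of torch tensors. It assumes `me` is the mean gate probability per expert and `ce` is the fraction of tokens whose top-1 choice is each expert, which is how Algo 1 reads; the helper names `softmax` and `aux_loss` and the `extra_factor` flag are hypothetical and are not taken from the repository.

```python
import math


def softmax(row):
    # Numerically stable softmax over one token's expert logits.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def aux_loss(logits, extra_factor=False):
    """Load-balancing loss sketch: sum(me * ce) * num_experts.

    logits: list of per-token lists, one logit per expert.
    extra_factor: if True, multiply by num_experts once more,
    mimicking the top2 scaling described in the commit message.
    """
    num_tokens = len(logits)
    num_experts = len(logits[0])
    gates = [softmax(row) for row in logits]

    # me[e]: mean gate probability assigned to expert e (assumed definition).
    me = [sum(g[e] for g in gates) / num_tokens for e in range(num_experts)]

    # ce[e]: fraction of tokens whose top-1 expert is e (assumed definition).
    counts = [0] * num_experts
    for g in gates:
        counts[g.index(max(g))] += 1
    ce = [c / num_tokens for c in counts]

    l_aux = sum(m * c for m, c in zip(me, ce)) * num_experts
    if extra_factor:
        l_aux *= num_experts
    return l_aux


# Uniform gate probabilities: 8 tokens, 4 experts, all logits equal.
uniform = [[0.0] * 4 for _ in range(8)]
print(aux_loss(uniform))                     # 1.0
print(aux_loss(uniform, extra_factor=True))  # 4.0
```

With uniform gate probabilities the single-factor form evaluates to 1.0 regardless of the number of experts, while the doubly scaled form grows linearly with `num_experts`, so the effective weight of the aux loss would depend on the expert count.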