From 4ab360a042560588a2fd1309faf134bf6e409706 Mon Sep 17 00:00:00 2001
From: Moshe Island
Date: Sun, 18 Feb 2024 17:26:08 +0200
Subject: [PATCH] MOE: Fix capacity when using TP for non-MoE

When non-expert layers use TP and the experts do not use TP, we drop the
duplicate tokens sent to the experts. Dropping duplicate tokens is done by
slicing the tokens tensor sent to the experts, where each TP rank handles
only 1/TP of the tokens. For that slicing to be valid, the capacity must
be divisible by TP.

Signed-off-by: Moshe Island
---
 deepspeed/moe/sharded_moe.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index b9613eff1ede..745ca853aadb 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -214,6 +214,11 @@ def top1gating(logits: Tensor,
     if not drop_tokens:
         new_capacity = torch.max(exp_counts).to(logits.device)
         dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
+        if groups._get_expert_model_parallel_world_size() == 1:
+            # If the non-expert is tensor-parallel, we need to pad the capacity to a multiple of 'tp'.
+            # This is because drop_tokens() will be activated to drop duplicate tokens.
+            tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
+            new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
         capacity = new_capacity

     # Compute l_aux
@@ -326,6 +331,11 @@ def top2gating(logits: Tensor,
         # Do not drop tokens - set capacity according to current expert assignments
         new_capacity = torch.max(exp_counts)
         dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
+        if groups._get_expert_model_parallel_world_size() == 1:
+            # If the non-expert is tensor-parallel, we need to pad the capacity to a multiple of 'tp'.
+            # This is because drop_tokens() will be activated to drop duplicate tokens.
+            tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
+            new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
         capacity = new_capacity

     # Store the capacity location for each token
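For illustration, here is a minimal standalone sketch of the rounding the patch performs and why it matters. The helper `pad_capacity_to_tp` and the toy tensors are hypothetical, written only to demonstrate the arithmetic; they are not part of the DeepSpeed API.

```python
import torch

def pad_capacity_to_tp(new_capacity: torch.Tensor, tp: int) -> torch.Tensor:
    # Round the capacity up to the nearest multiple of tp, mirroring the patch:
    # ceil(capacity / tp) * tp, cast back to the original integer dtype.
    return torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)

tp = 4
capacity = torch.tensor(10)                # max tokens routed to any expert
padded = pad_capacity_to_tp(capacity, tp)  # tensor(12)

# With the capacity divisible by tp, each tensor-parallel rank can take an
# equal contiguous slice of the duplicated tokens sent to the experts:
tokens = torch.arange(int(padded))         # stand-in for the routed tokens
per_rank = int(padded) // tp
slices = [tokens[r * per_rank:(r + 1) * per_rank] for r in range(tp)]
assert sum(len(s) for s in slices) == int(padded)
```

Rounding up rather than down means each expert carries at most `tp - 1` padding slots, a negligible overhead compared with mis-sized slices when the capacity is not divisible by TP.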