MOE: Fix capacity when using TP for non-MoE
When non-expert layers use TP and the experts do not, we drop the duplicate
tokens sent to the experts. We do this by slicing the tokens tensor sent to the
experts so that each expert handles only 1/TP of the tokens. For that slicing
to be valid, the capacity must be divisible by TP, so we round it up to the
next multiple of TP.

Signed-off-by: Moshe Island <misland@habana.ai>
misland-habana authored and mosheisland committed Feb 21, 2024
1 parent 88eca85 commit 4ab360a
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions deepspeed/moe/sharded_moe.py
@@ -214,6 +214,11 @@ def top1gating(logits: Tensor,
     if not drop_tokens:
         new_capacity = torch.max(exp_counts).to(logits.device)
         dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
+        if groups._get_expert_model_parallel_world_size() == 1:
+            # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
+            # This is since we are going to activate drop_tokens() to drop duplicate tokens.
+            tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
+            new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
         capacity = new_capacity

     # Compute l_aux
@@ -326,6 +331,11 @@ def top2gating(logits: Tensor,
         # Do not drop tokens - set capacity according to current expert assignments
         new_capacity = torch.max(exp_counts)
         dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
+        if groups._get_expert_model_parallel_world_size() == 1:
+            # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
+            # This is since we are going to activate drop_tokens() to drop duplicate tokens.
+            tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
+            new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
         capacity = new_capacity

     # Store the capacity location for each token
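To make the intent of the two added blocks concrete, below is a minimal, self-contained sketch of the padding-and-slicing idea. It is not the DeepSpeed implementation: the helper names pad_capacity_to_tp and keep_rank_slice are hypothetical, and the tensor shapes are illustrative only.

import torch

def pad_capacity_to_tp(new_capacity: torch.Tensor, tp: int) -> torch.Tensor:
    # Same arithmetic as the added lines: round the capacity up to the next multiple of tp.
    return torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)

def keep_rank_slice(dispatched_tokens: torch.Tensor, tp: int, tp_rank: int) -> torch.Tensor:
    # With TP in the non-expert layers, every TP rank holds an identical copy of the
    # tokens routed to the experts; keeping only this rank's contiguous 1/TP slice
    # removes the duplicates. The slices are equal-sized only if the leading dimension
    # is divisible by tp, which the capacity padding above guarantees.
    assert dispatched_tokens.size(0) % tp == 0
    chunk = dispatched_tokens.size(0) // tp
    return dispatched_tokens[tp_rank * chunk:(tp_rank + 1) * chunk]

# Example: a per-expert count of 13 with tp=4 is padded to a capacity of 16.
capacity = int(pad_capacity_to_tp(torch.tensor(13), tp=4))   # 16
tokens = torch.randn(8 * capacity, 1024)                      # e.g. 8 experts x capacity rows
local = keep_rank_slice(tokens, tp=4, tp_rank=0)              # rows 0..31 of 128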
