MOE gate fixes and enhancements #5156

Merged · 13 commits · Mar 7, 2024
7 changes: 5 additions & 2 deletions deepspeed/moe/layer.py
@@ -32,6 +32,7 @@ class MoE(nn.Module):
use_rts (bool, optional): default=True, whether to use Random Token Selection.
use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
enable_expert_tensor_parallelism (bool, optional): default=False, whether to use tensor parallelism for experts
top2_2nd_expert_sampling (bool, optional): default=True, whether to perform sampling for 2nd expert
"""

def __init__(self,
@@ -48,7 +49,8 @@ def __init__(self,
drop_tokens: bool = True,
use_rts: bool = True,
use_tutel: bool = False,
enable_expert_tensor_parallelism: bool = False) -> None:
enable_expert_tensor_parallelism: bool = False,
top2_2nd_expert_sampling: bool = True) -> None:

super(MoE, self).__init__()

@@ -69,7 +71,8 @@ def __init__(self,

experts = Experts(expert, self.num_local_experts, self.expert_group_name)
self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
min_capacity, noisy_gate_policy, drop_tokens, use_rts),
min_capacity, noisy_gate_policy, drop_tokens, use_rts,
top2_2nd_expert_sampling),
experts,
self.expert_group_name,
self.ep_size,
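For orientation, a hypothetical usage sketch of the new flag at the MoE layer level (toy sizes; assumes the script runs under a distributed launcher so process groups can be created):

```python
# Hypothetical sketch: passing the new flag through deepspeed.moe.layer.MoE.
# Sizes are toy values; a real run needs a distributed launcher so that
# deepspeed.init_distributed() can set up the process groups.
import torch
import deepspeed
from deepspeed.moe.layer import MoE

deepspeed.init_distributed()

expert = torch.nn.Linear(512, 512)            # any nn.Module can serve as the expert
moe = MoE(hidden_size=512,
          expert=expert,
          num_experts=8,
          k=2,
          top2_2nd_expert_sampling=False)     # disable Gumbel sampling of the 2nd expert

hidden = torch.randn(4, 16, 512)
output, l_aux, exp_counts = moe(hidden)       # forward also returns the aux loss and expert counts
```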
63 changes: 42 additions & 21 deletions deepspeed/moe/sharded_moe.py
@@ -214,6 +214,11 @@ def top1gating(logits: Tensor,
if not drop_tokens:
new_capacity = torch.max(exp_counts).to(logits.device)
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
if groups._get_expert_model_parallel_world_size() == 1:
# If the non-expert part is tensor-parallel, we need to pad the capacity to a multiple of 'tp'.
# This is because we are going to activate drop_tokens() to drop the duplicate tokens.
tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
capacity = new_capacity

# Compute l_aux
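A standalone sketch of the rounding performed by the padding above (toy values, no distributed setup; `tp` stands in for the tensor-model-parallel world size):

```python
# Sketch of the capacity padding: when drop_tokens is disabled but the
# non-expert part runs tensor-parallel, the per-expert capacity is rounded
# up to the next multiple of the TP world size.
import torch

new_capacity = torch.tensor(10)   # max tokens routed to any single expert (toy value)
tp = 4                            # tensor-model-parallel world size (toy value)
padded = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
print(padded)                     # tensor(12) -> next multiple of 4
```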
@@ -279,23 +284,27 @@ def top1gating(logits: Tensor,
return l_aux, combine_weights, dispatch_mask, exp_counts


def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
def top2gating(logits: Tensor,
capacity_factor: float,
min_capacity: int,
drop_tokens: bool = True,
top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Implements Top2Gating on logits."""
# everything is in fp32 in this function
gates = F.softmax(logits, dim=1)

capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))

# Create a mask for 1st's expert per token
indices1_s = torch.argmax(gates, dim=1)
num_experts = int(gates.shape[1])
mask1 = F.one_hot(indices1_s, num_classes=num_experts)

# Create a mask for 2nd's expert per token using Gumbel-max trick
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
if top2_2nd_expert_sampling:
# Create a mask for 2nd's expert per token using Gumbel-max trick
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits += gumbel_rsample(logits.shape, device=logits.device)

# Replace top-expert with min value
logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
indices2_s = torch.argmax(logits_except1, dim=1)
mask2 = F.one_hot(indices2_s, num_classes=num_experts)
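A self-contained sketch of the Gumbel-max selection used here for the 2nd expert (toy logits; `torch.distributions.Gumbel` stands in for gumbel_rsample):

```python
# Minimal sketch of the Gumbel-max trick: adding i.i.d. Gumbel noise to the
# logits and taking argmax is equivalent to sampling from softmax(logits).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 8)                      # [tokens, experts], toy values
mask1 = F.one_hot(torch.argmax(logits, dim=1), num_classes=8)

noisy = logits + torch.distributions.Gumbel(0.0, 1.0).sample(logits.shape)
# exclude each token's top-1 expert, then pick the 2nd expert via argmax
indices2 = torch.argmax(noisy.masked_fill(mask1.bool(), float("-inf")), dim=1)
```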

@@ -305,17 +314,29 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
# Update 2nd's location by accounting for locations of 1st
locations2 += torch.sum(mask1, dim=0, keepdim=True)

# gating decisions
exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

# Compute l_aux
me = torch.mean(gates, dim=0)
ce = torch.mean(mask1.float(), dim=0)
l_aux = torch.mean(me * ce) * num_experts * num_experts
l_aux = torch.mean(me * ce) * num_experts
Contributor:
@ykim362 - FYI. I think the current value of l_aux in top-2 was giving us good convergence. I am not sure about changing it; if we do, we will need to run training experiments to verify there is no regression in loss.

@mosheisland - have you trained models using top-2 and seen that l_aux with your change gives better convergence/loss?

Contributor (Author):
I did not run full training before and after my changes to compare.
I added this change to better align with the original paper.
Since you saw better results with the current formulation, let's keep it as-is.
I will upload a new commit that reverts this change.
BTW, when the same num_experts is used across all MoE layers, this can also be adjusted by changing --moe-loss-coeff.
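For reference, a toy sketch contrasting the two scalings discussed in this thread (the author notes the change is being reverted); only the constant factor differs, which is why --moe-loss-coeff can compensate when num_experts is fixed:

```python
# Toy sketch of the top-2 aux-loss scalings discussed above (not from the diff).
import torch
import torch.nn.functional as F

gates = F.softmax(torch.randn(16, 8), dim=1)               # [tokens, experts]
mask1 = F.one_hot(torch.argmax(gates, dim=1), 8).float()   # top-1 assignment per token
me = gates.mean(dim=0)                                      # mean gate probability per expert
ce = mask1.mean(dim=0)                                      # fraction of tokens routed to each expert
num_experts = gates.shape[1]
l_aux_paper   = torch.mean(me * ce) * num_experts           # scaling proposed (then reverted) here
l_aux_current = torch.mean(me * ce) * num_experts ** 2      # existing DeepSpeed top-2 scaling
```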


# gating decisions
exp_counts = torch.sum(mask1 + mask2, dim=0)

# Remove locations outside capacity from mask
mask1 *= torch.lt(locations1, capacity)
mask2 *= torch.lt(locations2, capacity)
if drop_tokens:
# Calculate configured capacity and remove locations outside capacity from mask
capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
mask1 *= torch.lt(locations1, capacity)
mask2 *= torch.lt(locations2, capacity)
else:
# Do not drop tokens - set capacity according to current expert assignments
new_capacity = torch.max(exp_counts)
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
if groups._get_expert_model_parallel_world_size() == 1:
# If the non-expert part is tensor-parallel, we need to pad the capacity to a multiple of 'tp'.
# This is because we are going to activate drop_tokens() to drop the duplicate tokens.
tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
capacity = new_capacity

# Store the capacity location for each token
locations1_s = torch.sum(locations1 * mask1, dim=1)
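A toy, single-process sketch of the no-drop capacity computation added above (the real code additionally all-reduces the max across ranks and pads it to a multiple of TP):

```python
# Sketch of capacity selection when drop_tokens=False in top-2 gating:
# count top-1 and top-2 assignments per expert and take the max, so that
# no token has to be dropped. Toy values, single process.
import torch
import torch.nn.functional as F

num_experts = 4
mask1 = F.one_hot(torch.tensor([0, 0, 1, 2]), num_experts)   # top-1 expert per token
mask2 = F.one_hot(torch.tensor([1, 2, 0, 3]), num_experts)   # top-2 expert per token
exp_counts = torch.sum(mask1 + mask2, dim=0)                  # tokens routed per expert
capacity = torch.max(exp_counts)                              # tensor(3): expert 0 receives 3 tokens
```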
@@ -342,7 +363,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
combine_weights = combine1_sec + combine2_sec
dispatch_mask = combine_weights.bool()

return l_aux, combine_weights, dispatch_mask, exp_counts
return l_aux, combine_weights, dispatch_mask, exp_counts.detach().to('cpu')


class TopKGate(Module):
@@ -372,13 +393,14 @@ def __init__(self,
min_capacity: int = 8,
noisy_gate_policy: Optional[str] = None,
drop_tokens: bool = True,
use_rts: bool = True) -> None:
use_rts: bool = True,
top2_2nd_expert_sampling: bool = True) -> None:
super().__init__()

# Only top-1 and top-2 are supported at the moment.
if k != 1 and k != 2:
raise ValueError('Only top-1 and top-2 gatings are supported.')
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
Contributor:
@ykim362 - can you please review this part? I remember we forced .float() here because we wanted the gate weight to always be fp32 even if everything else was fp16.
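A minimal sketch of the dtype handling after this change, assuming the intent is to keep the gate weight in the model dtype and only perform the gating matmul in fp32:

```python
# Sketch (assumed behavior): the gate weight stays in the model dtype
# (e.g. bf16/fp16); only the logits computation casts both operands to fp32.
import torch

wg = torch.nn.Linear(512, 8, bias=False).to(torch.bfloat16)   # gate kept in model dtype
x = torch.randn(4, 512, dtype=torch.bfloat16)
logits = torch.nn.functional.linear(x.float(), weight=wg.weight.float(), bias=None)
assert logits.dtype == torch.float32
```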

self.k = k
self.capacity_factor = capacity_factor
self.eval_capacity_factor = eval_capacity_factor
@@ -389,6 +411,7 @@ def __init__(self,
self.gate_time = 0.0
self.drop_tokens = drop_tokens
self.use_rts = use_rts
self.top2_2nd_expert_sampling = top2_2nd_expert_sampling

def forward(self,
input: torch.Tensor,
@@ -398,13 +421,11 @@ def forward(self,
if self.wall_clock_breakdown:
self.timers(TOPK_GATE_TIMER).start()

if self.wg.weight.dtype != torch.float32:
self.wg = self.wg.float()
input_fp32 = input.float()
# input jittering
if self.noisy_gate_policy == 'Jitter' and self.training:
input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
logits = self.wg(input_fp32)
logits = torch.nn.functional.linear(input_fp32, weight=self.wg.weight.float(), bias=None)

if self.k == 1:
gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
@@ -413,7 +434,7 @@ def forward(self,

else:
gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity)
self.min_capacity, self.drop_tokens, self.top2_2nd_expert_sampling)

if self.wall_clock_breakdown:
self.timers(TOPK_GATE_TIMER).stop()