MOE gate fixes and enhancements #5156
Merged

Changes from 5 commits (13 commits total)

Commits:
- 3e0c35f MOE: Support top2 with disable token dropping (misland-habana)
- 692d42d MOE: Fix top2 aux loss (misland-habana)
- 953e698 MOE: Support disable top2 2nd expert sampling (misland-habana)
- 64150db MOE: Fix capacity when using TP for non-MoE (misland-habana)
- aab9fc3 MOE: Fix gate conversion to fp32 (misland-habana)
- 1657955 Revert "MOE: Fix top2 aux loss" (misland-habana)
- 05b2262 Merge branch 'master' into moe/gate (loadams)
- 6c04451 Merge branch 'master' into moe/gate (mosheisland)
- c8e05ea Merge branch 'master' into moe/gate (mosheisland)
- 7eb3633 Merge branch 'master' into moe/gate (mosheisland)
- b76a541 Merge branch 'master' into moe/gate (mosheisland)
- 3d227ef Merge branch 'master' into moe/gate (mosheisland)
- 30118cf Merge branch 'master' into moe/gate (mosheisland)
Diff view
```diff
@@ -214,6 +214,11 @@ def top1gating(logits: Tensor,
     if not drop_tokens:
         new_capacity = torch.max(exp_counts).to(logits.device)
         dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
+        if groups._get_expert_model_parallel_world_size() == 1:
+            # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
+            # This is since we are going to activate drop_tokens() to drop duplicate tokens.
+            tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
+            new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
         capacity = new_capacity

     # Compute l_aux
```
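For intuition, the added lines round the no-drop capacity up to the nearest multiple of the tensor-parallel world size, so that when duplicate tokens are later dropped across TP ranks, every rank ends up with the same number of slots. A minimal standalone sketch of just the rounding step (hypothetical helper; the real code reads `tp` from DeepSpeed's `groups.mpu`):

```python
import torch

def pad_capacity_to_tp(new_capacity: torch.Tensor, tp: int) -> torch.Tensor:
    """Round capacity up to the nearest multiple of the TP world size."""
    # e.g. a capacity of 13 with tp=4 is padded to 16
    return torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)

print(pad_capacity_to_tp(torch.tensor(13), 4))  # tensor(16)
```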
```diff
@@ -279,23 +284,27 @@ def top1gating(logits: Tensor,
     return l_aux, combine_weights, dispatch_mask, exp_counts


-def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+def top2gating(logits: Tensor,
+               capacity_factor: float,
+               min_capacity: int,
+               drop_tokens: bool = True,
+               top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
     # everything is in fp32 in this function
     gates = F.softmax(logits, dim=1)

-    capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
-
     # Create a mask for 1st's expert per token
     indices1_s = torch.argmax(gates, dim=1)
     num_experts = int(gates.shape[1])
     mask1 = F.one_hot(indices1_s, num_classes=num_experts)

-    # Create a mask for 2nd's expert per token using Gumbel-max trick
-    # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
-    logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
+    if top2_2nd_expert_sampling:
+        # Create a mask for 2nd's expert per token using Gumbel-max trick
+        # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
+        logits += gumbel_rsample(logits.shape, device=logits.device)
+
     # Replace top-expert with min value
-    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
+    logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
     indices2_s = torch.argmax(logits_except1, dim=1)
     mask2 = F.one_hot(indices2_s, num_classes=num_experts)
```
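For reference, when `top2_2nd_expert_sampling` is enabled the second expert is chosen by the Gumbel-max trick: adding Gumbel noise to the logits and taking the argmax draws a sample from the softmax distribution, rather than deterministically picking the second-highest logit. A standalone sketch of the trick (using explicit Gumbel noise instead of DeepSpeed's `gumbel_rsample` helper):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 8)  # 4 tokens, 8 experts

# Top-1 expert per token
mask1 = F.one_hot(logits.argmax(dim=1), num_classes=8)

# Gumbel-max trick: argmax(logits + Gumbel noise) samples from softmax(logits)
gumbel = -torch.log(-torch.log(torch.rand_like(logits)))
noisy = (logits + gumbel).masked_fill(mask1.bool(), float("-inf"))
indices2_s = noisy.argmax(dim=1)  # sampled 2nd expert, never equal to the top-1
```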
```diff
@@ -305,17 +314,29 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
     # Update 2nd's location by accounting for locations of 1st
     locations2 += torch.sum(mask1, dim=0, keepdim=True)

-    # gating decisions
-    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
-
     # Compute l_aux
     me = torch.mean(gates, dim=0)
     ce = torch.mean(mask1.float(), dim=0)
-    l_aux = torch.mean(me * ce) * num_experts * num_experts
+    l_aux = torch.mean(me * ce) * num_experts

-    # Remove locations outside capacity from mask
-    mask1 *= torch.lt(locations1, capacity)
-    mask2 *= torch.lt(locations2, capacity)
+    # gating decisions
+    exp_counts = torch.sum(mask1 + mask2, dim=0)
+
+    if drop_tokens:
+        # Calculate configured capacity and remove locations outside capacity from mask
+        capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
+        mask1 *= torch.lt(locations1, capacity)
+        mask2 *= torch.lt(locations2, capacity)
+    else:
+        # Do not drop tokens - set capacity according to current expert assignments
+        new_capacity = torch.max(exp_counts)
+        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
+        if groups._get_expert_model_parallel_world_size() == 1:
+            # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
+            # This is since we are going to activate drop_tokens() to drop duplicate tokens.
+            tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
+            new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
+        capacity = new_capacity

     # Store the capacity location for each token
     locations1_s = torch.sum(locations1 * mask1, dim=1)
```
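When token dropping is disabled, the capacity is instead derived from the actual assignments: the busiest expert on any rank defines the capacity everywhere, which requires a MAX all-reduce so that all ranks dispatch with identical shapes. A rough sketch of that synchronization step (a single-process fallback is added here for illustration; the real code always reduces over `dist.get_world_group()`):

```python
import torch
import torch.distributed as dist

def no_drop_capacity(exp_counts: torch.Tensor) -> torch.Tensor:
    # Capacity must fit the most-loaded expert...
    new_capacity = torch.max(exp_counts)
    if dist.is_initialized():
        # ...on any rank, so every rank agrees on the dispatch shape
        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX)
    return new_capacity
```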
```diff
@@ -342,7 +363,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
     combine_weights = combine1_sec + combine2_sec
     dispatch_mask = combine_weights.bool()

-    return l_aux, combine_weights, dispatch_mask, exp_counts
+    return l_aux, combine_weights, dispatch_mask, exp_counts.detach().to('cpu')


 class TopKGate(Module):
```
```diff
@@ -372,13 +393,14 @@ def __init__(self,
                  min_capacity: int = 8,
                  noisy_gate_policy: Optional[str] = None,
                  drop_tokens: bool = True,
-                 use_rts: bool = True) -> None:
+                 use_rts: bool = True,
+                 top2_2nd_expert_sampling: bool = True) -> None:
         super().__init__()

         # Only top-1 and top-2 are supported at the moment.
         if k != 1 and k != 2:
             raise ValueError('Only top-1 and top-2 gatings are supported.')
-        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
+        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
         self.k = k
         self.capacity_factor = capacity_factor
         self.eval_capacity_factor = eval_capacity_factor
```

Review comment on the `self.wg` line:
@ykim362 - can you please review this part? I remember we forced .float() here because we wanted the gate weight to always be fp32 even if everything else was fp16.
```diff
@@ -389,6 +411,7 @@ def __init__(self,
         self.gate_time = 0.0
         self.drop_tokens = drop_tokens
         self.use_rts = use_rts
+        self.top2_2nd_expert_sampling = top2_2nd_expert_sampling

     def forward(self,
                 input: torch.Tensor,
```
```diff
@@ -398,13 +421,11 @@ def forward(self,
         if self.wall_clock_breakdown:
             self.timers(TOPK_GATE_TIMER).start()

-        if self.wg.weight.dtype != torch.float32:
-            self.wg = self.wg.float()
         input_fp32 = input.float()
         # input jittering
         if self.noisy_gate_policy == 'Jitter' and self.training:
             input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
-        logits = self.wg(input_fp32)
+        logits = torch.nn.functional.linear(input_fp32, weight=self.wg.weight.float(), bias=None)

         if self.k == 1:
             gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
```
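The fp32 fix changes strategy: instead of mutating the module (`self.wg = self.wg.float()`) on the first forward pass, the gate weight stays in the model's native dtype and is upcast only for the logits matmul. That keeps the stored parameter consistent with the rest of an fp16/bf16 model while still computing gating in fp32. A standalone sketch of the idea (hypothetical `Gate` module, not the DeepSpeed class):

```python
import torch
import torch.nn.functional as F

class Gate(torch.nn.Module):
    def __init__(self, model_dim: int, num_experts: int):
        super().__init__()
        # Weight stays in the surrounding model's dtype (e.g. bf16)
        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Upcast both operands for the matmul only; self.wg itself is untouched
        return F.linear(x.float(), weight=self.wg.weight.float(), bias=None)
```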
```diff
@@ -413,7 +434,7 @@ def forward(self,

         else:
             gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
-                                     self.min_capacity)
+                                     self.min_capacity, self.drop_tokens, self.top2_2nd_expert_sampling)

         if self.wall_clock_breakdown:
             self.timers(TOPK_GATE_TIMER).stop()
```
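Putting the new knobs together, a top-2 gate with token dropping disabled and deterministic second-expert selection would be constructed roughly as below (argument names taken from this diff; the import path is an assumption about where the class lives in DeepSpeed):

```python
from deepspeed.moe.sharded_moe import TopKGate  # assumed module path

gate = TopKGate(model_dim=1024,
                num_experts=8,
                k=2,
                drop_tokens=False,               # new: top-2 now honors drop_tokens
                top2_2nd_expert_sampling=False)  # new: skip Gumbel sampling of the 2nd expert
```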
Review comment:
@ykim362 - FYI. I think the current value of l_aux in top-2 was giving us good convergence. I am not sure about changing it; we would need to run training experiments to verify there is no regression in loss.
@mosheisland - have you trained models using top-2, and did you see that l_aux with your change gives better convergence/loss?
Reply (mosheisland):
I did not run full training before and after my changes to compare. I added this change to align better with the original paper. Since you saw better results with the current formulation, let's keep it as-is. I will upload a new commit that reverts this change. BTW, when the same num_experts is used across all MoE layers, the effective scale can be adjusted by changing --moe-loss-coeff.
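For context, a toy computation of the two scalings being discussed, with made-up gate values (`me` is the mean gate probability per expert, `ce` the fraction of tokens whose top-1 choice is that expert):

```python
import torch
import torch.nn.functional as F

gates = torch.tensor([[0.6, 0.3, 0.1],
                      [0.2, 0.7, 0.1],
                      [0.5, 0.4, 0.1],
                      [0.1, 0.8, 0.1]])  # softmax outputs: 4 tokens, 3 experts
mask1 = F.one_hot(gates.argmax(dim=1), num_classes=3)

me = gates.mean(dim=0)          # mean gate probability per expert
ce = mask1.float().mean(dim=0)  # fraction of tokens routed to each expert
num_experts = gates.shape[1]

l_aux_changed = (me * ce).mean() * num_experts             # as in this 5-commit view
l_aux_kept = (me * ce).mean() * num_experts * num_experts  # formulation restored by the revert
```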