[refactor] moe: simplify logic removing top expert (#125)
msbaines committed Oct 6, 2020
1 parent 662667d commit 6e7ad79
Showing 1 changed file with 1 addition and 3 deletions.
fairscale/nn/moe/top2gate.py: 1 addition & 3 deletions

@@ -29,7 +29,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
 def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
     gates = F.softmax(logits, dim=2)
-    min_logit = torch.finfo(logits.dtype).min  # type: ignore
 
     # gates has shape of GSE
     num_tokens = gates.shape[1]
@@ -46,8 +45,7 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
     logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
     # Replace top-expert with min value
-    mins = torch.full_like(logits, min_logit)
-    logits_except1 = torch.where(mask1.bool(), mins, logits_w_noise)
+    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
     indices2_gs = torch.argmax(logits_except1, dim=2)
     mask2 = F.one_hot(indices2_gs, num_classes=num_experts)
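The change is behavior-preserving: filling the top-1 expert's noisy logit with -inf excludes it from the argmax exactly as the old dtype-min tensor did, without allocating the intermediate `mins` tensor. A minimal sketch of that equivalence (not part of the commit), using an illustrative mask; in the real top2gating, mask1 is the one-hot argmax of the softmaxed gates rather than of the noisy logits:

    import torch
    import torch.nn.functional as F

    # Illustrative shapes (hypothetical): G groups, S tokens per group, E experts.
    G, S, E = 2, 4, 8
    logits_w_noise = torch.randn(G, S, E)

    # Illustrative top-1 mask for the sketch only.
    mask1 = F.one_hot(logits_w_noise.argmax(dim=2), num_classes=E)

    # Old formulation: build a full tensor of dtype-min values, then select.
    min_logit = torch.finfo(logits_w_noise.dtype).min
    mins = torch.full_like(logits_w_noise, min_logit)
    old = torch.where(mask1.bool(), mins, logits_w_noise)

    # New formulation: fill the masked positions directly, no extra tensor.
    new = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))

    # Both exclude the top expert, so the second-expert argmax is identical.
    assert torch.equal(old.argmax(dim=2), new.argmax(dim=2))

Using float("-inf") instead of torch.finfo(...).min is safe here because the masked tensor only feeds an argmax, not any further arithmetic.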