Skip to content

Commit

Permalink
[feat] moe: initial implementation of Top2Gating (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
msbaines committed Oct 2, 2020
1 parent 2eee136 commit 7815f6f
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 1 deletion.
8 changes: 7 additions & 1 deletion fairscale/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from .moe import Top2Gate
from .pipe import Pipe

# Pre-change line of the diff; at runtime it is superseded by the assignment below.
__all__ = ["Pipe"]
# Post-change line: the package exports both Pipe and Top2Gate.
__all__ = ["Pipe", "Top2Gate"]
6 changes: 6 additions & 0 deletions fairscale/nn/moe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from .top2gate import Top2Gate
111 changes: 111 additions & 0 deletions fairscale/nn/moe/top2gate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Implementation of Top2Gating described in https://arxiv.org/pdf/2006.16668.pdf
# Code is inspired by Top2GatingOnLogits from lingvo:
# https://github.com/tensorflow/lingvo/blob/21b8106c5f1d30a196c98eedc441d4fd70833b11/lingvo/core/moe_layers.py#L477

from typing import Tuple

import torch
from torch import Tensor
import torch.nn.functional as F

# Standard Gumbel distribution (location 0, scale 1), constructed once at import time.
gumbel = torch.distributions.gumbel.Gumbel(0, 1) # type: ignore


def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """Implements Top2Gating on logits.

    Every token is routed to its highest-probability expert and to a second
    expert drawn with the Gumbel-max trick. Each expert owns a buffer of
    ``capacity = 2 * S // E`` slots; first choices are placed before second
    choices, and assignments that do not fit are dropped.

    Args:
        logits: gating logits of shape (G, S, E) -- groups, tokens per group,
            experts. ``S`` must be a multiple of ``E``.

    Returns:
        A tuple ``(l_aux, combine_weights, dispatch_mask)``:

        * ``l_aux``: scalar tensor, the mean over groups and experts of
          (mean gate probability) * (fraction of tokens whose first choice
          is that expert).
        * ``combine_weights``: float tensor of shape (G, S, E, C) holding the
          normalized gate weight of each (token, expert, slot) assignment.
        * ``dispatch_mask``: ``combine_weights.bool()``.

    Raises:
        AssertionError: if the number of tokens is not a multiple of the
            number of experts.
    """
    gates = F.softmax(logits, dim=2)
    min_logit = torch.finfo(logits.dtype).min  # type: ignore

    # gates has shape of GSE
    num_tokens = gates.shape[1]
    num_experts = gates.shape[2]
    # Validate before deriving anything from the shape.
    assert num_tokens % num_experts == 0, "number of tokens must be a multiple of number of experts"
    # capacity = 2S/E
    capacity = 2 * num_tokens // num_experts

    # Create a mask for each token's 1st expert
    indices1_gs = torch.argmax(gates, dim=2)
    mask1 = F.one_hot(indices1_gs, num_classes=num_experts)

    # Create a mask for each token's 2nd expert using the Gumbel-max trick
    # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
    # The distribution is parameterized with tensors living on the logits'
    # device so the noise is sampled there; a Gumbel built from Python scalars
    # yields default-device samples that cannot be added to e.g. CUDA logits.
    zero = torch.tensor(0.0, device=logits.device)
    one = torch.tensor(1.0, device=logits.device)
    noise = torch.distributions.gumbel.Gumbel(zero, one).rsample(logits.shape)  # type: ignore
    logits_w_noise = logits + noise
    # Replace the top expert with the minimum value so it cannot be re-selected
    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), min_logit)
    indices2_gs = torch.argmax(logits_except1, dim=2)
    mask2 = F.one_hot(indices2_gs, num_classes=num_experts)

    # Compute locations in capacity buffer
    locations1 = torch.cumsum(mask1, dim=1) - 1
    locations2 = torch.cumsum(mask2, dim=1) - 1
    # 2nd choices are placed after all 1st choices of the same expert
    locations2 = locations2 + torch.sum(mask1, dim=1, keepdim=True)

    # Compute l_aux (uses the 1st-choice mask before capacity is applied)
    me = torch.mean(gates, dim=1)
    ce = torch.mean(mask1.float(), dim=1)
    l_aux = torch.mean(me * ce)

    # Remove locations outside capacity from mask
    mask1 = mask1 * torch.lt(locations1, capacity)
    mask2 = mask2 * torch.lt(locations2, capacity)

    # Store the capacity location for each token. A plain multiply-and-sum is
    # exact for integers and avoids integer einsum, which relies on integer
    # matmul kernels that are not available on every backend.
    locations1_gs = torch.sum(locations1 * mask1, dim=2)
    locations2_gs = torch.sum(locations2 * mask2, dim=2)

    # Normalize gate probabilities
    mask1_float = mask1.float()
    mask2_float = mask2.float()
    gates1_gs = torch.einsum("gse,gse->gs", gates, mask1_float)
    gates2_gs = torch.einsum("gse,gse->gs", gates, mask2_float)
    denom_gs = gates1_gs + gates2_gs
    # Avoid divide-by-zero for tokens whose assignments were both dropped
    denom_gs = torch.where(denom_gs > 0, denom_gs, torch.ones_like(denom_gs))
    gates1_gs = gates1_gs / denom_gs
    gates2_gs = gates2_gs / denom_gs

    # Calculate combine_weights and dispatch_mask
    gates1 = torch.einsum("gs,gse->gse", gates1_gs, mask1_float)
    gates2 = torch.einsum("gs,gse->gse", gates2_gs, mask2_float)
    # one_hot returns integers; cast so einsum sees a single dtype
    locations1_gsc = F.one_hot(locations1_gs, num_classes=capacity).to(gates1.dtype)
    locations2_gsc = F.one_hot(locations2_gs, num_classes=capacity).to(gates2.dtype)
    combine1_gsec = torch.einsum("gse,gsc->gsec", gates1, locations1_gsc)
    combine2_gsec = torch.einsum("gse,gsc->gsec", gates2, locations2_gsc)
    combine_weights = combine1_gsec + combine2_gsec
    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask


class Top2Gate(torch.nn.Module):
    """Gate module which implements Top2Gating as described in Gshard_.

    ::

        gate = Top2Gate(model_dim, num_experts)
        l_aux, combine_weights, dispatch_mask = gate(input)

    .. _Gshard: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        model_dim (int):
            size of model embedding dimension
        num_experts (int):
            number of experts in model
    """

    wg: torch.nn.Linear

    def __init__(self, model_dim: int, num_experts: int) -> None:
        super().__init__()
        # torch.nn.Linear(in_features, out_features) stores its weight as
        # (out_features, in_features). Passing (num_experts, model_dim) thus
        # yields a (model_dim, num_experts) matrix, which is the "me" operand
        # that forward() feeds to einsum. Only the weight is used; the layer
        # itself is never called.
        gate_projection = torch.nn.Linear(num_experts, model_dim, bias=False)
        self.wg = gate_projection

    def forward(self, input: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore
        """Project ``input`` (indexed as g, s, m) to expert logits and gate them."""
        projection = self.wg.weight
        expert_logits = torch.einsum("gsm,me -> gse", input, projection)
        return top2gating(expert_logits)
2 changes: 2 additions & 0 deletions stubs/torch/functional.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ from typing import Tuple, List, Union

# Type stub for split: returns a tuple of tensors.
def split(tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int=0) -> Tuple[Tensor,...]: ...

# Type stub for einsum: takes an equation string and any number of tensor operands.
def einsum(equation: str, *operands: Tensor) -> Tensor: ...

Empty file added tests/nn/moe/__init__.py
Empty file.
51 changes: 51 additions & 0 deletions tests/nn/moe/test_top2gating.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from fairscale.nn import Top2Gate
from fairscale.nn.moe.top2gate import top2gating


def test_create():
    """Smoke test: building a Top2Gate must not raise."""
    model_dim, num_experts = 4, 8
    Top2Gate(model_dim, num_experts)


def test_forward():
    """Check shapes, value ranges and capacity limits of a seeded forward pass."""
    torch.manual_seed(3)
    num_groups, num_tokens, model_dim, num_experts = 3, 12, 4, 6
    # Draw the input before building the gate so the seeded RNG stream is
    # consumed in a fixed order.
    tokens = torch.randn(num_groups, num_tokens, model_dim)
    gate = Top2Gate(model_dim, num_experts)
    capacity = 2 * num_tokens // num_experts
    l_aux, combine_weights, dispatch_mask = gate(tokens)
    # The comparison must be written `value == approx(expected)`; asserting a
    # bare `pytest.approx(...)` object is always truthy and checks nothing.
    assert l_aux.item() == pytest.approx(0.0283, abs=5e-4)
    expected_shape = (num_groups, num_tokens, num_experts, capacity)
    assert combine_weights.shape == expected_shape
    assert dispatch_mask.shape == expected_shape
    assert torch.equal(combine_weights.bool(), dispatch_mask)
    # No expert may receive more tokens than it has slots.
    assert torch.all(torch.sum(dispatch_mask, dim=(1, 3)) <= capacity)
    assert torch.all(combine_weights >= 0.0)
    assert torch.all(combine_weights <= 1.0)
    # The weights of every routed token are normalized to sum to one, so the
    # grand total is a whole number.
    weights_sum = torch.sum(combine_weights).item()
    assert round(weights_sum) == pytest.approx(weights_sum)
    # For this random seed, we get 36 slots filled.
    assert weights_sum == pytest.approx(36.0)


# Verify that top gate is allocated capacity as per Algorithm 1 in GShard paper.
def test_top1s():
    """Replay first-choice routing token by token and compare with dispatch_mask.

    Tokens are visited in order; the k-th token (0-based) whose top expert is
    ``e`` must occupy slot ``k`` of ``e`` while ``k < capacity``, and must not
    be dispatched to ``e`` at all once the expert is full.
    """
    # Seeded so that a failure can be reproduced.
    torch.manual_seed(0)
    num_tokens = 8
    num_experts = 4
    logits = torch.randn(1, num_tokens, num_experts)
    _, _, dispatch_mask = top2gating(logits)
    top1s = torch.argmax(logits, dim=2)
    capacity = 2 * num_tokens // num_experts
    tokens_seen = [0] * num_experts
    for token, expert in enumerate(top1s[0].tolist()):
        slot = tokens_seen[expert]
        tokens_seen[expert] += 1
        # Compare the slot index itself against capacity; comparing the
        # already-incremented counter would skip the last valid slot.
        if slot < capacity:
            assert dispatch_mask[0][token][expert][slot]
        else:
            assert not dispatch_mask[0][token][expert].any()

0 comments on commit 7815f6f

Please sign in to comment.