[feat] moe: initial implementation of MOELayer (#128)
Currently only implemented for a single process and expert.
msbaines committed Oct 8, 2020
1 parent 82dbd5d commit 22ff665
Showing 4 changed files with 119 additions and 1 deletion.
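
For orientation, a minimal usage sketch based on the docstring and the new test below; it assumes a single process and a single expert, which is all this commit supports:

import torch
from fairscale.nn import MOELayer, Top2Gate

model_dim = 8
num_experts = 1  # the commit message notes only a single expert is supported so far
gate = Top2Gate(model_dim, num_experts)
expert = torch.nn.Linear(model_dim, model_dim)
moe = MOELayer(gate, expert)

x = torch.randn(3, 4, 16, model_dim)  # 4-D input, as in the new test
y = moe(x)                            # output has the same shape as x
aux_loss = moe.l_aux                  # auxiliary load-balancing loss from the gate
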
2 changes: 1 addition & 1 deletion fairscale/nn/__init__.py
@@ -3,7 +3,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

-from .moe import Top2Gate
+from .moe import MOELayer, Top2Gate
from .pipe import Pipe

__all__ = ["Pipe", "Top2Gate"]
1 change: 1 addition & 0 deletions fairscale/nn/moe/__init__.py
@@ -3,4 +3,5 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

+from .moelayer import MOELayer
from .top2gate import Top2Gate
64 changes: 64 additions & 0 deletions fairscale/nn/moe/moelayer.py
@@ -0,0 +1,64 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import TYPE_CHECKING, Any

import torch
from torch import Tensor
from torch.nn import Module

if TYPE_CHECKING:
    Base = Module[Tensor]
else:
    Base = Module

# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.


class MOELayer(Base):
    """MOELayer module which implements MixtureOfExperts as described in Gshard_.
    ::

        gate = Top2Gate(model_dim, num_experts)
        moe = MOELayer(gate, expert)
        l_aux, combine_weights, dispatch_mask = moe(input)

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        gate (torch.nn.Module):
            gate network
        expert (torch.nn.Module):
            expert network
    """

    def __init__(self, gate: Module, expert: Module) -> None:
        super().__init__()
        self.gate = gate
        self.expert = expert

    def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
        dispatched_input = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
        # TODO(msb) all-to-all
        dispatched_input = torch.squeeze(dispatched_input, 0)  # drop E dimension
        return dispatched_input

    def all_to_all_combine(self, combine_weights: Tensor, input: Tensor) -> Tensor:
        # TODO(msb) all-to-all
        expert_output = torch.unsqueeze(input, 1)  # add E dimension
        output = torch.einsum("gsec,gecm->gsm", combine_weights, expert_output)
        return output

    def forward(self, *input: Any, **kwargs: Any) -> Tensor:
        # Implement Algorithm 2 from GShard paper.
        shape = input[0].shape
        # Reshape into S tokens per group.
        reshaped_input = input[0].reshape(shape[0], -1, shape[3])
        self.l_aux, combine_weights, dispatching_mask = self.gate(reshaped_input)
        dispatched_input = self.all_to_all_dispatch(dispatching_mask, reshaped_input)
        expert_output = self.expert(dispatched_input)
        combined_output = self.all_to_all_combine(combine_weights, expert_output)
        return combined_output.reshape(shape)
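
For reference, the einsum strings in all_to_all_dispatch and all_to_all_combine follow the (g)roup, (s)equence, (e)xpert, (c)apacity, (m)odel convention noted at the top of the file. A small shape-check sketch with made-up sizes (the capacity value here is illustrative, not what Top2Gate actually emits):

import torch

g, s, e, c, m = 3, 64, 1, 32, 8  # illustrative sizes; c (capacity) is made up

dispatch_mask = torch.rand(g, s, e, c)
tokens = torch.randn(g, s, m)
dispatched = torch.einsum("gsec,gsm->egcm", dispatch_mask, tokens)
print(dispatched.shape)  # torch.Size([1, 3, 32, 8]) -> (e, g, c, m)

combine_weights = torch.rand(g, s, e, c)
expert_output = torch.randn(g, e, c, m)
combined = torch.einsum("gsec,gecm->gsm", combine_weights, expert_output)
print(combined.shape)  # torch.Size([3, 64, 8]) -> (g, s, m)
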
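The two TODO(msb) comments mark where a multi-process version would exchange expert shards between ranks. A hedged sketch of that missing primitive, assuming torch.distributed.all_to_all_single and an already initialized process group (none of this is part of the commit):

import torch
import torch.distributed as dist

def all_to_all(tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: exchange equal-sized shards of `tensor` across all ranks.
    output = torch.empty_like(tensor)
    dist.all_to_all_single(output, tensor)
    return output
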
53 changes: 53 additions & 0 deletions tests/nn/moe/test_moelayer.py
@@ -0,0 +1,53 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from fairscale.nn import MOELayer, Top2Gate

skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")


def test_create():
    model_dim = 8
    num_experts = 4
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim)
    moe = MOELayer(gate, expert)


@skip_if_no_cuda
def test_create_cuda():
    model_dim = 8
    num_experts = 4
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim)
    moe = MOELayer(gate, expert).cuda()


def do_test_forward(device):
    model_dim = 8
    num_experts = 1
    input = torch.randn(3, 4, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
    # Use identity matrix
    expert.weight = torch.nn.Parameter(torch.eye(model_dim))
    moe = MOELayer(gate, expert).to(device)
    output = moe(input)
    assert moe.l_aux.item() == 1.0
    assert output.shape == input.shape
    # Re-assembled output should match input due to identity expert.
    assert torch.equal(input, output)


def test_forward_cpu():
    do_test_forward("cpu")


@skip_if_no_cuda
def test_forward_cuda():
    do_test_forward("cuda")
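
The new tests can be run on their own with pytest, for example from the repository root:

python -m pytest tests/nn/moe/test_moelayer.py
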
