
[feat] moe: add all_to_all backward support (#137)
msbaines committed Oct 16, 2020
1 parent 1e6c547 commit d99c445
Showing 3 changed files with 43 additions and 13 deletions.
4 changes: 2 additions & 2 deletions .circleci/config.yml
@@ -42,7 +42,7 @@ install_dep_15: &install_dep_15
   - run:
       name: Install Dependencies
       command: |
-        sudo apt-get install -y mpi-default-dev
+        sudo apt-get install -y libopenmpi-dev
         pip install --progress-bar off torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off -r requirements-test.txt
         python -c 'import torch; print("Torch version:", torch.__version__)'
@@ -52,7 +52,7 @@ install_dep_16: &install_dep_16
   - run:
       name: Install Dependencies
       command: |
-        sudo apt-get install -y mpi-default-dev
+        sudo apt-get install -y libopenmpi-dev
         pip install --progress-bar off torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off -r requirements-test.txt
         python -c 'import torch; print("Torch version:", torch.__version__)'
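
For context: swapping mpi-default-dev for libopenmpi-dev pins OpenMPI explicitly rather than Debian's default-MPI meta-package. Note that torch.distributed only exposes an MPI backend when PyTorch itself was compiled against MPI, which the prebuilt wheels installed above generally are not; a quick check (not part of this commit):

import torch.distributed as dist

# May print False for stock pip wheels even with libopenmpi-dev installed,
# since the MPI backend requires building PyTorch from source against MPI.
print("MPI backend available:", dist.is_mpi_available())
print("NCCL backend available:", dist.is_nccl_available())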
33 changes: 22 additions & 11 deletions fairscale/nn/moe/moelayer.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Tuple
 
 import torch
 from torch import Tensor
@@ -19,6 +19,24 @@
 # See https://arxiv.org/pdf/2006.16668.pdf for details.
 
 
+# Based on https://github.com/pytorch/pytorch/pull/40762
+class _AllToAll(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
+        ctx.group = group
+        world_size = dist.get_world_size(group)
+        input = input.contiguous()
+        output = torch.empty_like(input)
+        input_chunks = list(input.chunk(world_size))
+        output_chunks = list(output.chunk(world_size))
+        dist.all_to_all(output_chunks, input_chunks, group=group)
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
+        return (None, _AllToAll.apply(ctx.group, *grad_output))
+
+
 class MOELayer(Base):
     """MOELayer module which implements MixtureOfExperts as described in Gshard_.
     ::
@@ -42,21 +60,14 @@ def __init__(self, gate: Module, expert: Module, group: Optional[Any] = None) ->
         self.gate = gate
         self.expert = expert
         self.group = group if group is not None else dist.group.WORLD
-        self.world_size = dist.get_world_size(self.group)
 
     def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
         dispatched_input = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
-        dispatched_input = dispatched_input.contiguous()
-        chunks = list(dispatched_input.chunk(self.world_size))
-        dist.all_to_all(chunks, chunks, self.group)
-        return dispatched_input
+        return _AllToAll.apply(self.group, dispatched_input)
 
     def all_to_all_combine(self, combine_weights: Tensor, input: Tensor) -> Tensor:
-        expert_output = input.contiguous()
-        chunks = list(expert_output.chunk(self.world_size))
-        dist.all_to_all(chunks, chunks, self.group)
-        output = torch.einsum("gsec,egcm->gsm", combine_weights, expert_output)
-        return output
+        expert_output = _AllToAll.apply(self.group, input)
+        return torch.einsum("gsec,egcm->gsm", combine_weights, expert_output)
 
     def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
         assert len(input) == 1, "only single input Tensor supported"
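
Why the backward of _AllToAll is just another all_to_all: the forward sends rank i's input chunk j to rank j, where it lands as output chunk i, so the gradient of rank i's chunk j lives on rank j and is routed home by running the identical collective on the incoming gradients. A minimal round-trip sketch (not part of this commit): it assumes a PyTorch build with the MPI backend, a launch along the lines of mpirun -n 2 python sketch.py, and imports the module-private _AllToAll purely for illustration.

import torch
import torch.distributed as dist

from fairscale.nn.moe.moelayer import _AllToAll

# The MPI backend infers rank and world size from the MPI launcher.
dist.init_process_group(backend="mpi")
group = dist.group.WORLD

# One row per rank, so chunk(world_size) inside _AllToAll splits evenly.
x = torch.randn(dist.get_world_size(group), 4, requires_grad=True)
y = _AllToAll.apply(group, x)

# backward() runs a second all_to_all on the incoming gradients.
y.sum().backward()
assert x.grad is not None and x.grad.shape == x.shape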
19 changes: 19 additions & 0 deletions tests/nn/moe/test_moelayer.py
@@ -60,3 +60,22 @@ def test_forward(device):
     assert output.shape == input.shape
     # Re-assembled output should match input due to identity expert.
     assert torch.allclose(input, output)
+
+
+@pytest.mark.mpi
+@pytest.mark.parametrize("device", ["cpu"])
+def test_backward(device):
+    loss = torch.nn.MSELoss()
+    model_dim = 8
+    num_experts = dist.get_world_size(dist.group.WORLD)
+    input = torch.randn(3, 4, 16, model_dim).to(device)
+    gate = Top2Gate(model_dim, num_experts)
+    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
+    # Use identity matrix
+    expert.weight = torch.nn.Parameter(torch.eye(model_dim))
+    moe = MOELayer(gate, expert).to(device)
+    output = moe(input)
+    assert output.shape == input.shape
+    output = loss(output, input)
+    output.backward()
+    assert torch.allclose(expert.weight.grad, torch.zeros_like(expert.weight))
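
The zero-gradient assertion follows from the identity expert: as in test_forward above, the combined output reproduces the input, so the MSE loss is exactly zero and so is its gradient with respect to the expert weight; the value of the test is that backward() now propagates through both all_to_all calls at all. A single-process analogue of the same arithmetic (not from this commit):

import torch

# With an identity weight, linear(x) == x exactly, so the MSE loss is zero
# and the weight gradient is a zero matrix.
linear = torch.nn.Linear(8, 8, bias=False)
linear.weight = torch.nn.Parameter(torch.eye(8))
x = torch.randn(3, 8)
loss = torch.nn.functional.mse_loss(linear(x), x)
loss.backward()
assert torch.allclose(linear.weight.grad, torch.zeros_like(linear.weight))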
