[feat] moe: add all_to_all support (#134)
msbaines committed Oct 14, 2020 · 1 parent 177151e · commit 6d802f5
Showing 6 changed files with 69 additions and 32 deletions.
8 changes: 8 additions & 0 deletions .circleci/config.yml
@@ -42,6 +42,7 @@ install_dep_15: &install_dep_15
- run:
name: Install Dependencies
command: |
sudo apt-get install -y mpi-default-dev
pip install --progress-bar off torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
@@ -51,6 +52,7 @@ install_dep_16: &install_dep_16
- run:
name: Install Dependencies
command: |
sudo apt-get install -y mpi-default-dev
pip install --progress-bar off torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
@@ -84,6 +86,12 @@ run_unittests: &run_unittests
command: |
pytest --junitxml=test-results/junit.xml --verbose
run_mpi_unittests: &run_mpi_unittests
- run:
name: Run MPI Unit Tests
command: |
mpirun -n 4 python -m pytest --only-mpi --junitxml=test-results/junit.xml --verbose
run_flake8: &run_flake8
- run:
name: Run Linter (flake8)
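A note on the new run_mpi_unittests step: pytest-mpi's --only-mpi flag deselects every test that is not marked with @pytest.mark.mpi, so this job runs just the collective tests added below, once per MPI rank. Assuming Open MPI (mpi-default-dev) and the updated requirements-test.txt are installed, the same suite should be runnable locally with `mpirun -n 4 python -m pytest --only-mpi`.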
27 changes: 18 additions & 9 deletions fairscale/nn/moe/moelayer.py
@@ -3,10 +3,11 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional

import torch
from torch import Tensor
import torch.distributed as dist
from torch.nn import Module

if TYPE_CHECKING:
@@ -24,7 +25,8 @@ class MOELayer(Base):
gate = Top2Gate(model_dim, num_experts)
moe = MOELayer(gate, expert)
l_aux, combine_weights, dispatch_mask = moe(input)
output = moe(input)
l_aux = moe.l_aux
.. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
@@ -35,24 +37,31 @@ class MOELayer(Base):
expert network
"""

def __init__(self, gate: Module, expert: Module) -> None:
def __init__(self, gate: Module, expert: Module, group: Optional[Any] = None) -> None:
super().__init__()
self.gate = gate
self.expert = expert
self.group = group if group is not None else dist.group.WORLD
self.world_size = dist.get_world_size(self.group)

def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
dispatched_input = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
# TODO(msb) all-to-all
dispatched_input = torch.squeeze(dispatched_input, 0) # drop E dimension
dispatched_input = dispatched_input.contiguous()
chunks = list(dispatched_input.chunk(self.world_size))
dist.all_to_all(chunks, chunks, self.group)
return dispatched_input

def all_to_all_combine(self, combine_weights: Tensor, input: Tensor) -> Tensor:
# TODO(msb) all-to-all
expert_output = torch.unsqueeze(input, 1) # add E dimension
output = torch.einsum("gsec,gecm->gsm", combine_weights, expert_output)
expert_output = input.contiguous()
chunks = list(expert_output.chunk(self.world_size))
dist.all_to_all(chunks, chunks, self.group)
output = torch.einsum("gsec,egcm->gsm", combine_weights, expert_output)
return output

def forward(self, *input: Any, **kwargs: Any) -> Tensor:
def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
assert len(input) == 1, "only single input Tensor supported"
assert len(input[0].shape) == 4, "input Tensor must have dimensions: (g)roup, (s)equence, (t)oken, (m)odel"

# Implement Algorithm 2 from GShard paper.
shape = input[0].shape
# Reshape into S tokens per group.
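A detail worth noting in the new dispatch and combine paths: dist.all_to_all is called with the same list of chunks as both its output and input arguments. Because Tensor.chunk on a contiguous tensor returns views, each shard received from a peer rank is written straight back into the underlying buffer. A minimal sketch of that pattern, assuming an already-initialized process group whose backend supports all_to_all (e.g. MPI, as in the tests below); the buffer contents are illustrative:

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already run, e.g. with
# dist.Backend.MPI as in the tests added by this commit.
world_size = dist.get_world_size()
rank = dist.get_rank()

# One row destined for each peer; the buffer must be contiguous so
# that chunk() yields views into it rather than copies.
buf = torch.full((world_size, 4), float(rank))

# Passing the same list as output and input exchanges shards in place:
# row i is sent to rank i, and the row received from rank i overwrites it.
chunks = list(buf.chunk(world_size))
dist.all_to_all(chunks, chunks)

# buf[i] now holds the row sent by rank i (here: a row of value i).
print(rank, buf)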
2 changes: 2 additions & 0 deletions requirements-test.txt
@@ -1,9 +1,11 @@
black == 19.10b0
flake8 == 3.7.9
isort == 4.3.21
mpi4py == 3.0.3
mypy == 0.770
pytest == 5.4.1
pytest-cov == 2.10.0
pytest-mpi == 0.4
torchtext == 0.6.0
torch >= 1.5.1
torchvision >= 0.6.0
7 changes: 6 additions & 1 deletion stubs/torch/distributed/__init__.pyi
@@ -6,7 +6,10 @@ import datetime

from . import rpc as rpc

class Backend: ...
class Backend:
GLOO: str
MPI: str
NCCL: str

class ProcessGroup:
def size(self) -> int: ...
@@ -29,8 +32,10 @@ def broadcast(tensor: Tensor, src: Any, group: Any, async_op: Any = False): ...

def is_initialized() -> bool: ...

def init_process_group(backend: Union[str, Backend], timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ...
def new_group(ranks: List[int], timeout: datetime.timedelta = datetime.timedelta(0, 1800), backend: Union[None, str, Backend] = None): ...

def all_to_all(output: List[Tensor], input: List[Tensor], group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ...

53 changes: 31 additions & 22 deletions tests/nn/moe/test_moelayer.py
@@ -3,51 +3,60 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import os

import pytest
import torch
import torch.distributed as dist

from fairscale.nn import MOELayer, Top2Gate

skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")

BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore

def test_create():
model_dim = 8
num_experts = 4
gate = Top2Gate(model_dim, num_experts)
expert = torch.nn.Linear(model_dim, model_dim)
moe = MOELayer(gate, expert)
if torch.cuda.is_available():
devices = ["cpu", "cuda"]
else:
devices = ["cpu"]

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
if "OMPI_COMM_WORLD_SIZE" in os.environ:
dist.init_process_group(backend=dist.Backend.MPI)


def setup_module(module):
if "OMPI_COMM_WORLD_SIZE" not in os.environ:
dist.init_process_group(backend=BACKEND, rank=0, world_size=1)


@skip_if_no_cuda
def test_create_cuda():
def teardown_module(module):
if "OMPI_COMM_WORLD_SIZE" not in os.environ:
torch.distributed.destroy_process_group()


@pytest.mark.parametrize("device", devices)
def test_create(device):
model_dim = 8
num_experts = 4
gate = Top2Gate(model_dim, num_experts)
expert = torch.nn.Linear(model_dim, model_dim)
moe = MOELayer(gate, expert).cuda()
moe = MOELayer(gate, expert).to(device)


def do_test_forward(device):
@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_forward(device):
model_dim = 8
num_experts = 1
num_experts = dist.get_world_size(dist.group.WORLD)
input = torch.randn(3, 4, 16, model_dim).to(device)
gate = Top2Gate(model_dim, num_experts)
expert = torch.nn.Linear(model_dim, model_dim, bias=False)
# Use identity matrix
expert.weight = torch.nn.Parameter(torch.eye(model_dim))
moe = MOELayer(gate, expert).to(device)
output = moe(input)
assert moe.l_aux.item() == 1.0
assert output.shape == input.shape
# Re-assembled output should match input due to identity expert.
assert torch.equal(input, output)


def test_forward_cpu():
do_test_forward("cpu")


@skip_if_no_cuda
def test_forward_cuda():
do_test_forward("cuda")
assert torch.allclose(input, output)
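The test module chooses its process-group setup from the launcher: under mpirun, Open MPI exports OMPI_COMM_WORLD_SIZE, so the MPI backend (which discovers rank and world size on its own) is initialized at import time; any other launch falls back to a single-process group in setup_module. A standalone sketch of the same pattern (the script name is illustrative):

# Run as: mpirun -n 4 python example.py  (or plain: python example.py)
import os

import torch.distributed as dist

if "OMPI_COMM_WORLD_SIZE" in os.environ:
    # Launched via mpirun: the MPI backend derives rank/world size itself.
    dist.init_process_group(backend=dist.Backend.MPI)
else:
    # Single-process fallback, mirroring setup_module above.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group(backend=dist.Backend.GLOO, rank=0, world_size=1)

print(f"rank {dist.get_rank()} of {dist.get_world_size()}")
dist.destroy_process_group()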
4 changes: 4 additions & 0 deletions tests/optim/test_oss.py
@@ -29,6 +29,10 @@ def setup_module(module):
dist.init_process_group(backend=BACKEND, rank=0, world_size=1)


def teardown_module(module):
torch.distributed.destroy_process_group()


def dist_init(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
