[feat] Mixture of Experts #181

Merged (3 commits) on Jan 26, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- bugfix Favor, single feature map [#183]

### Added
- Mixture of Experts [#181]

## [0.0.8] - 2022-01-07
### Fixed
- Much faster fused dropout [#164]
2 changes: 2 additions & 0 deletions README.md
@@ -150,6 +150,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*

- [MLP](xformers/components/feedforward/mlp.py)
- [Fused](xformers/components/feedforward/fused_mlp.py)
- [Mixture of Experts](xformers/components/feedforward/mixture_of_experts.py)

</p></details>

@@ -211,3 +212,4 @@ The following repositories are used in xFormers, either in close to original for
* [LucidRain Reformer](https://github.com/lucidrains/reformer-pytorch)
* [RevTorch](https://github.com/RobinBruegger/RevTorch)
* [Nystromformer](https://github.com/mlpen/Nystromformer)
* [FairScale](https://github.com/facebookresearch/fairscale/)
3 changes: 3 additions & 0 deletions requirements-test.txt
@@ -20,3 +20,6 @@ git+git://github.com/rwightman/pytorch-image-models@v0.4.5#egg=timm

# Dependency for factory
hydra-core >= 1.1

# Dependency for Mixture of Experts
fairscale >= 0.4.5
13 changes: 12 additions & 1 deletion tests/test_block_factory.py
@@ -16,8 +16,9 @@
xFormerEncoderBlock,
xFormerEncoderConfig,
)
from xformers.helpers.test_utils import init_torch_distributed_local

BATCH = 20
BATCH = 4
SEQ = 128
MODEL = 96
DROPOUT = 0.5
@@ -79,8 +80,13 @@ def test_xformer_encoder_block(
"dropout": DROPOUT,
"activation": activation,
"hidden_layer_multiplier": 4,
"number_of_experts": 4,
"gate": "top_2",
}

if feedforward_name == "MixtureOfExperts":
init_torch_distributed_local()

position_encoding_config = {
"name": "sine",
"dim_model": MODEL,
@@ -169,8 +175,13 @@ def test_xformer_decoder_block(
"dropout": DROPOUT,
"activation": activation,
"hidden_layer_multiplier": 4,
"number_of_experts": 4,
"gate": "top_2",
}

if feedforward_name == "MixtureOfExperts":
init_torch_distributed_local()

position_encoding_config = {
"name": "sine",
"dim_model": MODEL,
43 changes: 42 additions & 1 deletion tests/test_feedforward.py
@@ -8,8 +8,10 @@

from xformers.components import Activation
from xformers.components.feedforward import FEEDFORWARD_REGISTRY, build_feedforward
from xformers.components.feedforward.mixture_of_experts import GateConfig
from xformers.helpers.test_utils import init_torch_distributed_local

BATCH = 20
BATCH = 4
SEQ = 512
EMBD = 16
LATENT = 128
@@ -34,8 +36,13 @@ def test_feedforward(
"dropout": DROPOUT,
"activation": activation,
"hidden_layer_multiplier": 4,
"number_of_experts": 4, # MoE
"gate": "top_2", # MoE
}

if feedforward_name == "MixtureOfExperts":
init_torch_distributed_local()

# dummy, just check construction and dimensions in the FW pass
ffw = build_feedforward(test_config)

Expand All @@ -47,3 +54,37 @@ def test_feedforward(
ffw = ffw.to(device)

_ = ffw(inputs)


def get_expert():
return torch.nn.Linear(LATENT, LATENT, bias=False)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="This test requires CUDA")
@pytest.mark.parametrize("gate", [g.value for g in GateConfig])
@pytest.mark.parametrize("number_of_local_experts", [None, 4])
@pytest.mark.parametrize("expert_constructor", [None, get_expert])
def test_moe(gate, number_of_local_experts, expert_constructor):
test_config = {
"name": "MixtureOfExperts",
"dim_model": LATENT,
"dropout": DROPOUT,
"activation": Activation.ReLU,
"hidden_layer_multiplier": 4,
"number_of_experts": 4,
"number_of_local_experts": number_of_local_experts,
"gate": gate,
"expert_constructor": expert_constructor,
}

init_torch_distributed_local()

# dummy, just check construction and dimensions in the FW pass
ffw = build_feedforward(test_config)

inputs = torch.rand(BATCH, SEQ, LATENT, device=torch.device("cuda"))
ffw = ffw.to(torch.device("cuda"))

outputs = ffw(inputs)
loss = torch.sum(outputs)
loss.backward()
2 changes: 2 additions & 0 deletions tests/test_model_factory.py
@@ -49,6 +49,8 @@
"activation": "relu",
"hidden_layer_multiplier": 4,
"dim_model": EMB,
"number_of_experts": 4,
"gate_config": "top_2",
},
}

2 changes: 0 additions & 2 deletions xformers/components/feedforward/fused_mlp.py
@@ -29,8 +29,6 @@ class FusedMlpConfig(FeedforwardConfig):
class FusedMLP(Feedforward):
"""
A MLP using fused linear layers.

.. warning: This is not currently competitive with PyTorch in terms of training speed
Review comment from the Contributor Author on this line: not true anymore :D

"""

def __init__(
150 changes: 150 additions & 0 deletions xformers/components/feedforward/mixture_of_experts.py
@@ -0,0 +1,150 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Optional, Union

import torch

from xformers.components import Activation
from xformers.components.feedforward import (
Feedforward,
FeedforwardConfig,
register_feedforward,
)

_is_fairscale_available = True

try:
import torch.distributed as dist
from fairscale.nn import MOELayer, Top2Gate

from xformers.components.feedforward import MLP

except ImportError:
logging.warning(
"Either FairScale or torch distributed is not available, MixtureOfExperts will not be exposed."
" Please install them if you would like to use MoE"
)
_is_fairscale_available = False


if _is_fairscale_available:

# Credits: initially implemented in FairScale for sanity checking
class RoundRobinGate(torch.nn.Module):
def __init__(self, model_dim, num_experts):
super().__init__()
self.model_dim = model_dim
self.num_experts = num_experts

def forward(self, input):
s = input.shape[0]
assert s % self.num_experts == 0, f"{s} % {self.num_experts} != 0"
capacity = 2 * s // self.num_experts
output = torch.zeros(
s, self.num_experts, capacity, dtype=input.dtype, device=input.device
)
for i in range(s):
output[i, i % self.num_experts, i // self.num_experts] = 1.0
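# Same return contract as FairScale's Top2Gate:
# (auxiliary loss, combine weights of shape [s, num_experts, capacity], boolean dispatch mask)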
return 0.0, output, output.bool()

class GateConfig(str, Enum):
RoundRobin = "round_robin"
Top2 = "top_2"
# Other gating techniques could be exposed here

@dataclass
class MoEConfig(FeedforwardConfig):
number_of_experts: int
gate: GateConfig
number_of_local_experts: Optional[int] = None
expert_constructor: Optional[Any] = None
hidden_layer_multiplier: Optional[int] = None
group: Optional[Any] = None

@register_feedforward("MixtureOfExperts", MoEConfig)
class MixtureOfExperts(Feedforward):
"""
An MLP variant which uses the "Mixture of Experts" paradigm, as described in Gshard_.
xFormers uses the FairScale_ implementation under the hood.

.. warning: Please note that most of the benefits of MoE are present in a distributed training environment

.. _Gshard: https://arxiv.org/pdf/2006.16668.pdf
.. _FairScale: https://github.com/facebookresearch/fairscale/
"""

def __init__(
self,
dim_model: int,
dropout: float,
activation: Activation,
number_of_experts: int,
gate: Union[GateConfig, torch.nn.Module],
number_of_local_experts: Optional[int] = None,
expert_constructor: Optional[Callable[[], torch.nn.Module]] = None,
hidden_layer_multiplier: Optional[int] = None,
group: Optional[Any] = None,
*_,
**__,
):
super().__init__()

# Handle a possibly uninitialized process group
assert (
dist.is_initialized()
), "Mixture of Experts require torch distributed to be initialized"

if number_of_local_experts is not None:
assert number_of_experts >= number_of_local_experts
else:
if dist.get_world_size() == 1:
logging.warning("Local experts no specified but world size of 1")
logging.warning("Assuming that all experts are local")
number_of_local_experts = number_of_experts
else:
number_of_local_experts = 1

# Programmatically handle the gating technique
if not isinstance(gate, torch.nn.Module):
gate_constructor = {
GateConfig.RoundRobin: RoundRobinGate,
GateConfig.Top2: Top2Gate,
}[gate]

self.gate = gate_constructor(dim_model, number_of_experts)
else:
self.gate = gate

# Programmatically handle the experts
if expert_constructor is None:

multiplier = (
hidden_layer_multiplier
if hidden_layer_multiplier is not None
else 4
)

def expert_constructor() -> torch.nn.Module:
return MLP(dim_model, dropout, activation, multiplier)

assert expert_constructor is not None

local_experts = torch.nn.ModuleList(
[expert_constructor() for _ in range(number_of_local_experts)]
)

self.moe = MOELayer(gate=self.gate, experts=local_experts, group=group)

self.requires_cuda = True
Review comment from a Contributor: I'm missing context here, is this used somewhere?

Reply from the Contributor Author: it's an "old" flag; it makes it easier for CI to test or skip something depending on the hardware needs, without maintaining an escape list in different places (I think it came from the Triton parts). We can change that, it was "a" way to solve this.


def forward(self, inputs: torch.Tensor) -> torch.Tensor:
# FairScale MoE assumes that the dimensions are [S, B, E]
# xFormers assumes [B, S, E]
return self.moe(inputs.movedim(0, 1)).movedim(0, 1)
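
For context, a minimal single-process usage sketch (not part of the diff; the dimensions, dropout value and CUDA device are illustrative, mirroring the tests in this PR): the block is built through the feedforward registry with the same config keys the tests use, the process group is bootstrapped with the new init_torch_distributed_local helper, and inputs follow the xFormers [B, S, E] convention.

```python
import torch

from xformers.components import Activation
from xformers.components.feedforward import build_feedforward
from xformers.helpers.test_utils import init_torch_distributed_local

# FairScale's MOELayer needs an initialized process group, even with a single process
init_torch_distributed_local()

moe = build_feedforward(
    {
        "name": "MixtureOfExperts",
        "dim_model": 128,
        "dropout": 0.1,
        "activation": Activation.ReLU,
        "hidden_layer_multiplier": 4,
        "number_of_experts": 4,  # all experts are local, since the world size is 1
        "gate": "top_2",
    }
)

# xFormers blocks take [batch, seq, embedding]; MixtureOfExperts transposes to
# FairScale's [seq, batch, embedding] internally and back on the way out.
x = torch.rand(4, 512, 128, device="cuda")
y = moe.to("cuda")(x)
assert y.shape == x.shape
```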
27 changes: 27 additions & 0 deletions xformers/helpers/test_utils.py
@@ -0,0 +1,27 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import tempfile

import torch


def init_torch_distributed_local():
if torch.distributed.is_initialized():
return

init_url = "file://" + tempfile.mkstemp()[1]
backend = (
torch.distributed.Backend.NCCL
if torch.cuda.is_available()
else torch.distributed.Backend.GLOO
)
torch.distributed.init_process_group(
backend=backend,
rank=0,
world_size=1,
init_method=init_url,
)
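
As a usage note (a sketch, not part of the diff): the helper is idempotent, so tests can call it unconditionally before building any distributed-aware block, and the backend it picks depends on whether a GPU is visible.

```python
import torch

from xformers.helpers.test_utils import init_torch_distributed_local

init_torch_distributed_local()
init_torch_distributed_local()  # no-op: a process group already exists

assert torch.distributed.is_initialized()
assert torch.distributed.get_world_size() == 1
# Typically "nccl" when CUDA is available, "gloo" otherwise
print(torch.distributed.get_backend())
```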