[feat] Mixture of Experts #181
@@ -0,0 +1,150 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Optional, Union

import torch

from xformers.components import Activation
from xformers.components.feedforward import (
    Feedforward,
    FeedforwardConfig,
    register_feedforward,
)

_is_fairscale_available = True

try:
    import torch.distributed as dist
    from fairscale.nn import MOELayer, Top2Gate

    from xformers.components.feedforward import MLP

except ImportError:
    logging.warning(
        "Either FairScale or torch distributed is not available, MixtureOfExperts will not be exposed."
        " Please install them if you would like to use MoE"
    )
    _is_fairscale_available = False


if _is_fairscale_available:

    # Credits: initially implemented in FairScale for sanity checking
    class RoundRobinGate(torch.nn.Module):
        def __init__(self, model_dim, num_experts):
            super().__init__()
            self.model_dim = model_dim
            self.num_experts = num_experts

        def forward(self, input):
            s = input.shape[0]
            assert s % self.num_experts == 0, f"{s} % {self.num_experts} != 0"
            capacity = 2 * s // self.num_experts
            output = torch.zeros(
                s, self.num_experts, capacity, dtype=input.dtype, device=input.device
            )
            for i in range(s):
                output[i, i % self.num_experts, i // self.num_experts] = 1.0
            return 0.0, output, output.bool()

    class GateConfig(str, Enum):
        RoundRobin = "round_robin"
        Top2 = "top_2"
        # Other gating techniques could be exposed here

    @dataclass
    class MoEConfig(FeedforwardConfig):
        number_of_experts: int
        gate: GateConfig
        number_of_local_experts: Optional[int] = None
        expert_constructor: Optional[Any] = None
        hidden_layer_multiplier: Optional[int] = None
        group: Optional[Any] = None
    @register_feedforward("MixtureOfExperts", MoEConfig)
    class MixtureOfExperts(Feedforward):
        """
        An MLP variant which uses the "Mixture of Experts" paradigm, as described in Gshard_.
        xFormers uses the FairScale_ implementation under the hood.

        .. warning: Please note that most of the benefits of MoE are present in a distributed training environment

        .. _Gshard: https://arxiv.org/pdf/2006.16668.pdf
        .. _FairScale: https://github.com/facebookresearch/fairscale/
        """
        def __init__(
            self,
            dim_model: int,
            dropout: float,
            activation: Activation,
            number_of_experts: int,
            gate: Union[GateConfig, torch.nn.Module],
            number_of_local_experts: Optional[int] = None,
            expert_constructor: Optional[Callable[[], torch.nn.Module]] = None,
            hidden_layer_multiplier: Optional[int] = None,
            group: Optional[Any] = None,
            *_,
            **__,
        ):
            super().__init__()

            # Handle a possibly uninitialized process group
            assert (
                dist.is_initialized()
            ), "Mixture of Experts requires torch distributed to be initialized"

            if number_of_local_experts is not None:
                assert number_of_experts >= number_of_local_experts
            else:
                if dist.get_world_size() == 1:
                    logging.warning("Local experts not specified but world size of 1")
                    logging.warning("Assuming that all experts are local")
                    number_of_local_experts = number_of_experts
                else:
                    number_of_local_experts = 1

            # Programmatically handle the gating technique
            if not isinstance(gate, torch.nn.Module):
                gate_constructor = {
                    GateConfig.RoundRobin: RoundRobinGate,
                    GateConfig.Top2: Top2Gate,
                }[gate]

                self.gate = gate_constructor(dim_model, number_of_experts)
            else:
                self.gate = gate

            # Programmatically handle the experts
            if expert_constructor is None:

                multiplier = (
                    hidden_layer_multiplier
                    if hidden_layer_multiplier is not None
                    else 4
                )

                def expert_constructor() -> torch.nn.Module:
                    return MLP(dim_model, dropout, activation, multiplier)

                assert expert_constructor is not None

            local_experts = torch.nn.ModuleList(
                [expert_constructor() for _ in range(number_of_local_experts)]
            )

            self.moe = MOELayer(gate=self.gate, experts=local_experts, group=group)

            self.requires_cuda = True
Review comment on `self.requires_cuda = True`: I'm missing context here, is this used somewhere?

Reply: it's an "old" flag, makes it easier for CI to test something or not test it depending on the HW needs without maintaining an escape list in different places (I think that it came from the Triton parts). We can change that, it was "a" way to solve this
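For illustration, a minimal sketch of how a `requires_cuda`-style flag could be consumed by a test harness; the helper name and the pytest wiring are assumptions, not something this PR adds:

```python
import pytest
import torch


def maybe_skip_for_hardware(block: torch.nn.Module) -> None:
    # Hypothetical CI helper: skip the current test when a block advertises
    # requires_cuda but no GPU is present, instead of maintaining a separate
    # escape list of block names in the CI configuration.
    if getattr(block, "requires_cuda", False) and not torch.cuda.is_available():
        pytest.skip("this block needs a CUDA device")
```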
        def forward(self, inputs: torch.Tensor) -> torch.Tensor:
            # FairScale MoE assumes that the dimensions are [S, B, E]
            # xFormers assumes [B, S, E]
            return self.moe(inputs.movedim(0, 1)).movedim(0, 1)
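For anyone trying this out locally, a minimal usage sketch (not part of the diff): the import paths and the `Activation` value are assumptions, and a CUDA device is assumed since the block sets `requires_cuda`.

```python
import torch

# import paths are assumptions; adjust to wherever these symbols are exposed
from xformers.components import Activation
from xformers.components.feedforward import MixtureOfExperts
from xformers.components.feedforward.mixture_of_experts import GateConfig

# MixtureOfExperts asserts that torch.distributed is initialized; a
# single-process group is enough for local experimentation
if not torch.distributed.is_initialized():
    torch.distributed.init_process_group(
        backend="nccl", rank=0, world_size=1, init_method="file:///tmp/moe_example_init"
    )

moe = MixtureOfExperts(
    dim_model=64,
    dropout=0.1,
    activation=Activation("gelu"),  # assumed to be a valid Activation value
    number_of_experts=4,
    gate=GateConfig.RoundRobin,     # or GateConfig.Top2, or a custom gate module
).cuda()

# xFormers layout is [batch, sequence, embedding]; forward() handles the
# transposition to FairScale's [S, B, E] internally
x = torch.randn(2, 16, 64, device="cuda")
y = moe(x)  # same shape as the input
```

Note that with `GateConfig.RoundRobin` the number of tokens (batch * sequence) has to be divisible by `number_of_experts`; the 2 * 16 = 32 tokens above satisfy that for 4 experts.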
@@ -0,0 +1,27 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import tempfile

import torch


def init_torch_distributed_local():
    if torch.distributed.is_initialized():
        return

    init_url = "file://" + tempfile.mkstemp()[1]
    backend = (
        torch.distributed.Backend.NCCL
        if torch.cuda.is_available()
        else torch.distributed.Backend.GLOO
    )
    torch.distributed.init_process_group(
        backend=backend,
        rank=0,
        world_size=1,
        init_method=init_url,
    )
Review comment: not true anymore :D
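A short sketch of how this helper could be exercised in a single-process test; the pytest usage and the import location are assumptions on top of the helper shown above:

```python
import pytest
import torch

# assumed location of the helper above; adjust to the actual test utilities module
from tests.utils import init_torch_distributed_local


@pytest.mark.skipif(
    not torch.distributed.is_available(), reason="requires torch.distributed"
)
def test_init_torch_distributed_local_is_idempotent():
    init_torch_distributed_local()
    assert torch.distributed.is_initialized()
    assert torch.distributed.get_world_size() == 1

    # a second call is a no-op thanks to the early return
    init_torch_distributed_local()
```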