Running microGPT example! Needs some proper testing
blefaudeux committed Jan 17, 2022
1 parent 04bb6c1 commit 8d5654d
Showing 7 changed files with 168 additions and 4 deletions.
7 changes: 5 additions & 2 deletions examples/microGPT.py
@@ -71,10 +71,12 @@ def __init__(
},
},
"feedforward_config": {
"name": "FusedMLP", # Use MLP if Triton is not available
"name": "MixtureOfExperts", # Use MLP if Triton is not available
"dropout": self.hparams.mlp_pdrop,
"activation": "gelu",
"hidden_layer_multiplier": self.hparams.hidden_layer_multiplier,
"number_of_experts": 4,
"gate_config": "top_2",
},
}
]
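Aside (not part of this commit): since the config comment above still names "MLP" as the fallback, one hedged way to pick the feedforward name based on what is actually installed is sketched below; the helper name is illustrative only, and the assumptions are that "MixtureOfExperts" needs fairscale while "FusedMLP" needs Triton.

import importlib.util

def pick_feedforward_name() -> str:
    # Prefer the MoE block when fairscale is importable, then the Triton fused MLP,
    # and fall back to the vanilla "MLP" otherwise.
    if importlib.util.find_spec("fairscale") is not None:
        return "MixtureOfExperts"
    if importlib.util.find_spec("triton") is not None:
        return "FusedMLP"
    return "MLP"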
@@ -273,7 +275,7 @@ def top_k_logits(logits, k):
# Adjust batch depending on the available memory on your machine.
# You can also use reversible layers to save memory
REF_BATCH = 512
BATCH = 256
BATCH = 64

WORKERS = 4
EPOCHS = 1
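Side note (not part of the diff): if the rest of the example derives a gradient accumulation factor from these constants, for instance accumulate_grad_batches = REF_BATCH // BATCH, which is an assumption about the unchanged surrounding code, then lowering BATCH from 256 to 64 raises that factor from 512 // 256 = 2 to 512 // 64 = 8.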
@@ -304,6 +306,7 @@ def top_k_logits(logits, k):
attention="scaled_dot_product",
warmup_tokens=REF_BATCH * WARMUP,
final_tokens=EPOCHS * len(train_dataset) * BLOCK,
hidden_layer_multiplier=1,
)
print(model)

3 changes: 3 additions & 0 deletions requirements-test.txt
@@ -20,3 +20,6 @@ git+git://github.com/rwightman/pytorch-image-models@v0.4.5#egg=timm

# Dependency for factory
hydra-core >= 1.1

# Dependency for Mixture of Experts
fairscale >= 0.4.5
2 changes: 2 additions & 0 deletions tests/test_block_factory.py
@@ -77,6 +77,8 @@ def test_xformer_encoder_block(
"dropout": DROPOUT,
"activation": activation,
"hidden_layer_multiplier": 4,
"number_of_experts": 4,
"gate_config": "top_2",
}

position_encoding_config = {
2 changes: 2 additions & 0 deletions tests/test_feedforward.py
@@ -34,6 +34,8 @@ def test_feedforward(
"dropout": DROPOUT,
"activation": activation,
"hidden_layer_multiplier": 4,
"number_of_experts": 4, # MoE
"gate_config": "top_2", # MoE
}

# dummy, just check construction and dimensions in the FW pass
2 changes: 2 additions & 0 deletions tests/test_model_factory.py
@@ -49,6 +49,8 @@
"activation": "relu",
"hidden_layer_multiplier": 4,
"dim_model": EMB,
"number_of_experts": 4,
"gate_config": "top_2",
},
}

2 changes: 0 additions & 2 deletions xformers/components/feedforward/fused_mlp.py
@@ -29,8 +29,6 @@ class FusedMlpConfig(FeedforwardConfig):
class FusedMLP(Feedforward):
"""
A MLP using fused linear layers.
.. warning: This is not currently competitive with PyTorch in terms of training speed
"""

def __init__(
154 changes: 154 additions & 0 deletions xformers/components/feedforward/mixture_of_experts.py
@@ -0,0 +1,154 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import logging
import tempfile
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Optional

import torch

from xformers.components import Activation
from xformers.components.feedforward import (
Feedforward,
FeedforwardConfig,
register_feedforward,
)

_is_fairscale_available = True

try:
import torch.distributed as dist
from fairscale.nn import MOELayer, Top2Gate

from xformers.components.feedforward import MLP

except ImportError:
logging.warning(
"Either FairScale or torch distributed is not available, MixtureOfExperts will not be exposed."
" Please install them if you would like to use MoE"
)
_is_fairscale_available = False


if _is_fairscale_available:

# Credits: initially implemented in FairScale for sanity checking
class RoundRobinGate(torch.nn.Module):
def __init__(self, model_dim, num_experts):
super().__init__()
self.model_dim = model_dim
self.num_experts = num_experts

def forward(self, input):
s = input.shape[0]
assert s % self.num_experts == 0, f"{s} % {self.num_experts} != 0"
capacity = 2 * s // self.num_experts
output = torch.zeros(
s, self.num_experts, capacity, dtype=input.dtype, device=input.device
)
for i in range(s):
output[i, i % self.num_experts, i // self.num_experts] = 1.0
return 0.0, output, output.bool()

class GateConfig(str, Enum):
RoundRobin = "round_robin"
Top2 = "top_2"
# Other gating techniques could be exposed here

@dataclass
class MoEConfig(FeedforwardConfig):
number_of_experts: int
gate_config: GateConfig
number_of_local_experts: Optional[int] = None
expert_constructor: Optional[Callable[[], torch.nn.Module]] = None
hidden_layer_multiplier: Optional[int] = None
group: Optional[Any] = None

@register_feedforward("MixtureOfExperts", MoEConfig)
class MixtureOfExperts(Feedforward):
"""
An MLP variant which uses the "Mixture of Experts" paradigm, as described in Gshard_.
xFormers uses the FairScale_ implementation under the hood.
.. warning: Please note that most of the benefits of MoE are present in a distributed training environment
.. _Gshard: https://arxiv.org/pdf/2006.16668.pdf
.. _FairScale: https://github.com/facebookresearch/fairscale/
"""

def __init__(
self,
dim_model: int,
dropout: float,
activation: Activation,
number_of_experts: int,
gate_config: GateConfig,
number_of_local_experts: Optional[int] = None,
expert_constructor: Optional[Callable[[], torch.nn.Module]] = None,
hidden_layer_multiplier: Optional[int] = None,
group: Optional[Any] = None,
*_,
**__,
):
super().__init__()

# Handle a possibly uninitialized process group
if group is None and not dist.is_initialized():
init_url = "file://" + tempfile.mkstemp()[1]
backend = (
dist.Backend.NCCL
if torch.cuda.is_available()
else dist.Backend.GLOO
)
dist.init_process_group(
backend=backend,
rank=0,
world_size=1,
init_method=init_url,
)

if number_of_local_experts is not None:
assert number_of_experts > number_of_local_experts
else:
if dist.get_world_size() == 1:
number_of_local_experts = number_of_experts
else:
number_of_local_experts = 1

# Programmatically handle the gating technique
gate_constructor = {
GateConfig.RoundRobin: RoundRobinGate,
GateConfig.Top2: Top2Gate,
}[gate_config]

self.gate = gate_constructor(dim_model, number_of_experts)

# Programmatically handle the experts
if expert_constructor is None:

multiplier = (
hidden_layer_multiplier
if hidden_layer_multiplier is not None
else 4
)

def expert_constructor() -> torch.nn.Module:
return MLP(dim_model, dropout, activation, multiplier)

assert expert_constructor is not None

local_experts = torch.nn.ModuleList(
[expert_constructor() for _ in range(number_of_local_experts)]
)

self.moe = MOELayer(gate=self.gate, experts=local_experts, group=group)

self.requires_cuda = True

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return self.moe(inputs)
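
End-of-file note (not part of the commit): a minimal single-process smoke-test sketch for the new block. It assumes fairscale is installed (the module only exposes these names in that case) and that Activation accepts the string values used elsewhere in this diff. The forward pass is left commented out because the layer sets requires_cuda = True and fairscale's MOELayer relies on distributed collectives, which the commit message itself flags as still needing proper testing.

import torch

from xformers.components import Activation
from xformers.components.feedforward.mixture_of_experts import GateConfig, MixtureOfExperts

# Constructing the block on its own spins up a single-process group
# (file-based init, GLOO on CPU, NCCL when CUDA is available)
moe = MixtureOfExperts(
    dim_model=64,
    dropout=0.1,
    activation=Activation("gelu"),
    number_of_experts=4,
    gate_config=GateConfig.Top2,
    hidden_layer_multiplier=1,
)

x = torch.randn(2, 16, 64)  # (batch, sequence, model dimension)
# On a CUDA machine the output is expected to keep the input shape:
# assert moe(x).shape == x.shape

Passing GateConfig.Top2 exercises fairscale's Top2Gate; GateConfig.RoundRobin would swap in the simple round-robin gate defined above instead.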
