Commit

code review, cleaner
blefaudeux committed Jan 26, 2022
1 parent da1e714 commit 56f0ea0
Showing 4 changed files with 50 additions and 31 deletions.
experimental/ragged_inference/test_utils.py: 19 additions, 0 deletions
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import tempfile
 from typing import Tuple
 
 import numpy as np
@@ -82,3 +83,21 @@ def assert_eq(actual, expected, msg="", rtol=None, atol=None):

 def bf16_cuda():
     return dict(device="cuda", dtype=torch.bfloat16)
+
+
+def init_torch_distributed_local():
+    if torch.distributed.is_initialized():
+        return
+
+    init_url = "file://" + tempfile.mkstemp()[1]
+    backend = (
+        torch.distributed.Backend.NCCL
+        if torch.cuda.is_available()
+        else torch.distributed.Backend.GLOO
+    )
+    torch.distributed.init_process_group(
+        backend=backend,
+        rank=0,
+        world_size=1,
+        init_method=init_url,
+    )
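
A minimal sketch of how a test could lean on this new helper (the test name is hypothetical and not part of this commit); it only relies on the helper being idempotent and setting up a single-process, file-initialized group, and it imports the helper the same way the updated tests below do:

import torch

from xformers.helpers.test_utils import init_torch_distributed_local


def test_distributed_bootstrap():
    # The helper returns early once a process group exists,
    # so calling it repeatedly is safe.
    init_torch_distributed_local()
    init_torch_distributed_local()

    assert torch.distributed.is_initialized()
    assert torch.distributed.get_world_size() == 1
    assert torch.distributed.get_rank() == 0
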
tests/test_block_factory.py: 9 additions, 2 deletions
@@ -16,6 +16,7 @@
     xFormerEncoderBlock,
     xFormerEncoderConfig,
 )
+from xformers.helpers.test_utils import init_torch_distributed_local
 
 BATCH = 4
 SEQ = 128
@@ -80,9 +81,12 @@ def test_xformer_encoder_block(
         "activation": activation,
         "hidden_layer_multiplier": 4,
         "number_of_experts": 4,
-        "gate_config": "top_2",
+        "gate": "top_2",
     }
 
+    if feedforward_name == "MixtureOfExperts":
+        init_torch_distributed_local()
+
     position_encoding_config = {
         "name": "sine",
         "dim_model": MODEL,
@@ -172,9 +176,12 @@ def test_xformer_decoder_block(
         "activation": activation,
         "hidden_layer_multiplier": 4,
         "number_of_experts": 4,
-        "gate_config": "top_2",
+        "gate": "top_2",
     }
 
+    if feedforward_name == "MixtureOfExperts":
+        init_torch_distributed_local()
+
     position_encoding_config = {
         "name": "sine",
         "dim_model": MODEL,
tests/test_feedforward.py: 8 additions, 2 deletions
@@ -9,6 +9,7 @@
 from xformers.components import Activation
 from xformers.components.feedforward import FEEDFORWARD_REGISTRY, build_feedforward
 from xformers.components.feedforward.mixture_of_experts import GateConfig
+from xformers.helpers.test_utils import init_torch_distributed_local
 
 BATCH = 4
 SEQ = 512
@@ -36,9 +37,12 @@ def test_feedforward(
         "activation": activation,
         "hidden_layer_multiplier": 4,
         "number_of_experts": 4,  # MoE
-        "gate_config": "top_2",  # MoE
+        "gate": "top_2",  # MoE
     }
 
+    if feedforward_name == "MixtureOfExperts":
+        init_torch_distributed_local()
+
     # dummy, just check construction and dimensions in the FW pass
     ffw = build_feedforward(test_config)
 
@@ -69,10 +73,12 @@ def test_moe(gate, number_of_local_experts, expert_constructor):
         "hidden_layer_multiplier": 4,
         "number_of_experts": 4,
         "number_of_local_experts": number_of_local_experts,
-        "gate_config": gate,
+        "gate": gate,
         "expert_constructor": expert_constructor,
     }
 
+    init_torch_distributed_local()
+
     # dummy, just check construction and dimensions in the FW pass
     ffw = build_feedforward(test_config)
 
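
The user-visible effect of the test changes above is the renamed MoE config key, from "gate_config" to "gate". A small, purely illustrative migration sketch (the values simply mirror the tests):

# Before this commit the MoE feedforward config used "gate_config";
# after it, the same value is passed under "gate".
old_style = {"number_of_experts": 4, "gate_config": "top_2"}
new_style = {"number_of_experts": 4, "gate": "top_2"}

# Migrating an existing config dict:
cfg = dict(old_style)
cfg["gate"] = cfg.pop("gate_config")
assert cfg == new_style
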
xformers/components/feedforward/mixture_of_experts.py: 14 additions, 27 deletions
@@ -5,10 +5,9 @@


 import logging
-import tempfile
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Union
 
 import torch
 
@@ -63,7 +62,7 @@ class GateConfig(str, Enum):
 @dataclass
 class MoEConfig(FeedforwardConfig):
     number_of_experts: int
-    gate_config: GateConfig
+    gate: GateConfig
     number_of_local_experts: Optional[int] = None
     expert_constructor: Optional[Any] = None
     hidden_layer_multiplier: Optional[int] = None
@@ -87,7 +86,7 @@ def __init__(
         dropout: float,
         activation: Activation,
         number_of_experts: int,
-        gate_config: GateConfig,
+        gate: Union[GateConfig, torch.nn.Module],
         number_of_local_experts: Optional[int] = None,
         expert_constructor: Optional[Callable[[], torch.nn.Module]] = None,
         hidden_layer_multiplier: Optional[int] = None,
@@ -98,24 +97,9 @@ def __init__(
         super().__init__()
 
         # Handle a possibly uninitialized process group
-        if group is None and not dist.is_initialized():
-            logging.warning(
-                "Torch Distributed is not initialized, please do so before instantiating MoE"
-            )
-            logging.warning("Attempting fallback initialization")
-
-            init_url = "file://" + tempfile.mkstemp()[1]
-            backend = (
-                dist.Backend.NCCL
-                if torch.cuda.is_available()
-                else dist.Backend.GLOO
-            )
-            dist.init_process_group(
-                backend=backend,
-                rank=0,
-                world_size=1,
-                init_method=init_url,
-            )
+        assert (
+            dist.is_initialized()
+        ), "Mixture of Experts require torch distributed to be initialized"
 
         if number_of_local_experts is not None:
             assert number_of_experts >= number_of_local_experts
@@ -128,12 +112,15 @@
             number_of_local_experts = 1
 
         # Programatically handle the gating technique
-        gate_constructor = {
-            GateConfig.RoundRobin: RoundRobinGate,
-            GateConfig.Top2: Top2Gate,
-        }[gate_config]
+        if not isinstance(gate, torch.nn.Module):
+            gate_constructor = {
+                GateConfig.RoundRobin: RoundRobinGate,
+                GateConfig.Top2: Top2Gate,
+            }[gate]
 
-        self.gate = gate_constructor(dim_model, number_of_experts)
+            self.gate = gate_constructor(dim_model, number_of_experts)
+        else:
+            self.gate = gate
 
         # Programatically handle the experts
         if expert_constructor is None:
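
Because the constructor now takes gate: Union[GateConfig, torch.nn.Module] and only looks up a gate constructor when the argument is not already an nn.Module, a caller can pass a prebuilt gate directly. A rough sketch under that assumption; MyGate is purely illustrative, parameter names not visible in the hunks above (such as dim_model) are assumed, and torch.distributed must already be initialized to satisfy the new assert:

import torch

from xformers.components import Activation
from xformers.components.feedforward.mixture_of_experts import MixtureOfExperts
from xformers.helpers.test_utils import init_torch_distributed_local


class MyGate(torch.nn.Module):
    # Hypothetical gate; a real one has to expose whatever interface the
    # underlying MoE layer expects from its gating module.
    def __init__(self, dim_model: int, number_of_experts: int):
        super().__init__()
        self.wg = torch.nn.Linear(dim_model, number_of_experts, bias=False)


init_torch_distributed_local()  # single-process fallback, as in the tests above

moe = MixtureOfExperts(
    dim_model=256,
    dropout=0.0,
    activation=Activation.ReLU,
    number_of_experts=4,
    hidden_layer_multiplier=4,
    gate=MyGate(256, 4),  # used as-is; a GateConfig value would be looked up instead
)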
