Skip to content

Commit

Permalink
rename is_postnorm to is_postscore (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
ghostplant committed Feb 26, 2022
1 parent 712bf2e commit bddc915
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 28 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,14 @@ Usage of MOELayer:
or a list of dict-type gate descriptions, e.g. [{'type': 'top', 'k', 2}, {'type': 'top', 'k', 2}],
the value of k in top-gating can be also negative, like -2, which indicates one GPU will hold 1/(-k) parameters of an expert
model_dim : the number of channels for MOE's input tensor
experts : a dict-type config for builtin expert network, or a torch.nn.Module-type custom expert network
experts : a dict-type config for builtin expert network
scan_expert_func : allow users to specify a lambda function to iterate each experts param, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)`
result_func : allow users to specify a lambda function to format the MoE output and aux_loss, e.g. `result_func = lambda output: (output, output.l_aux)`
group : specify the explicit communication group of all_to_all
seeds : a tuple containing a tripple of int to specify manual seed of (shared params, local params, others params after MoE's)
a2a_ffn_overlap_degree : the value to control a2a overlap depth, 1 by default for no overlap, 2 for overlap a2a with half gemm, ..
parallel_type : the parallel method to compute MoE, valid types: 'auto', 'data', 'model'
pad_samples : whether do auto padding on newly-coming input data to maximum data size in history
* Usage of dict-type Experts Config:
Expand Down
4 changes: 4 additions & 0 deletions tutel/impls/communicate.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,12 @@ def init(group: dist.ProcessGroup, num_split: int, split_dim: int) -> None:


class AllToAll(torch.autograd.Function):
_use_builtins = False

@staticmethod
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor):
AllToAll._use_builtins = True

ctx.group = group
world_size = get_world_size(group)
if world_size <= 1:
Expand Down
8 changes: 4 additions & 4 deletions tutel/impls/fast_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,13 @@ def __init__(self, num_global_experts, capacity, model_dim, dispatch_dtype):
self.original_dtype = dispatch_dtype
self.aligned_dim = model_dim // (2 if self.dtype == torch.float16 else 1)

def update(self, indices_, locations_, gates_, capacity=None, is_postnorm=True):
def update(self, indices_, locations_, gates_, capacity=None, is_postscore=True):
self.indices_ = [x.to(torch.int32).view(-1) for x in indices_]
self.locations_ = [x.to(torch.int32) for x in locations_]
self.gates_ = [x.to(self.dtype) for x in gates_]
sample_size = int(self.indices_[0].size(0))
capacity = int(capacity) or self.capacity
self.is_postnorm = is_postnorm
self.is_postscore = is_postscore

if sample_size != self.expected_sample_size or capacity != self.capacity:
self.expected_sample_size, self.capacity = sample_size, capacity
Expand All @@ -109,13 +109,13 @@ def update(self, indices_, locations_, gates_, capacity=None, is_postnorm=True):
self.func_fwd, self.func_bwd_data, self.func_bwd_gate, self.ones_helper = self.kernel_pool[tuple((sample_size, capacity))]

def encode(self, data):
if self.is_postnorm:
if self.is_postscore:
return GatingEncoder.apply(self, data.to(self.dtype)).to(self.original_dtype)
else:
return GatingEncoder.apply(self, data.to(self.dtype), *self.gates_).to(self.original_dtype)

def decode(self, data):
if self.is_postnorm:
if self.is_postscore:
return GatingDecoder.apply(self, data.to(self.dtype), *self.gates_).to(self.original_dtype)
else:
return GatingDecoder.apply(self, data.to(self.dtype)).to(self.original_dtype)
Expand Down
44 changes: 27 additions & 17 deletions tutel/impls/moe_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,22 @@ def __init__(
num_global_experts,
a2a_ffn_overlap_degree=1,
capacity_factor=1.0,
top_k=2,
k=2,
batch_prioritized_routing=False,
**kwargs,
fp32_gate=False,
is_postscore=True,
input_dropout_p=0,
):
super().__init__()
top_k = min(top_k, num_global_experts)
self.top_k = top_k
k = min(k, num_global_experts)
self.top_k = k
assert self.top_k > 0, "Top-k value %d is not valid." % self.top_k

self.wg = torch.nn.Linear(model_dim, num_global_experts, bias=False)

self.fp32_gate = kwargs.get('fp32_gate', False)
self.fp32_gate = fp32_gate
if self.fp32_gate:
self.wg = self.wg.float()
self.wg = self.wg.float()

self.capacity_factor = float(os.environ.get('CAP_FACTOR', capacity_factor))
self.is_ones_gate = (int(os.environ.get('ONES_GATE', 0)) == 1)
Expand All @@ -71,8 +73,7 @@ def __init__(
if int(os.environ.get('BATCH_PRIO', 0)) != 0:
self.batch_prioritized_routing = True

self.is_postnorm = kwargs.get('is_postnorm', True)
input_dropout_p = kwargs.get('input_dropout_p', 0)
self.is_postscore = is_postscore
self.input_dropout = torch.nn.Dropout(p=input_dropout_p) if input_dropout_p else None

self.a2a_ffn_overlap_degree = a2a_ffn_overlap_degree
Expand Down Expand Up @@ -134,7 +135,7 @@ def apply_on_expert_fn(self, input, ctx):

if self.is_ones_gate:
gates_s = [torch.ones_like(x) for x in gates_s]
self._fdr.update(indices_s, locations_s, gates_s, capacity=capacity, is_postnorm=self.is_postnorm)
self._fdr.update(indices_s, locations_s, gates_s, capacity=capacity, is_postscore=self.is_postscore)

dispatched_input = self._fdr.encode(input)

Expand Down Expand Up @@ -223,7 +224,19 @@ class MOELayer(torch.nn.Module):
"""Tutel optimized MOELayer
"""

def __init__(self, gate_type, model_dim: int, experts = None, scan_expert_func = None, result_func = None, group: Optional[Any] = None, seeds = None, a2a_ffn_overlap_degree = 1, **kwargs):
def __init__(
self,
gate_type,
model_dim: int,
experts=None,
scan_expert_func=None,
result_func=None,
group=None,
seeds=None,
a2a_ffn_overlap_degree=1,
parallel_type='auto',
pad_samples=False,
):
super().__init__()
assert model_dim % 2 == 0, "Model_dim (%s) must be even value, while this Model_dim mod 2 > 0." % model_dim
group = group or dist.group.WORLD
Expand Down Expand Up @@ -257,7 +270,6 @@ def __init__(self, gate_type, model_dim: int, experts = None, scan_expert_func =
self.num_global_experts = num_devices * self.num_local_experts
sharded_count = 1

parallel_type = kwargs.get('parallel_type', 'auto')
if sharded_count == 1 or not self.is_builtin_experts:
self.auto_parallel, self.use_model_parallel = False, False
elif parallel_type == 'auto':
Expand Down Expand Up @@ -413,11 +425,9 @@ def to(self, *args, **kwargs):
if single_gate_type['type'] == 'top':
if seeds is not None and seeds[0] is not None:
torch.manual_seed(seeds[0] + gi)
if "fp32_gate" in kwargs:
logging.warning(f'`fp32_gate` option in tutel.moe_layer has been deprecated, please move this option to gate_type = {{.., "fp32_gate": {kwargs["fp32_gate"]}}} instead.')
single_gate_type["fp32_gate"] = kwargs["fp32_gate"]

self.gates += [TopKGate(model_dim=model_dim, top_k=single_gate_type['k'], num_global_experts=self.num_global_experts, a2a_ffn_overlap_degree=a2a_ffn_overlap_degree, **single_gate_type)]
single_gate_type.pop('type')
self.gates += [TopKGate(model_dim=model_dim, num_global_experts=self.num_global_experts, a2a_ffn_overlap_degree=a2a_ffn_overlap_degree, **single_gate_type)]
else:
raise Exception("Unrecognized gate_type: %s" % single_gate_type)

Expand All @@ -435,7 +445,7 @@ def expert_fn(dispatched_input):
return expert_output

self.expert_fn = expert_fn
self.expected_sample_size = 0 if kwargs.get('scale_samples', False) else -1
self.expected_sample_size = 0 if pad_samples else -1

def get_parameter_iterator(self, param_type):
if param_type == 'gate':
Expand All @@ -445,7 +455,7 @@ def get_parameter_iterator(self, param_type):
else:
raise Exception("Specified parameter type is not recognized: %s. Valid `param_type` includes: gate, local_experts." % param_type)

def forward(self, input: Tensor, gate_index=0, **kwargs: Any):
def forward(self, input: Tensor, gate_index=0):
if self.skip_moe:
result_output = input
result_output.l_aux = None
Expand Down
19 changes: 13 additions & 6 deletions tutel/system_init.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import os, sys
import re
import logging

Expand All @@ -25,12 +25,19 @@ def init_affinity_at_program_beginning():
logging.warning('Failed to set NUMA status: %s' % ex)

def init_data_model_parallel(group_count=1, backend='nccl'):
from tutel.impls.communicate import create_groups_from_world
result = create_groups_from_world(group_count=group_count, include_init=backend)
from tutel.impls import communicate as C
result = C.create_groups_from_world(group_count=group_count, include_init=backend)
logging.critical(f'Registering device global rank {result.global_rank}: data_rank = {result.data_rank}, model_rank = {result.model_rank}')

def on_quit():
sys.stdout.flush()
sys.stderr.flush()
# Builtin dist.all_to_all_single in torch is unstable in some versions.
# Temp work around: https://github.com/pytorch/pytorch/issues/56390
if C.AllToAll._use_builtins:
os._exit(0)

# Temp work around for: https://github.com/pytorch/pytorch/issues/56390
import atexit
atexit.register(lambda *args: os._exit(0))
atexit.register(lambda *args: on_quit())

logging.critical(f'Registering device global rank {result.global_rank}: data_rank = {result.data_rank}, model_rank = {result.model_rank}')
return result

0 comments on commit bddc915

Please sign in to comment.