Merge pull request #208 from ghostplant/main
support Megablocks-style MoE inference
msftsw committed Aug 5, 2023
2 parents 9016428 + 2af4bf1 commit 71db950
Showing 6 changed files with 104 additions and 18 deletions.
43 changes: 37 additions & 6 deletions README.md
@@ -7,6 +7,41 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation.
- Supported CPU: fp64/fp32


### What's New:

- Tutel v0.3: Add a Megablocks-style solution to improve decoder inference on a single GPU with num_local_experts >= 2 (a combined usage sketch follows this list):
```py
>> Example (capacity_factor=0 for dropless-MoE):
# Using BatchMatmul:
python3 -m tutel.examples.helloworld --megablocks_size=0 --batch_size=1 --num_tokens=32 --top=1 --eval --num_local_experts=128 --capacity_factor=0
# Using Megablocks with block_size = 1:
python3 -m tutel.examples.helloworld --megablocks_size=1 --batch_size=1 --num_tokens=32 --top=1 --eval --num_local_experts=128 --capacity_factor=0
# Using Megablocks with block_size = 2:
python3 -m tutel.examples.helloworld --megablocks_size=2 --batch_size=1 --num_tokens=32 --top=1 --eval --num_local_experts=128 --capacity_factor=0

>> How to:
self._moe_layer.forward(x, .., megablocks_size=1) # Switch the Megablocks block size at run time (0 disables Megablocks)
```

- Tutel v0.2: Allow most configurations to be switched dynamically at no extra cost:
```py
>> Example:
python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld_switch --batch_size=16

>> How to:
self._moe_layer.forward(x, .., a2a_ffn_overlap_degree=2) # Switch the overlap granularity at run time (1 = no overlapping)
self._moe_layer.forward(x, .., adaptive_r=1) # Switch the parallelism at run time (0 = DP, 1 = DP + EP, W / E = MP + EP, otherwise DP + MP + EP)
self._moe_layer.forward(x, .., capacity_factor=1) # Switch the capacity volume at run time (positive = padding, negative = no padding, 0 = dropless)
self._moe_layer.forward(x, .., top_k=1) # Switch the top-k sparsity at run time
```

- Tutel v0.1: Optimize the Einsum complexity of data dispatch encoding and decoding, and add a 2DH option to handle All-to-All at scale:
```py
>> Example (suggest enabling 2DH only at scale):
python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld_switch --batch_size=16 --use_2dh=1
```
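
Putting these switches together, the following is a condensed single-GPU inference sketch in the spirit of `tutel/examples/helloworld.py`; the model dimensions and batch shape are illustrative assumptions, not values taken from this commit:
```py
import torch
import torch.nn.functional as F
from tutel import system, moe as tutel_moe

# Single-process setup, as in tutel/examples/helloworld.py
parallel_env = system.init_data_model_parallel(backend='nccl')

# Illustrative sizes (assumptions for this sketch, not values from the commit)
model_dim, hidden_size, num_local_experts = 1024, 2048, 2

moe = tutel_moe.moe_layer(
    gate_type={'type': 'top', 'k': 1, 'capacity_factor': 0.0},  # 0.0 => dropless MoE
    experts={'type': 'ffn', 'count_per_node': num_local_experts,
             'hidden_size_per_expert': hidden_size,
             'activation_fn': lambda x: F.relu(x)},
    model_dim=model_dim,
).to('cuda').eval()

x = torch.randn(1, 32, model_dim, device='cuda')
with torch.no_grad():                          # the Megablocks path is inference-only
    y = moe(x, top_k=1, megablocks_size=1)     # switch the block size per call; 0 disables Megablocks
```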


How to set up Tutel MoE for Pytorch and [run examples](tutel/examples), or [enable fairseq with MoE](tutel/examples/fairseq_moe):
```
* Recommended Pytorch (minimum version == 1.8.0):
@@ -48,15 +83,11 @@ How to set up Tutel MoE for Pytorch and [run examples](tutel/examples), or [enable fairseq with MoE](tutel/examples/fairseq_moe):
$ python3 ./tutel/examples/helloworld.py --batch_size=16
..
* Switch Test using single-node 8 GPUs:
$ python3 -m torch.distributed.launch --nproc_per_node=8 -m tutel.examples.helloworld_switch --batch_size=16
* Run Tutel MoE in Distributed Mode:
(Method A - Torch launcher for `Multi-Node x Multi-GPU`:)
$ ssh <node-ip-0> python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=<node-ip-0> -m tutel.examples.helloworld --batch_size=16
$ ssh <node-ip-1> python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=<node-ip-0> -m tutel.examples.helloworld --batch_size=16
$ ssh <node-ip-0> python3 -m torch.distributed.run --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=<node-ip-0> -m tutel.examples.helloworld --batch_size=16
$ ssh <node-ip-1> python3 -m torch.distributed.run --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=<node-ip-0> -m tutel.examples.helloworld --batch_size=16
(Method B - Tutel launcher for `Multi-Node x Multi-GPU`, requiring package `openmpi-bin`:)
# << Single Node >>
19 changes: 19 additions & 0 deletions tutel/custom/custom_kernel.cpp
@@ -11,6 +11,7 @@
#include <cuda_fp16.h>
#include <cuda.h>
#include <nvrtc.h>
#include <ATen/cuda/CUDAContext.h>
#else
#undef USE_NCCL
#endif
@@ -777,7 +778,25 @@ extern "C" __global__ void cumsum_fn(int* input0 /* (num_samples, batch_num) */,
return y;
}

torch::Tensor warp_sparse_bmm_infer(const torch::Tensor &x, const torch::Tensor &w, const torch::Tensor &sparse_groups_device, bool w_transpose, int64_t sparse_size) {
auto sparse_groups = sparse_groups_device.cpu().to(torch::kInt32);
auto group_ptr = ((int*)sparse_groups.data_ptr());

auto y = torch::empty({x.size(0), x.size(1), w_transpose ? w.size(1) : w.size(2)}, torch::TensorOptions().dtype(x.dtype()).device(x.device()));

// auto hCublas = at::cuda::getCurrentCUDABlasHandle(); -- waiting for Pytorch to add built-in support for cublasSgemmBatched()
for (int i = 0; i < sparse_groups.size(0); ++i) {
int group_size = group_ptr[i];
if (group_size > 0) {
auto y_sub = y.select(0, i).narrow(0, 0, int(group_size * sparse_size));
torch::matmul_out(y_sub, x.select(0, i).narrow(0, 0, int(group_size * sparse_size)), w_transpose ? w.select(0, i).t() : w.select(0, i));
}
}
return y;
}

TORCH_LIBRARY(tutel_ops, m) {
m.def("cumsum", warp_cumsum);
m.def("sparse_bmm_infer", warp_sparse_bmm_infer);
}
#endif
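
For reference, here is a plain-Python sketch of what the op registered above as `torch.ops.tutel_ops.sparse_bmm_infer` computes; this is an illustrative re-implementation rather than the registered kernel, and it zero-fills the inactive rows that the real op leaves uninitialized (`torch::empty`):
```py
import torch

def sparse_bmm_infer_ref(x, w, sparse_groups, w_transpose, sparse_size):
    # x: (num_experts, tokens, d_in); w: per-expert weights; sparse_groups: (num_experts,) active block counts
    out_dim = w.size(1) if w_transpose else w.size(2)
    y = torch.zeros(x.size(0), x.size(1), out_dim, dtype=x.dtype, device=x.device)
    for i, groups in enumerate(sparse_groups.tolist()):
        rows = int(groups) * sparse_size              # only the first `rows` tokens of expert i carry data
        if rows > 0:
            w_i = w[i].t() if w_transpose else w[i]
            y[i, :rows] = x[i, :rows] @ w_i           # the all-padding tail of each expert is skipped
    return y
```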
10 changes: 8 additions & 2 deletions tutel/examples/helloworld.py
@@ -34,6 +34,9 @@
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--use_2dh', default=False, action='store_true')
parser.add_argument('--eval', default=False, action='store_true')
parser.add_argument('--capacity_factor', type=float, default=1.0) # 0.0 for dMoE (dropless-MoE), negative for no-padded capacity.
parser.add_argument('--megablocks_size', type=int, default=1)

args = parser.parse_args()

parallel_env = system.init_data_model_parallel(backend='nccl' if args.device == 'cuda' else 'gloo')
@@ -66,7 +69,7 @@ def __init__(self):
super().__init__()

self._moe_layer = tutel_moe.moe_layer(
gate_type = {'type': 'top', 'k': top_value, 'fp32_gate': args.fp32_gate},
gate_type = {'type': 'top', 'k': top_value, 'fp32_gate': args.fp32_gate, 'capacity_factor': args.capacity_factor},
experts = {'type': 'ffn', 'count_per_node': num_local_experts, 'hidden_size_per_expert': hidden_size, 'activation_fn': lambda x: F.relu(x)},
model_dim = model_dim,
scan_expert_func = lambda name, param: setattr(param, 'skip_allreduce', True),
@@ -82,7 +85,10 @@ def __init__(self):
dist_print('[Statistics] param count for MoE local_experts = %s, param count for MoE gate = %s.\n' % (local_count, shared_count))

def forward(self, input):
result = self._moe_layer(input)
if args.megablocks_size > 0:
result = self._moe_layer(input, megablocks_size=args.megablocks_size)
else:
result = self._moe_layer(input)
result = F.log_softmax(torch.sum(result, dim=2), dim=1)
return result

13 changes: 13 additions & 0 deletions tutel/experts/ffn.py
@@ -57,6 +57,19 @@ def forward(self, x, ctx):
batched_fc1_bias = self.batched_fc1_bias.unsqueeze(1)
batched_fc2_bias = self.batched_fc2_bias.unsqueeze(1)

# Implementation of https://arxiv.org/pdf/2211.15841.pdf in Tutel v0.3.x
# which benefits decoder inference on a single GPU if num_local_experts >= 2
if ctx.megablocks_size > 0:
sparse_size = ctx.megablocks_size
sparse_groups = torch.div(ctx.dispatch_count + (sparse_size - 1), sparse_size, rounding_mode='floor')
sparse_groups = torch.minimum(sparse_groups, torch.tensor(x.size(1) // sparse_size, dtype=torch.int32, device=x.device))
y = torch.ops.tutel_ops.sparse_bmm_infer(x, batched_fc1_w, sparse_groups, True, sparse_size)
y = torch.add(y, batched_fc1_bias)
y = self.activation_fn(y)
y = torch.ops.tutel_ops.sparse_bmm_infer(y, batched_fc2_w, sparse_groups, False, sparse_size)
y = torch.add(y, batched_fc2_bias)
return y

if ctx.adaptive_degree == 0:
batched_fc1_w = net.zero_gather(batched_fc1_w, group=ctx.group).view(ctx.num_global_experts, -1, batched_fc1_w.size(2))
batched_fc2_w = net.zero_gather(batched_fc2_w, group=ctx.group).view(ctx.num_global_experts, -1, batched_fc2_w.size(2))
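For intuition, here is a small standalone illustration of how the `sparse_groups` tensor in the Megablocks branch above is derived from the per-expert dispatch counts (all numbers are made up):
```py
import torch

dispatch_count = torch.tensor([5, 0, 2, 9], dtype=torch.int32)  # hypothetical tokens routed to 4 local experts
tokens_per_expert = 16                                          # x.size(1): padded rows available per expert
sparse_size = 2                                                 # Megablocks block size

# Ceil-divide each per-expert count by the block size ...
sparse_groups = torch.div(dispatch_count + (sparse_size - 1), sparse_size, rounding_mode='floor')
# ... then clamp to the number of blocks that fit into the padded buffer.
sparse_groups = torch.minimum(sparse_groups,
                              torch.tensor(tokens_per_expert // sparse_size, dtype=torch.int32))
print(sparse_groups.tolist())  # [3, 0, 1, 5] -> expert i computes sparse_groups[i] * sparse_size rows
```
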
16 changes: 11 additions & 5 deletions tutel/impls/fast_dispatch.py
@@ -170,6 +170,9 @@ def extract_critical(scores, top_k, loss_fn=losses.gshard_loss, capacity_factor=
if normalize_gate:
denom_s = torch.clamp(sum(gates_s), min=torch.finfo(gates_s[0].dtype).eps)
gates_s = [x / denom_s for x in gates_s]
else:
locations2 = locations1
locations2 = locations2[-1] + 1

indices_s = [x.to(torch.int32) for x in indices_s]

@@ -183,8 +186,8 @@ def extract_critical(scores, top_k, loss_fn=losses.gshard_loss, capacity_factor=
if capacity_factor > 0:
capacity = top_k * int(capacity_factor * samples_per_expert)
else:
capacity = torch.max(torch.cat(locations_s, dim=0))
capacity = int(simple_all_reduce(capacity, group=group, op=torch.distributed.ReduceOp.MAX)) + 1
capacity = locations2.max()
capacity = int(simple_all_reduce(capacity, group=group, op=torch.distributed.ReduceOp.MAX))
if capacity_factor < 0:
capacity = min(capacity, top_k * int(-capacity_factor * samples_per_expert))

@@ -195,16 +198,19 @@ def extract_critical(scores, top_k, loss_fn=losses.gshard_loss, capacity_factor=
if get_world_rank(group) == 0:
logging.info(f"Capacity = {capacity}, real-time capacity-factor for top-{top_k_original} = {capacity / (top_k * samples_per_expert)}")

return (num_global_experts, indices_s, locations_s, gates_s, capacity), l_loss
return (num_global_experts, indices_s, locations_s, gates_s, capacity, locations2), l_loss

def get_dispatch_count(critial_data):
return critial_data[-1]

def fast_encode(data, critial_data, is_postscore=True):
num_global_experts = critial_data[0]
dispatcher = TutelMoeFastDispatcher(num_global_experts, 0, data.size(-1), data.dtype)
dispatcher.update(*critial_data[1:], is_postscore=is_postscore)
dispatcher.update(*critial_data[1:-1], is_postscore=is_postscore)
return dispatcher.encode(data).view(num_global_experts, -1, data.size(-1))

def fast_decode(data, critial_data, is_postscore=True):
num_global_experts = critial_data[0]
dispatcher = TutelMoeFastDispatcher(num_global_experts, 0, data.size(-1), data.dtype)
dispatcher.update(*critial_data[1:], is_postscore=is_postscore)
dispatcher.update(*critial_data[1:-1], is_postscore=is_postscore)
return dispatcher.decode(data).view(-1, data.size(-1))
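
Below is a minimal sketch of how the extended routing metadata is meant to be consumed. The input shapes and the single-process setup are illustrative assumptions (the process group is assumed to be initialized as in the examples), and any extract_critical arguments not shown keep their defaults:
```py
import torch
from tutel.impls.fast_dispatch import extract_critical, fast_encode, fast_decode, get_dispatch_count

scores = torch.rand(32, 4).softmax(dim=-1)    # hypothetical routing scores: (tokens, num_global_experts)
x = torch.randn(32, 128)                      # hypothetical token features: (tokens, model_dim)

# crit now carries the per-expert dispatch count (locations2) as its last entry.
crit, l_aux = extract_critical(scores, top_k=1, capacity_factor=0)

dispatch_count = get_dispatch_count(crit)     # forwarded to the experts as ctx.dispatch_count
dispatched = fast_encode(x, crit)             # dispatcher.update() now drops the trailing entry ([1:-1])
expert_output = dispatched                    # placeholder for the per-expert FFN computation
combined = fast_decode(expert_output, crit)   # back to (tokens, model_dim)
```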
21 changes: 16 additions & 5 deletions tutel/impls/moe_layer.py
@@ -18,7 +18,7 @@
import torch.nn.functional as F

from ..impls import communicate as C
from ..impls.fast_dispatch import fast_encode, fast_decode, extract_critical
from ..impls.fast_dispatch import fast_encode, fast_decode, extract_critical, get_dispatch_count
from ..impls.overlap import a2a_ffn_overlap_forward
from . import losses

@@ -216,7 +216,7 @@ def expert_local(self, x, reserve_shape):
self.protected_shape = y.shape
return y.reshape(y.size(0), y.size(1), -1)

def forward(self, input: Tensor, gate_index=0, capacity_factor=None, top_k=None, a2a_ffn_overlap_degree=None, reserve_dims=1, inequivalent_tokens=False, adaptive_r=None):
def forward(self, input: Tensor, gate_index=0, capacity_factor=None, top_k=None, a2a_ffn_overlap_degree=None, reserve_dims=1, inequivalent_tokens=False, adaptive_r=None, megablocks_size=0):
if self.skip_moe:
result_output = input
result_output.l_aux = None
@@ -234,6 +234,12 @@ def forward(self, input: Tensor, gate_index=0, capacity_factor=None, top_k=None,
self.a2a_ffn_overlap_degree = a2a_ffn_overlap_degree
a2a_ffn_overlap_degree = self.a2a_ffn_overlap_degree

top_k = top_k or gctx.top_k

if megablocks_size > 0:
if self.num_local_experts <= 1 or torch.is_grad_enabled() or self.world_size > 1:
megablocks_size = 0

def routing():
logits = gctx(x)

@@ -249,14 +255,17 @@ def routing():
_loss_fn = lambda gates, topk_ids: losses.load_importance_loss(
F.softmax(logits, dim=1), logits_w_noise.gather(index=topk_ids, dim=1),
self.num_global_experts, gctx.gate_noise)

mega_up = max(megablocks_size, 1)

return logits.dtype, extract_critical(scores,
top_k = gctx.top_k if top_k is None else top_k,
top_k = top_k,
loss_fn = _loss_fn,
capacity_factor = gctx.capacity_factor if capacity_factor is None else capacity_factor,
capacity_factor = capacity_factor or gctx.capacity_factor,
batch_prioritized_routing = self.batch_prioritized_routing,
normalize_gate = self.normalize_gate,
group = self.group,
alignment = self.sharded_count * a2a_ffn_overlap_degree,
alignment = (self.sharded_count * a2a_ffn_overlap_degree + mega_up - 1) // mega_up * mega_up,
inequivalent_tokens = inequivalent_tokens,
)

@@ -267,6 +276,8 @@ def routing():
else:
logits_dtype, (crit, l_aux) = routing()

self.megablocks_size = megablocks_size
self.dispatch_count = get_dispatch_count(crit)
y = fast_encode(x.to(logits_dtype), crit, self.is_postscore).to(x.dtype)

if adaptive_r is not None:
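As a side note, the new `alignment` expression above simply rounds the previous alignment up to a multiple of the Megablocks block size; a tiny numeric illustration (values are arbitrary):
```py
sharded_count, a2a_ffn_overlap_degree, megablocks_size = 1, 2, 4   # arbitrary example values

mega_up = max(megablocks_size, 1)
base_alignment = sharded_count * a2a_ffn_overlap_degree            # 2 (the previous alignment)
alignment = (base_alignment + mega_up - 1) // mega_up * mega_up    # rounded up to a multiple of 4
print(alignment)                                                   # 4
```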
