allow for different thresholds between second and third expert
lucidrains committed Aug 21, 2023
1 parent f9b8ce3 commit 22dfd4d
Showing 3 changed files with 28 additions and 15 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -62,8 +62,8 @@ out, balance_loss, router_z_loss = moe_block(inputs) # (4, 1024, 512), (1,), (1,
- [x] consult some MoE experts in the open source community; question why hierarchical MoE is needed, in light of results from soft-MoE
- [x] offer top-n gating generalization, as it seems top3 (with smaller threshold) can work even better
- [x] figure out if there was an error in <a href="https://github.com/lucidrains/mixture-of-experts/blob/master/mixture_of_experts/mixture_of_experts.py#L210">a previous transcription</a> - no there was not an error
- [x] allow for different thresholds for second vs third routed expert

- [ ] allow for different thresholds for second vs third routed expert
- [ ] improvise a `Top2GatingWithCoordinateDescent` for `MoE` without `importance`
- [ ] take care of scatter gather, and once done, port over to <a href="https://github.com/lucidrains/soft-moe-pytorch">soft moe</a>

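The practical effect of the README change above: `threshold_train` and `threshold_eval` can now be a single float or a tuple with one value per routed expert beyond the first. A rough usage sketch against the `TopNGating` signature shown later in this diff — the import path, treating `num_gates` as the expert count, and the four-value return (taken from how `MoE.forward` unpacks `self.gate(x)`) are assumptions on my part:

```python
import torch
from st_moe_pytorch.st_moe_pytorch import TopNGating  # assumed import path

# top-3 gating with a looser threshold for the second expert
# and a tighter one for the third
gate = TopNGating(
    dim = 512,
    num_gates = 16,                 # assumed: number of experts
    top_n = 3,
    threshold_train = (0.2, 0.1),   # one threshold per routed expert after the first
    threshold_eval = (0.2, 0.1)
)

x = torch.randn(1, 1024, 512)       # (batch, seq, dim)
dispatch_tensor, combine_tensor, balance_loss, router_z_loss = gate(x)
```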
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'st-moe-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.20',
version = '0.0.21',
license='MIT',
description = 'ST - Mixture of Experts - Pytorch',
author = 'Phil Wang',
39 changes: 26 additions & 13 deletions st_moe_pytorch/st_moe_pytorch.py
@@ -1,6 +1,6 @@
from functools import partial
from collections import namedtuple
from typing import Optional
from typing import Optional, Tuple, Union

import torch
from torch.nn import Module, ModuleList
@@ -34,8 +34,8 @@ def default(val, default):

return default() if callable(default) else default

def cast_tuple(el):
return el if isinstance(el, tuple) else (el,)
def cast_tuple(el, len = 1):
return el if isinstance(el, tuple) else ((el,) * len)

def pack_one(t, pattern):
return pack([t], pattern)
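
For reference, the updated helper above now broadcasts a scalar to a tuple of the requested length, which is what lets a single threshold still be passed for all routed experts:

```python
def cast_tuple(el, len = 1):
    return el if isinstance(el, tuple) else ((el,) * len)

print(cast_tuple(0.2, 2))         # (0.2, 0.2)  - scalar broadcast to both lower-ranked experts
print(cast_tuple((0.2, 0.1), 2))  # (0.2, 0.1)  - explicit per-rank thresholds pass through unchanged
```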
@@ -144,14 +144,16 @@ def forward(self, x):
# gating network

class TopNGating(Module):

@beartype
def __init__(
self,
dim,
num_gates,
eps = 1e-9,
top_n = 2,
threshold_train = 0.2,
threshold_eval = 0.2,
threshold_train: Union[float, Tuple[float, ...]] = 0.2,
threshold_eval: Union[float, Tuple[float, ...]] = 0.2,
capacity_factor_train = 1.25,
capacity_factor_eval = 2.,
straight_through_dispatch_tensor = True
@@ -163,9 +165,16 @@ def __init__(

assert top_n >= 2, 'must be 2 or more experts'
self.top_n = top_n
top_n_minus_1 = top_n - 1

threshold_train = cast_tuple(threshold_train, top_n_minus_1)
threshold_eval = cast_tuple(threshold_eval, top_n_minus_1)

assert len(threshold_train) == len(threshold_eval) == top_n_minus_1

self.register_buffer('threshold_train', torch.tensor([eps, *threshold_train]))
self.register_buffer('threshold_eval', torch.tensor([eps, *threshold_eval]))

self.threshold_train = threshold_train
self.threshold_eval = threshold_eval
self.capacity_factor_train = capacity_factor_train
self.capacity_factor_eval = capacity_factor_eval
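
To make the new buffer contents concrete — a small sketch assuming `top_n = 3` with example thresholds `(0.2, 0.1)`; the first entry is pinned to `eps` so the top-1 expert is effectively always routed to:

```python
import torch

eps = 1e-9
threshold_train = (0.2, 0.1)   # thresholds for the 2nd and 3rd routed expert

# mirrors the register_buffer call above
threshold = torch.tensor([eps, *threshold_train])
print(threshold)   # tensor([1.0000e-09, 2.0000e-01, 1.0000e-01])
```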

@@ -182,7 +191,7 @@ def forward(self, x):
k - top-n experts
"""

*_, b, group_size, dim, dtype, num_gates, eps = *x.shape, x.dtype, self.num_gates, self.eps
*_, b, group_size, dim, dtype, top_n, num_gates, eps = *x.shape, x.dtype, self.top_n, self.num_gates, self.eps

# threshold, capacity depending on training or eval

@@ -205,7 +214,7 @@

# find top N experts per position

gates, gate_indices = raw_gates.topk(k = 2, dim = -1)
gates, gate_indices = raw_gates.topk(k = top_n, dim = -1)

# move the top-n dimension to be first

@@ -230,9 +239,11 @@

probs = torch.zeros_like(gates).uniform_(0., 1.)

should_route = probs < (gates / max(threshold, eps))
threshold = rearrange(threshold, 'k -> k 1 1')
should_route = probs < (gates / threshold.clamp(min = eps))

# tokens should always be routed to first expert
# threshold for first expert already set to very small number, but just in case

should_route[0, ...] = True
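
A minimal sketch of what the new comparison does, assuming `gates` carries the top-n dimension first (as the earlier "move the top-n dimension to be first" comment indicates); the shapes here are made up for illustration:

```python
import torch
from einops import rearrange

k, b, n = 3, 1, 4                          # top-n ranks, batch, sequence positions
gates = torch.rand(k, b, n)                # gate value of each token at each rank
threshold = torch.tensor([1e-9, 0.2, 0.1]) # per-rank thresholds, as in the buffer above
eps = 1e-9

probs = torch.zeros_like(gates).uniform_(0., 1.)

# each rank is now measured against its own threshold, rather than one shared scalar
threshold = rearrange(threshold, 'k -> k 1 1')
should_route = probs < (gates / threshold.clamp(min = eps))

# rank 0 is forced on, so every token reaches at least its first expert
should_route[0, ...] = True
```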

@@ -317,6 +328,8 @@ def forward(self, x):
# plain mixture of experts

class MoE(Module):

@beartype
def __init__(self,
dim,
num_experts = 16,
@@ -355,12 +368,12 @@ def __init__(self,
self.loss_coef = loss_coef
self.router_z_loss_coef = router_z_loss_coef

def forward(self, inputs, **kwargs):
dispatch_tensor, combine_tensor, loss, router_z_loss = self.gate(inputs)
def forward(self, x):
dispatch_tensor, combine_tensor, loss, router_z_loss = self.gate(x)

# dispatch

expert_inputs = einsum('b n d, b n e c -> e b c d', inputs, dispatch_tensor)
expert_inputs = einsum('b n d, b n e c -> e b c d', x, dispatch_tensor)

# feed the expert inputs through the experts.

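A quick shape check of the dispatch einsum above, reading the letters as batch, sequence, dim, experts and capacity (my interpretation of the pattern, not named in the diff):

```python
import torch

b, n, d = 2, 8, 512    # batch, sequence length, model dim
e, c = 4, 3            # experts, capacity slots per expert

x = torch.randn(b, n, d)
dispatch_tensor = torch.randn(b, n, e, c)   # in the module this is a (near) one-hot routing tensor

expert_inputs = torch.einsum('b n d, b n e c -> e b c d', x, dispatch_tensor)
print(expert_inputs.shape)   # torch.Size([4, 2, 3, 512]) - one (b, c, d) batch per expert
```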
