Commit 5f6b992

keep cleaning
lucidrains committed Aug 21, 2023
1 parent 7d75c38 commit 5f6b992
Showing 2 changed files with 3 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'st-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.16',
+  version = '0.0.17',
   license='MIT',
   description = 'ST - Mixture of Experts - Pytorch',
   author = 'Phil Wang',
5 changes: 2 additions & 3 deletions st_moe_pytorch/st_moe_pytorch.py
@@ -157,7 +157,6 @@ def __init__(
         dim,
         num_gates,
         eps = 1e-9,
-        outer_expert_dims: Tuple[int, ...] = tuple(),
         second_threshold_train = 0.2,
         second_threshold_eval = 0.2,
         capacity_factor_train = 1.25,
@@ -167,7 +166,7 @@ def __init__(
         super().__init__()
         self.eps = eps
         self.num_gates = num_gates
-        self.w_gating = nn.Parameter(torch.randn(*outer_expert_dims, dim, num_gates))
+        self.to_gates = nn.Linear(dim, num_gates, bias = False)

         self.second_threshold_train = second_threshold_train
         self.second_threshold_eval = second_threshold_eval
@@ -196,7 +195,7 @@ def forward(self, x):

         # gate logits and gates

-        gate_logits = einsum('... b n d, ... d e -> ... b n e', x, self.w_gating)
+        gate_logits = self.to_gates(x)
         raw_gates = gate_logits.softmax(dim = -1)

# FIND TOP 2 EXPERTS PER POSITON
Expand Down
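The substantive change in this commit is the gating rewrite in st_moe_pytorch.py: the explicit w_gating parameter, contracted with an einsum, is replaced by a bias-free nn.Linear. The sketch below is not from the commit; the input shape, tolerance, and variable names are assumptions for illustration. It checks that the two formulations produce identical gate logits when the Linear weight is loaded with the transpose of the old gating matrix (nn.Linear stores its weight as (out_features, in_features)).

import torch
from torch import nn, einsum

dim, num_gates = 16, 4
x = torch.randn(2, 8, dim)                      # assumed (batch, seq, dim) input

# old formulation: explicit (dim, num_gates) weight contracted with einsum
w_gating = torch.randn(dim, num_gates)
logits_old = einsum('... n d, d e -> ... n e', x, w_gating)

# new formulation: bias-free Linear; its weight has shape (num_gates, dim),
# so copying the transpose of w_gating makes the two projections identical
to_gates = nn.Linear(dim, num_gates, bias = False)
with torch.no_grad():
    to_gates.weight.copy_(w_gating.t())
logits_new = to_gates(x)

assert torch.allclose(logits_old, logits_new, atol = 1e-6)
print(logits_new.softmax(dim = -1).shape)       # torch.Size([2, 8, 4])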
