Skip to content

Commit

Permalink
better naming
Browse files: browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 21, 2023
1 parent 1c4a2aa commit 1ca8170
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'st-moe-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.18',
version = '0.0.19',
license='MIT',
description = 'ST - Mixture of Experts - Pytorch',
author = 'Phil Wang',
Expand Down
10 changes: 5 additions & 5 deletions st_moe_pytorch/st_moe_pytorch.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -156,7 +156,7 @@ def __init__(
threshold_eval = 0.2,
capacity_factor_train = 1.25,
capacity_factor_eval = 2.,
detached_dispatch_tensor = True
straight_through_dispatch_tensor = True
):
super().__init__()
self.eps = eps
Expand All @@ -171,7 +171,7 @@ def __init__(
self.capacity_factor_train = capacity_factor_train
self.capacity_factor_eval = capacity_factor_eval

self.detached_dispatch_tensor = detached_dispatch_tensor
self.straight_through_dispatch_tensor = straight_through_dispatch_tensor
self.register_buffer('zero', torch.zeros((1,)), persistent = False)

def forward(self, x):
Expand Down Expand Up @@ -288,7 +288,7 @@ def forward(self, x):

dispatch_tensor = combine_tensor.bool().type(dtype)

if not self.detached_dispatch_tensor:
if self.straight_through_dispatch_tensor:
dispatch_tensor = dispatch_tensor + combine_tensor - combine_tensor.detach()

# balance losses - (batch, experts)
Expand Down Expand Up @@ -328,7 +328,7 @@ def __init__(self,
loss_coef = 1e-2,
router_z_loss_coef = 1e-3,
experts: Optional[Module] = None,
detached_dispatch_tensor = True
straight_through_dispatch_tensor = True
):
super().__init__()
self.dim = dim
Expand All @@ -345,7 +345,7 @@ def __init__(self,
dim,
top_n = gating_top_n,
num_gates = num_experts,
detached_dispatch_tensor = detached_dispatch_tensor,
straight_through_dispatch_tensor = straight_through_dispatch_tensor,
**gating_kwargs
)

Expand Down

0 comments on commit 1ca8170

Please sign in to comment.