From 1ca8170a6bc2ae94a8aa32a40aff91895d7f696a Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 21 Aug 2023 11:33:01 -0700
Subject: [PATCH] better naming

---
 setup.py                         |  2 +-
 st_moe_pytorch/st_moe_pytorch.py | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/setup.py b/setup.py
index f64c228..5da55ae 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'st-moe-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.18',
+  version = '0.0.19',
   license='MIT',
   description = 'ST - Mixture of Experts - Pytorch',
   author = 'Phil Wang',
diff --git a/st_moe_pytorch/st_moe_pytorch.py b/st_moe_pytorch/st_moe_pytorch.py
index 5c01b83..70b3d5d 100644
--- a/st_moe_pytorch/st_moe_pytorch.py
+++ b/st_moe_pytorch/st_moe_pytorch.py
@@ -156,7 +156,7 @@ def __init__(
         threshold_eval = 0.2,
         capacity_factor_train = 1.25,
         capacity_factor_eval = 2.,
-        detached_dispatch_tensor = True
+        straight_through_dispatch_tensor = True
     ):
         super().__init__()
         self.eps = eps
@@ -171,7 +171,7 @@ def __init__(
         self.capacity_factor_train = capacity_factor_train
         self.capacity_factor_eval = capacity_factor_eval
 
-        self.detached_dispatch_tensor = detached_dispatch_tensor
+        self.straight_through_dispatch_tensor = straight_through_dispatch_tensor
         self.register_buffer('zero', torch.zeros((1,)), persistent = False)
 
     def forward(self, x):
@@ -288,7 +288,7 @@ def forward(self, x):
 
         dispatch_tensor = combine_tensor.bool().type(dtype)
 
-        if not self.detached_dispatch_tensor:
+        if self.straight_through_dispatch_tensor:
             dispatch_tensor = dispatch_tensor + combine_tensor - combine_tensor.detach()
 
         # balance losses - (batch, experts)
@@ -328,7 +328,7 @@ def __init__(self,
         loss_coef = 1e-2,
         router_z_loss_coef = 1e-3,
         experts: Optional[Module] = None,
-        detached_dispatch_tensor = True
+        straight_through_dispatch_tensor = True
     ):
         super().__init__()
         self.dim = dim
@@ -345,7 +345,7 @@ def __init__(self,
         dim,
         top_n = gating_top_n,
         num_gates = num_experts,
-        detached_dispatch_tensor = detached_dispatch_tensor,
+        straight_through_dispatch_tensor = straight_through_dispatch_tensor,
         **gating_kwargs
     )
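
Note (not part of the patch): the renamed flag controls the straight-through estimator trick visible in the hunk at line 288 - the forward pass uses the hard boolean dispatch tensor while gradients flow through the soft combine weights. Below is a minimal standalone sketch of that pattern, assuming illustrative shapes and a 0.5 threshold; the tensor names mirror the diff but the values are made up.

    import torch

    # stand-in for the gating module's soft routing weights
    combine_tensor = torch.rand(2, 4, requires_grad = True)

    # hard 0/1 dispatch decisions derived from the soft weights (non-differentiable)
    dispatch_tensor = (combine_tensor > 0.5).float()

    # straight-through: forward value equals the hard dispatch tensor,
    # but the detached term cancels only in the forward direction, so the
    # backward pass sees the gradient of the soft combine weights
    straight_through = dispatch_tensor + combine_tensor - combine_tensor.detach()

    straight_through.sum().backward()
    print(combine_tensor.grad)  # ones - gradient routed through the soft weights

This also explains the flipped condition in the diff: the old name described detaching the dispatch tensor, so the trick ran under `if not self.detached_dispatch_tensor`; the new name describes the trick itself, so it runs under `if self.straight_through_dispatch_tensor` with the same default behavior.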