diff --git a/setup.py b/setup.py index ed8a524..ff1fc81 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'st-moe-pytorch', packages = find_packages(exclude=[]), - version = '0.1.2', + version = '0.1.4', license='MIT', description = 'ST - Mixture of Experts - Pytorch', author = 'Phil Wang', diff --git a/st_moe_pytorch/st_moe_pytorch.py b/st_moe_pytorch/st_moe_pytorch.py index b6c5961..60e5e80 100644 --- a/st_moe_pytorch/st_moe_pytorch.py +++ b/st_moe_pytorch/st_moe_pytorch.py @@ -420,11 +420,13 @@ def forward( gate_logits = self.to_gates(x) + maybe_noised_gate_logits = gate_logits + if noise_gates: - noise = gumbel_noise(gate_logits) - gate_logits = gate_logits + noise * noise_mult + noise = gumbel_noise(maybe_noised_gate_logits) + maybe_noised_gate_logits = maybe_noised_gate_logits + noise * noise_mult - raw_gates = gate_logits.softmax(dim = -1) + raw_gates = maybe_noised_gate_logits.softmax(dim = -1) # find top N experts per position