Skip to content

Commit

Permalink
router z loss should be calculated on the unnoised gating logits
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 21, 2023
1 parent d9f5f08 commit 51727d0
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'st-moe-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.2',
version = '0.1.4',
license='MIT',
description = 'ST - Mixture of Experts - Pytorch',
author = 'Phil Wang',
Expand Down
8 changes: 5 additions & 3 deletions st_moe_pytorch/st_moe_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,11 +420,13 @@ def forward(

gate_logits = self.to_gates(x)

maybe_noised_gate_logits = gate_logits

if noise_gates:
noise = gumbel_noise(gate_logits)
gate_logits = gate_logits + noise * noise_mult
noise = gumbel_noise(maybe_noised_gate_logits)
maybe_noised_gate_logits = maybe_noised_gate_logits + noise * noise_mult

raw_gates = gate_logits.softmax(dim = -1)
raw_gates = maybe_noised_gate_logits.softmax(dim = -1)

# find top N experts per position

Expand Down

0 comments on commit 51727d0

Please sign in to comment.