importance = gate_scores.sum(1)  # Sum over all experts
# Aux loss is mean squared difference between load and importance
loss = ((load - importance) ** 2).mean()
where load is of shape [num_experts, dim] and importance is of shape [batch_size, dim]. Testing the SwitchGate class on its own with an input whose batch_size > 1 raises an error like this:
RuntimeError: The size of tensor a (64) must match the size of tensor b (2) at non-singleton dimension 0
To Reproduce
Simply run a sample with batch_size > 1:
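import torch
from switch_transformers.model import SwitchGate  # module path taken from the file referenced in this issue

gate = SwitchGate(dim=16, num_experts=3)
x = torch.randn((2, 64, 16)).float()  # batch_size = 2
y, loss = gate(x, use_aux_loss=True)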
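The mismatch can be traced with shape bookkeeping alone, without installing torch. The sketch below is only an illustration under stated assumptions: it assumes the gate scores have shape (batch_size, seq_len, num_experts) and that load is computed as gate_scores.sum(0), which is not shown in the snippet quoted in this issue. The helpers reduce_shape and broadcast_shape are hypothetical names written for this example and are not part of the repository or of PyTorch.

```python
def reduce_shape(shape, dim):
    """Shape left after summing a tensor of `shape` over axis `dim` (no keepdim)."""
    return tuple(s for i, s in enumerate(shape) if i != dim)


def broadcast_shape(a, b):
    """Result shape of an elementwise op on shapes `a` and `b`, using
    right-aligned broadcasting rules; raises ValueError on a mismatch."""
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have the same rank.
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for axis, (x, y) in enumerate(zip(a, b)):
        if x != y and x != 1 and y != 1:
            raise ValueError(
                f"size {x} must match size {y} at non-singleton dimension {axis}"
            )
        out.append(y if x == 1 else x)
    return tuple(out)


# Shapes from the reproduction: x is (2, 64, 16) and num_experts=3, so the
# gate scores are ASSUMED to be (batch_size, seq_len, num_experts).
gate_scores = (2, 64, 3)

load = reduce_shape(gate_scores, 0)        # ASSUMED: load = gate_scores.sum(0)
importance = reduce_shape(gate_scores, 1)  # from the issue: gate_scores.sum(1)

print(load, importance)  # (64, 3) (2, 3)

try:
    broadcast_shape(load, importance)
except ValueError as err:
    print("mismatch:", err)
```

Under these assumptions the leading dimensions come out as 64 and 2, the same sizes reported in the RuntimeError above. With batch_size equal to 1 the second shape would be (1, 3), which broadcasts silently, and that would explain why the error only appears for batch_size > 1.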
Describe the bug
Shape mismatch is found in the computation of auxiliary loss values:
SwitchTransformers/switch_transformers/model.py, lines 70 to 74 in 36a1ea0