In [2]:
import torch
from einops import einsum 

In [3]:
# Configuration: seed the global RNG first so every tensor drawn below (and
# every output printed downstream) is reproducible under Restart & Run All.
torch.manual_seed(0)

I, O = 16, 8  # input width and output width

# One input vector plus two (I, O) projection matrices: W_glu feeds the
# sigmoid gate, W_enc is the gated "encoder" projection.
# NOTE: keep the draw order z -> W_glu -> W_enc; reordering changes the values.
z = torch.randn(I)
W_glu = torch.randn(I, O)
W_enc = torch.randn(I, O)

# GLU encoders
The vanilla forward pass is as follows:

In [3]:
# Vanilla GLU: a sigmoid gate computed from the W_glu projection scales the
# W_enc projection of the same input elementwise.
activ_fn = torch.sigmoid
a = activ_fn(W_glu.T @ z)  # gate activations, one per output unit
out = (W_enc.T @ z) * a    # elementwise gating (IEEE multiply is commutative)

print(out)

tensor([-1.9279,  0.4093,  2.1107, -3.8921, -0.6280,  0.0743,  2.3439, -0.0201])


# Rank-1 MoEs
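The same output can also be written as a mixture of experts: each expert is a rank-1 slice of `W_enc`, and the gate activations `a` serve as the expert coefficients: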

In [4]:
# Re-express the GLU as a mixture of O experts, with the gate activations `a`
# as the per-expert mixing coefficients. Expert n is W_enc with every column
# except column n zeroed out (i.e. W_enc @ diag(e_n)), so it has rank at most 1.
basis = torch.eye(O, dtype=W_enc.dtype)  # row n is the one-hot vector e_n; loop-invariant
out_moe = torch.zeros_like(a)  # tensor accumulator: stays a tensor even when O == 0
ranks = []
for n in range(O):
    # Broadcasting the one-hot row across W_enc's columns keeps only column n;
    # same matrix as W_enc @ torch.diag(e_n), without the O x O matmul.
    E = W_enc * basis[n]
    ranks.append(int(torch.linalg.matrix_rank(E)))

    out_moe += a[n] * (E.T @ z)

print(out_moe)

tensor([-1.9279,  0.4093,  2.1107, -3.8921, -0.6280,  0.0743,  2.3439, -0.0201])


In [5]:
# Check (using torch's default dtype-based tolerances) that the gated GLU
# output matches the mixture-of-experts reconstruction.
torch.testing.assert_close(actual=out, expected=out_moe)
print('... GLUs and MoEs are equivalent under re-parameterization')

... GLUs and MoEs are equivalent under re-parameterization


In [6]:
# Report the matrix rank recorded for each expert in the MoE loop.
print('Rank of each GLU expert:')
for n in range(O):
    print('Expert {}: {}'.format(n, ranks[n]))

Rank of each GLU expert:
Expert 0: 1
Expert 1: 1
Expert 2: 1
Expert 3: 1
Expert 4: 1
Expert 5: 1
Expert 6: 1
Expert 7: 1
