# Mixture of Experts (MoE)

In [1]:
import torch
from torch import nn
from einops import rearrange, reduce

# Random 2x3 demo input, sampled uniformly from [0, 1).
x = torch.rand((2, 3))
print(x)

tensor([[0.9532, 0.2905, 0.9662],
        [0.5892, 0.4584, 0.6422]])


In [2]:
class MLP(nn.Module):
    """One linear layer followed by dropout, L2 normalisation and GELU.

    Args:
        input_size: Size of the last dimension of the input tensor.
        layer_size: Number of output features of the linear layer.
        heads: Accepted for interface compatibility; not used by this block.
        dropout: Dropout probability applied right after the linear layer.
        bias: Whether the linear layer has an additive bias term.
    """

    def __init__(self, input_size, layer_size=64, heads=1, dropout=0.5, bias=False):
        super().__init__()
        projection = nn.Linear(input_size, layer_size, bias=bias)
        self.linear = nn.Sequential(projection, nn.Dropout(dropout))
        self.GELU = nn.GELU()

    def forward(self, x):
        """Return GELU(normalize(dropout(linear(x)))).

        The normalisation is applied over the last dimension, before the
        activation.
        """
        projected = self.linear(x)
        unit_norm = nn.functional.normalize(projected, dim=-1)
        return self.GELU(unit_norm)

In [3]:
class SparseGate(nn.Module):
    """Top-k gate that turns a 2-D input into one weight per expert.

    Args:
        input_size: Indexable shape of the tensor passed to ``forward``.
            ``input_size[-1]`` is the feature width fed to the first linear
            layer; ``input_size[0]`` is the input width of the final layer,
            which is applied to the transposed scores and therefore has to
            equal the number of rows of the input.
        expert_count: Number of experts, i.e. the length of the returned
            weight vector.
        hidden_layers: Width of the hidden projection.
        k: Number of experts that receive a non-zero weight.
        dropout: Dropout probability used after each of the three linear
            layers.
    """

    def __init__(self, input_size, expert_count, hidden_layers=64, k=2, dropout=0.5):
        super().__init__()
        feature_dim = input_size[-1]
        row_count = input_size[0]

        # features -> hidden
        self.input = nn.Sequential(
            nn.Linear(feature_dim, hidden_layers, bias=False),
            nn.Dropout(dropout),
        )
        # hidden -> one score per expert (still one row per input row)
        self.linear = nn.Sequential(
            nn.Linear(hidden_layers, expert_count, bias=False),
            nn.Dropout(dropout),
        )
        # collapses the row axis so a single score per expert remains
        self.output = nn.Sequential(
            nn.Linear(row_count, 1, bias=False),
            nn.Dropout(dropout),
        )
        self.softmax = nn.Softmax(dim=-1)
        self.k = k

    def forward(self, x):
        """Return a length-``expert_count`` vector of gate weights.

        Only the ``k`` highest-scoring experts get a non-zero entry; those
        entries are the softmax over the selected scores. Every other entry
        is zero.
        """
        hidden = self.input(x)
        row_scores = self.linear(hidden)
        # Transpose so the final layer mixes across rows, not across experts.
        pooled = self.output(row_scores.T)
        scores = rearrange(pooled, 'y x -> (y x)')
        top_scores, top_idx = torch.topk(scores, self.k, dim=-1)
        weights = torch.zeros_like(scores)
        weights[top_idx] = self.softmax(top_scores)
        return weights

In [17]:
class MoE(nn.Module):
    """Sparse mixture of experts: a top-k gate weighting a bank of MLP experts.

    Args:
        input_size: Indexable shape of the tensor passed to ``forward``. It is
            forwarded unchanged to the gate; ``input_size[-1]`` is also the
            input width of every expert.
        expert_count: Number of experts to create.
        hidden_layers: Hidden width of the gate (forwarded to the gate only).
        k: Number of experts the gate activates per forward pass.
        dropout: Dropout probability of the gate (forwarded to the gate only).

    Note:
        Experts are built as ``MLP(input_size[-1])`` and therefore use MLP's
        own default layer size and dropout.
    """

    def __init__(
        self,
        input_size,
        expert_count,
        hidden_layers=64,
        k=2,
        dropout=0.5
    ):
        super().__init__()
        self.gate = SparseGate(
            input_size,
            expert_count,
            k=k,
            hidden_layers=hidden_layers,
            dropout=dropout
        )
        # nn.ModuleList (not a plain list) registers every expert as a
        # submodule, so their weights appear in parameters()/state_dict() and
        # follow .to(), .train() and .eval(). With a plain list an optimizer
        # built from moe.parameters() would never update the experts.
        self.experts = nn.ModuleList(
            MLP(input_size[-1]) for _ in range(expert_count)
        )

    def forward(self, x):
        """Return the gate-weighted sum of the selected experts' outputs.

        Raises:
            RuntimeError: If the gate assigns no positive weight to any
                expert (e.g. ``k == 0`` or non-finite gate scores).
        """
        gate = self.gate(x)
        # Indices of active experts, in ascending order.
        selected = torch.nonzero(gate > 0).flatten().tolist()
        if not selected:
            raise RuntimeError(
                "MoE gate selected no experts; cannot combine expert outputs"
            )
        # Outputs are collected in a list so they stay on x's device and
        # dtype; a fresh torch.tensor([]) accumulator would live on the CPU.
        weighted = [self.experts[i](x) * gate[i] for i in selected]
        return torch.stack(weighted).sum(dim=0)

In [22]:
# Project the random input with a standalone MLP and show the result's shape.
mlp = MLP(x.size(-1))
z = mlp(x)
print(z.shape)

# Number of experts in the mixture.
e_count = 10

# The MoE is sized from z's full shape and then applied to z itself.
moe = MoE(z.size(), e_count, hidden_layers=64, k=2, dropout=0.5)
print(moe(z))

torch.Size([2, 64])
tensor([[ 4.1872e-02, -6.0294e-02,  5.3453e-02,  0.0000e+00, -2.3659e-02,
         -1.4250e-02, -1.2553e-01, -2.7605e-02,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  5.2603e-03,  0.0000e+00, -2.0094e-02,  8.0195e-02,
         -2.2227e-02, -9.2839e-03, -1.1017e-02,  0.0000e+00, -4.8096e-02,
          8.0005e-02,  1.1165e-01,  0.0000e+00,  4.2281e-03,  2.3398e-02,
          8.7653e-02,  0.0000e+00,  9.8433e-02, -2.5075e-02,  8.8758e-02,
         -6.6435e-02, -5.9241e-02,  4.6554e-02,  0.0000e+00,  3.5318e-03,
         -1.6312e-02,  0.0000e+00,  8.5300e-02,  2.5357e-02, -7.9469e-03,
         -1.2364e-02, -4.5439e-02,  0.0000e+00, -2.4371e-02, -2.3888e-02,
         -3.6232e-03,  1.3789e-01,  1.5770e-02, -3.2400e-02,  0.0000e+00,
         -3.5098e-02,  7.7504e-03,  9.6464e-03,  2.6843e-01,  1.0529e-02,
          1.5689e-02,  0.0000e+00,  2.8352e-02,  0.0000e+00, -9.8512e-03,
         -8.7423e-03,  4.0917e-02, -6.2688e-02,  0.0000e+00],
        [-1.3886e-02,  0.0000e