# Mixture of Experts (MoE)

In [1]:
import torch
from torch import nn
from einops import rearrange, reduce

# Toy input for the demo below: 2 rows, 3 features per row (the last
# dimension is what the demo feeds to MLP as its input width).
x = torch.rand((2, 3))
print(x)

tensor([[0.6499, 0.6035, 0.2683],
        [0.0190, 0.7235, 0.1508]])


In [2]:
class MLP(nn.Module):
    """One bias-optional linear projection -> dropout -> L2 normalize -> GELU.

    Args:
        input_size: Width of the last dimension of the input.
        layer_size: Width of the projected output.
        heads: Accepted for interface compatibility; not used by this module.
        dropout: Dropout probability applied right after the projection.
        bias: Whether the projection has a bias term.
    """

    def __init__(self, input_size, layer_size=64, heads=1, dropout=0.5, bias=False):
        super().__init__()
        projection = nn.Linear(input_size, layer_size, bias=bias)
        # Kept as a Sequential named `linear` so state_dict keys stay `linear.0.*`.
        self.linear = nn.Sequential(projection, nn.Dropout(dropout))
        self.GELU = nn.GELU()

    def forward(self, x):
        """Project ``x``, rescale each vector to unit L2 norm, then apply GELU."""
        projected = self.linear(x)
        # Normalization happens before the activation, along the feature axis.
        unit = nn.functional.normalize(projected, dim=-1)
        return self.GELU(unit)

In [19]:
class SparseGate(nn.Module):
    """Top-k router that scores experts for a whole input matrix at once.

    The input is projected to ``hidden_layers`` and then to one logit per expert
    per row. A final linear layer over the row axis collapses those into a single
    logit per expert, so the same ``k`` experts are chosen for every row. Because
    that last layer's width is ``input_size[0]``, the gate only accepts inputs
    with exactly that many rows.

    Args:
        input_size: Shape of the expected input; ``input_size[0]`` is the row
            count and ``input_size[-1]`` the feature width.
        expert_count: Number of experts to score.
        hidden_layers: Width of the intermediate projection.
        k: Number of experts selected per call.
        dropout: Dropout probability applied after every linear layer,
            including the final per-expert logits.
    """

    def __init__(self, input_size, expert_count, hidden_layers=64, k=2, dropout=0.5):
        super().__init__()
        rows, features = input_size[0], input_size[-1]

        def projection(fan_in, fan_out):
            # Bias-free linear layer followed by dropout.
            return nn.Sequential(
                nn.Linear(fan_in, fan_out, bias=False),
                nn.Dropout(dropout),
            )

        # Creation order (input, linear, output) is kept so weight init draws
        # from the RNG in the same sequence; attribute names fix state_dict keys.
        self.input = projection(features, hidden_layers)
        self.linear = projection(hidden_layers, expert_count)
        self.output = projection(rows, 1)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.k = k

    def forward(self, x):
        """Select the top-k experts for ``x``.

        Returns:
            A list of ``(expert_index, weight)`` pairs (0-d tensors), ordered by
            descending gate logit; the weights are a softmax over the selected
            logits only.
        """
        # Shape notes assume a 2-D x of shape (rows, features).
        per_row_logits = self.linear(self.input(x))      # (rows, expert_count)
        pooled = self.output(per_row_logits.T)           # (expert_count, 1)
        scores = rearrange(pooled, 'y x -> (y x)')       # (expert_count,)
        best_scores, best_experts = torch.topk(scores, self.k, dim=-1)
        weights = self.softmax(best_scores)
        return list(zip(best_experts, weights))

In [20]:
class MoE(nn.Module):
    """Sparse mixture-of-experts layer built from ``MLP`` experts.

    A ``SparseGate`` picks ``k`` experts for the whole input matrix, and the
    layer returns the gate-weighted sum of those experts' outputs. Each expert is
    ``MLP(input_size[-1])`` with MLP's own defaults, so the output width is
    MLP's default ``layer_size``.

    Args:
        input_size: Shape of the expected input, forwarded to the gate;
            ``input_size[-1]`` is each expert's input width.
        expert_count: Number of experts.
        hidden_layers: Hidden width of the gate network.
        k: Number of experts combined per forward pass.
        dropout: Dropout probability used inside the gate.
    """

    def __init__(
        self,
        input_size,
        expert_count,
        hidden_layers=64,
        k = 2,
        dropout = 0.5
    ):
        super(MoE, self).__init__()
        self.gate = SparseGate(
            input_size,
            expert_count,
            k = k,
            hidden_layers = hidden_layers,
            dropout = dropout
        )
        # Must be nn.ModuleList, not a plain list: only registered submodules
        # have their weights in parameters()/state_dict(), get updated by an
        # optimizer, and follow .to(device) and .train()/.eval().
        self.experts = nn.ModuleList(
            [MLP(input_size[-1]) for _ in range(expert_count)]
        )

    def forward(self, x):
        """Return the gate-weighted sum of the selected experts' outputs for ``x``.

        Raises:
            RuntimeError: if the gate selects no experts (e.g. ``k == 0``).
        """
        contributions = [
            self.experts[int(index)](x) * weight
            for index, weight in self.gate(x)
        ]
        # Stacking the contributions (rather than concatenating onto a CPU
        # float32 ``torch.tensor([])``) keeps the result on the experts'
        # device and dtype; summing over the new axis combines the experts.
        return torch.stack(contributions).sum(dim=0)

In [21]:
# Run the toy input through one MLP to get a (rows, 64) feature tensor
# (64 is MLP's default layer_size).
mlp = MLP(x.size(-1))
z = mlp(x)
print(z.shape)

e_count = 10  # number of experts

# Build a mixture sized for z's shape and apply it to z.
moe = MoE(z.size(), e_count, hidden_layers=64, k=2, dropout=0.5)
print(moe(z))

torch.Size([2, 64])
tensor([[ 0.0000,  0.0255,  0.0319,  0.0056,  0.1313,  0.0533, -0.0370,  0.0368,
          0.0385,  0.0000, -0.0040,  0.0124,  0.0499,  0.0538,  0.0115, -0.0306,
          0.0000, -0.0491,  0.0000, -0.0112, -0.0545,  0.0000, -0.0412, -0.0903,
          0.1216, -0.0041,  0.0444, -0.0175, -0.0441,  0.0000, -0.0674,  0.0000,
         -0.0255,  0.0149,  0.0142, -0.0036, -0.0241,  0.0515,  0.0000,  0.0000,
          0.0000,  0.0190,  0.0000, -0.0450,  0.0632,  0.0000, -0.0550, -0.0311,
          0.0466, -0.0263,  0.0750, -0.0436, -0.0452,  0.1023, -0.0096,  0.0000,
          0.0000, -0.0422,  0.0210,  0.0045,  0.0057,  0.0000,  0.0609, -0.0098],
        [ 0.0000,  0.0166,  0.0000,  0.0214,  0.1489,  0.0000, -0.0344,  0.1235,
          0.0220,  0.0000,  0.0000, -0.0533,  0.0149,  0.0772,  0.1256,  0.0000,
         -0.0243, -0.0206,  0.0000,  0.0299,  0.0000,  0.0000,  0.0003,  0.0000,
          0.1655, -0.0307,  0.0056,  0.0000,  0.0626,  0.0342, -0.0429, -0.0639,
       