# Mixture of Experts (MoE)

In [1]:
import torch
from torch import nn
from einops import rearrange, reduce

# Small demo input: a (2, 3) tensor of uniform samples in [0, 1).
x = torch.rand((2, 3))
print(x)

tensor([[0.6499, 0.6035, 0.2683],
        [0.0190, 0.7235, 0.1508]])


In [2]:
class MLP(nn.Module):
    """Linear projection -> dropout -> L2 normalisation -> GELU.

    Args:
        input_size: Size of the last dimension of the input.
        layer_size: Output width of the projection.
        heads: Accepted for interface compatibility but currently unused.
        dropout: Dropout probability applied right after the projection.
        bias: Whether the projection layer has a bias term.
    """

    def __init__(self, input_size, layer_size=64, heads=1, dropout=0.5, bias=False):
        super().__init__()
        projection = nn.Linear(input_size, layer_size, bias=bias)
        # Kept as a Sequential named `linear` so parameter names
        # (linear.0.weight / linear.0.bias) stay stable.
        self.linear = nn.Sequential(projection, nn.Dropout(dropout))
        self.GELU = torch.nn.GELU()

    def forward(self, x):
        """Project ``x``, rescale each vector to unit L2 norm, then apply GELU."""
        projected = self.linear(x)
        unit_vectors = nn.functional.normalize(projected, dim=-1)
        return self.GELU(unit_vectors)

In [19]:
class SparseGate(nn.Module):
    """Select the top-``k`` experts for a whole 2-D input batch.

    Each row is projected feature-wise to ``hidden_layers`` and then to one
    score per expert. A final bias-free linear layer collapses the
    ``input_size[0]`` rows into a single score per expert. The ``k`` highest
    scores are softmax-normalised into mixing weights.

    Args:
        input_size: Input shape, e.g. ``x.size()``. ``input_size[0]`` is the
            number of rows and ``input_size[-1]`` the feature width.
        expert_count: Number of experts to score.
        hidden_layers: Width of the intermediate projection.
        k: Number of experts to select.
        dropout: Dropout probability applied after every linear layer,
            including the final per-expert scores.

    ``forward`` returns a list of ``k`` ``(index, weight)`` pairs of 0-d
    tensors, ordered by descending score; the weights sum to 1.
    """

    def __init__(
        self,
        input_size,
        expert_count,
        hidden_layers=64,
        k=2,
        dropout=0.5,
    ):
        super().__init__()

        def projection(n_in, n_out):
            # Bias-free linear map followed by dropout.
            return nn.Sequential(nn.Linear(n_in, n_out, bias=False), nn.Dropout(dropout))

        features = input_size[-1]
        rows = input_size[0]
        # Feature-wise: features -> hidden_layers -> one score per expert.
        self.input = projection(features, hidden_layers)
        self.linear = projection(hidden_layers, expert_count)
        # Row-wise: collapse all rows into one score per expert.
        self.output = projection(rows, 1)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.k = k

    def forward(self, x):
        # Shape notes assume a 2-D (rows, features) input.
        row_scores = self.linear(self.input(x))          # (rows, expert_count)
        expert_scores = self.output(row_scores.T)        # (expert_count, 1)
        flat = rearrange(expert_scores, 'y x -> (y x)')  # (expert_count,)
        top_scores, top_idx = torch.topk(flat, self.k, dim=-1)
        return list(zip(top_idx, self.softmax(top_scores)))

In [32]:
class MoE(nn.Module):
    """Sparse mixture of experts, gated once per input batch.

    ``SparseGate`` scores ``expert_count`` experts for the whole 2-D input and
    selects the top ``k``. The output is the gate-weighted sum of the selected
    ``MLP`` experts' outputs.

    Args:
        input_size: Shape of the input, e.g. ``x.size()``. ``input_size[0]``
            (rows) sizes the gate's row-collapsing layer, and
            ``input_size[-1]`` (features) sizes the gate and the experts.
            Inputs must therefore have this shape.
        expert_count: Number of experts.
        hidden_layers: Hidden width of the gate.
        k: Number of experts mixed per forward pass.
        dropout: Dropout probability for the gate only. The experts use
            ``MLP``'s default dropout.
    """

    def __init__(
        self,
        input_size,
        expert_count,
        hidden_layers=64,
        k=2,
        dropout=0.5,
    ):
        super().__init__()
        self.gate = SparseGate(
            input_size,
            expert_count,
            k=k,
            hidden_layers=hidden_layers,
            dropout=dropout,
        )
        # nn.ModuleList, not a plain list: registers the experts as submodules
        # so their weights appear in parameters()/state_dict(), move with
        # .to(device), and follow train()/eval() (so their dropout is
        # disabled in eval mode).
        self.experts = nn.ModuleList(
            [MLP(input_size[-1]) for _ in range(expert_count)]
        )

    def forward(self, x):
        """Return the gate-weighted sum of the top-k experts applied to ``x``."""
        weighted = [
            weight * self.experts[int(idx)](x)
            for idx, weight in self.gate(x)
        ]
        # Stack to (k, rows, dim), then sum over the experts axis.
        return reduce(torch.stack(weighted), 'k m d -> m d', 'sum')

In [33]:
# Run the demo tensor through a standalone MLP to get features for the MoE.
# The modules are left in their default training mode, so dropout is active
# and the printed values vary from run to run.
mlp = MLP(x.size(-1))
z = mlp(x)
print(z.shape)

e_count = 10  # number of experts

# The MoE is sized from z's full shape, so it expects inputs with exactly
# z's (rows, features) shape.
moe = MoE(z.size(), e_count, hidden_layers=64, k=2, dropout=0.5)
print(moe(z))

torch.Size([2, 64])
tensor([[ 0.0000,  0.0451, -0.0163,  0.0000, -0.0265, -0.0006, -0.0319,  0.0381,
          0.0521,  0.0062, -0.0047, -0.0040, -0.0402,  0.0194,  0.0000, -0.0416,
          0.0815, -0.0694,  0.0266, -0.0547,  0.0211, -0.0225,  0.0000, -0.0342,
          0.0030,  0.0184, -0.0578,  0.0000,  0.0000,  0.0000,  0.0686, -0.0198,
          0.0000, -0.0569,  0.1204,  0.0398,  0.0318, -0.0225,  0.0000, -0.0078,
          0.0169, -0.0319,  0.0000,  0.1281, -0.0389,  0.0951,  0.0173, -0.0406,
         -0.0475, -0.0139,  0.0771,  0.0285,  0.0532,  0.0121,  0.0395,  0.0000,
          0.0000,  0.0000, -0.0526,  0.0048, -0.0043,  0.0489,  0.0631,  0.0870],
        [ 0.0360,  0.0000,  0.0860,  0.0113,  0.0245,  0.0000,  0.0225,  0.0000,
          0.1122, -0.0734, -0.0129,  0.0000,  0.0000,  0.0555,  0.1088,  0.0000,
         -0.0355, -0.1017,  0.0401,  0.0000,  0.0000,  0.0000,  0.0937,  0.0072,
          0.1026, -0.0120,  0.0834, -0.0067, -0.0160, -0.0166,  0.0000,  0.0000,
       