# Mixture of Experts (MoE)

In [1]:
import torch
from torch import nn
from einops import rearrange, reduce

# Sample input for the demo below: 2 rows of 3 features, each drawn
# uniformly from [0, 1).
x = torch.rand((2, 3))
print(x)

tensor([[0.9756, 0.4738, 0.3592],
        [0.6957, 0.7318, 0.4722]])


In [2]:
class MLP(nn.Module):
    """Single linear layer followed by dropout, L2 normalisation and GELU.

    Args:
        input_size: Number of input features (size of the last dimension).
        layer_size: Number of output features of the linear layer.
        heads: Accepted but not used by this class.
        dropout: Drop probability applied to the linear layer's output.
        bias: Whether the linear layer has a bias term.
    """

    def __init__(self, input_size, layer_size=64, heads=1, dropout=0.5, bias=False):
        super().__init__()
        projection = nn.Linear(input_size, layer_size, bias=bias)
        self.linear = nn.Sequential(projection, nn.Dropout(dropout))
        self.GELU = nn.GELU()

    def forward(self, x):
        """Project ``x``, scale each output row to unit L2 norm, apply GELU."""
        projected = self.linear(x)
        # Normalise along the feature axis so every row has unit length
        # before the non-linearity is applied.
        unit_rows = nn.functional.normalize(projected, dim=-1)
        return self.GELU(unit_rows)

In [3]:
class SparseGate(nn.Module):
    """Gate that scores every expert and keeps the ``k`` best.

    The input is projected to ``hidden_layers`` features, then to one score
    per expert; a final linear layer collapses the row axis so that a single
    score per expert remains for the whole input.

    Args:
        input_size: Shape of the 2-D input, ``(rows, features)``. The last
            layer has ``input_size[0]`` input features, so ``forward`` only
            accepts inputs with exactly that many rows.
        expert_count: Number of experts to score.
        hidden_layers: Width of the hidden projection.
        k: Number of experts to select.
        dropout: Drop probability used after each of the three linear layers.
    """

    def __init__(self, input_size, expert_count, hidden_layers=64, k=2, dropout=0.5):
        super().__init__()
        rows, features = input_size[0], input_size[-1]
        # features -> hidden
        self.input = nn.Sequential(
            nn.Linear(features, hidden_layers, bias=False),
            nn.Dropout(dropout),
        )
        # hidden -> one score per expert (per row)
        self.linear = nn.Sequential(
            nn.Linear(hidden_layers, expert_count, bias=False),
            nn.Dropout(dropout),
        )
        # rows -> 1: applied to the transposed scores to merge the row axis
        self.output = nn.Sequential(
            nn.Linear(rows, 1, bias=False),
            nn.Dropout(dropout),
        )
        self.softmax = nn.Softmax(dim=-1)
        self.k = k

    def forward(self, x):
        """Return ``k`` ``(expert_index, weight)`` pairs for the input ``x``.

        The weights are a softmax over the selected scores only, so they sum
        to one. Both elements of each pair are 0-dim tensors.
        """
        hidden = self.input(x)
        per_row_scores = self.linear(hidden)
        # Transpose to (expert_count, rows) so the last layer mixes over rows.
        merged = self.output(per_row_scores.T)
        scores = rearrange(merged, 'y x -> (y x)')
        best_scores, best_experts = torch.topk(scores, self.k, dim=-1)
        weights = self.softmax(best_scores)
        return list(zip(best_experts, weights))

In [4]:
class MoE(nn.Module):
    """Sparse mixture of experts: a gate picks ``k`` experts and their
    gate-weighted outputs are summed.

    Args:
        input_size: Shape of the input, ``(rows, features)``. It is passed
            unchanged to ``SparseGate``; ``input_size[-1]`` is the input
            width of every expert.
        expert_count: Number of expert ``MLP`` modules to create.
        hidden_layers: Hidden width of the gate.
        k: Number of experts the gate selects on each forward pass.
        dropout: Drop probability used inside the gate.
    """

    def __init__(
        self,
        input_size,
        expert_count,
        hidden_layers=64,
        k=2,
        dropout=0.5
    ):
        super().__init__()
        self.gate = SparseGate(
            input_size,
            expert_count,
            k=k,
            hidden_layers=hidden_layers,
            dropout=dropout
        )
        # NOTE(review): the experts are built with MLP's own defaults for
        # width and dropout; `hidden_layers` and `dropout` only reach the
        # gate. Confirm this is intended.
        self.experts = nn.ModuleList(MLP(input_size[-1]) for _ in range(expert_count))

    def forward(self, x):
        """Run the selected experts on ``x`` and return their weighted sum.

        Args:
            x: Input tensor; it is given both to the gate and to every
                selected expert.

        Returns:
            The sum over the selected experts of ``weight * expert(x)``,
            with the shape of a single expert's output.

        Raises:
            RuntimeError: If the gate selects no expert at all (there is
                nothing to stack).
        """
        # Collect the weighted terms and combine them in one step. The
        # result takes its device and dtype from the expert outputs, rather
        # than from a CPU float32 placeholder, and no tensor is re-copied
        # on every iteration.
        weighted = [weight * self.experts[index](x) for index, weight in self.gate(x)]
        return torch.stack(weighted).sum(dim=0)

In [5]:
# Lift the 3-feature sample input to MLP's default output width.
mlp = MLP(input_size=x.size(-1))
z = mlp(x)
print(z.shape)

# Number of experts in the mixture.
e_count = 10

# The MoE is sized from z's full shape: the gate needs the row count as
# well as the feature count. Neither module is switched to eval mode, so
# dropout is active and the printed values differ from run to run.
moe = MoE(
    input_size=z.size(),
    expert_count=e_count,
    hidden_layers=64,
    k=2,
    dropout=0.5,
)
print(moe)
print(moe(z))

torch.Size([2, 64])
MoE(
  (gate): SparseGate(
    (input): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=False)
      (1): Dropout(p=0.5, inplace=False)
    )
    (linear): Sequential(
      (0): Linear(in_features=64, out_features=10, bias=False)
      (1): Dropout(p=0.5, inplace=False)
    )
    (output): Sequential(
      (0): Linear(in_features=2, out_features=1, bias=False)
      (1): Dropout(p=0.5, inplace=False)
    )
    (softmax): Softmax(dim=-1)
  )
  (experts): ModuleList(
    (0): MLP(
      (linear): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=False)
        (1): Dropout(p=0.5, inplace=False)
      )
      (GELU): GELU()
    )
    (1): MLP(
      (linear): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=False)
        (1): Dropout(p=0.5, inplace=False)
      )
      (GELU): GELU()
    )
    (2): MLP(
      (linear): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=False)
        (1)