# Mixture of Experts (MoE)

In [1]:
import torch
from torch import nn
from einops import rearrange

x = torch.rand(2,3)
print(x)

tensor([[0.2868, 0.7317, 0.8240],
        [0.8960, 0.6777, 0.3623]])


In [2]:
class MLP(nn.Module):
    def __init__(
        self,
        input_size,
        layer_size = 64,
        heads = 1,
        dropout = 0.5,
        bias = False
    ):
        super(MLP, self).__init__()
        self.linear = nn.Sequential(
            nn.Linear(
                input_size,
                layer_size,
                bias = bias
            ),
            nn.Dropout(dropout)
        )
        self.GELU = torch.nn.GELU()

    def forward(self, x):
        z = self.linear(x)
        z = nn.functional.normalize(z, dim=-1)
        z = self.GELU(z)
        return z

In [3]:
class SparseGate(nn.Module):
    def __init__(
        self,
        input_size,
        expert_count,
        hidden_layers=64,
        k = 2,
        dropout = 0.5
    ):
        super(SparseGate, self).__init__()
        self.input = nn.Sequential(
            nn.Linear(
                input_size[-1],
                hidden_layers,
                bias = True
            ),
            nn.Dropout(dropout)
        )
        self.linear = nn.Sequential(
            nn.Linear(
                hidden_layers,
                expert_count,
                bias = True
            ),
            nn.Dropout(dropout)
        )
        self.output = nn.Sequential(
            nn.Linear(
                input_size[0],
                1,
                bias = False
            ),
            nn.Dropout(dropout)
        )
        self.softmax = torch.nn.Softmax(dim = -1)
        self.k = k
        
    def forward(self, x):
        x = self.input(x)
        x = self.linear(x)
        x = self.output(x.T)
        x = rearrange(x, 'y x -> (y x)')
        print(x)
        topk, indices = torch.topk(x, self.k, dim = -1)
        z = torch.zeros_like(x)
        z[indices] = topk
        z = self.softmax(z)
        return z

In [4]:
mlp = MLP(
    x.size(-1)
)

z = mlp(x)
print(z.shape)

router = SparseGate(
    z.size(),
    10,
    k = 3,
    hidden_layers = 20
)

r = router(z)
print(r)

torch.Size([2, 64])
tensor([ 0.0401, -0.0063,  0.0000, -0.0000,  0.0000,  0.0000, -0.0216,  0.0000,
        -0.0371, -0.0848], grad_fn=<ReshapeAliasBackward0>)
tensor([0.1037, 0.0996, 0.0996, 0.0996, 0.0996, 0.0996, 0.0996, 0.0996, 0.0996,
        0.0996], grad_fn=<SoftmaxBackward0>)
