# Mixture of Experts (MoE)

In [1]:
import torch
from torch import nn

# Toy input batch: 2 rows x 3 features, drawn uniformly from [0, 1).
x = torch.rand((2, 3))
print(x)

tensor([[0.0215, 0.8337, 0.6925],
        [0.5410, 0.0609, 0.2126]])


In [2]:
class MLP(nn.Module):
    def __init__(
        self,
        input_size,
        layer_size = 64,
        heads = 1,
        dropout = 0.5,
        bias = False
    ):
        super(MLP, self).__init__()
        self.linear = nn.Sequential(
            nn.Linear(
                input_size,
                layer_size,
                bias = bias
            ),
            nn.Dropout(dropout)
        )
        self.GELU = torch.nn.GELU()

    def forward(self, x):
        z = self.linear(x)
        z = nn.functional.normalize(z, dim=-1)
        z = self.GELU(z)
        return z

In [3]:
class Router(nn.Module):
    """Gating network that maps input features to routing probabilities.

    Pipeline: Linear (no bias) -> Dropout -> GELU -> Softmax.

    Args:
        input_size: Size of the last dimension of the input tensor.
        dropout: Dropout probability applied to the raw linear scores.
        num_experts: Number of scores produced per input row. Defaults to 1,
            the original single-score router.
        softmax_dim: Dimension the softmax normalizes over. Defaults to 0,
            which normalizes *across the batch* (the original behavior). For
            per-row routing over experts use ``num_experts=E, softmax_dim=-1``.

    Raises:
        ValueError: If ``num_experts`` is less than 1.
    """

    def __init__(self, input_size, dropout=0.5, *, num_experts=1, softmax_dim=0):
        super().__init__()
        if num_experts < 1:
            raise ValueError(f"num_experts must be >= 1, got {num_experts}")
        self.num_experts = num_experts
        # Sequential named ``linear`` keeps the parameter name ``linear.0.weight``.
        self.linear = nn.Sequential(
            nn.Linear(input_size, num_experts, bias=False),
            nn.Dropout(dropout),
        )
        # NOTE(review): GELU before softmax squashes every negative score toward 0,
        # so very different negative scores get near-identical probabilities.
        # Kept for compatibility with the original router.
        self.GELU = nn.GELU()
        # NOTE(review): with the default softmax_dim=0, each row's probability
        # depends on the other rows in the same batch.
        self.softmax = nn.Softmax(dim=softmax_dim)

    def forward(self, x):
        """Return routing probabilities for ``x``.

        The output keeps the leading dimensions of ``x``, has ``num_experts``
        entries in its last dimension, and sums to 1 along ``softmax_dim``.
        """
        scores = self.linear(x)
        scores = self.GELU(scores)
        return self.softmax(scores)

In [7]:
# Push the toy batch through the MLP, then score its output with the router.
mlp = MLP(input_size=x.shape[-1])
z = mlp(x)  # MLP features for each row of x

router = Router(input_size=z.shape[-1])
z = router(z)  # router probabilities (softmax over dim 0, per Router)
print(z)

# Largest entry along dim 0, i.e. the top-scoring row of the batch.
print(z.topk(1, dim=0))

tensor([[0.5003],
        [0.4997]], grad_fn=<SoftmaxBackward0>)
torch.return_types.topk(
values=tensor([[0.5003]], grad_fn=<TopkBackward0>),
indices=tensor([[0]]))
