# Mixture of Experts

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoE(nn.Module):
    """Sparse mixture-of-experts feed-forward layer with top-k routing.

    Mixtral of Experts: https://arxiv.org/pdf/2401.04088

    A bias-free linear router scores every expert for every token. Each token
    is processed only by its ``n_active_experts`` highest-scoring experts, and
    the expert outputs are blended with a softmax taken over the selected
    scores only.

    Args:
        n_experts: Total number of expert networks.
        n_active_experts: Number of experts evaluated per token (top-k).
        d_model: Feature size of the input and of the output.
        d_ff: Hidden size passed to each ``SwiGLU`` expert.

    Raises:
        ValueError: If ``n_active_experts`` is not in ``[1, n_experts]``.
    """

    def __init__(self, n_experts: int, n_active_experts: int, d_model: int, d_ff: int) -> None:
        super().__init__()
        # Fail fast with a clear message: an out-of-range k would otherwise only
        # surface later inside forward(), when torch.topk is called.
        if not 1 <= n_active_experts <= n_experts:
            raise ValueError(
                f"n_active_experts must be in [1, n_experts={n_experts}], "
                f"got {n_active_experts}"
            )
        self.n_active_experts = n_active_experts
        self.experts = nn.ModuleList([SwiGLU(d_model, d_ff) for _ in range(n_experts)])
        self.gate = nn.Linear(d_model, n_experts, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route each token through its top-k experts and blend the results.

        Args:
            x: Input whose last dimension is ``d_model``; the shape comments
                below label it (B, T, C).

        Returns:
            Tensor with the same shape as ``x``.
        """
        gate_logits = self.gate(x)  # (B, T, C) -> (B, T, E)
        topk = torch.topk(gate_logits, self.n_active_experts, dim=-1)  # (B, T, k)
        # Softmax over the selected logits only, so each token's k weights sum to 1.
        gate_weights = torch.softmax(topk.values, dim=-1)  # (B, T, k)
        outputs = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            expert_mask = topk.indices == i  # (B, T, k)
            token_mask = expert_mask.any(dim=-1)  # (B, T)
            if not token_mask.any():
                continue  # no token was routed to this expert
            # topk indices are distinct per token, so expert_mask holds at most
            # one True per token; both boolean selections below therefore list
            # the routed tokens in the same order.
            weights = gate_weights[expert_mask].unsqueeze(-1)  # (N_i, 1)
            outputs[token_mask] += weights * expert(x[token_mask])
        return outputs

class SwiGLU(nn.Module):
    """Gated feed-forward block: ``down_proj(silu(gate(x)) * up_proj(x))``.

    References:
        SwiGLU: https://arxiv.org/pdf/2002.05202
        Swish aka SiLU: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html

    Args:
        d_model: Feature size of the input and of the output.
        d_ff: Width of the gated hidden representation.
    """

    def __init__(self, d_model: int, d_ff: int) -> None:
        super().__init__()
        # Three bias-free projections, created in the order gate, up_proj,
        # down_proj so that seeded initialisation stays reproducible.
        self.gate = nn.Linear(in_features=d_model, out_features=d_ff, bias=False)
        self.up_proj = nn.Linear(in_features=d_model, out_features=d_ff, bias=False)
        self.down_proj = nn.Linear(in_features=d_ff, out_features=d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SiLU-gated projection and map back to ``d_model`` features."""
        gate_activation = F.silu(self.gate(x))
        gated_hidden = gate_activation * self.up_proj(x)
        return self.down_proj(gated_hidden)

# Smoke test: a random (B, T, C) batch must come back with the same shape.
B = 32
T = 16
C = 512
x_rand = torch.randn(B, T, C)
moe = MoE(
    n_experts=8,
    n_active_experts=2,
    d_model=C,
    d_ff=4 * C,
)
assert moe(x_rand).shape == (B, T, C)