In [15]:
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

![](https://noblecatt-1304922865.cos.ap-singapore.myqcloud.com/202502162115032.png)


In [16]:
from dataclasses import dataclass


@dataclass
class Config:
    """Hyper-parameters shared by every module of the Switch Transformer demo.

    Every entry is an annotated dataclass field, so any of them can be
    overridden at construction time, e.g. ``Config(hidden_dim=256, num_heads=4)``.
    """

    # MoE
    # Declared first so the positional form ``Config(n)`` still means
    # "n experts", as it did when this was the only dataclass field.
    num_experts: int = 4
    capacity_factor: float = 1.0
    # Whether SwitchMoE asks its Router for the auxiliary load-balancing loss.
    use_aux_loss: bool = False

    # Transformer backbone
    hidden_dim: int = 512
    num_heads: int = 8
    num_layers: int = 6
    ff_dim: int = 2048

In [17]:
class Router(nn.Module):
    """Switch-style top-1 router.

    A bias-free linear layer scores each token against every expert. The
    softmax probabilities are then sparsified so each token keeps only its
    single most likely expert.

    Args:
        config: must provide ``num_experts``, ``hidden_dim`` and ``capacity_factor``.
        epsilon: stored for backward compatibility; the top-1 gating below
            does not divide by anything, so it is not used.
    """

    def __init__(self, config: "Config", epsilon=1e-6):
        super().__init__()
        self.num_experts = config.num_experts
        self.hidden_dim = config.hidden_dim
        # NOTE: expert capacity (token dropping) is not enforced by this router;
        # the factor is kept so code that reads it keeps working.
        self.capacity_factor = config.capacity_factor

        self.epsilon = epsilon

        self.w_gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

    def forward(self, x, use_aux_loss=False):
        """Route every token to its top-1 expert.

        Args:
            x: tensor whose last dimension is ``hidden_dim``, e.g. (B, S, H).
            use_aux_loss: if true, also compute the Switch load-balancing loss.

        Returns:
            ``(gate_scores, loss)``. ``gate_scores`` has shape
            ``x.shape[:-1] + (num_experts,)`` and is zero everywhere except at
            each token's chosen expert, where it holds that expert's router
            probability. ``loss`` is a scalar tensor when ``use_aux_loss`` is
            true, otherwise ``None``.
        """
        probs = F.softmax(self.w_gate(x), dim=-1)  # (..., E)

        # Switch routing keeps only the best expert per token. The scatter runs
        # along the expert (last) axis, so each token gets a one-hot row.
        _, top1_idx = probs.topk(1, dim=-1)
        dispatch = torch.zeros_like(probs).scatter_(-1, top1_idx, 1.0)

        # Weight the chosen expert by the token's own probability. This keeps
        # the gate differentiable and independent of other batch examples.
        gate_scores = probs * dispatch

        if use_aux_loss:
            flat_dispatch = dispatch.reshape(-1, self.num_experts)
            flat_probs = probs.reshape(-1, self.num_experts)
            # f_i: fraction of tokens routed to expert i (one-hot, no gradient).
            tokens_per_expert = flat_dispatch.mean(dim=0)
            # P_i: mean router probability assigned to expert i.
            prob_per_expert = flat_probs.mean(dim=0)
            # Load-balancing loss E * sum_i f_i * P_i; equals 1.0 when the
            # router probabilities are uniform.
            loss = self.num_experts * torch.sum(tokens_per_expert * prob_per_expert)
            return gate_scores, loss

        return gate_scores, None

In [18]:
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: hidden_dim -> ff_dim -> ReLU -> hidden_dim."""

    def __init__(self, config: Config):
        super().__init__()

        self.hidden_dim, self.ff_dim = config.hidden_dim, config.ff_dim

        # Expansion, non-linearity, projection back to the model width.
        self.linear1 = nn.Linear(self.hidden_dim, self.ff_dim)
        self.activation = nn.ReLU()
        self.linear2 = nn.Linear(self.ff_dim, self.hidden_dim)

    def forward(self, x):
        hidden = self.linear1(x)
        hidden = self.activation(hidden)
        return self.linear2(hidden)

In [41]:
class SwitchMoE(nn.Module):
    """Mixture-of-experts feed-forward layer driven by a ``Router``.

    The output equals ``sum_e gate[..., e] * expert_e(x)``, but each expert
    only processes the tokens whose gate for it is non-zero. With a top-1
    router, every token runs through a single expert instead of all of them.

    Returns ``(output, loss)``. ``output`` has the shape of ``x``, which
    assumes each expert maps ``hidden_dim`` back to ``hidden_dim``, as
    ``FeedForward`` does. ``loss`` is whatever the router returned.
    """

    def __init__(self, config: "Config"):
        super().__init__()

        self.dim = config.hidden_dim
        self.ff_dim = config.ff_dim
        self.num_experts = config.num_experts
        self.capacity_factor = config.capacity_factor
        # Not used by this layer; kept so code that reads it keeps working.
        self.mult = 4
        self.use_aux_loss = config.use_aux_loss

        self.experts = nn.ModuleList(
            [FeedForward(config) for _ in range(self.num_experts)]
        )

        self.router = Router(config)

    def forward(self, x):
        gate_scores, loss = self.router(x, self.use_aux_loss)

        nan_gates = torch.isnan(gate_scores)
        if nan_gates.any():
            warnings.warn("nan in gate_scores; treating those gates as 0")
            # Out-of-place so the router's output tensor is left untouched.
            gate_scores = gate_scores.masked_fill(nan_gates, 0.0)

        # Collapse all leading dims into one token axis for gather/scatter.
        flat_x = x.reshape(-1, x.size(-1))
        flat_gates = gate_scores.reshape(-1, gate_scores.size(-1))
        flat_out = torch.zeros_like(flat_x)

        for expert_idx, expert in enumerate(self.experts):
            expert_gate = flat_gates[:, expert_idx]
            token_idx = expert_gate.nonzero(as_tuple=True)[0]
            if token_idx.numel() == 0:
                continue  # no token routed here: skip this expert entirely

            expert_out = expert(flat_x[token_idx])
            nan_out = torch.isnan(expert_out)
            if nan_out.any():
                warnings.warn("nan in expert outputs; treating those values as 0")
                expert_out = torch.where(nan_out, torch.zeros_like(expert_out), expert_out)

            weighted = expert_gate[token_idx].unsqueeze(-1) * expert_out
            # Accumulate back into each token's row (differentiable w.r.t. weighted).
            flat_out.index_add_(0, token_idx, weighted.to(flat_out.dtype))

        return flat_out.reshape(x.shape), loss

In [42]:
class MHA(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        config: must provide ``num_heads`` and ``hidden_dim``; ``hidden_dim``
            must be divisible by ``num_heads``.
    """

    def __init__(self, config: "Config"):
        super().__init__()

        self.num_heads = config.num_heads
        self.hidden_dim = config.hidden_dim

        assert (
            self.hidden_dim % self.num_heads == 0
        ), "hidden_dim must be divisible by num_heads"

        self.head_dim = self.hidden_dim // self.num_heads

        self.qkv_ln = nn.Linear(self.hidden_dim, 3 * self.hidden_dim)
        self.out_ln = nn.Linear(self.hidden_dim, self.hidden_dim)

    def forward(self, x, mask=None):
        """Attend over the sequence.

        Args:
            x: (B, S, hidden_dim) input; the ``view`` below requires this rank.
            mask: optional tensor broadcastable to (B, num_heads, S, S); key
                positions where ``mask == 0`` (False) are blocked.

        Returns:
            Tensor of shape (B, S, hidden_dim).
        """
        batch, seq_len = x.size(0), x.size(1)
        q, k, v = (
            t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            for t in self.qkv_ln(x).chunk(3, dim=-1)
        )

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)
        if mask is not None:
            # Mask the logits BEFORE the softmax so blocked keys get zero
            # weight. The dtype's most negative finite value also fits fp16/bf16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = F.softmax(scores, dim=-1)

        output = (
            torch.matmul(attn, v)
            .transpose(1, 2)
            .contiguous()
            .view(batch, seq_len, self.hidden_dim)
        )

        return self.out_ln(output)

In [43]:
class SwitchMoEBlock(nn.Module):
    """One pre-norm layer: residual attention, residual MoE, then a dense FFN.

    ``forward(x, mask)`` returns ``(x, aux_loss)``, where ``aux_loss`` is
    whatever the MoE sublayer returned.
    """

    def __init__(self, config: "Config"):
        super().__init__()

        self.mha = MHA(config)
        self.moe = SwitchMoE(config)
        self.ff = FeedForward(config)

        self.norm1 = nn.LayerNorm(config.hidden_dim)
        self.norm2 = nn.LayerNorm(config.hidden_dim)

    def forward(self, x, mask=None):
        # Pre-norm residual attention.
        x = x + self.mha(self.norm1(x), mask)
        # Pre-norm residual MoE. The residual keeps the token representation
        # alive instead of replacing it with the routed FFN output.
        moe_out, aux_loss = self.moe(self.norm2(x))
        x = x + moe_out
        # NOTE(review): this dense FFN has no LayerNorm of its own; adding one
        # would change the module's parameters, so it is left as is.
        x = x + self.ff(x)

        return x, aux_loss

In [44]:
class FakeEmbedding(nn.Module):
    """Token-id -> vector lookup used as a stand-in for a real embedding layer.

    Args:
        config: must provide ``hidden_dim`` (the embedding width).
        vocab_size: number of distinct token ids; inputs must lie in
            ``[0, vocab_size)``. Defaults to 1000.
    """

    def __init__(self, config: "Config", vocab_size: int = 1000):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, config.hidden_dim)

    def forward(self, x):
        return self.embedding(x)

In [49]:
class SwitchTransformer(nn.Module):
    """Token embedding followed by a stack of ``SwitchMoEBlock`` layers.

    ``forward`` returns the final hidden states and a list of the auxiliary
    losses produced by the blocks. Blocks that return ``None`` contribute
    nothing, so the list may be empty.
    """

    def __init__(self, config: Config):
        super().__init__()

        self.embedding = FakeEmbedding(config)

        self.num_layers = config.num_layers
        self.blocks = nn.ModuleList(
            SwitchMoEBlock(config) for _ in range(self.num_layers)
        )

    def forward(self, x, mask=None):
        hidden = self.embedding(x)
        collected = []
        for layer in self.blocks:
            hidden, layer_loss = layer(hidden, mask)
            if layer_loss is None:
                continue
            collected.append(layer_loss)
        return hidden, collected

In [50]:
def causal_mask(size):
    """Boolean (1, size, size) mask that is True on and below the diagonal.

    Position i may attend to positions j <= i; the strictly-upper triangle
    (future positions) is False.
    """
    future = torch.ones(1, size, size).triu(diagonal=1)
    return future.bool().logical_not()

In [51]:
# Smoke test: a random batch of token ids, run through the full model with a
# causal mask, must come back as (batch, seq_len, hidden_dim).
torch.manual_seed(0)  # reproducible weights and inputs under Restart & Run All

BATCH_SIZE = 32
SEQ_LEN = 128
VOCAB_SIZE = 1000  # must not exceed FakeEmbedding's table size (1000 by default)

config = Config()
transformer = SwitchTransformer(config)

x = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN))
tgt_mask = causal_mask(SEQ_LEN)

with torch.no_grad():  # forward-only check; no autograd graph needed
    output, aux_losses = transformer(x, tgt_mask)

assert output.shape == (BATCH_SIZE, SEQ_LEN, config.hidden_dim)
assert not aux_losses  # use_aux_loss is False in the default Config
print("All tests passed!")

All tests passed!
