In [1]:
import torch
import torch.nn as nn
import torch.distributed as dist

In [2]:
class Expert(nn.Module):
    """Single feed-forward expert: Linear -> GELU -> Linear.

    The first projection maps ``d_model`` features to ``d_ff`` and the second
    maps them back to ``d_model``, so the output has the same trailing
    dimension as the input.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.d_model, self.d_ff = d_model, d_ff

        # Build the layers in forward order, then wrap them so the
        # sub-module indices stay 0 (up), 1 (activation), 2 (down).
        up_projection = nn.Linear(d_model, d_ff)
        activation = nn.GELU()
        down_projection = nn.Linear(d_ff, d_model)
        self.mlp = nn.Sequential(up_projection, activation, down_projection)

    def forward(self, x):
        """Apply the expert MLP to ``x`` over its last dimension."""
        transformed = self.mlp(x)
        return transformed

In [3]:
class TopKRouter(nn.Module):
    """Route each token to its ``top_k`` highest-probability experts.

    A bias-free linear layer scores every token against every expert; the
    scores are softmaxed, the ``top_k`` experts are selected per token, and
    the selected probabilities are renormalised to sum to one.

    Args:
        d_model: Feature size of the incoming tokens.
        n_experts: Total number of experts to route across.
        top_k: Number of experts selected per token.
        capacity_factor: Stored on the module; not used by the router itself.
    """

    def __init__(self, d_model, n_experts, top_k=2, capacity_factor=1.25):
        super().__init__()

        self.d_model = d_model
        self.n_experts = n_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor

        self.router = nn.Linear(d_model, n_experts, bias=False)

    def forward(self, x):
        """Compute expert assignments for a batch of tokens.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, hidden_size)``.

        Returns:
            A tuple ``(expert_indices, expert_weights, load_balancing_loss)``
            where the first two have shape ``(batch_size * seq_len, top_k)``
            and the loss is a scalar tensor. Note the order: indices first.
        """
        batch_size, seq_len, hidden_size = x.shape
        num_tokens = batch_size * seq_len

        # Flatten batch and sequence so every row is one token.
        x = x.reshape(num_tokens, hidden_size)
        router_logits = self.router(x)

        router_probs = torch.softmax(router_logits, dim=-1)

        expert_weights, expert_indices = torch.topk(router_probs, self.top_k, dim=-1)

        # Renormalise so the selected experts' weights sum to 1 per token.
        expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)

        load_balancing_loss = self._compute_load_balancing_loss(router_probs)

        return expert_indices, expert_weights, load_balancing_loss

    def _compute_load_balancing_loss(self, router_probs):
        """Return ``n_experts * sum_i(f_i * p_i)`` as a scalar tensor.

        ``f_i`` is the fraction of tokens whose highest-probability (top-1)
        expert is ``i``; ``p_i`` is the mean router probability for expert
        ``i`` over all tokens.

        Args:
            router_probs: Tensor of shape ``(num_tokens, n_experts)``.
        """
        num_tokens = router_probs.shape[0]

        expert_assignment = torch.argmax(router_probs, dim=-1)
        # minlength pins the histogram to n_experts bins. Without it the
        # result is only as long as the largest index seen, so f_i would
        # mismatch (or silently broadcast against) p_i whenever the
        # highest-numbered experts receive no tokens.
        expert_count = torch.bincount(
            expert_assignment, minlength=self.n_experts
        ).float()

        f_i = expert_count / num_tokens

        p_i = router_probs.mean(dim=0)

        loss = self.n_experts * (f_i * p_i).sum()

        return loss

In [4]:
class DistributedMoELayer(nn.Module):
    """Mixture-of-experts layer whose experts are split across ranks.

    Each rank owns ``n_experts // world_size`` local ``Expert`` modules and a
    full ``TopKRouter`` over all ``n_experts``.

    Args:
        d_model: Token feature size.
        n_experts: Total number of experts across all ranks.
        d_ff: Hidden size of each expert MLP.
        world_size: Number of ranks the experts are divided over.
        rank: Index of this rank (stored on the module).
        top_k: Experts selected per token by the router.
        capacity_factor: Multiplier applied to the even per-expert token
            share when computing expert capacity.

    NOTE(review): ``forward`` calls ``_prepare_all_to_all``,
    ``_all_to_all_scatter``, ``_process_local_experts``,
    ``_all_to_all_gather`` and ``_combine_outputs``, none of which are
    defined in this class body; they are presumably attached elsewhere
    — confirm before use.
    """

    def __init__(self, d_model, n_experts, d_ff, world_size, rank, top_k=2, capacity_factor=1.25):
        super().__init__()

        self.d_model = d_model
        self.n_experts = n_experts
        self.d_ff = d_ff
        self.world_size = world_size
        self.rank = rank

        self.top_k = top_k
        self.capacity_factor = capacity_factor

        # Integer division: any remainder experts are not hosted anywhere.
        self.n_local_experts = n_experts // world_size

        self.router = TopKRouter(d_model, n_experts, top_k, capacity_factor)

        self.experts = nn.ModuleList([
            Expert(d_model, d_ff)
            for _ in range(self.n_local_experts)
        ])

    def forward(self, x):
        """Route tokens to experts and combine their outputs.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, d_model)``.

        Returns:
            ``(output, load_balancing_loss)`` where ``output`` has shape
            ``(batch_size, seq_len, d_model)``.
        """
        batch_size, seq_len, d_model = x.shape
        num_tokens = batch_size * seq_len

        # The router unpacks a 3-D (batch, seq, hidden) shape itself and
        # returns (indices, weights, loss) in that order, so it must be
        # called on the un-flattened input and unpacked accordingly.
        expert_indices, expert_weights, load_balancing_loss = self.router(x)

        # One row per token, matching the router's per-token outputs.
        # NOTE(review): assumes the dispatch helpers expect
        # (num_tokens, d_model) — confirm against their definitions.
        x = x.reshape(num_tokens, d_model)

        # Even share of tokens per expert, scaled by the capacity factor.
        capacity = int((num_tokens / self.n_experts) * self.capacity_factor)

        dispatch_data, combine_weights = self._prepare_all_to_all(
            x, expert_indices, expert_weights, capacity
        )

        received_data = self._all_to_all_scatter(dispatch_data)

        expert_output = self._process_local_experts(received_data)

        gathered_output = self._all_to_all_gather(expert_output)

        output = self._combine_outputs(gathered_output, combine_weights)

        output = output.reshape(batch_size, seq_len, d_model)

        return output, load_balancing_loss

In [6]:
class MoETransformerBlock(nn.Module):
    """
    Complete transformer block with MoE instead of dense FFN.

    Pre-norm layout: LayerNorm -> self-attention -> residual, then
    LayerNorm -> MoE -> residual.

    Args:
        hidden_size: Token feature size.
        num_heads: Number of attention heads.
        num_experts: Total number of MoE experts.
        intermediate_size: Hidden size of each expert MLP.
        top_k: Experts selected per token.
        capacity_factor: Expert capacity multiplier.
        world_size: Number of ranks sharing the experts. When None, taken
            from torch.distributed if a process group is initialised,
            otherwise 1.
        rank: This process's rank. When None, taken from torch.distributed
            if a process group is initialised, otherwise 0.
    """
    def __init__(self, hidden_size, num_heads, num_experts, intermediate_size,
                 top_k=2, capacity_factor=1.25, world_size=None, rank=None):
        super().__init__()
        self.hidden_size = hidden_size

        # Layer norms
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)

        # Attention (standard multi-head attention)
        self.attention = nn.MultiheadAttention(
            hidden_size,
            num_heads,
            batch_first=True
        )

        # Resolve the distributed topology; fall back to a single-process
        # layout when no process group has been set up.
        if world_size is None or rank is None:
            group_ready = dist.is_available() and dist.is_initialized()
            if world_size is None:
                world_size = dist.get_world_size() if group_ready else 1
            if rank is None:
                rank = dist.get_rank() if group_ready else 0

        # MoE layer (replaces dense FFN). world_size and rank occupy the
        # 4th and 5th positional slots of DistributedMoELayer, so top_k and
        # capacity_factor are passed by keyword to avoid landing in them.
        self.moe = DistributedMoELayer(
            hidden_size,
            num_experts,
            intermediate_size,
            world_size,
            rank,
            top_k=top_k,
            capacity_factor=capacity_factor,
        )

    def forward(self, x, attn_mask=None):
        """
        Args:
            x: (batch, seq, hidden)
            attn_mask: Optional attention mask

        Returns:
            output: (batch, seq, hidden)
            aux_loss: Load balancing loss
        """
        # Attention block
        normed = self.ln1(x)
        attn_output, _ = self.attention(normed, normed, normed, attn_mask=attn_mask)
        x = x + attn_output

        # MoE block
        normed = self.ln2(x)
        moe_output, aux_loss = self.moe(normed)
        x = x + moe_output

        return x, aux_loss