# Mixture-of-Experts (MoE)

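## Sparsely-Gated Mixture of Experts with an LSTM Gating Network [Shazeer et al., ICLR 2017](https://arxiv.org/pdf/1701.06538)

Note: for simplicity, the example below uses dense softmax gating (every expert is evaluated for every time step and the outputs are blended), not the noisy top-k sparse gating described in the paper.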

In [None]:
import torch
import torch.nn as nn

# Gating network based on LSTM
class LSTMGatingNetwork(nn.Module):
    """Router that turns a sequence into per-step mixing weights over experts.

    An LSTM reads the input sequence, a linear layer projects every hidden
    state to one logit per expert, and a softmax over the expert axis turns
    those logits into probabilities.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Size of the LSTM hidden state.
        num_experts: Number of experts to produce weights for.
    """

    def __init__(self, input_dim, hidden_dim, num_experts):
        super().__init__()
        # batch_first=True: tensors are laid out as (batch, seq_len, features).
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=num_experts)

    def forward(self, x):
        """Return gating probabilities of shape (batch, seq_len, num_experts).

        Args:
            x: Input tensor of shape (batch, seq_len, input_dim).
        """
        # Only the per-step hidden states are needed; the final (h, c) pair
        # returned by the LSTM is discarded.
        hidden_states = self.lstm(x)[0]           # (batch, seq_len, hidden_dim)
        expert_logits = self.fc(hidden_states)    # (batch, seq_len, num_experts)
        # Normalise over the expert axis so each time step sums to 1.
        return expert_logits.softmax(dim=-1)

# Mixture-of-Experts model with LSTM gating
class LSTMMoE(nn.Module):
    """Dense mixture of experts whose mixing weights come from an LSTM router.

    Every expert processes every time step; the outputs are blended with the
    per-step probabilities produced by ``LSTMGatingNetwork``.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Hidden size of the LSTM inside the gating network.
        output_dim: Output size of the default linear experts.
        num_experts: Number of experts.
        expert_network: Optional zero-argument callable returning a fresh
            expert module. When omitted, each expert is
            ``nn.Linear(input_dim, output_dim)``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_experts, expert_network=None):
        super().__init__()
        # The router is built first, then the experts.
        self.gating_network = LSTMGatingNetwork(input_dim, hidden_dim, num_experts)

        def build_expert():
            # Fall back to a plain linear projection when no factory is given.
            if expert_network is None:
                return nn.Linear(input_dim, output_dim)
            return expert_network()

        self.experts = nn.ModuleList([build_expert() for _ in range(num_experts)])

    def forward(self, x):
        """Blend all expert outputs with the gating weights.

        Args:
            x: Input tensor of shape (batch, seq_len, input_dim).

        Returns:
            Tensor of shape (batch, seq_len, output_dim).

        Raises:
            ValueError: If ``x`` is not 3-dimensional (the size unpack fails).
        """
        # The values are unused; the unpack itself enforces a 3-D input.
        _batch, _steps, _features = x.size()

        # Router probabilities, one distribution per time step.
        gates = self.gating_network(x)                      # (batch, seq_len, num_experts)

        # Run every expert on the full input and line them up on a new last axis.
        per_expert = torch.stack(
            [expert(x) for expert in self.experts], dim=-1
        )                                                   # (batch, seq_len, output_dim, num_experts)

        # Insert a singleton output axis so the gates broadcast over output_dim.
        gates_b = torch.unsqueeze(gates, dim=-2)            # (batch, seq_len, 1, num_experts)
        weighted = per_expert * gates_b

        # Collapse the expert axis: weighted sum of expert outputs.
        return weighted.sum(dim=-1)                         # (batch, seq_len, output_dim)

# Example usage: build a small LSTM-gated MoE and run a single forward pass.
# Sizes: 4 input features, 8 LSTM hidden units, 3 output features, 2 experts.
input_dim, hidden_dim, output_dim, num_experts = 4, 8, 3, 2
# expert_network is left at its default, so each expert is a Linear(input_dim, output_dim).
model = LSTMMoE(input_dim, hidden_dim, output_dim, num_experts)
# Dummy input: batch of 1 sequence, length 5, feature dim 4
x = torch.randn(1, 5, input_dim)
y = model(x)
# No RNG seed is set in this cell, so the printed values vary from run to run.
print("Input shape:", x.shape)
print("Output shape:", y.shape)
print("Output:", y)

In [None]:
# Output (example):
# Input shape: torch.Size([1, 5, 4])  
# Output shape: torch.Size([1, 5, 3])  
# Output: tensor([[[ 0.1991, -0.2271, -0.4974],
#          [-0.0026, -0.2181, -0.4217],
#          [-0.1261, -0.1725,  0.1611],
#          [-0.1749,  0.2343, -0.2493],
#          [ 0.2959,  0.3869, -0.7265]]], grad_fn=<SumBackward1>)

## Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity [Fedus et al., JMLR 2022](https://arxiv.org/pdf/2101.03961)

In [None]:
import torch
import torch.nn as nn

# Switch Transformer style MoE layer
class SwitchMoE(nn.Module):
    """Switch-Transformer-style MoE layer with top-1 (hard) routing.

    A linear router scores every expert for every token. Each token is sent
    to its single highest-scoring expert, and the expert's output is scaled
    by the router's softmax probability for that expert.

    Args:
        input_dim: Size of each token's feature vector.
        expert_hidden_dim: Hidden width of each expert's feed-forward MLP.
        output_dim: Size of each token's output vector.
        num_experts: Number of experts.
    """

    def __init__(self, input_dim, expert_hidden_dim, output_dim, num_experts):
        super().__init__()
        self.num_experts = num_experts
        self.output_dim = output_dim
        # Gating (router) network: a linear layer that scores each expert for a token
        self.gate = nn.Linear(input_dim, num_experts)
        # Expert networks: each is a feed-forward MLP (two linear layers with ReLU)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, expert_hidden_dim),
                nn.ReLU(),
                nn.Linear(expert_hidden_dim, output_dim),
            ) for _ in range(num_experts)
        ])

    def forward(self, x):
        """Route each token to its top-1 expert and return the gated outputs.

        Args:
            x: Input tensor of shape (batch, seq_len, input_dim). It may be
                non-contiguous and may have an empty batch.

        Returns:
            Tensor of shape (batch, seq_len, output_dim), created with the
            same dtype and device as ``x``.
        """
        batch_size, seq_len, dim = x.size()
        num_tokens = batch_size * seq_len

        # 1. Score the experts and pick the top-1 expert per token.
        gating_scores = self.gate(x)                          # (batch, seq_len, num_experts)
        expert_indices = gating_scores.argmax(dim=-1)         # (batch, seq_len)

        # Router probability of the chosen expert. Multiplying the expert
        # output by it is the only differentiable path from the loss to the
        # router; argmax alone gives the gate no gradient, so it could never
        # be trained.
        gate_probs = torch.softmax(gating_scores, dim=-1)
        chosen_probs = gate_probs.gather(-1, expert_indices.unsqueeze(-1))  # (batch, seq_len, 1)

        # 2. Flatten batch and sequence dimensions for easier indexing.
        # reshape (not view) so non-contiguous inputs are accepted.
        x_flat = x.reshape(num_tokens, dim)                   # (batch*seq_len, input_dim)
        indices_flat = expert_indices.reshape(-1)             # (batch*seq_len,)
        probs_flat = chosen_probs.reshape(num_tokens, 1)      # (batch*seq_len, 1)

        # new_zeros inherits dtype and device from x, so the layer works on
        # GPU and in non-default precisions.
        output_flat = x.new_zeros(num_tokens, self.output_dim)

        # 3. Process tokens grouped by expert to avoid a per-token loop.
        for expert_idx, expert in enumerate(self.experts):
            mask = indices_flat == expert_idx
            if mask.any():
                # All tokens assigned to this expert: (n_tokens_for_expert, input_dim)
                tokens_out = expert(x_flat[mask]) * probs_flat[mask]
                # Cast guards against an expert emitting a different dtype
                # than the buffer; masked assignment requires matching dtypes.
                output_flat[mask] = tokens_out.to(output_flat.dtype)

        # Explicit sizes (no -1) so an empty batch reshapes unambiguously.
        return output_flat.view(batch_size, seq_len, self.output_dim)

# Example usage: build a small top-1 routed MoE layer and run one forward pass.
input_dim = 5
output_dim = 5   # usually same as input_dim in transformer for residual connection
num_experts = 3
expert_hidden_dim = 10  # hidden layer size in each expert FFN

model = SwitchMoE(input_dim, expert_hidden_dim, output_dim, num_experts)
# Dummy input: batch of 2 sequences, each with 4 tokens (seq_len=4), token feature dim=5
x = torch.randn(2, 4, input_dim)
y = model(x)
# No RNG seed is set in this cell, so the printed values vary from run to run.
print("Input shape:", x.shape)
print("Output shape:", y.shape)
# Re-run the router on x and take the argmax to show which expert each token is assigned to.
print("Token-to-Expert assignments:\n", model.gate(x).argmax(dim=-1))
print("Output:\n", y)

In [None]:
# Output (example):
# Input shape: torch.Size([2, 4, 5])  
# Output shape: torch.Size([2, 4, 5])  
# Token-to-Expert assignments:
#  tensor([[1, 1, 1, 0],
#         [1, 0, 2, 1]])  
# Output:
#  tensor([[[ 0.4486, -0.2896, -0.2615, -0.2078,  0.0117],
#           [ 0.4096, -0.3731, -0.2567, -0.2418, -0.0269],
#           [ 0.3724, -0.5101, -0.2314, -0.2437,  0.0134],
#           [-0.1222, -0.6251,  0.2697,  0.2317, -0.2568]],

#          [[ 0.5131, -0.3547, -0.3457, -0.1203,  0.1192],
#           [-0.0289, -0.3413,  0.2114,  0.0775,  0.0413],
#           [-0.2445,  0.0563,  0.1714,  0.2636,  0.3997],
#           [ 0.3778, -0.4098, -0.2812, -0.3058, -0.2247]]])