## Import Libraries
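Imports PyTorch and the modules needed to build the Mixture of Experts model. Note that `torch.manual_seed` only makes results reproducible when it is called with an integer seed, for example `torch.manual_seed(1337)`. The cell below references the function without calling it, which is why its output is the function signature and no seed is actually set.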

In [33]:
import torch
import torch.nn as nn 
from torch.nn import functional as F
torch.manual_seed

<function torch.random.manual_seed(seed) -> torch._C.Generator>

## Download Training Data
Downloads the input text file from GitHub that will be used for training the MoE model.

In [34]:
!curl -O https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1089k  100 1089k    0     0  1937k      0 --:--:-- --:--:-- --:--:-- 1934k


## Define Expert Module
Defines an `Expert` class - a simple MLP (Multi-Layer Perceptron) that serves as one expert in the MoE architecture. Each expert consists of two linear layers with ReLU activation and dropout for regularization.

In [35]:
# Expert module
class Expert(nn.Module):
    """An MLP: a linear layer, a non-linearity, then a second linear layer. Each expert is one such MLP."""
    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )
    
    def forward(self, x):
        return self.net(x)
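
As a quick sanity check on the description above, the following sketch rebuilds the same layer stack outside the class and confirms two properties: the output shape equals the input shape, and the parameter count is `8 * n_embd**2 + 5 * n_embd`. The values `n_embd = 8` and `dropout = 0.1` are illustrative assumptions.

```python
import torch
import torch.nn as nn

n_embd, dropout = 8, 0.1  # illustrative values (assumption)

# Same layer stack as Expert.net
net = nn.Sequential(
    nn.Linear(n_embd, 4 * n_embd),
    nn.ReLU(),
    nn.Linear(4 * n_embd, n_embd),
    nn.Dropout(dropout),
)

x = torch.randn(2, 4, n_embd)   # (batch, seq_len, n_embd)
out = net(x)                    # nn.Linear acts on the last dimension only
n_params = sum(p.numel() for p in net.parameters())

print(out.shape)   # torch.Size([2, 4, 8])
print(n_params)    # 552 = 8 * 8**2 + 5 * 8
```

Because the expert maps `n_embd` back to `n_embd`, it can later be applied to any subset of tokens and written back into a tensor shaped like the input.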

## Understanding Gating Mechanism - Setup
Sets up a simple example to demonstrate how the gating/routing mechanism works. Creates sample multi-head attention output and a linear layer to produce logits for expert selection.

In [36]:
# Understanding how gating works
num_experts = 4
top_k = 2
n_embed = 8

# Example multi-head attention output for a simple illustration: n_embed=8, context_length=4
mh_output = torch.rand(1, 4, n_embed)
topkgate_linear = nn.Linear(n_embed, num_experts)  # nn.Linear(8, 4)
logits = topkgate_linear(mh_output)
print(logits)

tensor([[[-0.0123,  0.3042,  0.4986,  0.3198],
         [ 0.1164,  0.1188,  0.2375,  0.1165],
         [ 0.3512,  0.1809,  0.1322,  0.2058],
         [-0.2247,  0.2521,  0.6194,  0.2505]]], grad_fn=<ViewBackward0>)


## Select Top-K Experts
Demonstrates selecting the top-k (top 2) experts with highest logits for each token. Returns both the logit values and the indices of selected experts.

In [37]:
top_k_logits, top_k_indices = logits.topk(top_k, dim=-1) # Get top-k experts
top_k_logits, top_k_indices

(tensor([[[0.4986, 0.3198],
          [0.2375, 0.1188],
          [0.3512, 0.2058],
          [0.6194, 0.2521]]], grad_fn=<TopkBackward0>),
 tensor([[[2, 3],
          [2, 1],
          [0, 3],
          [2, 1]]]))
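
To make the behavior of `topk` easier to see, here is a minimal sketch on a hand-written tensor (the values are made up for illustration). It returns the k largest entries along the chosen dimension, sorted in descending order, together with their positions.

```python
import torch

logits = torch.tensor([[0.1, 0.9, 0.3, 0.5]])   # one token, four experts
values, indices = logits.topk(2, dim=-1)

print(values)   # tensor([[0.9000, 0.5000]])
print(indices)  # tensor([[1, 3]])
```

The indices identify which experts were selected; the values are what gets scattered back into the sparse logits in the next step.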

## Create Sparse Logits
Creates a sparse representation by filling a tensor with -inf and scattering only the top-k logit values back. This ensures only selected experts will have non-zero weights after softmax.

In [38]:
zeros = torch.full_like(logits, float('-inf'))  # full_like creates a tensor with the same shape and dtype as logits, filled with the specified value (here -inf)
sparse_logits = zeros.scatter(-1, top_k_indices, top_k_logits)
sparse_logits

tensor([[[  -inf,   -inf, 0.4986, 0.3198],
         [  -inf, 0.1188, 0.2375,   -inf],
         [0.3512,   -inf,   -inf, 0.2058],
         [  -inf, 0.2521, 0.6194,   -inf]]], grad_fn=<ScatterBackward0>)

## Apply Softmax for Gating Weights
Applies softmax to the sparse logits to produce final gating weights. Non-selected experts (with -inf logits) will have zero probability, implementing sparse routing.

In [39]:
gating_output = F.softmax(sparse_logits, dim=-1)
gating_output

tensor([[[0.0000, 0.0000, 0.5446, 0.4554],
         [0.0000, 0.4704, 0.5296, 0.0000],
         [0.5363, 0.0000, 0.0000, 0.4637],
         [0.0000, 0.4092, 0.5908, 0.0000]]], grad_fn=<SoftmaxBackward0>)
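
The following sketch reproduces the first row of the output above by hand. It illustrates that a softmax over logits padded with `-inf` gives the same non-zero weights as a softmax over the selected logits alone, since `exp(-inf)` is exactly zero.

```python
import torch
import torch.nn.functional as F

# First row of the sparse logits shown above
row = torch.tensor([float('-inf'), float('-inf'), 0.4986, 0.3198])
weights = F.softmax(row, dim=-1)

# Softmax restricted to the two selected logits
selected = F.softmax(torch.tensor([0.4986, 0.3198]), dim=-1)

print(weights)   # tensor([0.0000, 0.0000, 0.5446, 0.4554])
print(selected)  # tensor([0.5446, 0.4554])
```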

## Define TopkRouter Class
Implements the complete top-k routing mechanism as a PyTorch module. It takes the multi-head attention output, computes one logit per expert with a learned linear layer, keeps the top-k logits for each token, and applies softmax to produce the gating weights.

In [40]:
class TopkRouter(nn.Module):

    def __init__(self, n_embed, num_experts, top_k):
        super(TopkRouter, self).__init__()
        self.top_k = top_k
        self.linear = nn.Linear(n_embed, num_experts)

    def forward(self, mh_output):
        # mh_output is the output tensor from the multi-head self-attention block
        logits = self.linear(mh_output)
        # Keep the k largest logits per token and the indices of the experts they belong to
        top_k_logits, indices = logits.topk(self.top_k, dim=-1)
        # Fill with -inf so that non-selected experts get zero weight after softmax
        zeros = torch.full_like(logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

## Test TopkRouter
Tests the TopkRouter with 3 experts, top-k=2, and embedding dimension of 8. Demonstrates the shape of outputs and shows which experts are selected for each token position.

In [41]:
num_experts = 3
top_k = 2
n_embd = 8
mh_output = torch.randn(1, 4, n_embd)  # Example input
top_k_gate = TopkRouter(n_embd, num_experts, top_k)
gating_output, indices = top_k_gate(mh_output)
gating_output.shape, gating_output, indices

(torch.Size([1, 4, 3]),
 tensor([[[0.0000, 0.7829, 0.2171],
          [0.7314, 0.0000, 0.2686],
          [0.5791, 0.0000, 0.4209],
          [0.7519, 0.0000, 0.2481]]], grad_fn=<SoftmaxBackward0>),
 tensor([[[1, 2],
          [0, 2],
          [0, 2],
          [0, 2]]]))
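
Two properties of the router output are worth checking explicitly: each token's gating weights sum to 1, and exactly `top_k` of them are non-zero. The sketch below verifies both with a small functional version of the routing steps. The helper name `topk_gate` is hypothetical and is fed random logits instead of a learned linear layer.

```python
import torch
import torch.nn.functional as F

def topk_gate(logits, k):
    """Hypothetical helper: top-k, scatter into -inf, then softmax."""
    top_vals, top_idx = logits.topk(k, dim=-1)
    sparse = torch.full_like(logits, float('-inf')).scatter(-1, top_idx, top_vals)
    return F.softmax(sparse, dim=-1), top_idx

logits = torch.randn(1, 4, 3)            # (batch, seq_len, num_experts)
gates, idx = topk_gate(logits, k=2)

rows_sum_to_one = torch.allclose(gates.sum(dim=-1), torch.ones(1, 4))
nonzero_per_token = (gates > 0).sum(dim=-1)

print(rows_sum_to_one)     # True
print(nonzero_per_token)   # tensor([[2, 2, 2, 2]])
```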

## Noisy Top-K Gating

Noisy top-k gating is an important technique for training Mixture of Experts models effectively.

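**Problem Without Noisy Gating:**
- The router tends to send most tokens to the same few "favored" experts (the ones with the highest logits)
- The effect is self-reinforcing: favored experts receive more gradient updates, improve faster, and are selected even more often
- This causes load imbalance, where some experts are overused while others are rarely trained
- Rarely selected experts hold parameters that contribute little, so model capacity and compute are wasted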

**Solution - Noisy Top-K Gating:**
- Add Gaussian noise to the logits from the gating linear layer during training
- This encourages random exploration of different expert combinations
- Creates a balance between **exploitation** (using the best experts) and **exploration** (trying other experts)

**Benefits:**
1. **Load Balancing**: Distributes tokens more evenly across all experts
2. **Better Training**: Prevents experts from becoming inactive or redundant
3. **Improved Performance**: Leads to more efficient and effective MoE models
4. **Regularization**: Acts as a form of regularization during training

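**Implementation:**
- During training, add scaled Gaussian noise to the logits: `noisy_logits = logits + randn_like(logits) * softplus(noise_logits)`. Here `noise_logits` comes from a second learned linear layer, so the noise scale is learned per token and per expert, and `softplus` keeps it positive.
- During inference, use the clean logits without noise so that expert selection is deterministic.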
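
The effect of the noise on expert selection can be illustrated with a small experiment. This is a sketch under simplifying assumptions: every token is given the same made-up clean logits, and the noise scale is a fixed constant (`noise_scale = 0.5`) instead of the learned `softplus` term.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for repeatability of this sketch
num_experts, top_k, n_tokens = 4, 2, 1000

# Every token has the same clean logits, so experts 0 and 1 always win without noise
clean = torch.tensor([1.0, 0.8, 0.6, 0.4]).repeat(n_tokens, 1)
noise_scale = 0.5  # fixed here (assumption); the router learns this scale
noisy = clean + torch.randn(n_tokens, num_experts) * noise_scale

def selection_counts(logits):
    """Count how many tokens select each expert under top-k routing."""
    idx = logits.topk(top_k, dim=-1).indices
    return torch.bincount(idx.flatten(), minlength=num_experts)

clean_counts = selection_counts(clean)
noisy_counts = selection_counts(noisy)

print(clean_counts)  # experts 0 and 1 receive all 1000 tokens each
print(noisy_counts)  # tokens are spread over all four experts
```

Experts with higher clean logits are still chosen more often under noise, which is the exploitation side of the trade-off; the noise only ensures the others are tried as well.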

In [42]:
class NoisyTopkRouter(nn.Module):
    """
    Implements noisy top-k routing for Mixture of Experts.
    Adds Gaussian noise to logits during training to encourage load balancing.
    """
    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.top_k = top_k
        self.num_experts = num_experts
        # Layer for router logits
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        # Layer for noise scaling
        self.noise_linear = nn.Linear(n_embed, num_experts)
    
    def forward(self, mh_output, training=True):
        """
        Args:
            mh_output: Output from multihead self-attention block
            training: Whether in training mode (adds noise) or inference mode
        
        Returns:
            router_output: Gating weights for expert selection
            indices: Indices of top-k selected experts
        """
        # Get logits from the router linear layer
        logits = self.topkroute_linear(mh_output)
        
        if training:
            # Get noise scaling logits
            noise_logits = self.noise_linear(mh_output)
            # Add scaled unit gaussian noise to the logits
            # F.softplus ensures noise scaling is positive
            noise = torch.randn_like(logits) * F.softplus(noise_logits)
            noisy_logits = logits + noise
        else:
            # Use clean logits during inference
            noisy_logits = logits
        
        # Select top-k experts
        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        
        # Create sparse logits (set non-top-k to -inf)
        zeros = torch.full_like(logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        
        # Apply softmax for gating weights
        router_output = F.softmax(sparse_logits, dim=-1)
        
        return router_output, indices

In [43]:
# Test NoisyTopkRouter
num_experts = 4
top_k = 2
n_embd = 8
mh_output = torch.randn(1, 4, n_embd)  # Batch size 1, sequence length 4

# Create router
noisy_top_k_gate = NoisyTopkRouter(n_embd, num_experts, top_k)
gating_output, indices = noisy_top_k_gate(mh_output)
gating_output.shape, gating_output, indices

(torch.Size([1, 4, 4]),
 tensor([[[0.4249, 0.0000, 0.5751, 0.0000],
          [0.7868, 0.0000, 0.0000, 0.2132],
          [0.2520, 0.0000, 0.7480, 0.0000],
          [0.8716, 0.1284, 0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>),
 tensor([[[2, 0],
          [0, 3],
          [2, 0],
          [0, 1]]]))
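
The router above takes an explicit `training` flag that defaults to `True`, so a caller who forgets to pass `training=False` at inference time still gets noisy routing. One possible alternative, sketched here as an assumption and not as the notebook's approach, is to rely on the `self.training` attribute that every `nn.Module` already has and that `.train()` and `.eval()` toggle. The class name `ModeAwareNoisyRouter` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModeAwareNoisyRouter(nn.Module):  # hypothetical name
    def __init__(self, n_embed, num_experts, top_k):
        super().__init__()
        self.top_k = top_k
        self.route_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear = nn.Linear(n_embed, num_experts)

    def forward(self, x):
        logits = self.route_linear(x)
        if self.training:  # toggled by .train() / .eval()
            noise_scale = F.softplus(self.noise_linear(x))
            logits = logits + torch.randn_like(logits) * noise_scale
        top_vals, idx = logits.topk(self.top_k, dim=-1)
        sparse = torch.full_like(logits, float('-inf')).scatter(-1, idx, top_vals)
        return F.softmax(sparse, dim=-1), idx

router = ModeAwareNoisyRouter(n_embed=8, num_experts=4, top_k=2)
x = torch.randn(1, 4, 8)

router.eval()  # no noise is added in eval mode
with torch.no_grad():
    g1, i1 = router(x)
    g2, i2 = router(x)
deterministic_in_eval = torch.equal(g1, g2) and torch.equal(i1, i2)
print(deterministic_in_eval)  # True
```

With this design, calling `model.eval()` on the enclosing model switches off the routing noise together with dropout.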

## Create the Sparse Mixture of Experts (MoE)

![MoE Architecture](assets/MoE_enhanced.png)

## SparseMoE Block Output Computation

The SparseMoE (Sparse Mixture of Experts) block's output is computed through a multi-step process:

### (a) Expert Selector Weight Matrix
The router/gating mechanism first produces the **expert selector weight matrix**. This matrix contains:
- One row per token in the sequence
- One column per expert in the mixture
- Values representing the gating weights for routing each token to experts
- Only **top-k experts have non-zero weights** (others are zero due to sparse routing)

### (b) Top-K Expert Output Multiplication
After acquiring the expert selector weight matrix, the **top-k gating weights are selectively multiplied** with the outputs from the corresponding top-k experts for each token:

$$\text{weighted\_expert\_output} = \text{gating\_weight} \times \text{expert\_output}$$

For each token, only the k experts with the highest gating weights process that token, making the computation sparse and efficient.

### (c) Weighted Sum - SparseMoE Output
This selective multiplication of gating weights with expert outputs forms a **weighted sum**, which constitutes the SparseMoE block's final output:

$$\text{SparseMoE\_output} = \sum_{i \in \text{top-k experts}} \text{gating\_weight}_i \times \text{expert\_output}_i$$

Where:
- The sum iterates only over the selected top-k experts
- Each expert's output is weighted by its corresponding gating weight
- The result is a single output vector per token that combines information from multiple experts
- This weighted sum effectively blends the knowledge of multiple specialized experts for each token
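
The weighted sum can be worked through for a single token. In the sketch below, the gating weights and expert outputs are made-up numbers with 4 experts, top-k = 2 and a 2-dimensional output. It also shows that summing over only the selected experts gives the same result as a dense sum over all experts, because the unselected weights are zero.

```python
import torch

gating = torch.tensor([0.0, 0.6, 0.0, 0.4])   # one token; experts 1 and 3 selected
expert_outputs = torch.tensor([                # made-up outputs, one row per expert
    [1.0, 1.0],
    [2.0, 0.0],
    [5.0, 5.0],
    [0.0, 3.0],
])

selected = gating.nonzero().flatten()          # tensor([1, 3])
sparse_sum = (gating[selected].unsqueeze(1) * expert_outputs[selected]).sum(dim=0)
dense_sum = gating @ expert_outputs            # same value, but uses every expert's output

print(sparse_sum)  # tensor([1.2000, 1.2000]) = 0.6 * [2, 0] + 0.4 * [0, 3]
```

In a real MoE layer the saving comes from never computing the outputs of unselected experts at all, not just from skipping them in the sum.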

In [44]:
class SparseMoE(nn.Module):
    """
    Sparse Mixture of Experts module that routes tokens to top-k experts
    and combines their outputs using learned gating weights.
    """
    def __init__(self, n_embed, num_experts, top_k, dropout=0.1):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed, dropout) for _ in range(num_experts)])
        self.top_k = top_k
    
    def forward(self, x, training=True):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, n_embed)
            training: Whether in training mode (affects gating noise)
        
        Returns:
            final_output: Weighted combination of expert outputs
        """
        # Get gating weights and expert indices from router
        gating_output, indices = self.router(x, training=training)
        
        # Initialize output tensor with same shape as input
        final_output = torch.zeros_like(x)
        
        # Reshape inputs for batch processing
        flat_x = x.view(-1, x.size(-1))
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))
        
        # Loop over the experts; each one processes only the tokens routed to it
        for i, expert in enumerate(self.experts):
            # Create a mask for the inputs where the current expert is in top-k
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)
            
            if flat_mask.any():
                # Extract inputs for this expert
                expert_input = flat_x[flat_mask]
                
                # Process through expert
                expert_output = expert(expert_input)
                
                # Extract and apply gating scores
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores
                
                # Add this expert's weighted output into the rows of the tokens routed to it
                # (weighted_output already has shape (num_selected_tokens, n_embed))
                final_output[expert_mask] += weighted_output
        
        return final_output

In [45]:
# Test SparseMoE
num_experts = 4
top_k = 2
n_embd = 8
dropout = 0.1

# Create SparseMoE block
sparse_moe = SparseMoE(n_embd, num_experts, top_k, dropout)

# Create sample input
batch_size = 2
seq_len = 4
x = torch.randn(batch_size, seq_len, n_embd)

# Forward pass
output = sparse_moe(x, training=True)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Output:\n{output}")
print(f"\nShape matches input: {output.shape == x.shape}")

Input shape: torch.Size([2, 4, 8])
Output shape: torch.Size([2, 4, 8])
Output:
tensor([[[-0.1479,  0.1550, -0.0756, -0.1690, -0.0151, -0.0343, -0.1801,
           0.0772],
         [-0.1597,  0.2710,  0.0299,  0.0714, -0.0486,  0.0669, -0.5236,
          -0.2877],
         [ 0.0617, -0.0243, -0.0955,  0.0070,  0.0805, -0.0490, -0.3016,
          -0.0283],
         [ 0.1945, -0.1716,  0.1555, -0.0450, -0.0152,  0.2378, -0.0966,
          -0.0355]],

        [[ 0.3903,  0.3048, -0.1689,  0.1523,  0.0050, -0.0926, -0.1269,
          -0.1720],
         [ 0.0941,  0.1254, -0.1340, -0.1856, -0.0643, -0.0652, -0.1836,
           0.0155],
         [-0.2607,  0.5290,  0.2636, -0.1587,  0.4651,  0.4380, -0.2503,
          -0.0622],
         [ 0.4173, -0.0921,  0.2127, -0.0892,  0.1658, -0.0387, -0.3208,
          -0.0671]]], grad_fn=<IndexPutBackward0>)

Shape matches input: True
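
A shape check alone does not confirm that the masking loop computes the intended weighted sum. One way to test it, sketched here under simplifying assumptions, is to compare the sparse loop against a dense reference that runs every expert on every token and then weights the results by the gates. The experts below are plain `nn.Linear` layers standing in for `Expert` (no dropout, so both paths are deterministic), and random logits stand in for the router.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # arbitrary seed
B, T, C, E, K = 2, 4, 8, 4, 2  # batch, seq_len, n_embed, num_experts, top_k
experts = nn.ModuleList([nn.Linear(C, C) for _ in range(E)])  # stand-ins for Expert
x = torch.randn(B, T, C)

# Top-k gating weights from random logits (stand-in for the router)
logits = torch.randn(B, T, E)
top_vals, idx = logits.topk(K, dim=-1)
sparse_logits = torch.full_like(logits, float('-inf')).scatter(-1, idx, top_vals)
gates = F.softmax(sparse_logits, dim=-1)

# Sparse path: same masking loop as in SparseMoE.forward
sparse_out = torch.zeros_like(x)
for i, expert in enumerate(experts):
    mask = (idx == i).any(dim=-1)                    # (B, T), True where expert i is selected
    if mask.any():
        weights = gates[mask][:, i].unsqueeze(1)     # (num_selected_tokens, 1)
        sparse_out[mask] += expert(x[mask]) * weights

# Dense reference: run every expert on every token, then weight by the gates
all_out = torch.stack([expert(x) for expert in experts], dim=-2)  # (B, T, E, C)
dense_out = (gates.unsqueeze(-1) * all_out).sum(dim=-2)           # (B, T, C)

matches = torch.allclose(sparse_out, dense_out, atol=1e-5)
print(matches)  # True
```

The dense reference costs `E / K` times more expert compute, so it is only useful as a test oracle on small inputs.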
