## Import Libraries
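Imports PyTorch and the modules needed to build the Mixture of Experts model. Note that the cell only references `torch.manual_seed` without calling it, so no seed is actually set (the output is just the function's signature). To make results reproducible, call it with an integer seed, e.g. `torch.manual_seed(42)`.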

In [1]:
import torch
import torch.nn as nn 
from torch.nn import functional as F
torch.manual_seed

<function torch.random.manual_seed(seed) -> torch._C.Generator>

## Download Training Data
Downloads the input text file from GitHub that will be used for training the MoE model.

In [3]:
!curl -O https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1089k  100 1089k    0     0   335k      0  0:00:03  0:00:03 --:--:--  335k


## Define Expert Module
Defines an `Expert` class - a simple MLP (Multi-Layer Perceptron) that serves as one expert in the MoE architecture. Each expert consists of two linear layers with ReLU activation and dropout for regularization.

In [4]:
# Expert module: one MLP used as a single expert in the MoE layer
class Expert(nn.Module):
    """An MLP expert: Linear -> ReLU -> Linear, followed by dropout."""
    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),  # expand to 4x the embedding width
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),  # project back to the embedding width
            nn.Dropout(dropout),
        )

    # forward must be a method of the class, not nested inside __init__
    def forward(self, x):
        return self.net(x)
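
As a quick check of the description above, the following minimal sketch instantiates one expert and confirms that it maps a tensor of shape `(batch, tokens, n_embd)` back to the same shape. The sizes and the dropout value are arbitrary choices for illustration, not values taken from the notebook.

```python
import torch
import torch.nn as nn

class Expert(nn.Module):
    """Same layout as described above: Linear -> ReLU -> Linear -> Dropout."""
    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

# Illustrative sizes only (assumptions for this check).
expert = Expert(n_embd=8, dropout=0.1)
x = torch.randn(2, 4, 8)   # (batch, tokens, n_embd)
y = expert(x)
print(y.shape)             # the expert preserves the input shape
```

Because the output shape matches the input shape, several experts can be applied to the same token embeddings and their outputs combined with per-expert weights.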

## Understanding Gating Mechanism - Setup
Sets up a simple example to demonstrate how the gating/routing mechanism works. Creates sample multi-head attention output and a linear layer to produce logits for expert selection.

In [7]:
# Understanding how gating works
num_experts = 4
top_k = 2
n_embed = 8

# Example multi-head attention output for a simple illustration:
# batch_size=1, context_length=4, n_embed=8
mh_output = torch.rand(1, 4, n_embed)
topkgate_linear = nn.Linear(n_embed, num_experts)  # nn.Linear(8, 4)
logits = topkgate_linear(mh_output)
print(logits)

tensor([[[-0.4867, -0.3086, -0.2504, -0.2049],
         [-0.0975,  0.0568, -0.0490, -0.4454],
         [-0.4659, -0.2366,  0.2183, -0.3786],
         [-0.3189, -0.3614,  0.2343, -0.5286]]], grad_fn=<ViewBackward0>)


## Select Top-K Experts
Demonstrates selecting the top-k (top 2) experts with highest logits for each token. Returns both the logit values and the indices of selected experts.

In [8]:
top_k_logits, top_k_indices = logits.topk(top_k, dim=-1) # Get top-k experts
top_k_logits, top_k_indices

(tensor([[[-0.2049, -0.2504],
          [ 0.0568, -0.0490],
          [ 0.2183, -0.2366],
          [ 0.2343, -0.3189]]], grad_fn=<TopkBackward0>),
 tensor([[[3, 2],
          [1, 2],
          [2, 1],
          [2, 0]]]))
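
To make the selection step reproducible without a random seed, the sketch below hard-codes the logits printed earlier and applies `topk` to them. It also shows that `gather` with the returned indices recovers the returned values, which is the relationship the later routing steps rely on.

```python
import torch

# Logits copied from the printed output above: 1 batch, 4 tokens, 4 experts.
logits = torch.tensor([[[-0.4867, -0.3086, -0.2504, -0.2049],
                        [-0.0975,  0.0568, -0.0490, -0.4454],
                        [-0.4659, -0.2366,  0.2183, -0.3786],
                        [-0.3189, -0.3614,  0.2343, -0.5286]]])

# topk returns the k largest entries along dim=-1, sorted largest first,
# together with their positions (the expert indices).
values, indices = logits.topk(2, dim=-1)

# Looking the indices up in the original tensor gives back the same values.
recovered = logits.gather(-1, indices)

print(indices)
print(torch.equal(recovered, values))
```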

## Create Sparse Logits
Creates a sparse representation by filling a tensor with -inf and scattering only the top-k logit values back. This ensures only selected experts will have non-zero weights after softmax.

In [9]:
# full_like creates a tensor with the same shape and dtype as logits, filled with
# the specified value (here -inf, despite the variable name)
zeros = torch.full_like(logits, float('-inf'))
sparse_logits = zeros.scatter(-1, top_k_indices, top_k_logits)
sparse_logits

tensor([[[   -inf,    -inf, -0.2504, -0.2049],
         [   -inf,  0.0568, -0.0490,    -inf],
         [   -inf, -0.2366,  0.2183,    -inf],
         [-0.3189,    -inf,  0.2343,    -inf]]], grad_fn=<ScatterBackward0>)
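
The `scatter` call can be read as "write each top-k logit into the column named by its index". The following sketch spells that out with an explicit loop on a two-token slice of the values above and checks that both approaches agree. The loop is for explanation only; the vectorized `scatter` is what one would use in practice.

```python
import torch

# Two tokens, top-2 logits and their expert indices (taken from the output above).
top_k_logits = torch.tensor([[[-0.2049, -0.2504],
                              [ 0.0568, -0.0490]]])
top_k_indices = torch.tensor([[[3, 2],
                               [1, 2]]])

# Vectorized version: start from -inf everywhere, then scatter along the last dim.
base = torch.full((1, 2, 4), float('-inf'))
sparse = base.scatter(-1, top_k_indices, top_k_logits)

# Explicit loop showing what scatter does along dim=-1:
# manual[b, t, index[b, t, j]] = src[b, t, j]
manual = torch.full((1, 2, 4), float('-inf'))
for b in range(top_k_indices.shape[0]):
    for t in range(top_k_indices.shape[1]):
        for j in range(top_k_indices.shape[2]):
            expert_id = top_k_indices[b, t, j].item()
            manual[b, t, expert_id] = top_k_logits[b, t, j]

print(torch.equal(sparse, manual))
```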

## Apply Softmax for Gating Weights
Applies softmax to the sparse logits to produce final gating weights. Non-selected experts (with -inf logits) will have zero probability, implementing sparse routing.

In [10]:
gating_output = F.softmax(sparse_logits, dim=-1)
gating_output

tensor([[[0.0000, 0.0000, 0.4886, 0.5114],
         [0.0000, 0.5264, 0.4736, 0.0000],
         [0.0000, 0.3882, 0.6118, 0.0000],
         [0.3651, 0.0000, 0.6349, 0.0000]]], grad_fn=<SoftmaxBackward0>)
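
The reason `-inf` entries vanish is that `exp(-inf) = 0`, so they contribute nothing to the softmax numerator or denominator. As a small check on the first token above, the nonzero weights should equal a softmax taken over the two surviving logits alone:

```python
import torch
from torch.nn import functional as F

# First token's sparse logits from the output above.
sparse_logits = torch.tensor([[float('-inf'), float('-inf'), -0.2504, -0.2049]])
weights = F.softmax(sparse_logits, dim=-1)

# Softmax over only the two selected logits.
dense = F.softmax(torch.tensor([[-0.2504, -0.2049]]), dim=-1)

print(weights)
print(dense)
```

In other words, masking with `-inf` before the softmax is equivalent to renormalizing over the selected experts only.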

## Define TopkRouter Class
Implements the complete top-k routing mechanism as a PyTorch module. It projects the multi-head attention output to one logit per expert with a learned linear layer, keeps the top-k logits for each token, and applies softmax so that only the selected experts receive nonzero weight.

In [15]:
class TopkRouter(nn.Module):

    def __init__(self, n_embed, num_experts, top_k):
        super().__init__()
        self.top_k = top_k
        self.linear = nn.Linear(n_embed, num_experts)

    def forward(self, mh_output):
        # mh_output is the output tensor from the multi-head self-attention block
        logits = self.linear(mh_output)
        # Keep the top-k logits per token and remember which experts they belong to
        top_k_logits, indices = logits.topk(self.top_k, dim=-1)
        # Mask every non-selected expert with -inf so softmax gives it zero weight
        zeros = torch.full_like(logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

## Test TopkRouter
Tests the TopkRouter with 3 experts, top-k=2, and embedding dimension of 8. Demonstrates the shape of outputs and shows which experts are selected for each token position.

In [16]:
num_experts = 3
top_k = 2
n_embd = 8
mh_output = torch.randn(1, 4, n_embd)  # Example input
top_k_gate = TopkRouter(n_embd, num_experts, top_k)
gating_output, indices = top_k_gate(mh_output)
gating_output.shape, gating_output, indices

(torch.Size([1, 4, 3]),
 tensor([[[0.7659, 0.0000, 0.2341],
          [0.0000, 0.7413, 0.2587],
          [0.0000, 0.7312, 0.2688],
          [0.3586, 0.6414, 0.0000]]], grad_fn=<SoftmaxBackward0>),
 tensor([[[0, 2],
          [1, 2],
          [1, 2],
          [1, 0]]]))
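
The output above suggests three properties that any top-k router of this kind should satisfy: each token's weights sum to 1, exactly `top_k` weights per token are nonzero, and the nonzero weights sit at the returned indices. The sketch below checks them using a functional form of the same routing steps. The helper name `topk_route` is hypothetical and introduced only for this check.

```python
import torch
from torch.nn import functional as F

def topk_route(logits, top_k):
    """Functional form of the routing steps (hypothetical helper for illustration)."""
    top_k_logits, indices = logits.topk(top_k, dim=-1)
    masked = torch.full_like(logits, float('-inf'))
    sparse_logits = masked.scatter(-1, indices, top_k_logits)
    return F.softmax(sparse_logits, dim=-1), indices

logits = torch.randn(1, 4, 3)   # (batch, tokens, num_experts)
weights, indices = topk_route(logits, top_k=2)

row_sums = weights.sum(dim=-1)                  # should all be 1
nonzero_per_token = (weights > 0).sum(dim=-1)   # should all equal top_k
picked = weights.gather(-1, indices)            # weights at the selected experts

print(row_sums, nonzero_per_token, picked)
```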

## Noisy Top-K Gating

Noisy top-k gating adds random noise to the router's logits during training so that tokens are spread more evenly across experts.

**Problem Without Noisy Gating:**
- All tokens tend to route to the same set of "favored" experts (the ones with highest logits)
- This causes load imbalance - some experts are overused while others are underutilized
- Underused experts receive few training updates, so their parameters and the compute reserved for them are largely wasted

**Solution - Noisy Top-K Gating:**
- Add Gaussian noise to the logits from the gating linear layer during training
- This encourages random exploration of different expert combinations
- Creates a balance between **exploitation** (using the best experts) and **exploration** (trying other experts)

**Benefits:**
1. **Load Balancing**: Distributes tokens more evenly across all experts
2. **Better Training**: Prevents experts from becoming inactive or redundant
3. **Improved Performance**: Leads to more efficient and effective MoE models
4. **Regularization**: Acts as a form of regularization during training

**Implementation:**
- During training, add noise to the logits: `logits_noisy = logits + Gaussian_noise`
- During inference, use the clean logits without noise so that expert selection is deterministic
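
The two rules above can be sketched as a small variation on the top-k router. This is a minimal illustration under stated assumptions, not the notebook's own implementation: the class name `NoisyTopkRouter` is hypothetical, and scaling the Gaussian noise by a softplus of a second learned linear layer (`noise_linear`) is one common choice that the text above does not specify. A fixed noise standard deviation would also fit the description.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class NoisyTopkRouter(nn.Module):
    """Hypothetical noisy variant of a top-k router (a sketch under assumptions)."""
    def __init__(self, n_embed, num_experts, top_k):
        super().__init__()
        self.top_k = top_k
        self.linear = nn.Linear(n_embed, num_experts)        # clean routing logits
        self.noise_linear = nn.Linear(n_embed, num_experts)  # assumed: learned noise scale

    def forward(self, mh_output):
        logits = self.linear(mh_output)
        if self.training:
            # Training only: add Gaussian noise, scaled by a positive per-expert factor.
            noise_scale = F.softplus(self.noise_linear(mh_output))
            logits = logits + torch.randn_like(logits) * noise_scale
        # From here on, the steps match plain top-k routing.
        top_k_logits, indices = logits.topk(self.top_k, dim=-1)
        masked = torch.full_like(logits, float('-inf'))
        sparse_logits = masked.scatter(-1, indices, top_k_logits)
        return F.softmax(sparse_logits, dim=-1), indices

router = NoisyTopkRouter(n_embed=8, num_experts=4, top_k=2)
x = torch.randn(1, 4, 8)

router.train()                       # noise is active
noisy_weights, noisy_indices = router(x)

router.eval()                        # noise is disabled, routing is deterministic
clean_a, idx_a = router(x)
clean_b, idx_b = router(x)
print(torch.equal(clean_a, clean_b))
```

Switching on `self.training` means the standard `model.train()` and `model.eval()` calls control whether noise is applied, so no extra flag has to be passed at inference time.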