In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed PyTorch's global RNG so parameter initialisation (nn.Linear) and dropout masks are reproducible across runs.
torch.manual_seed(69)

<torch._C.Generator at 0x78c477fbd3f0>

Download the Shakespeare text dataset (`input.txt`, ~1.1 MB of plain text).

In [2]:
!wget https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt

--2025-07-11 22:28:42--  https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-07-11 22:28:42 (17.9 MB/s) - ‘input.txt’ saved [1115394/1115394]



## Coding the feed-forward (FFN) experts individually

- Each expert uses `ReLU()` as its activation.

In [3]:
class Expert(nn.Module):
  """
  Expert network: a position-wise feed-forward MLP.

  Linear(embed_dim -> hidden_mult*embed_dim) -> ReLU -> Linear(-> embed_dim) -> Dropout.
  The output has the same shape as the input, so the outputs of several experts
  can be weighted and summed by the router.

  Args:
    embed_dim: size of the last (feature) dimension of the input.
    dropout: dropout probability applied to the expert output.
    hidden_mult: width multiplier of the hidden layer (4 is the usual transformer FFN ratio).
  """

  def __init__(self, embed_dim, dropout=0.1, hidden_mult=4):
    super().__init__()
    hidden_dim = hidden_mult * embed_dim
    self.net = nn.Sequential(
        nn.Linear(embed_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, embed_dim),
        nn.Dropout(dropout)
    )

  def forward(self, x):
    """Apply the MLP over the last dimension of ``x``; the output shape equals the input shape."""
    return self.net(x)

Next we need a router: a learned routing matrix (a linear layer) that scores each token against every expert and decides which top-k experts that token is sent to.

In [4]:
class NoisyTopK(nn.Module):
  """
  Noisy top-k router (gating network).

  For every token it scores all experts with a learned linear layer (the routing
  matrix). During training, Gaussian noise with a learned per-expert scale is added
  to those scores to encourage load balancing (noisy top-k gating). It then keeps
  only the top_k scores and softmaxes over them, so non-selected experts get a
  weight of exactly 0.

  Args:
    n_embed: size of the token embedding (last dim of the input).
    n_experts: number of experts to route between.
    top_k: number of experts each token is sent to (1 <= top_k <= n_experts).

  Raises:
    ValueError: if top_k is outside [1, n_experts].
  """
  def __init__(self, n_embed, n_experts, top_k):
    super().__init__()
    if not 1 <= top_k <= n_experts:
      raise ValueError(f"top_k must be in [1, n_experts={n_experts}], got {top_k}")
    # Routing matrix: embedding dim -> one logit per expert.
    self.linear = nn.Linear(n_embed, n_experts)
    self.top_k = top_k

    # Predicts a per-token, per-expert noise scale (made positive with softplus in forward).
    self.noise = nn.Linear(n_embed, n_experts)

  def forward(self, x):
    """
    Route each token to its top_k experts.

    Args:
      x: tensor of shape [..., n_embed] (e.g. [batch, seq, n_embed]).

    Returns:
      router_output: [..., n_experts] gating weights; each row sums to 1 and is
        non-zero only at the top_k selected experts.
      topk_indices: [..., top_k] indices of the selected experts.
    """
    logits = self.linear(x)

    if self.training:
      # The noise must actually be random: sample N(0, 1) and scale it by a learned,
      # strictly positive std. Adding self.noise(x) directly would only be a second
      # deterministic projection, not noise.
      noise_std = F.softplus(self.noise(x))
      noisy_logits = logits + torch.randn_like(logits) * noise_std
    else:
      # Deterministic routing at inference time.
      noisy_logits = logits

    # dim=-1 (not a hard-coded 2) so inputs of any rank are routed over the expert axis.
    topk_logits, topk_indices = torch.topk(noisy_logits, k=self.top_k, dim=-1)

    # Everything outside the top-k becomes -inf, so softmax assigns it weight 0.
    masked = torch.full_like(noisy_logits, float('-inf'))
    sparse_logits = masked.scatter(-1, topk_indices, topk_logits)
    router_output = F.softmax(sparse_logits, dim=-1)

    return router_output, topk_indices

We call this *noisy top-k gating*: Gaussian noise is added to the router's logits before the top-k experts are selected. This spreads tokens more evenly across the experts (load balancing) at the cost of only a small addition to the router.

## Finally, the sparse mixture of experts

In [5]:
class SparseMoE(nn.Module):
  """
  Sparse mixture-of-experts layer.

  A NoisyTopK router picks top_k experts per token. Each selected Expert processes
  only the tokens routed to it; its output is scaled by that token's gating weight
  and summed into the result.

  Args:
    embed_dim: size of the token embedding (last dim of the input).
    n_experts: number of Expert MLPs.
    top_k: number of experts each token is routed to.
    dropout: dropout probability passed to every Expert (0.1 matches Expert's default).
  """

  def __init__(self, embed_dim, n_experts, top_k, dropout=0.1):
    super().__init__()
    self.router = NoisyTopK(embed_dim, n_experts, top_k)
    self.experts = nn.ModuleList([Expert(embed_dim, dropout) for _ in range(n_experts)])
    self.topk = top_k

  def forward(self, x):
    """
    Args:
      x: tensor of shape [..., embed_dim] (e.g. [batch, seq, embed_dim]).

    Returns:
      Tensor with the same shape as x: for every token, the gating-weighted sum of
      the outputs of the experts it was routed to.
    """
    gating_output, indices = self.router(x)

    # Flatten all leading dims into one token axis. reshape (not view) so that
    # non-contiguous inputs, e.g. after a transpose, are accepted.
    flat_x = x.reshape(-1, x.size(-1))                                      # [tokens, emb]
    flat_gating_output = gating_output.reshape(-1, gating_output.size(-1))  # [tokens, n_experts]
    flat_indices = indices.reshape(-1, indices.size(-1))                    # [tokens, top_k]
    flat_output = torch.zeros_like(flat_x)

    for i, expert in enumerate(self.experts):
      # True for every token that selected expert i among its top_k.
      token_mask = (flat_indices == i).any(dim=-1)                          # [tokens]
      if not token_mask.any():
        continue

      # Run the expert only on its own tokens and weight the result by the gate.
      expert_output = expert(flat_x[token_mask])                            # [n_i, emb]
      gating_scores = flat_gating_output[token_mask, i].unsqueeze(-1)       # [n_i, 1]
      # No squeeze: with embed_dim == 1 it would break broadcasting against [n_i, 1].
      flat_output[token_mask] += gating_scores * expert_output

    return flat_output.reshape(x.shape)