<a href="https://colab.research.google.com/github/goelnikhils-lgtm/languagemodels/blob/main/Mixture_of_Expert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

MIXTURE OF EXPERT IMPLEMENTATION
#credits  https://www.youtube.com/watch?v=W7ktPe1HfZs&t=10s
#credits https://github.com/AviSoori1x/makeMoE/blob/main/makeMoE_from_Scratch.ipynb

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(42)

In [None]:
#load dataset
!wget https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt

In [None]:
#STEP 1
#Expert Module
#every expert is a FF layer neural layer
class Expert(nn.Module):
  """"An MLP is a single Linear Layer network followed by non-linearity i.e each expert"""
  """expansion and contraction layer"""
  """n_embd: embedding dimension - The input that comes to MOE is after MHA , Dropout and RMSNorm"""
  def __init__(self,n_embd):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embd,4*n_embd), #expansion layer in each MOE
        nn.ReLU(),
        nn.Linear(4*n_embd,n_embd), #contraction layer
        nn.Dropout(dropout),
    )
  def forward(self,x):
    return self.net(x)

In [None]:
#STEP 2
#Routing Module
"""Routing Module"""
num_experts = 3 #no of experts in MOE
top_k=2 #every token will be routed to top two experts
n_embed = 8 #every token is of 32 embeddings

#example MHA output of a simple illustrative example, consider n_embed = 32 , content_length=
mh_output = torch.randn(1,4,n_embed) #(B,no of tokens , embed_size of no of tokens)
topkgate_linear = nn.Linear(n_embed,num_experts) #nn.Linear(32,4)
logits = topkgate_linear(mh_output) #torch.Size([2, 4, 4]) expert selector matrix or logits
print(logits)

In [None]:
#STEP 3
top_k_logits , top_k_indices = logits.topk(top_k,dim=-1) #get top_k experts
top_k_logits , top_k_indices
print(top_k_logits)
print(top_k_indices)

In [None]:
#STEP 4
#use -inf and apply Softmax
zeros = torch.full_like(logits,float('-inf')) #full_like clones a tensor and fills it with a specified
sparse_logits = zeros.scatter(-1,top_k_indices,top_k_logits)
sparse_logits

In [None]:
gating_output = F.softmax(sparse_logits,dim=-1)
gating_output

In [None]:
#Step 5: Create a class for TopKRouting
class TopKRouter(nn.Module):
  def __init__(self,n_embed,num_experts,top_k):
    super(TopKRouter,self).__init__()
    self.top_k = top_k
    self.linear = nn.Linear(n_embed,num_experts)

  def forward(self,mh_output):
    #mh output is the output tensor from multihead self attention block
    logits = self.linear(mh_output)
    top_k_logits, indices = logits.topk(self.top_k,dim=-1)
    zeros = torch.full_like(logits,float('-inf'))
    sparse_logits = zeros.scatter(-1,indices,top_k_logits)
    router_output = F.softmax(sparse_logits,dim=-1)
    return router_output , indices

In [None]:
#testing this out:
num_experts = 3
top_k = 2
n_embd = 8

mh_output = torch.randn(1,4,n_embd) #example input (batchsize,no of tokens,embed_dim_size)
top_k_gate = TopKRouter(n_embd,num_experts,top_k)
gating_output , indices = top_k_gate(mh_output)
gating_output.shape , gating_output , indices

In [None]:
#Step 6
#NoisyTopk Routing
class NoisyTopkRouter(nn.Module):
  def __init__(self,n_embed,num_experts,top_k):
    super(NoisyTopkRouter,self).__init__()
    self.top_k = top_k
    #layer for router logits
    self.topkroute_linear = nn.Linear(n_embed,num_experts)
    self.noise_linear = nn.Linear(n_embed,num_experts)

  def forward(self,mh_ouput):
    #mh_output is the output tensor from multihead self attention block
    logits = self.topkroute_linear(mh_output)

    #Noise logits
    noise_logits = self.noise_linear(mh_output)
    #Adding scaled unit gaussian noise to the logits
    noise = torch.randn_like(logits)*F.softplus(noise_logits)
    noisy_logits = logits + noise

    top_k_logits, indices = noisy_logits.topk(self.top_k,dim=-1)
    zeros = torch.full_like(noisy_logits,float('-inf'))
    sparse_logits = zeros.scatter(-1,indices,top_k_logits)
    router_output = F.softmax(sparse_logits,dim=-1)
    return router_output , indices

In [None]:
#testing this out again
num_experts = 3
top_k = 2
n_embed = 8

mh_output = torch.randn(1,4,n_embed) #Example input
noisy_topk_gate = NoisyTopkRouter(n_embed,num_experts,top_k)
gating_output , indices = noisy_topk_gate(mh_output)
gating_output.shape , gating_output , indices


In [None]:
#Step 7: Create a Sparse Mixture of Expert (MoE) module
class SparseMoE(nn.Module):
  def __init__(self,n_embd,num_experts,top_k):
    super(SparseMoE,self).__init__()
    self.router = NoisyTopkRouter(n_embd,num_experts,top_k)
    self.experts = nn.ModuleList([Expert(n_embd) for _ in range(num_experts)])
    self.top_k = top_k

  def forward(self,x):
    gating_output , indices = self.router(x)
    final_output = torch.zeros_like(x)

    #Reshape inputs for Batch Processing
    flat_x = x.view(-1,x.size(-1))
    flat_gating_output = gating_output.view(-1,gating_output.size(-1))

    #Process each expert in parallel
    for i , expert in enumerate(self.experts):
      #Create a mask for the inputs where the current expert is in top-k
      expert_mask = (indices==i).any(dim=-1)
      flat_mask = expert_mask.view(-1)

      if flat_mask.any():
        expert_input  = flat_x[flat_mask]
        expert_output = expert(expert_input)

        #Extract and apply gating scores
        gating_scores = flat_gating_output[flat_mask,i].unsqueeze(1)
        weighted_output = expert_output * gating_scores

        #Update final output additively by indexing and adding
        final_output[expert_mask] += weighted_output.squeeze(1)
    return final_output

In [None]:
#Step 8
import torch
import torch.nn as nn

#let's test this out
num_experts = 3
top_k = 2
n_embed = 8
dropout = 0.1

mh_output = torch.randn(1,4,n_embed) #example multi-head attention output
sparse_moe = SparseMoE(n_embed,num_experts,top_k)
final_output = sparse_moe(mh_output)
print("Shape of the final output:", final_output.shape)
print("Final output:", final_output)

In [None]:
#Step 8 Putting all together
class Expert(nn.Module):
  """"An MLP is a single Linear Layer network followed by non-linearity i.e each expert"""
  """expansion and contraction layer"""
  """n_embd: embedding dimension - The input that comes to MOE is after MHA , Dropout and RMSNorm"""
  def __init__(self,n_embd):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embd,4*n_embd), #expansion layer in each MOE
        nn.ReLU(),
        nn.Linear(4*n_embd,n_embd), #contraction layer
        nn.Dropout(dropout),
    )
  def forward(self,x):
    return self.net(x)

#Change the above to accomdate noisy top-k gating
class NoisyTopkRouter(nn.Module):
  def __init__(self,n_embed,num_experts,top_k):
    super(NoisyTopkRouter,self).__init__()
    self.top_k = top_k
    #layer for router logits
    self.topkroute_linear = nn.Linear(n_embed,num_experts)
    self.noise_linear = nn.Linear(n_embed,num_experts)

  def forward(self,mh_ouput):
    #mh_output is the output tensor from multihead self attention block
    logits = self.topkroute_linear(mh_output)

    #Noise logits
    noise_logits = self.noise_linear(mh_output)
    #Adding scaled unit gaussian noise to the logits
    noise = torch.randn_like(logits)*F.softplus(noise_logits)
    noisy_logits = logits + noise

    top_k_logits, indices = noisy_logits.topk(self.top_k,dim=-1)
    zeros = torch.full_like(noisy_logits,float('-inf'))
    sparse_logits = zeros.scatter(-1,indices,top_k_logits)
    router_output = F.softmax(sparse_logits,dim=-1)
    return router_output , indices

#Now create the Sparse Mixture of experts module
class SparseMoE(nn.Module):
  def __init__(self,n_embd,num_experts,top_k):
    super(SparseMoE,self).__init__()
    self.router = NoisyTopkRouter(n_embd,num_experts,top_k)
    self.experts = nn.ModuleList([Expert(n_embd) for _ in range(num_experts)])
    self.top_k = top_k

  def forward(self,x):
    gating_output , indices = self.router(x)
    final_output = torch.zeros_like(x)

    #Reshape inputs for Batch Processing
    flat_x = x.view(-1,x.size(-1))
    flat_gating_output = gating_output.view(-1,gating_output.size(-1))

    #Process each expert in parallel
    for i , expert in enumerate(self.experts):
      #Create a mask for the inputs where the current expert is in top-k
      expert_mask = (indices==i).any(dim=-1)
      flat_mask = expert_mask.view(-1)

      if flat_mask.any():
        expert_input  = flat_x[flat_mask]
        expert_output = expert(expert_input)

        #Extract and apply gating scores
        gating_scores = flat_gating_output[flat_mask,i].unsqueeze(1)
        weighted_output = expert_output * gating_scores

        #Update final output additively by indexing and adding
        final_output[expert_mask] += weighted_output.squeeze(1)
    return final_output

In [None]:
#Step 9:Code the entire Transformer block: Part 1 MHA
class Head(nn.Module):
  """ one head of self attention"""
  def __init__(self,head_size):
      super().__init__()
      self.key = nn.Linear(n_embd,head_size,bias=False)
      self.query = nn.Linear(n_embd,head_size,bias=False)
      self.value = nn.Linear(n_embd,head_size,bias=False)
      self.register_buffer('tril',torch.tril(torch.ones(block_size,block_size)))
      self.dropout = nn.Dropout(dropout)

  def forward(self, x):
      B,T,C = x.shape
      k = self.key(x) # (B,T,C)
      q = self.query(x) # (B,T,C)
      #compute attention scores ("affinities")
      wei = q @ k.transpose(-2,-1) * C**-0.5 # (B,T,C) @ (B,C,T) -> (B,T,T)
      #causal mask
      wei = wei.masked_fill(self.tril[:T,:T]==0,float('-inf')) # (B,T,T)
      wei = F.softmax(wei,dim=-1) # (B,T,T)
      wei = self.dropout(wei)
      #perform the weighted aggregation of the values
      v = self.value(x) # (B,T,C)
      out = wei @ v # (B,T,T) @ (B,T,C) -> (B,T,C)
      return out

#Multi Head Attention
class MultiHeadAttention(nn.Module):
  def __init__(self,num_heads,head_size):
    super().__init__()
    self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
    self.proj = nn.Linear(n_embd,n_embd)
    self.dropout = nn.Dropout(dropout)

  def forward(self,x):
    out = torch.cat([h(x) for h in self.heads] , dim=-1)
    out = self.dropout(self.proj(out))
    return out

In [None]:
#Step 10: Code the entire transformer block: Part2(Assemble all layers)
class Block(nn.Module):
  def __init__(self,n_embd,n_head, num_experts, top_k):
    super().__init__()
    head_size = n_embd // n_head
    self.sa = MultiHeadAttention(n_head,head_size)
    self.smoe = SparseMoE(n_embd,num_experts,top_k)
    self.ln1 = nn.LayerNorm(n_embd)
    self.ln2 = nn.LayerNorm(n_embd)

  def forward(self,x):
    x = x + self.sa(self.ln1(x))
    x = x + self.smoe(self.ln2(x))
    return x


In [None]:
#Step 11: Define the entire language model architecture
class SparseMoeLanguageModel(nn.Module):

  def __init__(self):
    super().__init__()
    #each token directly reads off the logits for the next token from a lookup table
    self.token_embedding_table = nn.Embedding(vocab_size,n_embd)
    self.position_embedding_table = nn.Embedding(block_size,n_embd)
    #chain of transformer blocks. embedding is passed thru the same
    self.blocks = nn.Sequential(*[Block(n_embd,n_head=n_head,num_experts=num_experts,top_k=top_k) for _ in range(n_layer)])
    self.ln_f = nn.LayerNorm(n_embd) #final layer norm
    self.lm_head = nn.Linear(n_embd,vocab_size)

  def forward(self,idx,targets=None):
    B,T = idx.shape
    #idx and targets are both(B,T) tensor of integers
    tok_emb = self.token_embedding_table(idx) # (B,T,C)
    pos_emb = self.position_embedding_table(torch.arange(T,device=device)) # (T,C)
    x = tok_emb + pos_emb # (B,T,C) #input embedding
    x = self.blocks(x) # (B,T,C)
    x = self.ln_f(x) # (B,T,C) #layer normalization
    logits = self.lm_head(x) # (B,T,vocab_size)

    if targets is None:
      loss = None
    else:
      B,T,C = logits.shape
      logits = logits.view(B*T,C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits,targets)
    return logits,loss

  def generate(self,idx,max_new_tokens):
      #idx is (B,T) array of indices in the current context
      for _ in range(max_new_tokens):
        #crop idx to the last block_size tokens
        idx_cond = idx[:,-block_size:]
        #get the predictions
        logits,loss = self(idx_cond)
        #focus only on the last time step
        logits = logits[:,-1,:] #becomes (B,C)
        #apply softmax to get probabilities
        probs = F.softmax(logits,dim=-1) # (B,C)
        #sample from distribution
        idx_next = torch.multinomial(probs,num_samples=1) # (B,1)
        #append sampled index to the running sequence
        idx = torch.cat((idx,idx_next),dim=1) # (B,T+1)
        return idx

In [None]:
#Step 12: Create training and testing data
torch.manual_seed(1337)
with open('input.txt','r',encoding='utf-8') as f:
  text = f.read()
#here are all the unique characters that occur in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
#create a mapping from characters to integers
stoi = {ch:i for i,ch in enumerate(chars)} # mapping character to integer
itos = {i:ch for i,ch in enumerate(chars)} # mapping integer back to character
encode = lambda s: [stoi[c] for c in s] #encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) #decoder: take a list of integers, output a string

#train and test splits
data = torch.tensor(encode(text),dtype=torch.long)
n = int(0.9*len(data)) #first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

#data loading
def get_batch(split):
  #generate a small batch of data inputs x and targets y
  data = train_data if split == 'train' else val_data
  ix = torch.randint(len(data) - block_size , (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  x,y = x.to(device),y.to(device)
  return x,y

In [None]:
#Step 13: Define LLM Loss
@torch.no_grad()
def estimate_loss():
  out = {}
  model.eval()
  for split in ['train','test']:
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
      X,Y = get_batch(split)
      logits , loss = model(Y,Y)
      losses[k] = loss.item()
    out[split] = losses.mean()
  model.train()
  return out


In [None]:
#Step 14: Define training loop parameters and other hyperparameters
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init

#hyper parameters
batch_size = 16
block_size = 32
max_iters = 20
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("device",device)
eval_iters = 400
head_size = 16
n_embed = 128
n_head = 8
n_layer = 8
dropout = 0.1
num_experts = 8
top_k = 2
#----

In [None]:
#Step 15: Intialize the entire model
def kaiming_init_weights(m):
  if isinstance(m,(nn.Linear)):
    init.kaiming_normal(m.weight)
model = SparseMoeLanguageModel()
model.apply(kaiming_init_weights)


In [None]:
#Step16: Run the pre-training loop
m = model.to(device)
print(sum(p.numel() for p in m.parameters())/1e6,'M parameters')
optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate)

for iter in range(max_iters):
  if iter % eval_interval == 0 or iter == max_iters - 1:
    losses = estimate_loss()
    print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

  #sample a batch of data
  xb,yb = get_batch('train')

  #evaluate the loss
  logits,loss = model(xb,yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

In [None]:
#Step 17: Inference
#generate from the model. Not great. Not too bad either
context = torch.zeros((1,1), dtype = torch.long , device = device)
print(decode(m.generate(context,max_new_tokens=2000)[0].tolist()))

In [None]:
#MOE CODE AGAIN ON 12/19

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(42)
#optional

In [None]:
!wget https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt

In [None]:
#creating a NN for each expert
class Expert(nn.Module):
  def __init__(self,n_embd):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embd,4*n_embd),
        nn.ReLU(),
        nn.Linear(4*n_embd,n_embd),
        nn.Dropout(dropout),
    )
  def forward(self,x):
    return self.net(x)

In [None]:
#Understanding how gating works
num_experts = 3
top_k = 2
n_embed = 8

#output from MHA goes into MOE
mh_output = torch.randn(1,4,n_embed) #example input [B,#no of tokens,dimension of each token]
topkgate_linear = nn.Linear(n_embed,num_experts) #[embedding_dimension , #no of experts] . Linear Layer y = a*xT + b
logits = topkgate_linear(mh_output) #mh_output is the input to the linear layer and what it returns is a matirx of size mh_output * num_experts
print(logits)

In [None]:
#implement the top-k load balancing
#from every row we will select two experts to do that we would have to get the value from the logits matrix but before that we need to convert to prob
top_k_logits, indices = logits.topk(top_k,dim=-1) #get top 2 expert
top_k_logits , indices

In [None]:
#create a class for TopKRouting
class TopKRouter(nn.Module): #output is expert selector weight matrix
  def __init__(self,n_embed,num_experts,top_k):
    super(TopKRouter,self).__init__()
    self.top_k = top_k
    self.linear = nn.Linear(n_embed,num_experts)

  def forward(self,mh_output): #mh_output is the output tensor of MHA
    logits = self.linear(mh_output) #get logits from the output
    top_k_logits, indices = logits.topk(self.top_k,dim=-1) #get top k logits
    zeros = torch.full_like(logits,float('-inf')) #fill the logits matrix with - infinity. -infinity indicate to get zero probability from Softmax
    sparse_logits = zeros.scatter(-1,indices,top_k_logits) #get sparse logits across last index
    router_output = F.softmax(sparse_logits,dim=-1) # take a softmax of the sparse logits along last dimension
    #print(router_output)
    #print(indices)
    return router_output , indices

In [None]:
#test expert selector weight matrix
num_experts = 3
top_k = 2
n_embd = 8

mh_output = torch.randn(1,4,n_embd)
top_k_gate = TopKRouter(n_embd,num_experts,top_k)
gating_output , indices = top_k_gate(mh_output)
gating_output.shape , gating_output , indices

In [None]:
#Noisy TopK Routing
class NoisyTopkRouter(nn.Module):
  def __init__(self,n_embd,num_experts,top_k):
    super(NoisyTopkRouter,self).__init__()
    self.top_k = top_k
    self.topkroute_linear = nn.Linear(n_embd,num_experts)
    self.noise_linear = nn.Linear(n_embd,num_experts)

  def forward(self,mh_output):
    logits = self.topkroute_linear(mh_output)
    noise_logits = self.noise_linear(mh_output)

    #Added scale unit gaussian noise to logits
    noise = torch.randn_like(logits)*F.softplus(noise_logits)
    noisy_logits = logits + noise

    top_k_logits, indices = noisy_logits.topk(self.top_k,dim=-1)
    zeros = torch.full_like(noisy_logits,float('-inf'))
    sparse_logits = zeros.scatter(-1,indices,top_k_logits)
    router_output = F.softmax(sparse_logits,dim=-1)
    return router_output , indices

In [None]:
#test expert selector weight matrix
num_experts = 3
top_k = 2
n_embd = 8

mh_output = torch.randn(1,4,n_embd)
top_k_gate = NoisyTopkRouter(n_embd,num_experts,top_k)
gating_output , indices = top_k_gate(mh_output)
gating_output.shape , gating_output , indices

In [None]:
#Create a Sparse
class SparseMoE(nn.Module):
  def __init__(self,n_embed,num_experts,top_k):
    super(SparseMoE,self).__init__()
    self.router = NoisyTopkRouter(n_embed,num_experts,top_k)
    self.experts = nn.ModuleList([Expert(n_embd) for _ in range(num_experts)])
    self.top_k = top_k

  def forward(self,x):
    gating_output , indices = self.router(x)
    final_output = torch.zeros_like(x)

    #Reshape inputs for batch processing
    flat_x = x.view(-1,x.size(-1))
    flat_gating_output = gating_output.view(-1,gating_output.size(-1))

    #Process each expert in parallel
    for i , expert in enumerate(self.experts):
      #Create a mask for the inputs where the current expert is in top-k
      expert_mask = (indices==i).any(dim=-1)
      flat_mask = expert_mask.view(-1)

      if flat_mask.any():
        expert_input = flat_x[flat_mask]
        expert_output = expert(expert_input)

        #extract and apply gating scores
        gating_scores = flat_gating_output[flat_mask,i].unsqueeze(1)
        weighted_output = expert_output * gating_scores

        #Update final output additively by indexing and adding
        final_output[expert_mask] += weighted_output.squeeze(1)
    return final_output

In [None]:
import torch
import torch.nn as nn

#test
num_experts = 3
top_k = 2
n_embd = 8
dropout = 0.1

mh_output = torch.randn(1,4,n_embd) #example multi head attention output
sparse_moe = SparseMoE(n_embd,num_experts,top_k)
final_output = sparse_moe(mh_output)
print("Shape of the final output:", final_output.shape)
print(final_output)

In [None]:
#Putting all MOE together
class Expert(nn.Module):
  def __init__(self,n_embd):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embd,4*n_embd),
        nn.ReLU(),
        nn.Linear(4*n_embd,n_embd),
        nn.Dropout(dropout),
    )
  def forward(self,x):
    return self.net(x)

#Noisy TopK Routing
class NoisyTopkRouter(nn.Module):
  def __init__(self,n_embd,num_experts,top_k):
    super(NoisyTopkRouter,self).__init__()
    self.top_k = top_k
    self.topkroute_linear = nn.Linear(n_embd,num_experts)
    self.noise_linear = nn.Linear(n_embd,num_experts)

  def forward(self,mh_output):
    logits = self.topkroute_linear(mh_output)
    noise_logits = self.noise_linear(mh_output)

    #Added scale unit gaussian noise to logits
    noise = torch.randn_like(logits)*F.softplus(noise_logits)
    noisy_logits = logits + noise

    top_k_logits, indices = noisy_logits.topk(self.top_k,dim=-1)
    zeros = torch.full_like(noisy_logits,float('-inf'))
    sparse_logits = zeros.scatter(-1,indices,top_k_logits)
    router_output = F.softmax(sparse_logits,dim=-1)
    return router_output , indices

#Create a Sparse
class SparseMoE(nn.Module):
  def __init__(self,n_embed,num_experts,top_k):
    super(SparseMoE,self).__init__()
    self.router = NoisyTopkRouter(n_embed,num_experts,top_k)
    self.experts = nn.ModuleList([Expert(n_embd) for _ in range(num_experts)])
    self.top_k = top_k

  def forward(self,x):
    gating_output , indices = self.router(x)
    final_output = torch.zeros_like(x)

    #Reshape inputs for batch processing
    flat_x = x.view(-1,x.size(-1))
    flat_gating_output = gating_output.view(-1,gating_output.size(-1))

    #Process each expert in parallel
    for i , expert in enumerate(self.experts):
      #Create a mask for the inputs where the current expert is in top-k
      expert_mask = (indices==i).any(dim=-1)
      flat_mask = expert_mask.view(-1)

      if flat_mask.any():
        expert_input = flat_x[flat_mask]
        expert_output = expert(expert_input)

        #extract and apply gating scores
        gating_scores = flat_gating_output[flat_mask,i].unsqueeze(1)
        weighted_output = expert_output * gating_scores

        #Update final output additively by indexing and adding
        final_output[expert_mask] += weighted_output.squeeze(1)
    return final_output

In [None]:
#construct the transformer block

class Head(nn.Module):
  """ one head of self-attention"""

  def __init__(self, head_size):
    super().__init__()
    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    self.dropout = nn.Dropout(dropout)

  def forward(self,x):
    B,T,C = x.shape
    k = self.key(x) #(B,T,C)
    q = self.query(x) #(B,T,C)

    #compute attention scores("affinities")
    wei = q @ k.transpose(-2,-1) * C**-0.5 #(B,T,C) @ (B,C,T) ---> (B,T,T)
    wei.masked_fill(self.tril[:T,:T] == 0,float('-inf')) # (B,T,T) #causality
    wei = F.softmax(wei,dim=-1) # (B,T,T)
    wei = self.dropout(wei)
    #perform the weighted aggregation of the values
    v = self.value(x) #(B,T,C)
    out = wei @ v #(B,T,T) @ (B,T,C) ---> (B,T,C)
    return out

class MultiHeadAttention(nn.Module):
  """multiple heads of self-attention in parallel"""
  def __init__(self,num_heads,head_size):
    super().__init__()
    self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
    self.proj = nn.Linear(n_embd,n_embd)
    self.dropout = nn.Dropout(dropout)

  def forward(self,x):
    out = torch.cat([h(x) for h in self.heads],dim=-1)
    out = self.dropout(self.proj(out))
    return out

In [None]:
#first create a self attention + mixture of experts block , that may be repeated several numbe of times
class Block(nn.Module):
  """ Mixture of Experts Transformer block: communication followed by computation(multi-head self attention + SparseMoE) """
  def __init__(self,n_embd,n_head,num_experts,top_k):
    #n_embd: embedding dimension, n_head: number of heads we'd like
    super().__init__()
    head_size = n_embd // n_head
    self.sa = MultiHeadAttention(n_head,head_size)
    self.smoe = SparseMoE(n_embd,num_experts,top_k)
    self.ln1 = nn.LayerNorm(n_embd)
    self.ln2 = nn.LayerNorm(n_embd)

  def forward(self,x):
    x = x + self.sa(self.ln1(x)) #multi head attention layer and then layer normalization(short cut connection)
    x = x + self.smoe(self.ln2(x)) # MOE followed by layer normalization (shorcut connection)
    return x

In [None]:
#Define the entire language model architecture
#finally putting it together to create a sparse MOE language model
class SparseMoeLanguageModel(nn.Module):
  def __init__(self):
    super().__init__()
    #each token directly reads off the logits for the next token from a lookup table
    self.token_embedding_table = nn.Embedding(vocab_size,n_embd)
    self.position_embedding_table = nn.Embedding(block_size,n_embd)
    self.blocks = nn.Sequential(*[Block(n_embd,n_head=n_head,num_experts = num_experts,top_k = top_k) for _ in range(n_layer)])
    self.ln_f = nn.LayerNorm(n_embd) #final layer norm
    self.lm_head = nn.Linear(n_embd,vocab_size)

  def forward(self,idx,targets=None):
    B,T = idx.shape
    #idx and targets are both (B,T) tensor of integers
    tok_emb = self.token_embedding_table(idx) #(B,T,C)
    pos_emb = self.position_embedding_table(torch.arange(T,device=device)) #(T,C)
    x = tok_emb + pos_emb # (B,T,C)
    x = self.blocks(x) # #each transformer block .... attention and MOE is happening inside these Transformer blocks
    x= self.ln_f(x)
    logits = self.lm_head(x) # (B,T,vocab_size)

    if targets is None:
      loss = None
    else:
      B,T,C = logits.shape
      logits = logits.view(B*T,C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits,targets)

    return logits,loss

In [None]:
#input and target for LLM
#Step 12: Create training and testing data
torch.manual_seed(1337)
with open('input.txt','r',encoding='utf-8') as f:
  text = f.read()
#here are all the unique characters that occur in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
#create a mapping from characters to integers
stoi = {ch:i for i,ch in enumerate(chars)} # mapping character to integer
itos = {i:ch for i,ch in enumerate(chars)} # mapping integer back to character
encode = lambda s: [stoi[c] for c in s] #encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) #decoder: take a list of integers, output a string

#train and test splits
data = torch.tensor(encode(text),dtype=torch.long)
n = int(0.9*len(data)) #first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

#data loading
def get_batch(split):
  #generate a small batch of data inputs x and targets y
  data = train_data if split == 'train' else val_data
  ix = torch.randint(len(data) - block_size , (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  x,y = x.to(device),y.to(device)
  return x,y

In [None]:
#loss function
@torch.no_grad()
def estimate_loss():
  out = {}
  model.eval()
  for split in ['train','val']:
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
      X,Y = get_batch(split)
      logits,loss = model(X,Y)
      losses[k] = loss.item()
    out[split] = losses.mean()
  model.train()
  return out

In [None]:
#define the loop and hyper parameters
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init

#hyper parameters
batch_size = 16
block_size = 32
max_iters = 200 # this has to be very high if we need to see some good output from LLM.... on A100 GPU it needs to be around 50000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("device",device)
eval_iters = 400 #evaluate the loss after every 400 iterations
head_size = 16
n_embd = 128
n_head = 8
n_layer = 8
dropout = 0.1
num_experts = 8
top_k = 2

In [None]:
#intialize the entire model
def intialize_model(m):
  if isinstance(m,(nn.Linear)):
    init.kaiming_normal_(m.weight)
model = SparseMoeLanguageModel()
model.apply(intialize_model)

In [None]:
#create a PyTorch optimizer

optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate)
for iter in range(max_iters):
  #every once in a while evaluate the loss on train and val sets
  if iter % eval_interval == 0 or iter == max_iters - 1:
    losses = estimate_loss()
    print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
  #sample a batch of data
  xb , yb = get_batch('train')

  #evaluate the loss
  logits , loss = model(xb,yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

In [None]:
#inference
context = torch.zeros((1,1) , dtype = torch.long , device = device)
print(decode(m.generate(context,max_new_tokens=500)[0].tolist()))

In [None]:
#makeMOE -
%pip install mlflow

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import mlflow
import mlflow.pytorch


In [None]:
torch.manual_seed(42)

In [None]:
#downloading the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/AviSoori1x/makeMOE/main/input.txt

In [None]:
#read the dataset to inspect the dataset
with open('input.txt','r',encoding='utf-8') as f:
  text = f.read()

In [None]:
print("length of dataset is:" ,len(text))

In [None]:
print(text[:1000])

In [None]:
#the unique characters that occur in text .. we need to map these characters to integers
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

In [None]:
#create a mappinhg from characters to integers - but why ?
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s] #take a string , output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) #take a list of integers , output a string

print(encode("Sparse MoE Implementation"))
print(decode(encode("Sparse MoE Implementation")))

In [None]:
#let's now encode the entire text dataset and store it into a torch.tensor
data = torch.tensor(encode(text),dtype=torch.long)
print(data.shape,data.dtype)
print(data[:1000])

In [None]:
n = int(0.9*len(data)) #first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [None]:
block_size = 8 #sequence length for prediction
train_data[:block_size+1]

In [None]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
  context = x[:t+1]
  target = y[t]
  print(f"when input is {context} the target: {target}")

In [None]:
batch_size = 4 #how many independent sequences will be process in parallel
block_size = 8 #what is the maximum context length for predictions

ix = torch.randint(len(data) - block_size, (batch_size,))
ix

In [None]:
x = torch.stack([data[i:i+block_size] for i in ix])
y = torch.stack([data[i+1:i+block_size+1] for i in ix])
x

In [None]:
y

In [None]:
def get_batch(split):
  #generate a small batch of data of inputs x and targets y
  data = train_data if split =='train' else val_data
  ix = torch.randint(len(data) - block_size , (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  return x,y
xb , yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')
for b in range(batch_size): #batch dimension
  for t in range(block_size): #time dimension
    context = xb[b,:t+1]
    target = yb[b,t]
    print(f"when input is {context.tolist()} the target: {target}")

In [None]:
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch , seqeuence length , embedding dimension
x = torch.randn(B,T,C)

head_size = 16 #attention head dimension
key = nn.Linear(C,head_size,bias=False)
query = nn.Linear(C,head_size,bias=False)
value = nn.Linear(C,head_size,bias=False)
k = key(x) # (B,T,16)
q = query(x) # (B,T,16)
wei = q @ k.transpose(-2,-1) # (B,T,16) @ (B,16,T) ---> (B,T,T)

tril = torch.tril(torch.ones(T,T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0,float('-inf'))
wei = F.softmax(wei,dim=-1) # B,T,T

v = value(x) # B,T,H
out = wei @ v # B,T,H
out.shape

In [None]:
n_embd = 64
n_head = 4
n_layer = 4
head_size = 16
dropout = 0.1

class Head(nn.Module):
  """ one head of self-attention"""

  def __init__(self,head_size):
    super().__init__()
    self.key = nn.Linear(n_embd,head_size,bias=False)
    self.query = nn.Linear(n_embd,head_size,bias=False)
    self.value = nn.Linear(n_embd,head_size,bias=False)
    self.register_buffer('tril',torch.tril(torch.ones(block_size,block_size)))
    self.dropout = nn.Dropout(dropout)

  def forward(self,x):
    B,T,C = x.shape
    k = self.key(x) #(B,T,C)
    q = self.query(x) #(B,T,C)
    #compute attention scores
    wei = q @ k.transpose(-2,-1)*C**-0.5 #B,T,C  @(B,C,T) -> (B,T,T)
    wei = wei.masked_fill(self.tril[:T,:T] == 0,float('-inf')) # (B,T,T)
    wei = F.softmax(wei,dim=-1)
    wei = self.dropout(wei)
    #perform the weighted aggregation of the values
    v = self.value(x) #(B,T,C)
    out = wei @ v #(B,T,T) @ (B,T,C) ---> (B,T,C)
    return out


In [None]:
#multi-headed self attention
class MultiHeadAttention(nn.Module):

  def __init__(self,num_heads,head_size):
    super().__init__()
    self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
    self.proj = nn.Linear(n_embd,n_embd)
    self.dropout = nn.Dropout(dropout)

  def forward(self,x):
    out = torch.cat([h(x) for h in self.heads],dim=-1)
    out = self.dropout(self.proj(out))
    return out

In [None]:
#confirming that multi head self attention works
B,T,C = 4,8,64
x = torch.randn(B,T,C)
mha = MultiHeadAttention(4,16)
mha(x).shape

In [None]:
#Expert module
class Expert(nn.Module):
  """ Simple MLP for each Expert in MoE """
  def __init__(self,n_embd):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embd,4*n_embd),
        nn.ReLU(),
        nn.Linear(4*n_embd,n_embd),
        nn.Dropout(dropout),
    )
  def forward (self,x):
    return self.net(x)

In [None]:
#top-k gating intituion through an example
num_experts = 4
top_k = 2
n_embed = 32 #embedding dimension 32

#multi-head attention output of a simple illustrative example , consider n_embed = 32 , context_length = 4 and batch_size = 2
mh_output = torch.randn(2,4,n_embed) # 3Dimension Tensor #(B,Context Length , Token embedding dimension)

# Linear layer is for affine transformation and the inputs is a Weight MATRIX of size n_embed and num_experts and this matrix is applied on input
topk_gate_linear = nn.Linear(n_embed,num_experts) #(32,4)
logits = topk_gate_linear(mh_output) #(2,4,32)
top_k_logits , top_k_indices = logits.topk(top_k,dim=-1) #get top-k experts (2,2,4) matrix as return
top_k_logits , top_k_indices

(tensor([[[ 0.5148, -0.3222],
          [ 0.4683,  0.3065],
          [ 0.5787, -0.1454],
          [ 0.0301, -0.1731]],
 
         [[ 0.4308, -0.0073],
          [ 0.6029,  0.3178],
          [ 0.6333,  0.3774],
          [ 1.4708,  0.8050]]], grad_fn=<TopkBackward0>),
 tensor([[[1, 3],
          [1, 0],
          [2, 3],
          [3, 0]],
 
         [[3, 1],
          [1, 0],
          [0, 1],
          [2, 3]]]))

In [None]:
zeros = torch.full_like(logits,float('-inf')) #full like clones a tensor and fills with -inf values
sparse_logits = zeros.scatter(-1,top_k_indices,top_k_logits)
sparse_logits

tensor([[[   -inf,  0.5148,    -inf, -0.3222],
         [ 0.3065,  0.4683,    -inf,    -inf],
         [   -inf,    -inf,  0.5787, -0.1454],
         [-0.1731,    -inf,    -inf,  0.0301]],

        [[   -inf, -0.0073,    -inf,  0.4308],
         [ 0.3178,  0.6029,    -inf,    -inf],
         [ 0.6333,  0.3774,    -inf,    -inf],
         [   -inf,    -inf,  1.4708,  0.8050]]], grad_fn=<ScatterBackward0>)

In [None]:
#apply softmax on sparse logits
gating_output = F.softmax(sparse_logits,dim=-1)
gating_output

tensor([[[0.0000, 0.6978, 0.0000, 0.3022],
         [0.4596, 0.5404, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.6735, 0.3265],
         [0.4494, 0.0000, 0.0000, 0.5506]],

        [[0.0000, 0.3922, 0.0000, 0.6078],
         [0.4292, 0.5708, 0.0000, 0.0000],
         [0.5636, 0.4364, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.6606, 0.3394]]], grad_fn=<SoftmaxBackward0>)

In [None]:
class TopKRouter(nn.Module):
  def __init__(self,n_embed,num_experts,top_k):
    super(TopKRouter,self).__init__()
    self.top_k = top_k
    self.linear = nn.Linear(n_embed,num_experts)

  def forward(self,mh_output):
    logits = self.linear(mh_output)
    top_k_logits , indices = logits.topk(self.top_k,dim=-1)
    zeros = torch.full_like(logits,float('-inf'))
    sparse_logits = zeros.scatter(-1,indices,top_k_logits)
    gating_output = F.softmax(sparse_logits,dim=-1)
    return gating_output , indices

In [None]:
#test the above out
num_experts = 4
top_k = 2
n_embed = 32

mh_output = torch.randn(2,4,n_embed)
router = TopKRouter(n_embed,num_experts,top_k)
gating_output , indices = router(mh_output)
gating_output.shape , gating_output , indices

(torch.Size([2, 4, 4]),
 tensor([[[0.0000, 0.6366, 0.3634, 0.0000],
          [0.0000, 0.4316, 0.0000, 0.5684],
          [0.5030, 0.0000, 0.0000, 0.4970],
          [0.0000, 0.0000, 0.4805, 0.5195]],
 
         [[0.4824, 0.0000, 0.0000, 0.5176],
          [0.4244, 0.0000, 0.5756, 0.0000],
          [0.1620, 0.0000, 0.0000, 0.8380],
          [0.4032, 0.5968, 0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>),
 tensor([[[1, 2],
          [3, 1],
          [0, 3],
          [3, 2]],
 
         [[3, 0],
          [2, 0],
          [3, 0],
          [1, 0]]]))

In [None]:
#Changing the above to accomdate noisy top-k gating
class NoisyTopkRouter(nn.Module):
  def __init__(self,n_embed,num_experts,top_k):
    super(NoisyTopkRouter,self).__init__()
    self.top_k = top_k
    self.linear = nn.Linear(n_embed,num_experts)
    self.noise_linear = nn.Linear(n_embed,num_experts)

  def forward(self, mh_output):
    logits = self.linear(mh_output)
    noise = self.noise_linear(mh_output)
    noisy_logits = logits + noise
    top_k_logits , indices = noisy_logits.topk(self.top_k,dim=-1)
    zeros = torch.full_like(logits,float('-inf'))
    sparse_logits = zeros.scatter(-1,indices,top_k_logits)
    router_output = F.softmax(sparse_logits,dim=-1)
    return router_output , indices

In [None]:
num_experts = 8
top_k = 2
n_embed = 16

mh_output = torch.randn(2,4,n_embed)
gating = NoisyTopkRouter(n_embed,num_experts,top_k)
gating_output , indices = gating(mh_output)
gating_output.shape , gating_output , indices

(torch.Size([2, 4, 8]),
 tensor([[[0.0000, 0.0000, 0.3578, 0.6422, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.5695, 0.0000, 0.0000, 0.0000, 0.0000, 0.4305, 0.0000],
          [0.0000, 0.0000, 0.5943, 0.0000, 0.0000, 0.0000, 0.0000, 0.4057],
          [0.5610, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4390]],
 
         [[0.4770, 0.5230, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.5314, 0.0000, 0.0000, 0.4686, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.3633, 0.0000, 0.0000, 0.0000, 0.6367],
          [0.0000, 0.0000, 0.0000, 0.7481, 0.0000, 0.0000, 0.0000, 0.2519]]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[[3, 2],
          [1, 6],
          [2, 7],
          [0, 7]],
 
         [[1, 0],
          [3, 6],
          [7, 3],
          [3, 7]]]))

In [None]:
class SparseMoE(nn.Module):
  def __init__(self,n_embed,num_experts,top_k):
    super(SparseMoE,self).__init__()
    self.router = NoisyTopkRouter(n_embed,num_experts,top_k)
    self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
    self.top_k = top_k

  def forward(self,x):
    gating_output , indices = self.router(x)
    final_output = torch.zeros_like(x)

    #Reshape inputs for batch processing
    flat_x = x.view(-1,x.size(-1))
    flat_gating_output = gating_output.view(-1,gating_output.size(-1))

    #Process
    for i , expert in enumerate(self.experts):
      expert_mask = (indices == i).any(dim=-1)
      flat_mask = expert_mask.view(-1)

      if flat_mask.any():
        expert_input = flat_x[flat_mask]
        expert_output = expert(expert_input)

        #Extract and apply gating scores
        gating_scores = flat_gating_output[flat_mask,i].unsqueeze(1)
        weighted_output = expert_output * gating_scores

        final_output[expert_mask] += weighted_output.squeeze(1)

    return final_output

In [None]:
import torch
import torch.nn as nn

#let's test this out
num_experts = 8
top_k = 2
n_embed = 16
dropout = 0.1

mh_output = torch.randn(4,8,n_embed)
sparse_moe = SparseMoE(n_embed,num_experts,top_k)
final_output = sparse_moe(mh_output)
print(final_output)

In [None]:
#putting it all together - building a LLM
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init

#hyperparameters
batch_size = 16
block_size = 32
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("device",device)
eval_iters = 400
head_size = 16
n_embd = 128
n_head = 8
n_layer = 8
dropout = 0.1
num_experts = 8
top_k = 2
#-----

torch.manual_seed(1337)
#wget https://raw.githubusercontent.com/AviSoori1x/makeMOE/main/input.txt

with open("input.txt",'r',encoding= 'utf-8') as f:
  text = f.read()

#here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)

#create a mapping from characters to integers
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

#Train and Test Splits
data = torch.tensor(encode(text),dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

#data loading
def get_batch(split):
  #generate a small batch of data of inputs x and targets y
  data = train_data if split == 'train' else val_data
  ix = torch.randint(len(data) - block_size , (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  x,y = x.to(device),y.to(device)
  return x,y

@torch.no_grad()
def estimate_loss():
  out = {}
  model.eval()
  for split in ['train','val']:
    losses = torch.zeros(eval_iters)
    for k,v in range(eval_iters):
      X,Y = get_batch(split)
      logits,loss = model(X,Y)
      losses[k] = loss.item()
    out[split] = losses.mean()
  model.train()
  return out

device cpu


In [None]:
class Head(nn.Module):
  """ on head of self-attention"""
  def __init__(self,head_size):
    super().__init__()
    self.key = nn.Linear(n_embd,head_size,bias=False)
    self.query = nn.Linear(n_embd,head_size,bias=False)
    self.value = nn.Linear(n_embd,head_size,bias=False)
    self.register_buffer('tril',torch.tril(torch.ones(block_size,block_size)))

    self.dropout = nn.Dropout(dropout)

  def forward(self,x):
    B,T,C = x.shape
    k = self.key(x)
    q = self.query(x)

    wei = q @ k.transpose(-2,-1)*C**-0.5
    wei = wei.masked_fill(self.tril[:T,:T] == 0,float('-inf'))
    wei = F.softmax(wei,dim=-1)
    wei = self.dropout(wei)

    #perform the weighted aggregration of the values
    v = self.value(x)
    out = wei@v
    return out

class MultiHeadAttention(nn.Module):
  def __init__(self,num_heads,head_size):
    super().__init__()
    self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
    self.proj = nn.Linear(n_embd,n_embd)
    self.dropout = nn.Dropout(dropout)

  def forward(self,x):
    out = torch.cat([h(x) for h in self.heads],dim=-1)
    out = self.dropout(self.proj(out))
    return out

In [None]:
#Expert Module
class Expert(nn.Module):
  """ Simple MLP for each Expert in MoE """
  def __init__(self,n_embd):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embd,4*n_embd),
        nn.ReLU(),
        nn.Linear(4*n_embd,n_embd),
        nn.Dropout(dropout),
    )
  def forward(self,x):
    return self.next(x)

class NoisyTopkRouter(nn.Module):
  def __init__(self,n_embed,num_experts,top_k):
    super(NoisyTopkRouter,self).__init__()
    self.top_k = top_k
    self.linear = nn.Linear(n_embed,num_experts)
    self.noise_linear = nn.Linear(n_embed,num_experts)

  def forward(self,mh_output):
    logits = self.linear(mh_output)
    noise = self.noise_linear(mh_output)
    noisy_logits = logits + noise
    top_k_logits , indices = noisy_logits.topk(self.top_k,dim=-1)
    zeros = torch.full_like(logits,float('-inf'))
    sparse_logits = zeros.scatter(-1,indices,top_k_logits)
    router_output = F.softmax(sparse_logits,dim=-1)
    return router_output , indices

class SparseMoE(nn.Module):
  def __init__(self,n_embed,num_experts,top_k):
    super(SparseMoE,self).__init__()
    self.router = NoisyTopkRouter(n_embed,num_experts,top_k)
    self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
    self.top_k = top_k
  def forward(self,x):
    gating_output , indices = self.router(x)
    final_output = torch.zeros_like(x)

    #reshape inputs for batch processing
    flat_x = x.view(-1,x.size(-1))
    flat_gating_output = gating_output.view(-1,gating_output.size(-1))

    #Process each expert in parallel
    for i,expert in enumerate(self.experts):
      expert_mask = (indices == i).any(dim=-1)
      flat_mask = expert_mask.view(-1)
      if flat_mask.any():
        expert_input = flat_x[flat_mask]
        expert_output = expert(expert_input)

        #Extract and apply gating scores
        gating_scores = flat_gating_output[flat_mask,i].unsqueeze(1)
        weighted_output = expert_output * gating_scores

        final_output[expert_mask] += weighted_output.squeeze(1)
    return final_output

In [None]:
#create the Transformer Block
class Block(nn.Module):
  def __init__(self,n_embed,n_head,num_experts,top_k):
    super().__init__()
    head_size = n_embed // n_head
    self.sa = MultiHeadAttention(n_head,head_size)
    self.smoe = SparseMoE(n_embed,num_experts,top_k)
    self.ln1 = nn.LayerNorm(n_embed)
    self.ln2 = nn.LayerNorm(n_embed)

  def forward(self,x):
    x = x + self.sa(self.ln1(x))
    x = x + self.smoe(self.ln2(x))
    return x

In [None]:
#Finally putting it all together to create a sparse MoE labguage model
class SparseMoELM(nn.Module):
  def __init__(self):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size,n_embd)
    self.position_embedding_table = nn.Embedding(block_size,n_embd)
    self.blocks = nn.Sequential(*[Block(n_embd,n_head = n_head,num_experts = num_experts,top_k = top_k) for _ in range(n_layer)])
    self.ln_f = nn.LayerNorm(n_embd)
    self.lm_head = nn.Linear(n_embd,vocab_size)

  def forward(self,idx, targets = None):
    B,T = idx.shape
    #idx and targets are both (B,T) tensor of integers
    tok_emb = self.token_embedding_table(idx) #(B,T,C)
    pos_emb = self.position_embedding_table(torch.arange(T,device=device)) #(T,C)
    x = tok_emb + pos_emb #(B,T,C)
    x = self.blocks(x) #(B,T,C)
    x = self.ln_f(x) #(B,T,C)
    logits = self.lm_head(x) #(B,T,vocab_size)

    if targets is None:
      loss = None
    else:
      B,T,C = logits.shape
      logits = logits.view(B*T,C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits,targets)

    return logits,loss

  def generate(self,idx,max_new_tokens):
    #idx is (B,T) array of indices in the current context
    for _ in range(max_new_tokens):
      #crop idx to the last block_size tokens
      idx_cond = idx[:,-block_size:]
      #get the predictions
      logits,loss = self(idx_cond)
      #focus only on the last time step
      logits = logits[:,-1,:] #(B,C)
      #apply softmax to get probabilities
      probs = F.softmax(logits,dim=-1) #(B,C)
      #sample from the distribution
      idx_next = torch.multinomial(probs,num_samples=1) #(B,1)
      #append sampled index to the running sequence
      idx = torch.cat((idx,idx_next),dim=1) #(B,T+1)
    return idx

In [None]:
#weight intialization of model - THIS IS KEY
def kaiming_init_weights(m):
  if isinstance(m,(nn.Linear)):
    init.kaiming_normal_(m.weight)

#https://www.geeksforgeeks.org/deep-learning/kaiming-initialization-in-deep-learning/
#https://www.geeksforgeeks.org/deep-learning/xavier-initialization/

In [None]:
model = SparseMoELM()
model.apply(kaiming_init_weights)

In [20]:
#MULTI TOKEN PREDICTION
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

In [21]:
class RMSNorm(nn.Module):
  # RMS will be applied to both hidden state and input embedding . Post that concatenation of hidden state and input embedding will be done
  def __init__(self,d_model,eps:float = 1e-8):
    super().__init__()
    self.eps = eps

  def forward(self,x):
    #x :{batch ,d_model}
    rms = torch.rsqrt(x.pow(2).mean(-1,keepdim=True) + self.eps)
    return x/rms

In [18]:
class SimpleMTP(nn.Module):
  def __init__(self,d_model: int,vocab_size: int,num_heads: int=3, nhead: int =2):

    """
      d_model: hidden size (8 in your example)
      num_heads: number of sequential MTP steps(D)
      nhead: attention heads in each Transformer block
    """
    super().__init__()
    self.num_heads = num_heads
    self.d_model = d_model
    self.vocab_size = vocab_size

    #shared modules
    self.rms = RMSNorm(d_model)
    self.embed = nn.Embedding(vocab_size,d_model)
    self.unembed = nn.Linear(d_model,vocab_size, bias= False)
    #share weights between emed and unembed
    self.unembed.weight = self.embed.weight

    #projection
    self.proj = nn.ModuleList([nn.Linear(2*d_model,d_model) for _ in range(num_heads)])
    self.Transformer = nn.ModuleList([nn.TransformerEncoderLayer(d_model,nhead) for _ in range(num_heads)])

  def forward(self,token_ids: torch.LongTensor, init_hidden: torch.Tensor = None):
    """
    token_ids: (batch,seq_len) integer IDs for your input tokens
    init_hidden: optional (batch,seq_len,d_model) base hidden states
                 if none , uses token_embeddings as intial hidden
    Returns:
      logits_out:Tensor of shape (batch,T-D,D,vocab_size)
      where T = seq_len and D = num_heads
    """
    B,T = token_ids.shape
    device = token_ids.device

    #token embeddings: (B,T,d_model)
    embeds = self.embed(token_ids)
    if init_hidden is None:
      h0_seq = embeds
    else:
      h0_seq = init_hidden
    outputs = []
    max_i = T - self.num_heads -1 #T is the input sequence length and num_heads = 2
    for i in range(0,max_i+1):
      h_prev = h0_seq[:,i,:] #[B, d_model]
      #collect logits for all k at this i
      logits_k = []
      for k in range(self.num_heads):
        future_pos = i +(k+1)
        tok_embed = embeds[:,future_pos,:] #(B,d_model)
        # 1)RMS-normalize
        h_norm = self.rmsnorm(h_prev) # (B,d_model)
        e_norm = self.rmsnorm(tok_embed)  #(B, d_model)

        # 2) concatenate ->(B,2*d_model)
        merged = torch.cat([h_norm,e_norm], dim = -1)

        #3) Project back to d_model
        proj = self.projections[k](merged) #(B,d_model)

        #4)Transformer block (expects shape (B,S,d_model))
        x = proj.unsqueeze(1) #(1,B,d_model)
        x = self.Transformer[k](x) #(1,B,d_model) #get transformer block output


In [19]:
#MTP prediction
batch_size , seq_len , d_model , vocab_size = 1,8,8,5000
model = SimpleMTP(d_model = d_model, vocab_size = vocab_size , num_heads = 3)
tokens = torch.randint(0,vocab_size,(batch_size,seq_len))

#forward pass
logits = model(tokens)
print("logits.shape",logits.shape)
#logits shape = [1,4-3,3,5000] ->[B,T-D,D,V] -> batch size , T-D , D , Vocabulary
print("Head at k=0 at i =0 logits:",logits[0,0,0])
pred_ids = logits[0,0].argmax(dim=-1)
print("Predicted ids:",pred_ids) # a length of tensor 3

AttributeError: 'SimpleMTP' object has no attribute 'rmsnorm'