<a href="https://colab.research.google.com/github/goelnikhils-lgtm/languagemodels/blob/main/Mixture_of_Expert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(42)
#optional https://www.youtube.com/watch?v=W7ktPe1HfZs&t=10s

<torch._C.Generator at 0x7f6e6c627e70>

In [None]:
#load dataset
!wget https://raw.githubusercontent.com/AviSoori1x/makeMoE/main/input.txt

In [None]:
#STEP 1
#Expert Module
#every expert is a small feed-forward (FF) neural network
class Expert(nn.Module):
  """One expert of the MoE layer: a position-wise feed-forward network.

  The network expands the embedding to 4*n_embd features, applies ReLU,
  contracts back to n_embd and finishes with dropout, so the output has
  the same shape as the input.

  Args:
    n_embd: size of the last dimension of the input and output tensors.
    dropout_rate: dropout probability applied after the contraction layer.
      When omitted, the module-level ``dropout`` hyperparameter is read at
      construction time.
  """
  def __init__(self,n_embd,dropout_rate=None):
    super().__init__()
    if dropout_rate is None:
      dropout_rate = dropout #fall back to the notebook-wide hyperparameter
    self.net = nn.Sequential(
        nn.Linear(n_embd,4*n_embd), #expansion layer
        nn.ReLU(),
        nn.Linear(4*n_embd,n_embd), #contraction layer
        nn.Dropout(dropout_rate),
    )
  def forward(self,x):
    """Apply the feed-forward network along the last dimension of x."""
    return self.net(x)

In [None]:
#STEP 2
#Routing module: walk through the gating computation on a toy tensor
"""Routing Module"""
num_experts = 3 #number of experts in the MoE layer
top_k=2 #every token is routed to its top two experts
n_embed = 8 #embedding dimension of every token in this toy example

#toy stand-in for the multi-head attention output: batch of 1, 4 tokens, n_embed features each
mh_output = torch.randn(1,4,n_embed) #(batch, number of tokens, n_embed) = (1, 4, 8)
topkgate_linear = nn.Linear(n_embed,num_experts) #nn.Linear(8,3): one logit per expert
logits = topkgate_linear(mh_output) #torch.Size([1, 4, 3]) expert-selection logits
print(logits)

In [None]:
#STEP 3
#keep the top_k largest logits of every token together with the expert ids they belong to
top_k_logits , top_k_indices = logits.topk(top_k,dim=-1) #get top_k experts along the expert axis
top_k_logits , top_k_indices #NOTE(review): bare expression with no effect; the prints below show the values
print(top_k_logits)
print(top_k_indices)

In [None]:
#STEP 4
#fill the non-selected experts with -inf so that the softmax applied next gives them zero weight
zeros = torch.full_like(logits,float('-inf')) #full_like builds a tensor shaped like logits and fills it with the specified value
sparse_logits = zeros.scatter(-1,top_k_indices,top_k_logits) #write the kept logits back at their expert positions
sparse_logits

In [None]:
#softmax over the expert axis: the -inf entries become 0 and the kept entries sum to 1 per token
gating_output = F.softmax(sparse_logits,dim=-1)
gating_output

In [None]:
#Step 5: Create a class for TopKRouting
class TopKRouter(nn.Module):
  """Deterministic top-k gate.

  A linear layer scores every expert for every token. Only the top_k
  highest scores are kept; the rest are set to -inf before the softmax,
  so the returned weights are zero for the experts that were not selected.
  """
  def __init__(self,n_embed,num_experts,top_k):
    super().__init__()
    self.top_k = top_k
    self.linear = nn.Linear(n_embed,num_experts)

  def forward(self,mh_output):
    """Return (gating weights over all experts, ids of the selected experts)."""
    scores = self.linear(mh_output)
    kept_scores, indices = torch.topk(scores,k=self.top_k,dim=-1)
    #start from an all -inf tensor and put the surviving scores back in place
    masked_scores = torch.full_like(scores,float('-inf')).scatter(dim=-1,index=indices,src=kept_scores)
    return torch.softmax(masked_scores,dim=-1) , indices

In [None]:
#testing this out: run TopKRouter on a random toy input
num_experts = 3
top_k = 2
n_embd = 8

mh_output = torch.randn(1,4,n_embd) #example input (batch size, number of tokens, embedding size)
top_k_gate = TopKRouter(n_embd,num_experts,top_k)
gating_output , indices = top_k_gate(mh_output) #gating weights per expert and the selected expert ids
gating_output.shape , gating_output , indices

In [None]:
#Step 6
#NoisyTopk Routing
class NoisyTopkRouter(nn.Module):
  """Noisy top-k gate: a top-k router whose logits are perturbed by learned noise.

  Two linear layers are applied to the input: one produces the router
  logits, the other produces per-expert noise scales (through softplus).
  Unit Gaussian noise multiplied by those scales is added to the logits
  before the top_k experts are selected.

  Args:
    n_embed: size of the last dimension of the input.
    num_experts: number of experts to score.
    top_k: number of experts kept per token.
  """
  def __init__(self,n_embed,num_experts,top_k):
    super(NoisyTopkRouter,self).__init__()
    self.top_k = top_k
    #layer for router logits
    self.topkroute_linear = nn.Linear(n_embed,num_experts)
    #layer for the per-expert noise scale
    self.noise_linear = nn.Linear(n_embed,num_experts)

  def forward(self,mh_output):
    """Return (gating weights over all experts, ids of the selected experts).

    mh_output is the tensor to route. The parameter was previously
    misspelled, which made the body read a module-level variable instead
    of the argument.
    """
    logits = self.topkroute_linear(mh_output)

    #Noise logits
    noise_logits = self.noise_linear(mh_output)
    #Adding scaled unit gaussian noise to the logits; softplus keeps the scale positive
    noise = torch.randn_like(logits)*F.softplus(noise_logits)
    noisy_logits = logits + noise

    top_k_logits, indices = noisy_logits.topk(self.top_k,dim=-1)
    #non-selected experts get -inf so that softmax assigns them zero weight
    zeros = torch.full_like(noisy_logits,float('-inf'))
    sparse_logits = zeros.scatter(-1,indices,top_k_logits)
    router_output = F.softmax(sparse_logits,dim=-1)
    return router_output , indices

In [None]:
#testing this out again: run NoisyTopkRouter on a random toy input
num_experts = 3
top_k = 2
n_embed = 8

mh_output = torch.randn(1,4,n_embed) #Example input: batch of 1, 4 tokens, n_embed features each
noisy_topk_gate = NoisyTopkRouter(n_embed,num_experts,top_k)
gating_output , indices = noisy_topk_gate(mh_output) #gating weights per expert and the selected expert ids
gating_output.shape , gating_output , indices


In [None]:
#Step 7: Create a Sparse Mixture of Expert (MoE) module
class SparseMoE(nn.Module):
  """Sparse mixture-of-experts layer.

  A NoisyTopkRouter selects top_k experts per token; each expert only
  processes the tokens routed to it, and the expert outputs are summed,
  weighted by the router's gating scores. The output has the shape of x.

  Args:
    n_embd: size of the last dimension of the input.
    num_experts: number of Expert modules to create.
    top_k: number of experts each token is routed to.
  """
  def __init__(self,n_embd,num_experts,top_k):
    super(SparseMoE,self).__init__()
    self.router = NoisyTopkRouter(n_embd,num_experts,top_k)
    self.experts = nn.ModuleList([Expert(n_embd) for _ in range(num_experts)])
    self.top_k = top_k

  def forward(self,x):
    """Route every token of x through its selected experts and combine the results."""
    gating_output , indices = self.router(x)
    final_output = torch.zeros_like(x)

    #Flatten the leading dimensions so tokens can be picked with a 1-D mask.
    #reshape (rather than view) also accepts non-contiguous inputs.
    flat_x = x.reshape(-1,x.size(-1))
    flat_gating_output = gating_output.reshape(-1,gating_output.size(-1))

    #Experts are visited one after another; each sees only its own tokens
    for i , expert in enumerate(self.experts):
      #Mask of the tokens that have the current expert among their top-k
      expert_mask = (indices==i).any(dim=-1)
      flat_mask = expert_mask.reshape(-1)

      if flat_mask.any():
        expert_input  = flat_x[flat_mask]
        expert_output = expert(expert_input)

        #Extract and apply gating scores, shaped (tokens, 1) to broadcast over features
        gating_scores = flat_gating_output[flat_mask,i].unsqueeze(1)
        weighted_output = expert_output * gating_scores

        #Accumulate into the output. weighted_output already has one row per
        #selected token; squeezing it would drop the feature axis when n_embd == 1.
        final_output[expert_mask] += weighted_output
    return final_output

In [None]:
#Step 8
import torch
import torch.nn as nn

#let's test this out: run the SparseMoE layer on a random toy input
num_experts = 3
top_k = 2
n_embed = 8
dropout = 0.1 #must be defined before the experts are built: Expert reads this module-level name

mh_output = torch.randn(1,4,n_embed) #example multi-head attention output
sparse_moe = SparseMoE(n_embed,num_experts,top_k)
final_output = sparse_moe(mh_output)
print("Shape of the final output:", final_output.shape)
print("Final output:", final_output)

In [None]:
#Step 8 Putting all together
class Expert(nn.Module):
  """One expert of the MoE layer: a position-wise feed-forward network.

  The network expands the embedding to 4*n_embd features, applies ReLU,
  contracts back to n_embd and finishes with dropout, so the output has
  the same shape as the input.

  Args:
    n_embd: size of the last dimension of the input and output tensors.
    dropout_rate: dropout probability applied after the contraction layer.
      When omitted, the module-level ``dropout`` hyperparameter is read at
      construction time.
  """
  def __init__(self,n_embd,dropout_rate=None):
    super().__init__()
    if dropout_rate is None:
      dropout_rate = dropout #fall back to the notebook-wide hyperparameter
    self.net = nn.Sequential(
        nn.Linear(n_embd,4*n_embd), #expansion layer
        nn.ReLU(),
        nn.Linear(4*n_embd,n_embd), #contraction layer
        nn.Dropout(dropout_rate),
    )
  def forward(self,x):
    """Apply the feed-forward network along the last dimension of x."""
    return self.net(x)

#Change the above to accomdate noisy top-k gating
class NoisyTopkRouter(nn.Module):
  """Noisy top-k gate: a top-k router whose logits are perturbed by learned noise.

  Two linear layers are applied to the input: one produces the router
  logits, the other produces per-expert noise scales (through softplus).
  Unit Gaussian noise multiplied by those scales is added to the logits
  before the top_k experts are selected.

  Args:
    n_embed: size of the last dimension of the input.
    num_experts: number of experts to score.
    top_k: number of experts kept per token.
  """
  def __init__(self,n_embed,num_experts,top_k):
    super(NoisyTopkRouter,self).__init__()
    self.top_k = top_k
    #layer for router logits
    self.topkroute_linear = nn.Linear(n_embed,num_experts)
    #layer for the per-expert noise scale
    self.noise_linear = nn.Linear(n_embed,num_experts)

  def forward(self,mh_output):
    """Return (gating weights over all experts, ids of the selected experts).

    mh_output is the tensor to route. The parameter was previously
    misspelled, which made the body read a module-level variable instead
    of the argument.
    """
    logits = self.topkroute_linear(mh_output)

    #Noise logits
    noise_logits = self.noise_linear(mh_output)
    #Adding scaled unit gaussian noise to the logits; softplus keeps the scale positive
    noise = torch.randn_like(logits)*F.softplus(noise_logits)
    noisy_logits = logits + noise

    top_k_logits, indices = noisy_logits.topk(self.top_k,dim=-1)
    #non-selected experts get -inf so that softmax assigns them zero weight
    zeros = torch.full_like(noisy_logits,float('-inf'))
    sparse_logits = zeros.scatter(-1,indices,top_k_logits)
    router_output = F.softmax(sparse_logits,dim=-1)
    return router_output , indices

#Now create the Sparse Mixture of experts module
class SparseMoE(nn.Module):
  """Sparse mixture-of-experts layer.

  A NoisyTopkRouter selects top_k experts per token; each expert only
  processes the tokens routed to it, and the expert outputs are summed,
  weighted by the router's gating scores. The output has the shape of x.

  Args:
    n_embd: size of the last dimension of the input.
    num_experts: number of Expert modules to create.
    top_k: number of experts each token is routed to.
  """
  def __init__(self,n_embd,num_experts,top_k):
    super(SparseMoE,self).__init__()
    self.router = NoisyTopkRouter(n_embd,num_experts,top_k)
    self.experts = nn.ModuleList([Expert(n_embd) for _ in range(num_experts)])
    self.top_k = top_k

  def forward(self,x):
    """Route every token of x through its selected experts and combine the results."""
    gating_output , indices = self.router(x)
    final_output = torch.zeros_like(x)

    #Flatten the leading dimensions so tokens can be picked with a 1-D mask.
    #reshape (rather than view) also accepts non-contiguous inputs.
    flat_x = x.reshape(-1,x.size(-1))
    flat_gating_output = gating_output.reshape(-1,gating_output.size(-1))

    #Experts are visited one after another; each sees only its own tokens
    for i , expert in enumerate(self.experts):
      #Mask of the tokens that have the current expert among their top-k
      expert_mask = (indices==i).any(dim=-1)
      flat_mask = expert_mask.reshape(-1)

      if flat_mask.any():
        expert_input  = flat_x[flat_mask]
        expert_output = expert(expert_input)

        #Extract and apply gating scores, shaped (tokens, 1) to broadcast over features
        gating_scores = flat_gating_output[flat_mask,i].unsqueeze(1)
        weighted_output = expert_output * gating_scores

        #Accumulate into the output. weighted_output already has one row per
        #selected token; squeezing it would drop the feature axis when n_embd == 1.
        final_output[expert_mask] += weighted_output
    return final_output

In [None]:
#Step 9:Code the entire Transformer block: Part 1 MHA
class Head(nn.Module):
  """One head of causal (masked) self-attention.

  The constructor reads the module-level hyperparameters n_embd,
  block_size and dropout.

  Args:
    head_size: output width of the key, query and value projections.
  """
  def __init__(self,head_size):
    super().__init__()
    self.key = nn.Linear(n_embd,head_size,bias=False)
    self.query = nn.Linear(n_embd,head_size,bias=False)
    self.value = nn.Linear(n_embd,head_size,bias=False)
    #lower-triangular matrix used as the causal mask; a buffer, not a parameter
    self.register_buffer('tril',torch.tril(torch.ones(block_size,block_size)))
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    """Attend over x of shape (B,T,C) and return a (B,T,head_size) tensor."""
    T = x.shape[1]
    k = self.key(x) # (B,T,head_size)
    q = self.query(x) # (B,T,head_size)
    #compute attention scores ("affinities"), scaled by 1/sqrt(head_size):
    #the scale must follow the width of q/k, not the embedding width of x
    wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B,T,hs) @ (B,hs,T) -> (B,T,T)
    #causal mask: a position may not attend to later positions
    wei = wei.masked_fill(self.tril[:T,:T]==0,float('-inf')) # (B,T,T)
    wei = F.softmax(wei,dim=-1) # (B,T,T)
    wei = self.dropout(wei)
    #perform the weighted aggregation of the values
    v = self.value(x) # (B,T,head_size)
    out = wei @ v # (B,T,T) @ (B,T,head_size) -> (B,T,head_size)
    return out

#Multi Head Attention
class MultiHeadAttention(nn.Module):
  """Several attention heads run on the same input, concatenated and projected.

  The constructor reads the module-level hyperparameters n_embd and dropout.

  Args:
    num_heads: number of Head modules to create.
    head_size: output width of each head.
  """
  def __init__(self,num_heads,head_size):
    super().__init__()
    self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
    #the concatenated heads are num_heads*head_size wide; this equals n_embd
    #only when n_embd is divisible by num_heads, so size the input explicitly
    self.proj = nn.Linear(num_heads*head_size,n_embd)
    self.dropout = nn.Dropout(dropout)

  def forward(self,x):
    """Return the projected concatenation of all head outputs, with dropout applied."""
    out = torch.cat([h(x) for h in self.heads] , dim=-1)
    out = self.dropout(self.proj(out))
    return out

In [None]:
#Step 10: Code the entire transformer block: Part2(Assemble all layers)
class Block(nn.Module):
  """Transformer block: pre-norm self-attention followed by a pre-norm sparse MoE.

  Both sub-layers are wrapped in a residual connection.
  """
  def __init__(self,n_embd,n_head, num_experts, top_k):
    super().__init__()
    #each head gets an equal (floored) share of the embedding width
    self.sa = MultiHeadAttention(n_head,n_embd // n_head)
    self.smoe = SparseMoE(n_embd,num_experts,top_k)
    self.ln1 , self.ln2 = nn.LayerNorm(n_embd) , nn.LayerNorm(n_embd)

  def forward(self,x):
    """Return x after the attention and MoE residual updates."""
    attended = self.sa(self.ln1(x))
    x = x + attended
    routed = self.smoe(self.ln2(x))
    return x + routed


In [None]:
#Step 11: Define the entire language model architecture
class SparseMoeLanguageModel(nn.Module):
  """Decoder-only language model whose transformer blocks use a sparse MoE layer.

  The constructor takes no arguments and reads the module-level
  hyperparameters vocab_size, n_embd, block_size, n_head, num_experts,
  top_k and n_layer.
  """

  def __init__(self):
    super().__init__()
    #token ids and positions are embedded separately and summed in forward()
    self.token_embedding_table = nn.Embedding(vocab_size,n_embd)
    self.position_embedding_table = nn.Embedding(block_size,n_embd)
    #chain of transformer blocks the embeddings are passed through
    self.blocks = nn.Sequential(*[Block(n_embd,n_head=n_head,num_experts=num_experts,top_k=top_k) for _ in range(n_layer)])
    self.ln_f = nn.LayerNorm(n_embd) #final layer norm
    self.lm_head = nn.Linear(n_embd,vocab_size)

  def forward(self,idx,targets=None):
    """Compute next-token logits and, when targets are given, the loss.

    Args:
      idx: (B,T) tensor of token ids.
      targets: optional (B,T) tensor of token ids to predict.

    Returns:
      (logits, loss). Without targets, logits is (B,T,vocab_size) and loss
      is None. With targets, logits is flattened to (B*T,vocab_size) and
      loss is the cross-entropy against the flattened targets.
    """
    B,T = idx.shape
    tok_emb = self.token_embedding_table(idx) # (B,T,C)
    #build the positions on the device of the input rather than on a global device
    pos_emb = self.position_embedding_table(torch.arange(T,device=idx.device)) # (T,C)
    x = tok_emb + pos_emb # (B,T,C) #input embedding
    x = self.blocks(x) # (B,T,C)
    x = self.ln_f(x) # (B,T,C) #layer normalization
    logits = self.lm_head(x) # (B,T,vocab_size)

    if targets is None:
      loss = None
    else:
      B,T,C = logits.shape
      logits = logits.view(B*T,C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits,targets)
    return logits,loss

  def generate(self,idx,max_new_tokens):
    """Sample max_new_tokens tokens autoregressively and append them to idx.

    Args:
      idx: (B,T) tensor of token ids forming the current context.
      max_new_tokens: number of tokens to sample.

    Returns:
      (B,T+max_new_tokens) tensor; idx itself when max_new_tokens is 0.
    """
    #the position table has one row per position, i.e. the maximum context length
    context_len = self.position_embedding_table.num_embeddings
    for _ in range(max_new_tokens):
      #crop idx to the last context_len tokens
      idx_cond = idx[:,-context_len:]
      #get the predictions
      logits,loss = self(idx_cond)
      #focus only on the last time step
      logits = logits[:,-1,:] #becomes (B,C)
      #apply softmax to get probabilities
      probs = F.softmax(logits,dim=-1) # (B,C)
      #sample from distribution
      idx_next = torch.multinomial(probs,num_samples=1) # (B,1)
      #append sampled index to the running sequence
      idx = torch.cat((idx,idx_next),dim=1) # (B,T+1)
    #return only after the loop, once every requested token has been sampled
    return idx

In [None]:
#Step 12: Build the character vocabulary and the train/validation tensors
torch.manual_seed(1337)
with open('input.txt','r',encoding='utf-8') as f:
  text = f.read()

#vocabulary: every distinct character of the text, in sorted order
chars = sorted(set(text))
vocab_size = len(chars)

#lookup tables between characters and their integer ids
itos = dict(enumerate(chars)) # integer id -> character
stoi = {ch:i for i,ch in itos.items()} # character -> integer id

def encode(s):
  """Turn a string into the list of its character ids."""
  return [stoi[c] for c in s]

def decode(l):
  """Turn a list of character ids back into a string."""
  return ''.join(itos[i] for i in l)

#encode the whole text once, then split it: first 90% for training, the rest for validation
data = torch.tensor(encode(text),dtype=torch.long)
n = int(0.9*len(data))
train_data , val_data = data[:n] , data[n:]

#data loading
def get_batch(split):
  """Sample batch_size random windows of block_size tokens and their shifted targets.

  'train' selects train_data; any other value selects val_data. Returns
  (x, y) moved to device, where y is x shifted one position to the right.
  """
  source = val_data if split != 'train' else train_data
  #random start offsets, leaving room for a full window plus the shifted target
  starts = torch.randint(len(source) - block_size , (batch_size,))
  inputs = torch.stack([source[s:s+block_size] for s in starts])
  targets = torch.stack([source[s+1:s+block_size+1] for s in starts])
  return inputs.to(device) , targets.to(device)

In [None]:
#Step 13: Define LLM Loss
@torch.no_grad()
def estimate_loss():
  """Average the model loss over eval_iters random batches per split.

  Returns a dict with the keys 'train' and 'val', each holding the mean
  loss as a tensor. The model is put in eval mode for the measurement and
  switched back to train mode before returning.
  """
  out = {}
  model.eval()
  #the key must be 'val': that is what the training loop reads from the result
  for split in ['train','val']:
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
      X,Y = get_batch(split)
      #feed the inputs X; passing the targets as inputs would score a different task
      logits , loss = model(X,Y)
      losses[k] = loss.item()
    out[split] = losses.mean()
  model.train()
  return out


In [None]:
#Step 14: Define training loop parameters and other hyperparameters
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init

#hyper parameters
batch_size = 16 #sequences per training batch
block_size = 32 #maximum context length in tokens
max_iters = 20 #number of optimizer steps
eval_interval = 100 #evaluate every this many steps (and on the last step)
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("device",device)
eval_iters = 400 #batches averaged per split when estimating the loss
head_size = 16
#the model classes read the name n_embd; defining only n_embed here would
#leave them using whatever n_embd an earlier demo cell happened to set
n_embd = 128
n_embed = n_embd #alias kept for code that uses the other spelling
n_head = 8
n_layer = 8
dropout = 0.1
num_experts = 8
top_k = 2
#----

In [None]:
#Step 15: Intialize the entire model
def kaiming_init_weights(m):
  """Re-initialise the weight of an nn.Linear module with Kaiming-normal values.

  Modules of any other type are left untouched; biases are not modified.
  Intended to be passed to nn.Module.apply.
  """
  if isinstance(m,nn.Linear):
    #in-place initialiser; the variant without the trailing underscore is deprecated
    init.kaiming_normal_(m.weight)
#build the model from the module-level hyperparameters, then run the
#initialiser on the model and each of its submodules
model = SparseMoeLanguageModel()
model.apply(kaiming_init_weights)


In [None]:
#Step16: Run the pre-training loop
m = model.to(device)
print(sum(p.numel() for p in m.parameters())/1e6,'M parameters')
optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate)

#the counter is named step so that the builtin iter() is not shadowed
for step in range(max_iters):
  #report the estimated losses every eval_interval steps and on the last step;
  #estimate_loss() is expected to return 'train' and 'val' entries
  if step % eval_interval == 0 or step == max_iters - 1:
    losses = estimate_loss()
    print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

  #sample a batch of data
  xb,yb = get_batch('train')

  #evaluate the loss and take one optimizer step
  logits,loss = model(xb,yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

In [None]:
#Step 17: Inference
#generate from the model. Not great. Not too bad either
#start from a single token with id 0
context = torch.zeros((1,1), dtype = torch.long , device = device)
m.eval() #switch dropout off while sampling
with torch.no_grad(): #no gradients are needed for generation
  generated = m.generate(context,max_new_tokens=2000)
print(decode(generated[0].tolist()))