In [2]:

#for displaying figures in code editor
#%matplotlib inline
import matplotlib.pyplot as plt
from time import perf_counter
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn
import random




In [3]:
class Expert(nn.Module):
    """ An MLP is a simple linear layer followed by a non-linearity i.e. each Expert """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

In [4]:
# First define the top k router module
class TopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(TopkRouter, self).__init__()
        self.top_k = top_k
        self.linear =nn.Linear(n_embed, num_experts)

    def forward(self, mh_output):
        logits = self.linear(mh_output) # (B,Token,num_experts)
        top_k_logits, indices = logits.topk(self.top_k, dim=-1)# Get top-k experts
        #indices: for each observation get the highest values (indices of best models) 
        # (return the two biggest values of the last dim(-1))
        zeros = torch.full_like(logits, float('-inf')) #zero(-inf) matrx of shape logits
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)#keep k-values in matrix
        router_output = F.softmax(sparse_logits, dim=-1)#convert to probs
        return router_output, indices #output router output and indices

#Changing the above to accomodate noisy top-k gating
class NoisyTopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.top_k = top_k
        #layer for router logits
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear =nn.Linear(n_embed, num_experts)


    def forward(self, mh_output):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.topkroute_linear(mh_output)
        #Noise logits
        noise_logits = self.noise_linear(mh_output)
        #Adding scaled unit gaussian noise to the logits
        noise = torch.randn_like(logits)*F.softplus(noise_logits)
        noisy_logits = logits + noise
        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

In [5]:
#Testing this out:
num_experts = 4
n_embd = 32
top_k = 2

mh_output = torch.randn(2, 4, n_embd)  # Example input
print(mh_output.shape)
top_k_gate = NoisyTopkRouter(n_embd, num_experts, top_k)

gating_output, indices = top_k_gate(mh_output)
gating_output.shape, gating_output, indices.shape
#And it works!!

torch.Size([2, 4, 32])


(torch.Size([2, 4, 4]),
 tensor([[[0.3192, 0.6808, 0.0000, 0.0000],
          [0.3447, 0.6553, 0.0000, 0.0000],
          [0.4407, 0.0000, 0.0000, 0.5593],
          [0.0000, 0.9703, 0.0000, 0.0297]],
 
         [[0.1508, 0.8492, 0.0000, 0.0000],
          [0.3747, 0.6253, 0.0000, 0.0000],
          [0.2809, 0.0000, 0.7191, 0.0000],
          [0.8214, 0.0000, 0.1786, 0.0000]]], grad_fn=<SoftmaxBackward0>),
 torch.Size([2, 4, 2]))

In [24]:
class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k

    def forward(self, x):
        #Get weights matrix, containing top-k and their indices(matrix) 
        gating_output, indices = self.router(x) #[B,tokens,num_experts],[B,tokens,top_k]
        final_output = torch.zeros_like(x)

        # Flatten observations and weight matrix
        
        flat_x = x.view(-1, x.size(-1)) #[B*tokens,d_model]
        flat_gating_output = gating_output.view(-1, gating_output.size(-1)) #[B*tokens,num_experts]
        # Process each expert in parallel
        for i, expert in enumerate(self.experts):
            # Create a mask for the inputs where the current expert is in top-k 
            #(indices == i): convert [B,tokens,top_k] to bool [B,tokens,top_k] if i present  
            #(indices == i).any(dim=-1): dim reduction from [B,tokens,top_k] to [B,tokens] #ex i=0: if [1,0] -->[False,True]->[True]
                                   #[2,5] -->[False,False]->[False]
                                   #[3,0] -->[False,True]->[True]
            print(indices)
            print("--------")
            #indices [B,Tokens,top_k]
            #expert_mask:bool:[B,tokens] (true if indice i was one of the top_k values for each obsevation (token))
            expert_mask = (indices == i).any(dim=-1) #[B,tokens]
            #flat_mask:bool: boolean mask for top_k indices for current expert
            flat_mask = expert_mask.view(-1) #bool[B*tokens] #observations where current expert is top-k
            print("Number of top_k: ",(flat_mask).sum(dim = -1))
            top_k_value =   (flat_mask).sum(dim = -1)
            
            if flat_mask.any(): #if current expert is in top_k for any observation(token)
                #Create a mask for the inputs where the current expert is in top_k
                #get obsevations for which the current expert was in top-k (best model)             
                expert_input = flat_x[flat_mask]             
                #pass top_k observations to expert current expert
#                print("expert input ", expert_input.shape)
            expert_output = expert(expert_input) #[top_k observations,d_model]
          
            # Extract and apply gating scores
            #get the observations' top-k weights for the current expert
            #flat_gating_output[flat_mask, i] returns current expert's top_k weights
            gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1) #[top_k observations,1]
            print("flat_gating_output")
            #updating weights 
            print("expert_output ", expert_output.shape)
            print("gating_scores ", gating_scores.shape)
            weighted_output = expert_output * gating_scores #[top_k observations,d_model]
            print("weighted_output",weighted_output.shape)
            # Update final output additively by indexing and adding
            final_output[expert_mask] += weighted_output.squeeze(1) #[top_k observations,d_model]
            
            print("final_output ", final_output.shape)
        return final_output

In [25]:
# class SparseMoE(nn.Module):
#     def __init__(self, n_embed, num_experts, top_k):
#         super(SparseMoE, self).__init__()
#         self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
#         self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
#         self.top_k = top_k
# 
#     def forward(self, x):
#         #Get weights matrix, containing top-k and their indices(matrix) 
#         gating_output, indices = self.router(x) #[B,tokens,num_experts],[B,tokens,top_k]
#         final_output = torch.zeros_like(x)
#         final_outputt = 0
#         
#         # Flatten observations and weight matrix
#         
#         flat_x = x.view(-1, x.size(-1)) #[B*tokens,d_model]
#         flat_gating_output = gating_output.view(-1, gating_output.size(-1)) #[B*tokens,num_experts]
#         # Process each expert in parallel
#         for i, expert in enumerate(self.experts):
#             # Create a mask for the inputs where the current expert is in top-k 
#             #(indices == i): convert [B,tokens,top_k] to bool [B,tokens,top_k] if i present  
#             #(indices == i).any(dim=-1): dim reduction from [B,tokens,top_k] to [B,tokens] #ex i=0: if [1,0] -->[False,True]->[True]
#                                    #[2,5] -->[False,False]->[False]
#                                    #[3,0] -->[False,True]->[True]
#             print(indices)
#             print("--------")
#             #indices [B,Tokens,top_k]
#             #expert_mask:bool:[B,tokens] (true if indice i was one of the top_k values for each obsevation (token))
#             expert_mask = (indices == i).any(dim=-1) #[B,tokens]
#             #flat_mask:bool: boolean mask for top_k indices for current expert
#             #flat_mask = expert_mask.view(-1) #bool[B*tokens] #observations where current expert is top-k
#             #print("Number of top_k: ",(flat_mask).sum(dim = -1))
#             #top_k_value =   (flat_mask).sum(dim = -1)
#             
#             #if flat_mask.any(): #if current expert is in top_k for any observation(token)
#                 # Create a mask for the inputs where the current expert is in top_k
#                 #get obsevations for which the current expert was in top-k (best model)             
#             #    expert_input = flat_x[flat_mask]             
#                 #pass top_k observations to expert current expert
#             #print("expert input ", expert_input.shape)
#             expert_output = expert(flat_x) #[top_k observations,d_model]
#           
#             # Extract and apply gating scores
#             #get the observations' top-k weights for the current expert
#             #flat_gating_output[flat_mask, i] returns current expert's top_k weights
#             
#             gating_scores = flat_gating_output[:, i].unsqueeze(1) #[top_k observations,1]
#             #updating weights 
#             weighted_output = expert_output * gating_scores #[top_k observations,d_model]
#             # Update final output additively by indexing and adding
#             final_outputt += weighted_output.squeeze(1) #[top_k observations,d_model]
#             
#             print("final_output ", final_output.shape)
#         return final_output

In [26]:
#Let's test this out
num_experts = 7
top_k = 2
n_embd = 16
dropout=0.1

mh_output = torch.randn(4, 8, n_embd)  # Example multi-head attention output
print("mh_output", mh_output.shape) #[B,tokens,d_model]

sparse_moe = SparseMoE(n_embd, num_experts, top_k)
final_output = sparse_moe(mh_output)
#print("Shape of the final output:", final_output.shape)
print(final_output.shape)

mh_output torch.Size([4, 8, 16])
tensor([[[2, 3],
         [0, 2],
         [2, 3],
         [3, 5],
         [5, 1],
         [6, 5],
         [4, 3],
         [4, 6]],

        [[0, 2],
         [1, 4],
         [0, 2],
         [0, 2],
         [6, 2],
         [6, 1],
         [1, 0],
         [2, 0]],

        [[2, 3],
         [2, 0],
         [3, 0],
         [0, 1],
         [3, 2],
         [1, 6],
         [4, 6],
         [3, 6]],

        [[2, 6],
         [6, 3],
         [6, 2],
         [6, 3],
         [3, 0],
         [0, 3],
         [1, 3],
         [0, 2]]])
--------
Number of top_k:  tensor(12)
flat_gating_output
expert_output  torch.Size([12, 16])
gating_scores  torch.Size([12, 1])
weighted_output torch.Size([12, 16])
final_output  torch.Size([4, 8, 16])
tensor([[[2, 3],
         [0, 2],
         [2, 3],
         [3, 5],
         [5, 1],
         [6, 5],
         [4, 3],
         [4, 6]],

        [[0, 2],
         [1, 4],
         [0, 2],
         [0, 2],
       