In [1]:
# | default_exp moe

%load_ext autoreload
%autoreload 2

%env TOKENIZERS_PARALLELISM=false

env: TOKENIZERS_PARALLELISM=false


In [2]:
from icecream import ic

In [3]:
# | export

import torch
import torch.nn as nn
from torch.functional import F

# Mixture of experts

https://huggingface.co/blog/AviSoori1x/makemoe-from-scratch

In [13]:
# | export
# Expert module
class Expert(nn.Module):
    """An MLP is a simple linear layer followed by a non-linearity i.e. each Expert"""

    def __init__(self, n_embd: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

In [14]:
# Understanding how gating/router works
num_experts = 4
top_k = 2
n_embed = 32


# Example multi-head attention output for a simple illustrative example, consider n_embed=32, context_length=4 and batch_size=2
mh_output = torch.randn(2, 4, n_embed)

topkgate_linear = nn.Linear(n_embed, num_experts)  # nn.Linear(32, 4)

logits = topkgate_linear(mh_output)
top_k_logits, top_k_indices = logits.topk(top_k, dim=-1)  # Get top-k experts
print(logits)
"top k logits:", top_k_logits, "top k indices:", top_k_indices

tensor([[[-1.8681,  0.0758, -0.5876, -0.1003],
         [ 0.5114,  0.3268,  0.2343,  0.0684],
         [-0.4361,  0.2250,  0.1455, -0.2001],
         [ 0.4120,  0.1028,  0.6582,  0.0750]],

        [[-0.3097,  0.1241,  0.3161, -0.2436],
         [ 0.4127, -0.6060, -0.9747,  0.1802],
         [-0.3331, -0.0450,  0.3510,  0.2207],
         [ 1.0282, -1.1125, -0.4474, -0.0438]]], grad_fn=<ViewBackward0>)


('top k logits:',
 tensor([[[ 0.0758, -0.1003],
          [ 0.5114,  0.3268],
          [ 0.2250,  0.1455],
          [ 0.6582,  0.4120]],
 
         [[ 0.3161,  0.1241],
          [ 0.4127,  0.1802],
          [ 0.3510,  0.2207],
          [ 1.0282, -0.0438]]], grad_fn=<TopkBackward0>),
 'top k indices:',
 tensor([[[1, 3],
          [0, 1],
          [1, 2],
          [2, 0]],
 
         [[2, 1],
          [0, 3],
          [2, 3],
          [0, 3]]]))

In [15]:
# keep the top-k experts and set the rest to -inf
zeros = torch.full_like(
    logits, float("-inf")
)  # full_like clones a tensor and fills it with a specified value (like infinity) for masking or calculations.
sparse_logits = zeros.scatter(-1, top_k_indices, top_k_logits)
# transform the logits into a probability distribution
gating_output = F.softmax(sparse_logits, dim=-1)
sparse_logits, gating_output

(tensor([[[   -inf,  0.0758,    -inf, -0.1003],
          [ 0.5114,  0.3268,    -inf,    -inf],
          [   -inf,  0.2250,  0.1455,    -inf],
          [ 0.4120,    -inf,  0.6582,    -inf]],
 
         [[   -inf,  0.1241,  0.3161,    -inf],
          [ 0.4127,    -inf,    -inf,  0.1802],
          [   -inf,    -inf,  0.3510,  0.2207],
          [ 1.0282,    -inf,    -inf, -0.0438]]], grad_fn=<ScatterBackward0>),
 tensor([[[0.0000, 0.5439, 0.0000, 0.4561],
          [0.5460, 0.4540, 0.0000, 0.0000],
          [0.0000, 0.5199, 0.4801, 0.0000],
          [0.4387, 0.0000, 0.5613, 0.0000]],
 
         [[0.0000, 0.4521, 0.5479, 0.0000],
          [0.5579, 0.0000, 0.0000, 0.4421],
          [0.0000, 0.0000, 0.5325, 0.4675],
          [0.7450, 0.0000, 0.0000, 0.2550]]], grad_fn=<SoftmaxBackward0>))

In [16]:
# | export
# First define the top k router module
class TopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(TopkRouter, self).__init__()
        self.top_k = top_k
        self.linear = nn.Linear(n_embed, num_experts)

    def forward(self, mh_ouput):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.linear(mh_output)
        top_k_logits, indices = logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(logits, float("-inf"))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

In [17]:
# Testing this out:
num_experts = 4
top_k = 2
n_embd = 32

mh_output = torch.randn(2, 4, n_embd)  # Example input
top_k_gate = TopkRouter(n_embd, num_experts, top_k)
gating_output, indices = top_k_gate(mh_output)
gating_output.shape, gating_output, indices
# And it works!!

(torch.Size([2, 4, 4]),
 tensor([[[0.5215, 0.4785, 0.0000, 0.0000],
          [0.0000, 0.5903, 0.4097, 0.0000],
          [0.4191, 0.0000, 0.5809, 0.0000],
          [0.5085, 0.4915, 0.0000, 0.0000]],
 
         [[0.5912, 0.4088, 0.0000, 0.0000],
          [0.6839, 0.3161, 0.0000, 0.0000],
          [0.6206, 0.3794, 0.0000, 0.0000],
          [0.0000, 0.5278, 0.4722, 0.0000]]], grad_fn=<SoftmaxBackward0>),
 tensor([[[0, 1],
          [1, 2],
          [2, 0],
          [0, 1]],
 
         [[0, 1],
          [0, 1],
          [0, 1],
          [1, 2]]]))

In [18]:
# softplus is a smoothed version of RELU function.
input = torch.tensor([0.2, 2.3, 10.0, -0.1, -3.2, -10.0])
sp = F.softplus(input)

rand_noise = torch.randn_like(input)
out = rand_noise * F.softplus(input)
(
    input,
    sp.numpy().round(4),
    out.numpy().round(4),
)

(tensor([  0.2000,   2.3000,  10.0000,  -0.1000,  -3.2000, -10.0000]),
 array([ 0.7981,  2.3955, 10.    ,  0.6444,  0.04  ,  0.    ],
       dtype=float32),
 array([-5.3660e-01,  5.3200e-01,  4.0517e+00, -7.4320e-01,  1.0960e-01,
        -1.0000e-04], dtype=float32))

In [19]:
# | export
# Changing the above to accomodate noisy top-k gating
class NoisyTopkRouter(nn.Module):
    """Essentially, you don't want all the tokens to be sent to the same set of 'favored' experts.
    You want a fine balance of exploitation and exploration. For this purpose, to load balance,
    it is helpful to add standard normal noise to the logits from the gating linear layer.
    This makes training more efficient"""

    def __init__(self, n_embed, num_experts, top_k):
        super(NoisyTopkRouter, self).__init__()
        self.top_k = top_k
        # layer for router logits
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear = nn.Linear(n_embed, num_experts)

    def forward(self, mh_output):
        # mh_ouput is the output tensor from multihead self attention block
        logits = self.topkroute_linear(mh_output)

        # Noise logits
        noise_logits = self.noise_linear(mh_output)

        # Adding scaled unit gaussian noise to the logits
        # softplus ensures that the noise is always positive and right skewed
        noise = torch.randn_like(logits) * F.softplus(noise_logits)
        # noisy logit add noise to the logits so some tokens are sent to different experts and not just the top-k.
        # It pushes the model to explore more.
        noisy_logits = logits + noise

        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(noisy_logits, float("-inf"))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices

In [20]:
num_experts = 8
top_k = 2
n_embd = 16

mh_output = torch.randn(2, 4, n_embd)  # Example input
noisy_top_k_gate = NoisyTopkRouter(n_embd, num_experts, top_k)
gating_output, indices = noisy_top_k_gate(mh_output)
gating_output.shape, gating_output, indices

(torch.Size([2, 4, 8]),
 tensor([[[0.0000, 0.0000, 0.6716, 0.3284, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.4136, 0.0000, 0.0000, 0.5864, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5642, 0.0000, 0.0000, 0.0000, 0.0000, 0.4358, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4763, 0.5237]],
 
         [[0.2149, 0.0000, 0.0000, 0.0000, 0.7851, 0.0000, 0.0000, 0.0000],
          [0.4339, 0.0000, 0.5661, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.6298, 0.0000, 0.0000, 0.0000, 0.3702, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.5971, 0.0000, 0.4029, 0.0000]]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[[2, 3],
          [3, 0],
          [0, 5],
          [7, 6]],
 
         [[4, 0],
          [2, 0],
          [0, 4],
          [4, 6]]]))

In [21]:
# | export
class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k

    def forward(self, x):
        gating_output, indices = self.router(x)
        ic(gating_output)
        final_output = torch.zeros_like(x)

        # Reshape inputs for batch processing
        flat_x = x.view(-1, x.size(-1))
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        # Process each expert in parallel
        for i, expert in enumerate(self.experts):
            # Create a mask for the inputs where the current expert is in top-k
            expert_mask = (indices == i).any(dim=-1)
            ic(expert_mask)
            flat_mask = expert_mask.view(-1)

            if flat_mask.any():
                expert_input = flat_x[flat_mask]
                expert_output = expert(expert_input)

                # Extract and apply gating scores
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                # Update final output additively by indexing and adding
                final_output[expert_mask] += weighted_output.squeeze(1)

        return final_output

In [22]:
import torch
import torch.nn as nn

ic.enable()
# ic.disable()

# Let's test this out
num_experts = 4
top_k = 2
n_embd = 16
dropout = 0.1

mh_output = torch.randn(1, 8, n_embd)  # Example multi-head attention output
sparse_moe = SparseMoE(n_embd, num_experts, top_k)
final_output = sparse_moe(mh_output)
print("Shape of the final output:", final_output.shape)

ic| gating_output: tensor([[[0.5464, 0.4536, 0.0000, 0.0000],
                            [0.6865, 0.0000, 0.3135, 0.0000],
                            [0.5082, 0.0000, 0.0000, 0.4918],
                            [0.0000, 0.6408, 0.0000, 0.3592],
                            [0.0000, 0.4666, 0.0000, 0.5334],
                            [0.7807, 0.2193, 0.0000, 0.0000],
                            [0.6676, 0.0000, 0.3324, 0.0000],
                            [0.7278, 0.0000, 0.2722, 0.0000]]], grad_fn=<SoftmaxBackward0>)
ic| expert_mask: tensor([[ True,  True,  True, False, False,  True,  True,  True]])
ic| expert_mask: tensor([[ True, False, False,  True,  True,  True, False, False]])
ic| expert_mask: tensor([[False,  True, False, False, False, False,  True,  True]])
ic| expert_mask: tensor([[False, False,  True,  True,  True, False, False, False]])


Shape of the final output: torch.Size([1, 8, 16])
