<h1>Mixture of Experts and Switch Transformers</h1>

<h2>Overview</h2>

In this post, I will explore two related techniques in machine learning: Mixture of Experts (MoE) and Switch Transformers.

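The Mixture of Experts (MoE) approach uses a set of specialized models, or experts, each of which excels on a particular part of the input. A mechanism decides which experts to use for each input. A plain ensemble combines every model in the same way for every input, but an MoE adapts its choice of experts to the input. For language modeling, this means employing multiple sub-networks, each adept at capturing specific linguistic patterns.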

Switch Transformers apply the MoE idea inside the transformer architecture. They replace the dense feed-forward layer in some transformer blocks with several expert feed-forward networks and route each token to exactly one of them. Each token still passes through only one expert, so the number of parameters can grow with the number of experts while the computation per token stays roughly constant. This lets a language model be scaled up substantially.

In the following sections, I will first explain the concept of MoEs. I will then modify the model that I constructed in the previous post, where I explained what language models are.

<h2>Mixture of Experts (MoE)</h2>

The concept of "Mixture of Experts" (MoE) is an approach in machine learning where multiple specialized components (or "experts") come together to make a collective decision.

In an MoE model, there are multiple expert networks, and each one is responsible for handling a specific subset or type of data. Experts are individual models or subnetworks, each trained to specialize in a different aspect of the data. For example, in a language modeling task, one expert might specialize in grammar, another in vocabulary usage, another in capturing sentiment, and so on. In practice, this specialization is learned during training rather than assigned by hand, and it need not line up with such human-readable categories.


<h3>Advantages</h3>
<ul>
    <li><b>Specialization:</b> Each expert can become highly specialized in a specific subset of the data, leading to more tailored and accurate predictions.</li>
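    <li><b>Scalability:</b> Instead of growing a single massive network, you can add more experts to increase capacity. When only a few experts are active for each input, the computation per input stays roughly constant as experts are added.</li>
    <li><b>Reduced Overfitting:</b> Individual experts are typically smaller networks, which can make each of them less prone to overfitting. Large sparse MoE models can still overfit as a whole, however, particularly when fine-tuned on small datasets.</li>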
    <li><b>Flexibility:</b> MoE offers flexibility in terms of architecture. Experts can have different architectures or even be different types of models altogether.</li>
</ul>
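The snippet below makes the scalability point concrete. It counts the parameters of a single MoE feed-forward layer built from two-layer experts like the one defined in Step 1, plus a linear router, and it assumes top-k routing. The sizes (512 and 2048) are illustrative assumptions, and the count covers only the MoE layer, not attention or embeddings.

```python
def expert_params(d_model, d_hidden):
    """Parameters of one two-layer feed-forward expert (weights + biases)."""
    return (d_model * d_hidden + d_hidden) + (d_hidden * d_model + d_model)


def moe_params(d_model, d_hidden, num_experts, k):
    """Return (parameters stored, parameters used per token) for a top-k routed MoE layer."""
    router = d_model * num_experts + num_experts  # linear gate: d_model -> num_experts
    total = num_experts * expert_params(d_model, d_hidden) + router
    active = k * expert_params(d_model, d_hidden) + router
    return total, active


for n in (1, 8, 64):
    total, active = moe_params(d_model=512, d_hidden=2048, num_experts=n, k=1)
    print(f"{n:>3} experts: {total:>12,} stored, {active:>10,} used per token")
```

The stored parameters grow linearly with the number of experts. With k=1, the per-token cost stays at one expert's worth plus a small router.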

<h3>Implementation</h3>

In the <a href='https://github.com/lsafarne/NLPBites.github.io/blob/main/Language_Model.ipynb'>previous post</a>, I explained how to implement a language model from scratch. In what follows, I will modify that model to incorporate MoEs.

```python
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
```

<h3>Step 1: Define an Expert</h3>

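```python
class Expert(nn.Module):
    """A small two-layer feed-forward network; each expert is one of these."""
    def __init__(self, input_size, hidden_size, output_size):
        super(Expert, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, hidden_size),   # project up to the hidden width
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),  # project to the output width
        )

    def forward(self, x):
        # nn.Linear acts on the last dimension, so x can be (batch, seq_len, input_size)
        return self.fc(x)
```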
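A quick check of how the expert treats a batch of sequences. The sizes are arbitrary assumptions, and the class is repeated so the snippet runs on its own. `nn.Linear` operates on the last dimension only, so the expert is applied to every token position independently.

```python
import torch
import torch.nn as nn


class Expert(nn.Module):
    # Same two-layer feed-forward expert as in Step 1
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        return self.fc(x)


torch.manual_seed(0)
expert = Expert(input_size=16, hidden_size=64, output_size=16)  # illustrative sizes
tokens = torch.randn(2, 10, 16)   # (batch, seq_len, features)
out = expert(tokens)
print(out.shape)                  # torch.Size([2, 10, 16])

# Feeding one token's feature vector alone gives the same result as its slot in the batch
single = expert(tokens[0, 3])
print(torch.allclose(single, out[0, 3], atol=1e-5))  # True
```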

<h3>Step 2: Define the Mixture of Experts</h3>

```python
class MixtureOfExperts(nn.Module):
    def __init__(self, num_experts, input_size, hidden_size, output_size):
        super(MixtureOfExperts, self).__init__()
        # nn.ModuleList registers every expert's parameters with this module
        self.experts = nn.ModuleList([Expert(input_size, hidden_size, output_size) for _ in range(num_experts)])

    def forward(self, x):
        # A real MoE needs a routing rule that decides which expert handles each input.
        # As a placeholder, pick one expert at random; note that the whole batch
        # (every token of every sequence) goes to that single expert.
        selected_expert = torch.randint(len(self.experts), (1,)).item()
        return self.experts[selected_expert](x)
```
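The placeholder above sends the entire batch to a single expert, but MoE layers normally make their routing decision per token. The sketch below keeps the random placeholder and draws one expert index per token. It dispatches each expert's tokens with a boolean mask. `route_tokens` is a hypothetical helper, and the experts are plain `nn.Linear` layers for brevity.

```python
import torch
import torch.nn as nn


def route_tokens(x, experts, assignment, d_out):
    """Hypothetical helper: send each token to the expert named in `assignment`.

    x:          (batch, seq_len, d) token features
    experts:    modules mapping d -> d_out
    assignment: (batch, seq_len) integer expert index per token
    """
    out = x.new_zeros(*assignment.shape, d_out)
    for e, expert in enumerate(experts):
        mask = assignment == e              # (batch, seq_len): tokens routed to expert e
        if mask.any():
            out[mask] = expert(x[mask])     # expert e sees only its own tokens, as an (n, d) batch
    return out


torch.manual_seed(0)
experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
x = torch.randn(2, 5, 8)
assignment = torch.randint(len(experts), (2, 5))   # random expert per token (placeholder rule)
y = route_tokens(x, experts, assignment, d_out=8)
print(y.shape)   # torch.Size([2, 5, 8])
```

Each expert runs only on the tokens assigned to it. This is what lets sparse MoE layers save computation compared with running every expert on every token.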

<h3>Step 3: Create the Language Model</h3>

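```python
class LanguageModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, output_size, num_experts, num_heads, num_layers):
        super(LanguageModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Transformer encoder; hidden_size is the width of each layer's feed-forward sublayer
        encoder_layers = TransformerEncoderLayer(embedding_dim, num_heads, hidden_size, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)

        # MoE: takes the encoder's embedding_dim features, produces output_size features
        self.moe = MixtureOfExperts(num_experts, embedding_dim, hidden_size, output_size)

        # Fully connected layer producing logits over the vocabulary
        self.fc = nn.Linear(output_size, vocab_size)

    def forward(self, x):
        # x: (batch, seq_len) token ids
        embedded = self.embedding(x)                 # (batch, seq_len, embedding_dim)
        output = self.transformer_encoder(embedded)  # (batch, seq_len, embedding_dim)
        output = self.moe(output)                    # (batch, seq_len, output_size)
        output = self.fc(output)                     # (batch, seq_len, vocab_size)
        return output
```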
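The language model returns a vector of vocabulary logits for every position. The sketch below shows how such logits are typically trained for next-token prediction, assuming batch-first tensors. The targets are the inputs shifted left by one position, and both logits and targets are flattened because `F.cross_entropy` expects the class dimension second. The encoder in the model above is called without an attention mask, so by default every position can attend to later tokens. The sketch therefore also builds a causal mask with `nn.Transformer.generate_square_subsequent_mask`. It uses a small stand-in encoder with arbitrary sizes and no MoE layer, so that it runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, d_model, seq_len, batch = 50, 32, 12, 4   # illustrative sizes

embedding = nn.Embedding(vocab_size, d_model)
layer = nn.TransformerEncoderLayer(d_model, nhead=4, dim_feedforward=64, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)
head = nn.Linear(d_model, vocab_size)

tokens = torch.randint(vocab_size, (batch, seq_len + 1))
inputs, targets = tokens[:, :-1], tokens[:, 1:]   # predict token t+1 from tokens up to t

# Float mask with -inf above the diagonal: position t may attend only to positions <= t
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len)
hidden = encoder(embedding(inputs), mask=causal_mask)
logits = head(hidden)                              # (batch, seq_len, vocab_size)

loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
loss.backward()
print(logits.shape, loss.item())
```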

<h2>Gating Networks</h2>

Alongside the experts, a gating network can be used to determine which expert (or combination of experts) should be used for a given input. In other words, the gating network determines the weight of each expert's output based on the input. This way, the model can learn to rely on different experts for different regions of the input space.

The gating network therefore controls how the outputs of the individual experts are combined. It takes the same input as the experts and outputs a set of weights that determine how much each expert's prediction contributes to the final prediction.

The idea is that for a given input, some experts may be more relevant than others, so the gating network <span style='background-color:yellow'>"routes"</span> the input to the appropriate experts by assigning higher weights to them.

Putting it together, for a given input:
<ol>
    <li>The gating network outputs a weight for each expert; the weights are non-negative and sum to 1.</li>
    <li>Each expert network processes the input independently and produces its own output.</li>
    <li>The final output is the weighted sum of the experts' outputs, using the gating network's weights.</li>
</ol>
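Written as equations, assume N experts E_1, ..., E_N and a gate made of one linear layer (weights W_g, bias b_g) followed by a softmax, which is the form used in the implementation that follows. The output y for an input x is then:

```latex
g(x) = \operatorname{softmax}(W_g x + b_g) \in \mathbb{R}^{N},
\qquad
y = \sum_{i=1}^{N} g_i(x)\, E_i(x),
\qquad
g_i(x) \ge 0, \quad \sum_{i=1}^{N} g_i(x) = 1
```

In this dense form every expert E_i is evaluated for every input, so the cost grows with N. Sparse variants compute only the experts with the largest g_i(x).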

<h3>Implementation</h3>

```python
class MixtureOfExpertsGatingNet(nn.Module):
    def __init__(self, num_experts, input_size, hidden_size, output_size):
        super(MixtureOfExpertsGatingNet, self).__init__()
        self.experts = nn.ModuleList([Expert(input_size, hidden_size, output_size) for _ in range(num_experts)])

        # Gating network is a simple feedforward network with softmax output
        self.gating = nn.Sequential(
            nn.Linear(input_size, num_experts),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        # x: (..., input_size) -> weights: (..., num_experts), summing to 1 over the experts
        weights = self.gating(x)

        # Run every expert and stack the results: (..., output_size, num_experts)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-1)

        # Weighted sum over the experts; unsqueeze(-2) broadcasts each weight over output_size
        output = (expert_outputs * weights.unsqueeze(-2)).sum(dim=-1)
        return output  # (..., output_size)
```
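The gating network above is dense: every expert runs on every token, so adding experts adds computation. Sparse MoE layers evaluate only the top-k experts per token. The Switch Transformer is the k=1 case, routing each token to exactly one expert. The sketch below assumes top-k selection on the softmax probabilities, keeps those probabilities as weights without renormalizing them, and uses the hypothetical class name `SparseMoE`.

```python
import torch
import torch.nn as nn


class SparseMoE(nn.Module):
    """Hypothetical top-k gated MoE layer; with k=1 each token visits a single expert."""

    def __init__(self, num_experts, input_size, hidden_size, output_size, k=1):
        super().__init__()
        self.k = k
        self.output_size = output_size
        self.gate = nn.Linear(input_size, num_experts)  # router: one logit per expert
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(input_size, hidden_size), nn.ReLU(),
                          nn.Linear(hidden_size, output_size))
            for _ in range(num_experts)
        ])

    def forward(self, x):
        lead_shape = x.shape[:-1]
        tokens = x.reshape(-1, x.shape[-1])               # (T, input_size), T = number of tokens
        probs = torch.softmax(self.gate(tokens), dim=-1)  # (T, num_experts)
        top_p, top_i = probs.topk(self.k, dim=-1)         # (T, k) probabilities and expert indices

        out = tokens.new_zeros(tokens.shape[0], self.output_size)
        for e, expert in enumerate(self.experts):
            hit = top_i == e                               # (T, k): where expert e was selected
            rows = hit.any(dim=-1).nonzero(as_tuple=True)[0]
            if rows.numel() == 0:
                continue                                   # no tokens for this expert: skip its compute
            weight = (top_p * hit).sum(dim=-1)[rows]       # expert e's gate probability per routed token
            out = out.index_add(0, rows, weight.unsqueeze(-1) * expert(tokens[rows]))
        return out.reshape(*lead_shape, self.output_size)


torch.manual_seed(0)
moe = SparseMoE(num_experts=4, input_size=8, hidden_size=16, output_size=8, k=1)
x = torch.randn(2, 5, 8)   # (batch, seq_len, features)
print(moe(x).shape)        # torch.Size([2, 5, 8])
```

Setting k equal to the number of experts recovers the dense mixture, because the selected probabilities then sum to 1. Implementations such as the Switch Transformer also cap how many tokens each expert may receive (a capacity factor) and add an auxiliary load-balancing loss; both are omitted here.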

<h3>Note</h3>

The nn.Linear layer in the gating network maps the last dimension of its input to num_experts, so an input of shape [batch_size, seq_len, input_size] becomes [batch_size, seq_len, num_experts]. Specifying dim=-1 for the Softmax applies it across the num_experts dimension, so each token gets its own probability distribution over all experts.
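The shapes described above can be checked directly. The sizes below are arbitrary. The last lines show why an expert's weight must be selected from the last dimension when the input is 3-D.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, input_size, num_experts = 2, 5, 16, 4   # illustrative sizes

gating = nn.Sequential(nn.Linear(input_size, num_experts), nn.Softmax(dim=-1))
x = torch.randn(batch, seq_len, input_size)
weights = gating(x)
print(weights.shape)   # torch.Size([2, 5, 4]): one distribution per token
print(torch.allclose(weights.sum(dim=-1), torch.ones(batch, seq_len)))  # True

# Expert i's weight for every token lives in the LAST dimension.
# For this 3-D tensor, weights[:, i] would select sequence position i instead.
w1 = weights[..., 1].unsqueeze(-1)   # (batch, seq_len, 1): broadcasts over output features
print(w1.shape)        # torch.Size([2, 5, 1])
```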

```python
class LanguageModelGatingNet(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, output_size, num_experts, num_heads, num_layers):
        super(LanguageModelGatingNet, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Transformer encoder; hidden_size is the width of each layer's feed-forward sublayer
        encoder_layers = TransformerEncoderLayer(embedding_dim, num_heads, hidden_size, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)

        # MoE with gating network: embedding_dim features in, output_size features out
        self.moe = MixtureOfExpertsGatingNet(num_experts, embedding_dim, hidden_size, output_size)

        # Fully connected layer producing logits over the vocabulary
        self.fc = nn.Linear(output_size, vocab_size)

    def forward(self, x):
        # x: (batch, seq_len) token ids
        embedded = self.embedding(x)                 # (batch, seq_len, embedding_dim)
        output = self.transformer_encoder(embedded)  # (batch, seq_len, embedding_dim)
        output = self.moe(output)                    # (batch, seq_len, output_size)
        output = self.fc(output)                     # (batch, seq_len, vocab_size)
        return output
```

Our small language model with MoE and a gating network is now complete. Note that this architecture is oversimplified. Each TransformerEncoderLayer already applies dropout, layer normalization, and a ReLU activation by default, but the MoE layer and the output layer we added around the encoder have no normalization or dropout, and the output head is a single linear layer.
Non-linear activation functions introduce non-linearities into the model, which lets the network model complex patterns. Batch normalization or layer normalization can help stabilize training and speed up convergence; in transformer models, layer normalization is the usual choice. Dropout is a regularization technique used to prevent overfitting: during training, it randomly sets a fraction of its input units to 0 at each update and scales the remaining units up so that their expected value is unchanged.
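Here is what dropout and layer normalization do to a tensor, shown on a small random tensor with arbitrary sizes and using PyTorch's standard `nn.Dropout` and `nn.LayerNorm` modules.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 8)   # illustrative activations: 4 tokens, 8 features each

# Dropout in training mode: each element is zeroed with probability p,
# and the surviving elements are scaled by 1 / (1 - p)
dropout = nn.Dropout(p=0.5)
dropout.train()
y = dropout(x)
kept = y != 0
print(torch.allclose(y[kept], x[kept] / (1 - 0.5)))   # True

# In evaluation mode dropout does nothing
dropout.eval()
print(torch.equal(dropout(x), x))                      # True

# LayerNorm normalizes each token's feature vector to zero mean and unit variance
norm = nn.LayerNorm(8)
z = norm(x)
print(z.mean(dim=-1))                  # approximately 0 for every token
print(z.var(dim=-1, unbiased=False))   # approximately 1 for every token
```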

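Let's modify the above language model to incorporate these concepts: dropout and layer normalization in the encoder, GELU activations in the encoder's feed-forward sublayers, and a deeper output head with an extra hidden layer.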

```python
class LanguageModelGatingNet(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, output_size, num_experts, num_heads, num_layers, dropout_prob=0.1):
        super(LanguageModelGatingNet, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Transformer encoder with dropout and GELU; each TransformerEncoderLayer already has its
        # own two LayerNorms, and `norm` adds a final LayerNorm after the last layer
        encoder_layers = TransformerEncoderLayer(embedding_dim, num_heads, hidden_size, dropout=dropout_prob,
                                                 activation="gelu", batch_first=True)
        encoder_norm = nn.LayerNorm(embedding_dim)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers, norm=encoder_norm)

        # MoE with gating network: embedding_dim features in, output_size features out
        self.moe = MixtureOfExpertsGatingNet(num_experts, embedding_dim, hidden_size, output_size)

        # Enhanced fully connected head with dropout, an additional layer, and an activation
        self.fc = nn.Sequential(
            nn.Dropout(dropout_prob),
            nn.Linear(output_size, hidden_size),  # Additional layer
            nn.ReLU(),                            # Activation function
            nn.Linear(hidden_size, vocab_size),   # Output layer: logits over the vocabulary
        )

    def forward(self, x):
        # x: (batch, seq_len) token ids
        embedded = self.embedding(x)                 # (batch, seq_len, embedding_dim)
        output = self.transformer_encoder(embedded)  # (batch, seq_len, embedding_dim)
        output = self.moe(output)                    # (batch, seq_len, output_size)
        output = self.fc(output)                     # (batch, seq_len, vocab_size)
        return output
```