<h1>Mixture of Experts and Switch Transformers</h1>

<h1>Overview</h1>

In this post, I will delve into two advanced methodologies in machine learning: the Mixture of Experts (MoEs) and Switch Transformers.

The Mixture of Experts (MoEs) approach uses an ensemble of specialized models, each of which excels in a particular domain. For language modeling, this means employing multiple models, each adept at capturing specific linguistic nuances.

Switch Transformers build on the MoE approach and modify the standard transformer architecture so that each token is routed to just one expert. As a result, a language model's parameter count can be scaled up substantially without a proportional increase in computation, which can improve the model's overall performance.

In the upcoming sections, I'll begin by introducing the concept of MoEs. Next, I'll adjust the model I presented in the <a href='https://github.com/lsafarne/NLPBites.github.io/blob/main/Language_Model.ipynb'>previous post</a>, where I explained the basics of language models. Subsequently, I'll discuss Switch Transformers and conclude by integrating them into our language model.

<h2>Mixture of Experts (MoEs)</h2>

The concept of "Mixture of Experts" (MoE) is an approach in machine learning where multiple specialized components (or "experts") come together to make a collective decision.

In a MoE model, there are multiple expert networks, and each one is responsible for handling a specific subset or type of data. Experts are individual models or subnetworks, each trained to specialize in a different aspect of the data. For example, in a language modeling task, one expert might specialize in grammar, another in vocabulary usage, another in capturing sentiment, and so on.


<h3>Advantages</h3>
<ul>
    <li><b>Specialization:</b> Each expert can become highly specialized in a specific subset of the data, leading to more tailored and accurate predictions.</li>
    <li><b>Scalability:</b> Instead of growing a single massive network, we can increase capacity by adding more experts; if only one or a few experts are active for each input, the computation per input stays roughly constant.</li>
    <li><b>Reduced Overfitting:</b> Individual experts are typically smaller networks, making them less prone to overfitting.</li>
    <li><b>Flexibility:</b> MoE offers flexibility in terms of architecture. Experts can have different architectures or even be different types of models altogether.</li>
</ul>
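
To make the scalability point concrete, here is a small parameter-count sketch. It assumes each expert is a two-layer feed-forward network like the Expert module in the implementation that follows (Linear, ReLU, Linear, with biases, mapping back to the input size) and it ignores the gating parameters. The function name `moe_param_counts` is hypothetical.

```python
def moe_param_counts(num_experts, d_model, d_hidden, experts_per_input):
    """Return (total_params, active_params_per_input) for an MoE layer.

    Assumes every expert is Linear(d_model, d_hidden) -> ReLU -> Linear(d_hidden, d_model),
    with biases, and ignores the (small) gating network.
    """
    per_expert = (d_model * d_hidden + d_hidden) + (d_hidden * d_model + d_model)
    total = num_experts * per_expert
    active = experts_per_input * per_expert
    return total, active


total, active = moe_param_counts(num_experts=32, d_model=256, d_hidden=512, experts_per_input=1)
print(total, active)  # 8413184 262912
```

With 32 experts and one active expert per input, the layer stores 32 times as many parameters as a single expert, while each input only pays for one expert's matrix multiplications (plus the gate).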

<h3>Implementation</h3>

In the <a href='https://github.com/lsafarne/NLPBites.github.io/blob/main/Language_Model.ipynb'>previous post</a>, I explained how to implement a language model from scratch. In what follows, I will modify that model to incorporate MoEs.

In [43]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

<h3>Step 1: Define an Expert</h3>

In [None]:
class Expert(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Expert, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        return self.fc(x)

<h3>Step 2: Define the Mixture of Experts</h3>

In [None]:
class MixtureOfExperts(nn.Module):
    def __init__(self, num_experts, input_size, hidden_size, output_size):
        super(MixtureOfExperts, self).__init__()
        self.experts = nn.ModuleList([Expert(input_size, hidden_size, output_size) for _ in range(num_experts)])

    def forward(self, x):
        # A real MoE needs logic to decide which expert handles each input.
        # For simplicity, I pick one expert at random per forward pass; it processes the whole batch.
        selected_expert = torch.randint(len(self.experts), (1,)).item()
        return self.experts[selected_expert](x)

<h3>Step 3: Create the Language Model</h3>

In [None]:
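class LanguageModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, output_size, num_experts, num_heads, num_layers):
        super(LanguageModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        # Transformer layers (batch_first=True because inputs are [batch_size, seq_len])
        encoder_layers = TransformerEncoderLayer(embedding_dim, num_heads, hidden_size, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
        
        # MoE: the experts take the encoder output, whose size is embedding_dim
        self.moe = MixtureOfExperts(num_experts, embedding_dim, hidden_size, output_size)
    
        # Fully connected output layer producing vocabulary logits
        self.fc = nn.Linear(output_size, vocab_size)

    def forward(self, x):
        # Causal mask: position t may only attend to positions <= t (next-token prediction)
        seq_len = x.size(1)
        causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf'), device=x.device), diagonal=1)
        embedded = self.embedding(x)
        output = self.transformer_encoder(embedded, mask=causal_mask)
        output = self.moe(output)
        output = self.fc(output)
        return output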

<h2>Gating Networks</h2>

In the model above, a single randomly selected expert processes the encoder output, and the same expert is used for every token in the batch. This raises the question: can we select experts more intelligently than at random? The answer lies in gating networks. Alongside the experts, a gating network determines which expert (or combination of experts) should be used for a given input; in other words, it determines the weight of each expert's output based on the input. This way, the model can learn to rely on different experts for different regions of the input space.

The gating network takes the same input as the experts and outputs a set of weights that determine how much each expert's output contributes to the final prediction.

For a given input, some experts may be more relevant than others, so the gating network <span style='background-color:yellow'>"routes"</span> the input to the appropriate experts by assigning higher weights to them. Concretely:
<ol>
    <li>For a given input, the gating network outputs a weight for each expert.</li>
    <li>The expert networks process the input independently and produce their outputs.</li>
    <li>The final output is a weighted sum of the experts' outputs, using the gating network's weights.</li>
</ol>
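
A tiny worked example of these three steps, using made-up gate logits and made-up expert outputs (three experts, output size 2), shows how the softmax weights combine the experts:

```python
import torch

# Made-up gate logits for one input and three experts
gate_logits = torch.tensor([2.0, 1.0, 0.0])
weights = torch.softmax(gate_logits, dim=-1)  # non-negative, sums to 1

# Made-up outputs of the three experts for that input (output_size = 2)
expert_outputs = torch.tensor([[1.0, 0.0],
                               [0.0, 1.0],
                               [1.0, 1.0]])

# Final output: sum over experts of weight_i * output_i
combined = (weights.unsqueeze(-1) * expert_outputs).sum(dim=0)

print(weights)   # approximately tensor([0.6652, 0.2447, 0.0900])
print(combined)  # approximately tensor([0.7553, 0.3348])
```

The first expert dominates because its logit is highest, but every expert still contributes in proportion to its weight.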

<h3>Implementation</h3>

In [None]:
class MixtureOfExpertsGatingNet(nn.Module):
    def __init__(self, num_experts, input_size, hidden_size, output_size):
        super(MixtureOfExpertsGatingNet, self).__init__()
        self.experts = nn.ModuleList([Expert(input_size, hidden_size, output_size) for _ in range(num_experts)])
        
        # Gating network is a simple feedforward network with softmax output
        self.gating = nn.Sequential(
            nn.Linear(input_size, num_experts),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        # x: [..., input_size], e.g. [batch_size, seq_len, input_size]
        weights = self.gating(x)  # [..., num_experts], one weight per expert
        
        # Stack the expert outputs along a new expert dimension: [..., num_experts, output_size]
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-2)
        
        # Weighted sum over the expert dimension: [..., output_size]
        output = (weights.unsqueeze(-1) * expert_outputs).sum(dim=-2)
        return output

<h3>Note</h3>

The nn.Linear layer maps the input to a tensor whose last dimension is num_experts: [batch_size, num_experts] for a 2D input, or [batch_size, seq_len, num_experts] for the token sequences in our language model. Specifying dim=-1 for the Softmax applies it across the num_experts dimension, so each data point (each token, for sequences) gets a probability distribution over all experts.
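
A quick shape check with illustrative sizes (4 sequences of 10 tokens, 3 experts) confirms that the softmax produces one distribution per token. It also shows why, for a 3D input, expert i's weight must be selected with `weights[..., i]`: `weights[:, i]` indexes the sequence dimension instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, input_size, num_experts = 4, 10, 16, 3  # illustrative sizes

gating = nn.Sequential(nn.Linear(input_size, num_experts), nn.Softmax(dim=-1))
x = torch.randn(batch_size, seq_len, input_size)
weights = gating(x)

print(weights.shape)  # torch.Size([4, 10, 3])
# Every token has its own distribution over the experts
print(torch.allclose(weights.sum(dim=-1), torch.ones(batch_size, seq_len)))  # True

# Selecting expert 0's weight: index the last dimension, not dimension 1
print(weights[..., 0].shape)  # torch.Size([4, 10])
print(weights[:, 0].shape)    # torch.Size([4, 3]), i.e. sequence position 0 for every expert
```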

In [26]:
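class LanguageModelGatingNet(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, output_size, num_experts, num_heads, num_layers):
        super(LanguageModelGatingNet, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        # Transformer layers (batch_first=True because inputs are [batch_size, seq_len])
        encoder_layers = TransformerEncoderLayer(embedding_dim, num_heads, hidden_size, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
        
        # MoE with gating network; its input is the encoder output (size embedding_dim)
        self.moe = MixtureOfExpertsGatingNet(num_experts, embedding_dim, hidden_size, output_size)
    
        # Fully connected output layer producing vocabulary logits
        self.fc = nn.Linear(output_size, vocab_size)

    def forward(self, x):
        # Causal mask: position t may only attend to positions <= t (next-token prediction)
        seq_len = x.size(1)
        causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf'), device=x.device), diagonal=1)
        embedded = self.embedding(x)
        output = self.transformer_encoder(embedded, mask=causal_mask)
        output = self.moe(output)
        output = self.fc(output)
        return output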

Our small language model with MoEs and a gating network is now complete. Please note that this architecture is oversimplified. The encoder layers already include activations and layer normalization, but in practice we would also add activation functions, normalization, and probably dropout to the rest of the network.

Non-linear activation functions introduce non-linearities into the model, which helps the network model complex patterns. Batch normalization or layer normalization can help stabilize the learning process and reduce training time; in transformer models, layer normalization is more common. Dropout is a regularization technique used to prevent overfitting: during training, it randomly sets a fraction p of the input units to 0 at each update and scales the remaining units by 1/(1 - p).
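
A short sketch of how PyTorch's `nn.Dropout` and `nn.LayerNorm` behave, using a made-up all-ones input and p = 0.5: dropout zeroes some units and doubles the rest only in training mode, and layer normalization gives each position's feature vector zero mean and unit variance (its learnable scale and shift start at 1 and 0).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Dropout: in training mode, zero a random fraction p of units and scale the rest by 1/(1 - p)
p = 0.5
dropout = nn.Dropout(p)
x = torch.ones(1000)

dropout.train()
y_train = dropout(x)
kept = y_train[y_train != 0]
print(kept.unique())  # tensor([2.])

dropout.eval()
y_eval = dropout(x)  # identity in evaluation mode
print(torch.equal(y_eval, x))  # True

# Layer normalization: normalize each position's feature vector (last dimension)
layer_norm = nn.LayerNorm(8)
z = layer_norm(torch.randn(2, 5, 8))
print(torch.allclose(z.mean(dim=-1), torch.zeros(2, 5), atol=1e-5))  # True
```

Scaling the surviving units keeps the expected activation the same in training and evaluation, which is why no rescaling is needed at inference time.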

Let's modify the language model above to incorporate these concepts. I also added an extra hidden layer to the fully connected output network to make the model deeper.

In [None]:
class LanguageModelGatingNet(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, output_size, num_experts, num_heads, num_layers, dropout_prob=0.1):
        super(LanguageModelGatingNet, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        # Transformer layers with dropout, GELU activation, and normalization.
        # TransformerEncoderLayer creates its own LayerNorm modules (it has no norm1/norm2 arguments);
        # norm_first=True applies them before attention and feed-forward, and the encoder-level
        # norm normalizes the final output.
        encoder_layers = TransformerEncoderLayer(embedding_dim, num_heads, hidden_size, dropout=dropout_prob,
                                                 activation="gelu", batch_first=True, norm_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers, norm=nn.LayerNorm(embedding_dim))
        
        # MoE with gating network; its input is the encoder output (size embedding_dim)
        self.moe = MixtureOfExpertsGatingNet(num_experts, embedding_dim, hidden_size, output_size)
    
        # Enhanced fully connected network with an additional layer and activation
        self.fc = nn.Sequential(
            nn.Dropout(dropout_prob),
            nn.Linear(output_size, output_size),  # Additional layer
            nn.ReLU(),  # Activation function
            nn.Linear(output_size, vocab_size)  # Output layer
        )

    def forward(self, x):
        # Causal mask: position t may only attend to positions <= t (next-token prediction)
        seq_len = x.size(1)
        causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf'), device=x.device), diagonal=1)
        embedded = self.embedding(x)
        output = self.transformer_encoder(embedded, mask=causal_mask)
        output = self.moe(output)
        output = self.fc(output)
        return output

<h2>Switch Transformers</h2>

Drawing on the principles of MoEs and gating networks, William Fedus et al. introduced Switch Transformers in <a href='https://arxiv.org/abs/2101.03961'>this paper</a>. Switch Transformers incorporate experts into their design: in the paper, the dense feed-forward layer of a transformer block is replaced by a set of expert feed-forward networks. Adding experts increases the number of parameters, which can improve a model's performance given a sufficiently large dataset. Additionally, as discussed earlier, experts can learn to specialize, which allows the model to focus on specific segments or combinations of the input; for language models, the input is text or, more precisely, a sequence of tokens.

However, more parameters usually mean more computation. In their paper, "Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity", Fedus et al. argue that by selecting only the top-scoring expert, it is possible to harness the benefits of MoEs without a drastic increase in floating point operations (FLOPs). A gating network (the router) scores the experts for each token, and the model 'switches' to the highest-scoring one, hence the name. Instead of aggregating the outputs of all the experts and passing the combined result forward, only the selected expert is evaluated, and its output is scaled by its gate value. Consequently, even though adding experts increases the number of parameters, the number of operations per token does not grow proportionally.
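
A back-of-the-envelope sketch of why top-1 routing keeps computation flat: with made-up router probabilities for 6 tokens and 4 experts, it counts how many expert evaluations a dense, weighted-sum MoE needs compared with switch routing.

```python
import torch

torch.manual_seed(0)
num_tokens, num_experts = 6, 4

# Made-up router probabilities: one distribution over the experts per token
router_probs = torch.softmax(torch.randn(num_tokens, num_experts), dim=-1)

# Dense (soft) MoE: every expert processes every token
dense_expert_evaluations = num_tokens * num_experts

# Switch (top-1) routing: each token is processed only by its highest-probability expert,
# whose output is then scaled by that probability (the gate value)
gate_values, expert_indices = router_probs.max(dim=-1)
switch_expert_evaluations = expert_indices.numel()

print(dense_expert_evaluations, switch_expert_evaluations)  # 24 6
print(torch.bincount(expert_indices, minlength=num_experts))  # tokens assigned to each expert
```

Adding experts raises `dense_expert_evaluations` proportionally, while `switch_expert_evaluations` stays equal to the number of tokens.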

<h3>Implementation</h3>

In [35]:
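class SwitchLayer(nn.Module):
    def __init__(self, dim, num_experts, hidden_size):
        super(SwitchLayer, self).__init__()
        self.num_experts = num_experts
        self.hidden_size = hidden_size
        self.experts = nn.ModuleList([nn.Linear(dim, hidden_size) for _ in range(num_experts)])
        self.gate = nn.Linear(dim, num_experts)  # router

    def forward(self, x):
        # x: [batch_size, seq_len, dim]; flatten so that each token is routed independently
        batch_size, seq_len, dim = x.shape
        tokens = x.reshape(-1, dim)  # [num_tokens, dim]

        # Gating mechanism: router probabilities, then the top-1 expert per token
        gates = nn.functional.softmax(self.gate(tokens), dim=-1)  # [num_tokens, num_experts]
        gate_values, expert_indices = gates.max(dim=-1)  # [num_tokens] each

        # Route each token only to its selected expert (switch routing)
        output = tokens.new_zeros(tokens.size(0), self.hidden_size)
        for i, expert in enumerate(self.experts):
            mask = expert_indices == i
            if mask.any():
                # Scale by the gate value so that the router also receives gradients
                output[mask] = expert(tokens[mask]) * gate_values[mask].unsqueeze(-1)

        return output.reshape(batch_size, seq_len, self.hidden_size)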
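
One detail of top-1 routing worth checking: picking the top expert with `max` is not differentiable, so if the selected expert's output were used unscaled, the router would receive no gradient and could never learn. Multiplying by the gate value, as the Switch Transformer paper does, fixes this. The sketch below uses a hypothetical router and experts (plain Linear layers with made-up sizes) and a hypothetical helper `top1_outputs`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, num_experts, num_tokens = 8, 4, 5
router = nn.Linear(dim, num_experts)
experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_experts)])
tokens = torch.randn(num_tokens, dim)


def top1_outputs(scale_by_gate):
    probs = torch.softmax(router(tokens), dim=-1)
    gate_values, expert_indices = probs.max(dim=-1)
    # Each token is processed by its selected expert only
    out = torch.stack([experts[i](tok) for i, tok in zip(expert_indices.tolist(), tokens)])
    return out * gate_values.unsqueeze(-1) if scale_by_gate else out


# Without the gate scaling, the loss does not depend on the router at all
top1_outputs(scale_by_gate=False).sum().backward()
no_grad_without_scaling = router.weight.grad is None
print(no_grad_without_scaling)  # True

# With the gate scaling, gradients reach the router
top1_outputs(scale_by_gate=True).sum().backward()
print(router.weight.grad is not None)  # True
```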

In [45]:
class LanguageModelSwitchNet(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, output_size, num_experts, num_heads, num_layers, dropout_prob=0.1):
        super(LanguageModelSwitchNet, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        # Transformer layers with dropout (batch_first=True because inputs are [batch_size, seq_len])
        encoder_layers = TransformerEncoderLayer(embedding_dim, num_heads, hidden_size, dropout=dropout_prob, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
        
        # Switch layer instead of MoE: maps embedding_dim -> hidden_size
        self.switch = SwitchLayer(embedding_dim, num_experts, hidden_size)
    
        # Fully connected network with additional layer and activation
        self.fc = nn.Sequential(
            nn.Dropout(dropout_prob),
            nn.Linear(hidden_size, output_size),  # Additional layer
            nn.ReLU(),  # Activation function
            nn.Linear(output_size, vocab_size)  # Output layer
        )

    def forward(self, x):
        # Causal mask: position t may only attend to positions <= t (next-token prediction)
        seq_len = x.size(1)
        causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf'), device=x.device), diagonal=1)
        embedded = self.embedding(x)
        output = self.transformer_encoder(embedded, mask=causal_mask)
        output = self.switch(output)
        output = self.fc(output)
        return output
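
For next-token prediction, the encoder must not let position t attend to later positions; otherwise the model can simply read off the token it is asked to predict. A causal mask enforces this. The sketch below uses a single `nn.TransformerEncoderLayer` with made-up sizes and dropout disabled, perturbs only the last position, and checks that outputs at earlier positions do not change. The helper `causal_mask` is hypothetical and built with `torch.triu`.

```python
import torch
import torch.nn as nn


def causal_mask(seq_len):
    # -inf above the diagonal: position t may attend only to positions <= t
    return torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)


torch.manual_seed(0)
# Illustrative sizes; dropout=0.0 keeps the comparison deterministic
layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, dim_feedforward=32, dropout=0.0, batch_first=True)

x = torch.randn(2, 5, 16)  # [batch_size, seq_len, d_model]
x_changed = x.clone()
x_changed[:, -1] += 1.0  # perturb only the last position

mask = causal_mask(5)
with torch.no_grad():
    out = layer(x, src_mask=mask)
    out_changed = layer(x_changed, src_mask=mask)
    out_unmasked = layer(x)
    out_unmasked_changed = layer(x_changed)

# With the mask, outputs at earlier positions ignore the change at the last position
print(torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-6))  # True
```

Without the mask (`out_unmasked`), every position attends to the perturbed last token, so the earlier outputs change as well.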

In [37]:
import sentencepiece as spm
from torchtext.datasets import PennTreebank

# Load the trained SentencePiece model
sp = spm.SentencePieceProcessor()
sp.load('models/SentencePiecePennTree.model')

True

In [38]:
from torch.utils.data import DataLoader, Dataset

class PTBDataset(Dataset):
    def __init__(self, data, seq_len=30):
        tokens = sp.encode_as_ids(data)
        self.data = [tokens[i:i+seq_len] for i in range(len(tokens) - seq_len)]
        self.targets = [tokens[i+1:i+seq_len+1] for i in range(len(tokens) - seq_len)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx]), torch.tensor(self.targets[idx])

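def get_raw_text_from_dataset(dataset):
    # Join sentences with a space so words at sentence boundaries are not fused together;
    # str.join also avoids the quadratic cost of repeated string concatenation
    return ' '.join(dataset)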

train_raw_text = get_raw_text_from_dataset(PennTreebank(split='train'))
valid_raw_text = get_raw_text_from_dataset(PennTreebank(split='valid'))

train_dataset = PTBDataset(train_raw_text)
valid_dataset = PTBDataset(valid_raw_text)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=64)
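
PTBDataset turns the token stream into overlapping training examples: each input is a window of seq_len token ids, and its target is the same window shifted one token to the right, so position t is trained to predict token t + 1. A toy run with a made-up token list and a hypothetical `make_windows` helper that reproduces the two list comprehensions:

```python
def make_windows(tokens, seq_len):
    """Hypothetical helper mirroring PTBDataset: overlapping input windows, plus the
    same windows shifted one token to the right as targets."""
    inputs = [tokens[i:i + seq_len] for i in range(len(tokens) - seq_len)]
    targets = [tokens[i + 1:i + seq_len + 1] for i in range(len(tokens) - seq_len)]
    return inputs, targets


inputs, targets = make_windows([10, 11, 12, 13, 14], seq_len=3)
print(inputs)   # [[10, 11, 12], [11, 12, 13]]
print(targets)  # [[11, 12, 13], [12, 13, 14]]
```

Because the stride is 1, a stream of N tokens yields N - seq_len examples that overlap heavily.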

In [46]:
import torch
import torch.optim as optim

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize the model
vocab_size = sp.get_piece_size()
embedding_dim = 256 # Can adjust based on our needs
hidden_size=512 
output_size=512 
num_heads=8 
num_layers=6
num_experts=32


model = LanguageModelSwitchNet(vocab_size, embedding_dim, hidden_size, output_size, num_experts, num_heads, num_layers, dropout_prob=0.1).to(device)


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

# Lists for recording losses (training: every 500 batches; validation: every batch)
train_losses = []
val_losses = []

num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, (data, targets) in enumerate(train_loader):
        data, targets = data.to(device), targets.to(device)
        outputs = model(data)
        loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        if batch_idx % 500 == 0:
            print(f"Epoch: {epoch+1} | Batch: {batch_idx+1} | Loss: {loss.item()}")
            train_losses.append(loss.item())
            
    # Calculate the average training loss for this epoch
    avg_train_loss = train_loss / len(train_loader)
    

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, targets in valid_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            valid_loss = criterion(outputs.view(-1, vocab_size), targets.view(-1)).item()
            val_losses.append(valid_loss)
            val_loss += valid_loss
        # Calculate the average validation loss for this epoch
        avg_val_loss = val_loss / len(valid_loader)
    print(f"Validation Loss after epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")



Implementing the full Switch Transformer architecture from scratch would be quite complex, but we can use the <b>Hugging Face</b> &#x1F60A; Transformers library, which provides a prebuilt Switch Transformer implementation that you can use as part of your model.

In [12]:
#!pip install transformers

In [13]:
import transformers
from transformers import SwitchTransformersConfig
from transformers import SwitchTransformersModel

In a <a href='https://github.com/lsafarne/NLPBites.github.io/blob/main/text_tokenization.ipynb'>previous post</a>, I discussed SentencePiece tokenizers and trained one using the Penn Treebank dataset. For brevity in this post, I will reuse that tokenizer and dataset. For those interested in details, I recommend referring to the aforementioned post.

First, we need to define a configuration object:

In [17]:
vocab_size = sp.get_piece_size()
config = SwitchTransformersConfig(
    vocab_size=vocab_size, # Size of the vocabulary
    num_experts=32, # Number of experts
    d_model=256, # Model (hidden) dimension; must match embedding_dim used below
    num_heads=8, # Number of attention heads
    num_layers=6, # Number of encoder layers
)

Then, we create the Switch Transformer itself; the <i>config</i> object provides its architectural details. Since we only need a stack of Switch Transformer layers to encode token sequences, we use the encoder-only SwitchTransformersEncoderModel rather than the full encoder-decoder SwitchTransformersModel, which would also require decoder inputs:

In [22]:
from transformers import SwitchTransformersEncoderModel

switch_transformer = SwitchTransformersEncoderModel(config)

In [27]:
import torch.nn as nn
class SwitchTransformerLanguageModel(nn.Module):
    def __init__(self, vocab_size, switch_transformer, embedding_dim):
        super(SwitchTransformerLanguageModel, self).__init__()
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        # Switch Transformer layers
        self.switch_transformer = switch_transformer
        
        # Output fully connected layer (embedding_dim must equal config.d_model)
        self.fc = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        # x: [batch_size, seq_len] token ids
        
        embedded = self.embedding(x) 
        # embedded: [batch_size, seq_len, embedding_dim]
        
        # Pass the embeddings as inputs_embeds; passed positionally, the float tensor would be
        # treated as token ids (input_ids) and the embedding lookup would fail.
        # Note: this encoder attends in both directions, so future tokens are visible to it.
        switch_out = self.switch_transformer(inputs_embeds=embedded).last_hidden_state
        # switch_out: [batch_size, seq_len, embedding_dim]
        
        output = self.fc(switch_out)
        # output: [batch_size, seq_len, vocab_size]
        
        return output



<h2>Training</h2>

We reuse the tokenizer, datasets, and data loaders defined earlier, so we can go straight to the training loop.

In [34]:
import torch
import torch.optim as optim

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize the model
vocab_size = sp.get_piece_size()
embedding_dim = 256 # Can adjust based on our needs

model = SwitchTransformerLanguageModel(vocab_size, switch_transformer, embedding_dim).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

# Lists for recording losses (training: every 500 batches; validation: every batch)
train_losses = []
val_losses = []

num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, (data, targets) in enumerate(train_loader):
        data, targets = data.to(device), targets.to(device)
        outputs = model(data)
        loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        if batch_idx % 500 == 0:
            print(f"Epoch: {epoch+1} | Batch: {batch_idx+1} | Loss: {loss.item()}")
            train_losses.append(loss.item())
            
    # Calculate the average training loss for this epoch
    avg_train_loss = train_loss / len(train_loader)
    

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, targets in valid_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            valid_loss = criterion(outputs.view(-1, vocab_size), targets.view(-1)).item()
            val_losses.append(valid_loss)
            val_loss += valid_loss
        # Calculate the average validation loss for this epoch
        avg_val_loss = val_loss / len(valid_loader)
    print(f"Validation Loss after epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

