In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pandas as pd
from icecream import ic
from tqdm import tqdm

#### <span style="font-family: 'Bebas Neue'; font-size:1.2em;">The wikitext dataset.

- <span style="font-family: 'Bebas Neue'; font-size:1em;">Has about 1.8 million rows of text from Wikipedia</span>
- <span style="font-family: 'Bebas Neue'; font-size:1em;">All rows are joined with newlines into one long string for processing</span>

In [None]:
# The two parquet shards of the WikiText training split, read from the working directory.
shard_paths = ('train-00000-of-00002.parquet', 'train-00001-of-00002.parquet')
df1, df2 = (pd.read_parquet(path) for path in shard_paths)
comb = pd.concat([df1, df2], ignore_index=True)

# One long newline-separated string; astype(str) turns any non-string cell into text.
text_data = "\n".join(comb["text"].astype(str))

<span style="font-family: 'Bebas Neue'; font-size:24px;">I tried training a custom tokenizer, but dropped it because of loading issues and used a pre-trained tokenizer instead.</span>  
<span style="font-family: 'Bebas Neue'; font-size:24px;">I will revisit the custom tokenizer later.</span>

In [None]:
# from tokenizers import BertWordPieceTokenizer

# tokenizer = BertWordPieceTokenizer(
#     clean_text=True,
#     handle_chinese_chars=True,
#     strip_accents=True,
#     lowercase=True,
# )

# tokenizer.train_from_iterator(
#     text_data.split("\n"),
#     vocab_size=32000,
#     min_frequency=2,
#     special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"],
#     limit_alphabet=1000,
#     wordpieces_prefix="##"
# )

# def tokenize_text(text_data, tokenizer, max_length, chunk_size):
#     encodings = []
#     for i in range(0, len(text_data), chunk_size):
#         chunk = text_data[i:i+chunk_size]
#         chunk_texts = chunk.split("\n")
#         chunk_encodings = tokenizer.encode_batch(chunk_texts)
#         for enc in chunk_encodings:
#             encoding = {
#                 "ids": enc.ids[:max_length],
#                 "attention_mask": [1] * len(enc.ids[:max_length]),
#             }
#             overflow_ids = enc.ids[max_length:]
#             while overflow_ids:
#                 encoding["overflowing_tokens"] = overflow_ids[:max_length]
#                 encoding["overflow_attention_mask"] = [1] * len(overflow_ids[:max_length])
#                 overflow_ids = overflow_ids[max_length:]
#                 encodings.append(encoding)
#                 encoding = {
#                     "ids": [],
#                     "attention_mask": [],
#                 }
#     return encodings

# # Set the chunk size according to your memory constraints
# chunk_size = 10000  # Adjust this value based on your system's memory capacity

# # Tokenize the text data in chunks
# encodings = tokenize_text(text_data, tokenizer, max_length=256, chunk_size=chunk_size)

<span style="font-family: 'Bebas Neue'; font-size:24px;"> Only the first 1% of the text is used for training while the pipeline is being tested.</span>  
<span style="font-family: 'Bebas Neue'; font-size:24px;"> The next 0.5% of the text is held out as the validation set (`test_text`).</span>

In [None]:
# Character-level split: first 1% of the text for training, the following 0.5% for validation.
train_idx = int(0.01 * len(text_data))
train_text, test_text = text_data[:train_idx], text_data[train_idx:int(train_idx * 1.5)]

### <span style="font-family: 'Bebas Neue'; font-size:34px;">Gets the data ready for training.
- <span style="font-family: 'Bebas Neue'; font-size:24px;">Uses a pre-trained tokenizer from Hugging Face. Since I am closely following BERT, it is `BertTokenizerFast` (`bert-base-uncased`).</span>
- <span style="font-family: 'Bebas Neue'; font-size:24px;">Wraps the tokenized sequences in a custom PyTorch dataset.</span>
- <span style="font-family: 'Bebas Neue'; font-size:24px;">Uses the Hugging Face `DataCollatorForLanguageModeling`, which masks tokens with probability 0.15 (i.e. 15%).</span>
- <span style="font-family: 'Bebas Neue'; font-size:24px;">Tokenizes the text chunk by chunk to keep memory usage bounded.</span>
- <span style="font-family: 'Bebas Neue'; font-size:24px;">Produces the `train_dl` and `test_dl` data loaders.</span>


In [None]:
from transformers import BertTokenizerFast, DataCollatorForLanguageModeling
from torch.utils.data import Dataset, DataLoader

# Batch size used by the DataLoaders built below.
bs = 32
# Pre-trained BERT tokenizer with the uncased vocabulary.
tokenizer = BertTokenizerFast.from_pretrained(
    "bert-base-uncased"
)

class MaskedLanguageModelDataset(Dataset):
    """Map-style dataset over pre-tokenized examples.

    ``encodings`` is a sequence of dicts. Each item returned exposes only the
    ``input_ids`` and ``attention_mask`` entries; in this notebook masking and
    labels are added afterwards by the DataLoader's collate function.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        example = self.encodings[idx]
        return {field: example[field] for field in ("input_ids", "attention_mask")}

# Masks 15% of the tokens in each batch; unmasked positions get label -100.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15
)


def get_mlm(text_data, max_length=256, chunk_size=10000, batch_size=None, shuffle=True):
    """Tokenize a long string into fixed-length sequences and wrap them in an MLM DataLoader.

    The text is tokenized piece by piece. Each piece holds at most ``chunk_size`` characters
    and, when it contains a newline, ends right after its last newline, so words are never
    split between two pieces. Each piece is tokenized with truncation and overflow, so it
    yields one or more sequences of at most ``max_length`` tokens.

    Args:
        text_data: raw text whose lines are separated by "\\n".
        max_length: maximum tokens per sequence.
        chunk_size: maximum characters passed to the tokenizer per call (bounds memory).
        batch_size: DataLoader batch size; ``None`` uses the notebook-level ``bs``.
        shuffle: whether the DataLoader reshuffles the sequences every epoch.

    Returns:
        A DataLoader whose batches are produced by ``data_collator``.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    def tokenize_text(text_data, tokenizer, max_length, chunk_size):
        encodings = []
        start, text_end = 0, len(text_data)
        while start < text_end:
            end = min(start + chunk_size, text_end)
            if end < text_end:
                # Back up to the last newline inside the window so no word is cut in half.
                newline = text_data.rfind("\n", start, end)
                if newline > start:
                    end = newline + 1
            chunk = text_data[start:end]
            start = end
            if not chunk.strip():
                # Whitespace-only pieces would only produce [CLS] [SEP] sequences.
                continue

            chunk_encodings = tokenizer(
                chunk,
                truncation=True,
                max_length=max_length,
                return_overflowing_tokens=True,
                padding=True,
            )
            # One Encoding per max_length window (the overflow windows are included).
            for encoding in chunk_encodings.encodings:
                encodings.append({
                    "input_ids": encoding.ids,
                    "attention_mask": encoding.attention_mask,
                })
        return encodings

    encodings = tokenize_text(text_data, tokenizer, max_length=max_length, chunk_size=chunk_size)
    dataset = MaskedLanguageModelDataset(encodings)

    return DataLoader(
        dataset,
        batch_size=bs if batch_size is None else batch_size,
        shuffle=shuffle,
        collate_fn=data_collator
    )


train_dl = get_mlm(train_text)
# Validation order does not need to change between epochs.
test_dl = get_mlm(test_text, shuffle=False)

- <span style="font-family: 'Bebas Neue'; font-size:24px;"> **input_ids**: maps sub-words to integers according to the tokenizer</span>
- <span style="font-family: 'Bebas Neue'; font-size:24px;"> **attention_mask**: 1 for real tokens and 0 for padding, so padded positions are ignored by attention</span>
- <span style="font-family: 'Bebas Neue'; font-size:24px;"> **labels**: targets for the MLM task; -100 marks positions that PyTorch's cross-entropy loss ignores</span>

In [None]:
# Show the first 20 positions of example 4 from one training batch.
sample_batch = next(iter(train_dl))
for field in ("input_ids", "attention_mask", "labels"):
    print(getattr(sample_batch, field)[4][:20])

#### <span style="font-family: 'Bebas Neue'; font-size:1.2em;">Layer Norm
- <span style="font-family: 'Bebas Neue'; font-size:1.2em;"> Two trainable params gamma and beta

In [None]:
class LayerNormalization(nn.Module):
    """Layer normalization over the last dimension with a learnable scale and shift.

    Computes ``gamma * (x - mean) / sqrt(var + eps) + beta``, where ``mean`` and the
    biased (population) ``var`` are taken over the last dimension of ``x``.

    Args:
        shape: size of the last dimension (shape of ``gamma`` and ``beta``).
        eps: constant added to the variance for numerical stability.
    """

    def __init__(self, shape, eps=1e-5):
        super().__init__()
        self.param_shape = shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(shape))   # scale, starts at 1
        self.beta = nn.Parameter(torch.zeros(shape))   # shift, starts at 0

    def forward(self, x):
        # Statistics per position; keepdim leaves a size-1 last axis for broadcasting.
        centered = x - x.mean(dim=-1, keepdim=True)
        variance = centered.pow(2).mean(dim=-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        return self.gamma * normalized + self.beta

#### <span style="font-family: 'Bebas Neue'; font-size:34px">Embedding implementation</span>
- <span style="font-family: 'Bebas Neue'; font-size:24px"> Maps each token id in the vocabulary to an embedding vector.</span>
- <span style="font-family: 'Bebas Neue'; font-size:24px"> Position embeddings are learned (trainable), as in BERT.</span>

In [None]:
class Embeddings(nn.Module):
    """Learned token embeddings plus learned absolute position embeddings, followed by dropout.

    Args:
        vocab_size: number of token ids.
        vector_dim: embedding width.
        max_len: number of position embeddings (longest supported sequence).
        drop: dropout probability applied to the summed embeddings.
    """

    def __init__(self, vocab_size, vector_dim, max_len, drop):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, vector_dim)
        self.position_embed = nn.Embedding(max_len, vector_dim)
        self.dropout = nn.Dropout(p=drop)

    def forward(self, x):
        """x: (batch, seq_len) token ids, seq_len <= max_len -> (batch, seq_len, vector_dim)."""
        tokens = self.token_embeddings(x)
        # Positions 0..seq_len-1, repeated for every row of the batch.
        positions = torch.arange(x.size(1), dtype=torch.long, device=x.device).unsqueeze(0).expand_as(x)
        return self.dropout(tokens + self.position_embed(positions))

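<span style="font-family: 'Bebas Neue'; font-size:24px"> Each token is projected into three vectors: a query $q$, a key $k$ and a value $v$.</span>  

<span style="font-family: 'Bebas Neue'; font-size:24px"> The attention weights $\mathrm{softmax}(QK^\top / \sqrt{d_k})$ give, for every token, a probability distribution over how much to focus on each other token.</span>  

<span style="font-family: 'Bebas Neue'; font-size:24px">The output is the attention-weighted sum of the value vectors, so each token's vector incorporates information from the tokens it attends to.</span>  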

In [None]:
def sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask=None) -> torch.Tensor:
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    Args:
        q, k, v: tensors whose last two axes are (seq_len, head_dim); in this notebook
            they are (batch, heads, seq_len, head_dim).
        mask: optional key mask where 0 marks positions that must not be attended to.
            A 2-D (batch, seq_len) padding mask is broadcast over heads and query
            positions; any other mask must already broadcast to the
            (..., q_len, k_len) score tensor.

    Returns:
        Tensor with the shape of ``q`` except that its last axis is ``v``'s last axis.

    Note:
        A query row whose keys are all masked becomes NaN (softmax over all -inf).
    """
    d_k = q.shape[-1]  # per-head feature size

    # Similarity of every query with every key: (..., q_len, k_len).
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)

    if mask is not None:
        if mask.dim() == 2:
            # (batch, k_len) -> (batch, 1, 1, k_len): same mask for every head and query.
            mask = mask[:, None, None, :]
        # -inf scores become 0 after the softmax, so masked keys get no weight.
        scores = scores.masked_fill(mask == 0, float('-inf'))

    attention = F.softmax(scores, dim=-1)
    return torch.matmul(attention, v)

<span style="font-family: 'Bebas Neue'; font-size:34px">Multi-Head Attention
- <span style="font-family: 'Bebas Neue'; font-size:24px"> Splits attention across several heads that are computed in parallel.</span>
- <span style="font-family: 'Bebas Neue'; font-size:24px"> The intuition is that many heads can learn richer attention patterns than a single head.</span>

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Pipeline: fused QKV projection -> per-head scaled dot-product attention ->
    merge heads -> dropout -> output projection -> LayerNormalization -> dropout.
    The caller (EncoderBlock) adds the residual connection.

    Args:
        vector_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        drop: dropout probability.

    Raises:
        ValueError: if ``vector_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, vector_dim, num_heads, drop):
        super(MultiHeadAttention, self).__init__()
        if vector_dim % num_heads != 0:
            raise ValueError(f"vector_dim ({vector_dim}) must be divisible by num_heads ({num_heads})")
        self.vector_dim = vector_dim
        self.num_heads = num_heads
        self.head_dims = vector_dim // num_heads
        self.qkv_layer = nn.Linear(vector_dim, vector_dim * 3)  # q, k and v in one matmul
        self.linear = nn.Linear(vector_dim, vector_dim)
        self.dropout = nn.Dropout(p=drop)
        self.layer_norm = LayerNormalization(vector_dim)

    def forward(self, x, mask=None):
        """x: (batch, seq_len, vector_dim); mask: optional key-padding mask passed to ``sdpa``."""
        bs, max_len, _ = x.shape

        qkv = self.qkv_layer(x)                                                 # (bs, seq, 3*dim)
        qkv = qkv.reshape(bs, max_len, self.num_heads, self.head_dims * 3)      # (bs, seq, heads, 3*hd)
        qkv = qkv.permute(0, 2, 1, 3)                                           # (bs, heads, seq, 3*hd)
        q, k, v = qkv.chunk(3, dim=-1)                                          # each (bs, heads, seq, hd)

        out = sdpa(q, k, v, mask)                                               # (bs, heads, seq, hd)

        # Put the sequence axis back in front of the head axis before merging heads.
        # Reshaping (bs, heads, seq, hd) directly would mix different tokens into one row.
        out = out.permute(0, 2, 1, 3).reshape(bs, max_len, self.num_heads * self.head_dims)

        out = self.dropout(out)
        out = self.linear(out)
        out = self.layer_norm(out)
        out = self.dropout(out)
        return out

<span style="font-family: 'Bebas Neue'; font-size:34px"> Feed Forward network

In [None]:
class PosFFN(nn.Module):
    """Position-wise feed-forward network.

    Applied independently at every position:
    Linear(vector_dim -> hidden_dim) -> GELU -> LayerNormalization(hidden_dim)
    -> Dropout -> Linear(hidden_dim -> vector_dim).
    """

    def __init__(self, vector_dim, hidden_dim, drop):
        super().__init__()
        self.lin1 = nn.Linear(vector_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, vector_dim)
        self.gelu = nn.GELU()
        self.drop = nn.Dropout(p=drop)
        # Normalizes the expanded hidden activations, between the two projections.
        self.layer_norm = LayerNormalization(hidden_dim)

    def forward(self, x):
        expanded = self.gelu(self.lin1(x))
        expanded = self.drop(self.layer_norm(expanded))
        return self.lin2(expanded)

<span style="font-family: 'Bebas Neue'; font-size:34px"> Encoder block
- <span style="font-family: 'Bebas Neue'; font-size:24px"> The Multi-head attention and FFN to get a single encoder block

In [None]:
class EncoderBlock(nn.Module):
    """One encoder layer: a self-attention sub-layer, then a position-wise FFN sub-layer.

    Each sub-layer's output is added to its input (residual connection).
    """

    def __init__(self, vector_dim, num_heads, hidden_dim, drop):
        super(EncoderBlock, self).__init__()
        self.attention = MultiHeadAttention(vector_dim, num_heads, drop)
        self.ffn = PosFFN(vector_dim=vector_dim, hidden_dim=hidden_dim, drop=drop)

    def forward(self, x, attention_mask):
        """x: (batch, seq_len, vector_dim) -> same shape."""
        after_attention = self.attention(x, mask=attention_mask) + x
        return self.ffn(after_attention) + after_attention

<span style="font-family: 'Bebas Neue'; font-size:34px"> Stacking the encoder blocks

In [None]:
class Encoder(nn.Module):
    """A stack of ``num_layers`` EncoderBlocks applied one after another.

    Every block takes the attention mask as a second argument, so the container is
    iterated rather than called. ``nn.ModuleList`` is the container meant for that;
    parameter names (``layers.<i>.*``) are the same as with ``nn.Sequential``.
    """

    def __init__(self, vector_dim, num_heads, drop, hidden_dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(
            [EncoderBlock(vector_dim, num_heads, hidden_dim, drop) for _ in range(num_layers)]
        )

    def forward(self, x, attention_mask):
        for layer in self.layers:
            x = layer(x, attention_mask)
        return x

<span style="font-family: 'Bebas Neue'; font-size:34px"> The final model.
- <span style="font-family: 'Bebas Neue'; font-size:24px">Ready for token prediction.

In [None]:
# Hyper-parameters for the dense baseline model.
vocab_size = len(tokenizer.get_vocab())  # 30522 for bert-base-uncased
max_len = 256        # tokens per sequence (same as the tokenizer max_length)
vector_dim = 768     # model width
num_heads = 8        # attention heads -> 768 / 8 = 96 dims per head
hidden_dim = 2048    # FFN inner size
num_layers = 2       # number of encoder blocks
drop = 0.1           # dropout probability

class TransformerModel(nn.Module):
    """Dense encoder-only language model: Embeddings -> Encoder -> vocabulary projection.

    forward returns per-token vocabulary logits of shape (batch, seq_len, vocab_size).
    """

    def __init__(self, vocab_size, vector_dim, max_len, num_heads, drop, hidden_dim, num_layers):
        super().__init__()
        self.embeddings = Embeddings(vocab_size, vector_dim, max_len, drop)
        self.encoder = Encoder(vector_dim, num_heads, drop, hidden_dim, num_layers)
        self.linear = nn.Linear(vector_dim, vocab_size)  # hidden state -> vocabulary logits

    def forward(self, x, attention_mask):
        hidden = self.encoder(self.embeddings(x), attention_mask)
        return self.linear(hidden)

# Use the Apple-silicon GPU when present, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Smoke test: push one random batch through the model and check the output shape.
model = TransformerModel(vocab_size, vector_dim, max_len, num_heads, drop, hidden_dim, num_layers).to(device)

bs = 32
input_tensor = torch.randint(0, vocab_size, (bs, max_len), device=device)
attention_mask = torch.ones(bs, max_len, device=device)  # no padding: every position is attended

# Shape check only, so no autograd graph is needed.
with torch.no_grad():
    output = model(input_tensor, attention_mask)

print("Input tensor shape:", input_tensor.shape)
print("Attention mask shape:", attention_mask.shape)
print("Output tensor shape:", output.shape)

### Training Loop

In [None]:
# Train the dense baseline TransformerModel on the masked-language-modelling objective.
# The hyper-parameters (vocab_size, vector_dim, ...) and `device` come from the
# model-definition cell above; tqdm is imported in the first cell.
model = TransformerModel(vocab_size, vector_dim, max_len, num_heads, drop, hidden_dim, num_layers)

model.to(device)
# CrossEntropyLoss skips targets equal to -100 (its default ignore_index); the MLM
# collator uses -100 for positions that were not masked.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

num_epochs = 100
for epoch in range(num_epochs):
    train_loss = 0.0
    val_loss = 0.0

    model.train()
    pbar = tqdm(train_dl, desc=f'Epoch {epoch+1}')
    for step, batch in enumerate(pbar, start=1):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass (the logits are already on `device`)
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs.view(-1, vocab_size), labels.view(-1))

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        pbar.set_postfix(loss=train_loss / step)

    # The epoch average has to exist before it is printed.
    train_epoch_loss = train_loss / len(train_dl)
    print(f'Train Loss: {train_epoch_loss}')

    model.eval()
    with torch.no_grad():
        pbar2 = tqdm(test_dl, desc=f'Epoch {epoch + 1}')
        for batch in pbar2:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs.view(-1, vocab_size), labels.view(-1))

            val_loss += loss.item()
    valid_loss = val_loss / len(test_dl)
    print(f'Valid Loss: {valid_loss}')

<span style="font-family: 'Bebas Neue'; font-size:24px"> WIP: Attention mechanism for the routing gate
- <span style="font-family: 'Bebas Neue'; font-size:24px"> Trying it with 1 attention head 

In [None]:
# class SingleHeadRouterAttention(nn.Module):
#     def __init__(self,vector_dim,num_experts,num_heads = 1,bs=32,max_len=256):
#         super().__init__()
#         self.vector_dim = vector_dim
#         self.num_heads = num_heads
#         self.head_dims = vector_dim // num_heads # 768
#         self.qkv_layer = nn.Linear(vector_dim, vector_dim * 3) # shape: (768, 2304)
#         self.linear = nn.Linear(vector_dim,num_experts) #shape: (768, num_experts)
#         self.bs = bs
#         self.max_len = max_len

#     def forward(self, x, mask = None): # 32 x 256 x 768
        
#         bs, max_len, _ = x.shape # Gets the batch size and max length

#         qkv = self.qkv_layer(x) # Splits the inp embedding into three diff values but together for faster speeds. Out: 32 x 256 x 2304

#         qkv = qkv.reshape(bs, max_len, self.num_heads, self.head_dims * 3) # Divide according to the number of heads. Out: 32 x 256 x 8 x 288

#         qkv = qkv.permute(0,2,1,3) # Reshaping to get it ready for inp into attention head. Out: 32 x 8 x 256 x 288

#         q, k, v = qkv.chunk(3,dim = -1) # Chunk it into 3 parts Q, K and V Out: (32, 8, 256, 96)

#         out = sdpa(q, k, v, mask) # Scaled dot product attention Out: (32, 8, 256, 96)

#         out = out.reshape(bs, max_len, self.num_heads * self.head_dims) # Back to vector_size with attention info incorporated Out: (32, 256, 768)

#         out = self.linear(out) # Final Linear transformation Out: Out: (32, 256, 768)

#         return out

<span style="font-family: 'Bebas Neue'; font-size:34px">Sparse Feed Forward Network
- <span style="font-family: 'Bebas Neue'; font-size:24px"> Takes in number of experts and top_k, the top k experts to route the tokens.
- <span style="font-family: 'Bebas Neue'; font-size:24px"> Implements a linear gating mechanism 
- <span style="font-family: 'Bebas Neue'; font-size:24px"> The output is multiplied by the routing weights, appropriately.

In [None]:
class SwitchFFN(nn.Module):
    """Sparse mixture-of-experts FFN with a linear router (Switch Transformer style).

    Each token is sent to its ``top_k`` highest-probability experts (``PosFFN``
    modules), and each expert's output is scaled by the token's routing weight.
    An expert processes at most ``expert_capacity`` tokens per forward pass; tokens
    over capacity get no contribution from that expert. Their output row therefore
    stays zero, so the caller's residual connection passes them through unchanged.

    Args:
        vector_dim: token feature size.
        hidden_dim: hidden size of every expert FFN.
        num_experts: number of experts.
        expert_drop: dropout probability inside each expert.
        top_k: experts per token (1 = Switch routing).
        capacity_factor: multiplier on the even share of token slots per expert.
    """

    def __init__(self, vector_dim, hidden_dim, num_experts, expert_drop=0.4, top_k=1, capacity_factor=1.25):
        super().__init__()
        self.vector_dim = vector_dim
        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.top_k = int(top_k)
        self.capacity_factor = capacity_factor
        self.gate = nn.Linear(self.vector_dim, self.num_experts, bias=False)  # router
        self.experts = nn.ModuleList([PosFFN(vector_dim, hidden_dim, expert_drop) for _ in range(self.num_experts)])

    def forward(self, x):
        """Route the tokens of ``x`` (batch, seq_len, vector_dim) through the experts.

        Returns:
            ``(output, router_logits)``: ``output`` has the shape of ``x``;
            ``router_logits`` is (batch * seq_len, num_experts), used by the load-balancing loss.
        """
        batch_size, sequence_length, vector_dim = x.shape
        flat_x = x.reshape(-1, vector_dim)  # one row per token
        num_tokens = flat_x.shape[0]

        router_logits = self.gate(flat_x)
        routing_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)  # (tokens, top_k)
        if self.top_k > 1:
            # Renormalize over the chosen experts. For top-1 this would make every weight
            # exactly 1 and give the router no gradient from the task loss, so
            # Switch routing keeps the raw probability.
            routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(flat_x.dtype)

        # Integer number of token slots per expert (it is used as a slice bound below).
        expert_capacity = max(1, int(self.capacity_factor * num_tokens * self.top_k / self.num_experts))

        final_hidden_states = torch.zeros_like(flat_x)

        # One-hot routing mask laid out as (num_experts, top_k, num_tokens).
        expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        for expert_idx, expert_layer in enumerate(self.experts):
            # slot_idx: which of the token's top_k choices; token_idx: which token.
            slot_idx, token_idx = torch.where(expert_mask[expert_idx])
            if token_idx.numel() == 0:
                continue

            # Keep the first `expert_capacity` assignments (torch.where orders them by
            # slot, then token position); the rest are dropped for this expert.
            slot_idx = slot_idx[:expert_capacity]
            token_idx = token_idx[:expert_capacity]

            expert_out = expert_layer(flat_x[token_idx]) * routing_weights[token_idx, slot_idx, None]
            # Accumulate rather than assign, so a token routed to several experts
            # (top_k > 1) receives the sum of their weighted outputs.
            final_hidden_states.index_add_(0, token_idx, expert_out.to(flat_x.dtype))

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, vector_dim)
        return final_hidden_states, router_logits

<span style="font-family: 'Bebas Neue'; font-size:34px">Switch Encoder
- <span style="font-family: 'Bebas Neue'; font-size:24px"> Inserts a switch layer after three normal FFN layers
- <span style="font-family: 'Bebas Neue'; font-size:24px"> Collects the router logits from every switch layer (required to compute the load-balancing loss).</span>

In [None]:
class SwitchEncoderBlock(nn.Module):
    """Encoder stack that mixes dense FFN layers with Switch (MoE) FFN layers.

    Layer ``i`` uses a SwitchFFN when ``i % 4 == 3`` and a PosFFN otherwise. Every
    layer has its own MultiHeadAttention. Each layer computes
    ``x = x + attention(x)`` and then ``x = x + ffn(x)``.

    Args:
        vector_dim, num_heads, hidden_dim, drop: as for MultiHeadAttention / PosFFN.
        num_layers: number of layers in the stack.
        num_experts, top_k, expert_drop: forwarded to every SwitchFFN.

    forward returns ``(x, router_logits)``. ``router_logits`` is the dim-0
    concatenation of the router logits of every switch layer, or an empty
    ``(0, num_experts)`` tensor when the stack has no switch layer.
    """

    def __init__(self, vector_dim, num_heads, hidden_dim, num_layers, num_experts, top_k=1, expert_drop=0.4, drop=0.1):
        super().__init__()
        self.num_experts = num_experts
        # One attention module per layer (layers do not share attention weights).
        self.attentions = nn.ModuleList(
            [MultiHeadAttention(vector_dim, num_heads, drop) for _ in range(num_layers)]
        )
        # Keyword arguments: SwitchFFN's positional order is (..., expert_drop, top_k, ...).
        self.layers = nn.ModuleList([
            PosFFN(vector_dim, hidden_dim, drop) if i % 4 != 3
            else SwitchFFN(vector_dim, hidden_dim, num_experts, expert_drop=expert_drop, top_k=top_k)
            for i in range(num_layers)
        ])

    def forward(self, x, attention_mask):
        total_router_logits = []

        for attention, ffn in zip(self.attentions, self.layers):
            x = x + attention(x, attention_mask)

            if isinstance(ffn, SwitchFFN):
                ffn_out, router_logits = ffn(x)
                total_router_logits.append(router_logits)  # needed for the load-balancing loss
            else:
                ffn_out = ffn(x)
            x = x + ffn_out

        if total_router_logits:
            router_logits = torch.cat(total_router_logits, dim=0)
        else:
            router_logits = x.new_zeros((0, self.num_experts))
        return x, router_logits

<span style="font-family: 'Bebas Neue'; font-size:34px">The Final Switch Encoder

In [None]:
# Hyper-parameters for the Switch Transformer.
vocab_size = len(tokenizer.get_vocab())  # 30522 for bert-base-uncased
max_len = 256        # tokens per sequence
vector_dim = 768     # model width
num_heads = 8        # attention heads
hidden_dim = 2048    # FFN / expert inner size
num_layers = 8       # layers with i % 4 == 3 (here 3 and 7) use a SwitchFFN
drop = 0.1           # dropout outside the experts
num_experts = 4      # experts per switch layer
expert_drop = 0.4    # dropout inside each expert
top_k = 1            # experts per token (Switch routing)

class SwitchTransformer(nn.Module):
    """Encoder-only MLM model whose encoder mixes dense and Switch (MoE) FFN layers.

    forward returns ``(logits, router_logits)``: vocabulary logits of shape
    (batch, seq_len, vocab_size), and the router logits collected by the encoder
    for the load-balancing loss.

    Args:
        vocab_size, vector_dim, max_len, num_heads, drop, hidden_dim, num_layers:
            model dimensions; ``drop`` is used by the embeddings and the encoder.
        top_k: experts per token in the switch layers.
        expert_drop: dropout inside each expert.
        num_experts: experts per switch layer. ``None`` (default) reads the
            notebook-level ``num_experts`` variable.
    """

    def __init__(self, vocab_size, vector_dim, max_len, num_heads, drop, hidden_dim, num_layers, top_k=1, expert_drop=0.4, num_experts=None):
        super().__init__()
        if num_experts is None:
            # Backward-compatible fallback to the notebook-level setting.
            num_experts = globals()["num_experts"]
        self.embeddings = Embeddings(vocab_size, vector_dim, max_len, drop)
        self.encoder = SwitchEncoderBlock(
            vector_dim, num_heads, hidden_dim, num_layers, num_experts,
            top_k=top_k, expert_drop=expert_drop, drop=drop,
        )
        self.linear = nn.Linear(vector_dim, vocab_size)  # hidden state -> vocabulary logits

    def forward(self, x, attention_mask):
        x = self.embeddings(x)
        x, router_logits = self.encoder(x, attention_mask)
        return self.linear(x), router_logits

# Use the Apple-silicon GPU when present, otherwise CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Smoke test: one random batch through SwitchTransformer to check the output and router-logit shapes.
model = SwitchTransformer(
    vocab_size, vector_dim, max_len, num_heads, drop, hidden_dim, num_layers,
    top_k=top_k, expert_drop=expert_drop,
).to(device)

bs = 32
input_tensor = torch.randint(0, vocab_size, (bs, max_len), device=device)
attention_mask = torch.ones(bs, max_len, device=device)  # no padding: every position is attended

# Shape check only, so no autograd graph is needed.
with torch.no_grad():
    output, router_logits = model(input_tensor, attention_mask)

print("Input tensor shape:", input_tensor.shape)
print("Attention mask shape:", attention_mask.shape)
print("Output tensor shape:", output.shape)
print("Router Logits shape:", router_logits.shape)

<span style="font-family: 'Bebas Neue'; font-size:34px">Load balancing loss
- <span style="font-family: 'Bebas Neue'; font-size:24px">Computes the fraction of tokens dispatched to each expert.</span>
- <span style="font-family: 'Bebas Neue'; font-size:24px">Computes the mean router probability assigned to each expert.</span>
- <span style="font-family: 'Bebas Neue'; font-size:24px">Multiplies the two per expert and sums over experts.</span>
- <span style="font-family: 'Bebas Neue'; font-size:24px">Scales the result by the hyper-parameter alpha.</span>

In [None]:
def load_loss(router_logits, num_experts=num_experts, top_k=top_k, attention_mask=None, alpha=10e-2):
    """Switch-Transformer auxiliary load-balancing loss.

        loss = alpha * num_experts * sum_i f_i * P_i

    ``f_i`` is the fraction of tokens routed to expert ``i`` (computed per top-k
    slot and summed over slots). ``P_i`` is the mean router probability given to
    expert ``i``. With perfectly uniform top-1 routing the sum equals 1, so the
    loss equals ``alpha``.

    Args:
        router_logits: (num_tokens, num_experts) router logits. SwitchEncoderBlock
            concatenates each switch layer's (batch * seq_len, num_experts) logits
            along dim 0.
        num_experts: number of experts (default: notebook value at definition time).
        top_k: experts selected per token.
        attention_mask: optional (batch, seq_len) mask with 0 at padding. When given,
            padding tokens are excluded from both statistics. ``router_logits.shape[0]``
            must then be a whole multiple of the mask size (one copy per switch layer).
        alpha: scale of the loss.

    Returns:
        Scalar tensor (zero when ``router_logits`` is empty).

    Raises:
        ValueError: if ``attention_mask`` does not tile ``router_logits``.
    """
    if router_logits.shape[0] == 0:
        # No switch layers -> nothing to balance (avoids a NaN mean over zero tokens).
        return router_logits.new_zeros(())

    router_probs = F.softmax(router_logits, dim=-1)

    _, selected_experts = torch.topk(router_probs, top_k, dim=-1)           # (tokens, top_k)
    expert_mask = F.one_hot(selected_experts, num_experts).to(router_probs.dtype)  # (tokens, top_k, E)

    if attention_mask is None:
        tok_per_expert = expert_mask.mean(dim=0)        # (top_k, E) token fraction per expert
        rout_per_expert = router_probs.mean(dim=0)      # (E,) mean router probability per expert
    else:
        token_mask = attention_mask.reshape(-1).to(router_probs.dtype)
        num_tokens = router_logits.shape[0]
        if num_tokens % token_mask.numel() != 0:
            raise ValueError(
                f"router_logits has {num_tokens} rows, not a multiple of the "
                f"{token_mask.numel()} positions in attention_mask"
            )
        # Logits are concatenated layer by layer, so tile the mask once per switch layer.
        token_mask = token_mask.repeat(num_tokens // token_mask.numel())
        num_real = token_mask.sum().clamp_min(1.0)
        tok_per_expert = (expert_mask * token_mask[:, None, None]).sum(dim=0) / num_real
        rout_per_expert = (router_probs * token_mask[:, None]).sum(dim=0) / num_real

    loss = num_experts * torch.sum(tok_per_expert * rout_per_expert.unsqueeze(0))
    return alpha * loss

### Training SwitchTransformer

In [None]:
# Fresh SwitchTransformer for training. top_k and expert_drop come from the
# hyper-parameter cell, so changing them there takes effect here.
model = SwitchTransformer(
    vocab_size, vector_dim, max_len, num_heads, drop, hidden_dim, num_layers,
    top_k=top_k, expert_drop=expert_drop,
)

model.to(device)
# CrossEntropyLoss ignores label -100 (unmasked positions from the MLM collator).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

In [None]:
# Train the SwitchTransformer: MLM cross-entropy plus the auxiliary load-balancing loss.
# tqdm is imported in the first cell.
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = 0.0

    model.train()
    pbar = tqdm(train_dl, desc=f'Epoch {epoch+1}')
    for step, batch in enumerate(pbar, start=1):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass (both outputs are already on `device`)
        outputs, router_logits = model(input_ids, attention_mask)

        # Aggregate loss: task loss + router load balancing (computed on non-padding tokens)
        load_balance_loss = load_loss(router_logits, attention_mask=attention_mask)
        loss = criterion(outputs.view(-1, vocab_size), labels.view(-1)) + load_balance_loss

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        # Running mean over the batches seen so far (tqdm's own counter can lag behind).
        pbar.set_postfix(loss=train_loss / step)

    print(f'Train_loss:{train_loss/len(train_dl)}')