In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pandas as pd
from icecream import ic
from tqdm import tqdm

#### <span style="font-family: 'Bebas Neue'; font-size:20px;">The wikitext dataset.

- <span style="font-family: 'Bebas Neue'; font-size:16px;">Has 1.8 million rows of text from wikipedia
- <span style="font-family: 'Bebas Neue'; font-size:16px;">Combined it to a very long string for processing

In [3]:
df1 = pd.read_parquet('train-00000-of-00002.parquet')
df2 = pd.read_parquet('train-00001-of-00002.parquet')
comb = pd.concat([df1,df2], ignore_index=True)

text_data = "\n".join(comb["text"].astype(str))

<span style="font-family: 'Bebas Neue'; font-size:16px;">Tried training a custom tokenizer, but due to loading issues dropped it and grabbed a pre-trained tokenizer  
<span style="font-family: 'Bebas Neue'; font-size:16px;">Will eventually figure it out

In [4]:
# from tokenizers import BertWordPieceTokenizer

# tokenizer = BertWordPieceTokenizer(
#     clean_text=True,
#     handle_chinese_chars=True,
#     strip_accents=True,
#     lowercase=True,
# )

# tokenizer.train_from_iterator(
#     text_data.split("\n"),
#     vocab_size=32000,
#     min_frequency=2,
#     special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"],
#     limit_alphabet=1000,
#     wordpieces_prefix="##"
# )

# def tokenize_text(text_data, tokenizer, max_length, chunk_size):
#     encodings = []
#     for i in range(0, len(text_data), chunk_size):
#         chunk = text_data[i:i+chunk_size]
#         chunk_texts = chunk.split("\n")
#         chunk_encodings = tokenizer.encode_batch(chunk_texts)
#         for enc in chunk_encodings:
#             encoding = {
#                 "ids": enc.ids[:max_length],
#                 "attention_mask": [1] * len(enc.ids[:max_length]),
#             }
#             overflow_ids = enc.ids[max_length:]
#             while overflow_ids:
#                 encoding["overflowing_tokens"] = overflow_ids[:max_length]
#                 encoding["overflow_attention_mask"] = [1] * len(overflow_ids[:max_length])
#                 overflow_ids = overflow_ids[max_length:]
#                 encodings.append(encoding)
#                 encoding = {
#                     "ids": [],
#                     "attention_mask": [],
#                 }
#     return encodings

# # Set the chunk size according to your memory constraints
# chunk_size = 10000  # Adjust this value based on your system's memory capacity

# # Tokenize the text data in chunks
# encodings = tokenize_text(text_data, tokenizer, max_length=256, chunk_size=chunk_size)

<span style="font-family: 'Bebas Neue'; font-size:16px;"> Will be working with 1% of the data for testing purposes  
<span style="font-family: 'Bebas Neue'; font-size:16px;"> Grabbed the next 0.5% of the data as valid_set

In [5]:
train_idx = int(0.01 * len(text_data))
train_text = text_data[:train_idx]
test_text = text_data[train_idx:int(train_idx*1.5)]

### <span style="font-family: 'Bebas Neue'; font-size:20px;">Gets the data ready for training.
- <span style="font-family: 'Bebas Neue'; font-size:16px;">Grabbed a pre-trained tokenizer from HF. Since i will be closely following BERT, hence BERTTokenizer
- <span style="font-family: 'Bebas Neue'; font-size:16px;">Makes a custom torch dataset 
- <span style="font-family: 'Bebas Neue'; font-size:16px;">Data collator function from HF, which masks the text with prob of 0.15%.
- <span style="font-family: 'Bebas Neue'; font-size:16px;">Operates on chunk of the text for faster processing
- <span style="font-family: 'Bebas Neue'; font-size:16px;">Resultant train_dl and text_dl is obtained.


In [6]:
from transformers import BertTokenizerFast, DataCollatorForLanguageModeling
from torch.utils.data import Dataset, DataLoader

bs = 32

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

class MaskedLanguageModelDataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
    
    def __getitem__(self, idx):
        return {
            "input_ids": self.encodings[idx]["input_ids"],
            "attention_mask": self.encodings[idx]["attention_mask"],
        }
    
    def __len__(self):
        return len(self.encodings)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15
)

def get_mlm(text_data):
    def tokenize_text(text_data, tokenizer, max_length, chunk_size):
        encodings = []
        for i in range(0, len(text_data), chunk_size):
            chunk = text_data[i:i+chunk_size]
            chunk_encodings = tokenizer(chunk, truncation=True, max_length=max_length, return_overflowing_tokens=True, padding=True)
            
            for encoding in chunk_encodings.encodings:
                encodings.append({
                    "input_ids": encoding.ids,
                    "attention_mask": encoding.attention_mask,
                })
        
        return encodings

    chunk_size = 10000  

    encodings = tokenize_text(text_data, tokenizer, max_length=256, chunk_size=chunk_size)

    dataset = MaskedLanguageModelDataset(encodings)

    dataloader = DataLoader(
        dataset,
        batch_size=bs,
        shuffle=True,
        collate_fn=data_collator
    )
    return dataloader

train_dl = get_mlm(train_text)
test_dl = get_mlm(test_text)

  from .autonotebook import tqdm as notebook_tqdm


- <span style="font-family: 'Bebas Neue'; font-size:16px;"> **input_ids**: Maps sub-words to inetgers according to the tokenizer
- <span style="font-family: 'Bebas Neue'; font-size:16px;"> **attention_mask**: Ignores padded tokens
- <span style="font-family: 'Bebas Neue'; font-size:16px;"> **labels**: Labesl for the task. -100 means ignore the label when calculating the loss in torch

In [7]:
for i,batch in enumerate(train_dl):
    print(batch.input_ids[4][:20])
    print(batch.attention_mask[4][:20])
    print(batch.labels[4][:20])
    break

tensor([  101,   103,  8109,  1010,  2403,  1025, 13591,  2726,  1010,   103,
         1025, 25869,   103, 24919,  1010,  2340,  1007,  1010,  2062,  2084])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([-100, 3841, -100, -100, -100, -100, -100, -100, -100, 2410, -100, -100,
        2140, -100, -100, -100, -100, -100, -100, -100])


#### <span style="font-family: 'Bebas Neue'; font-size:1.2em;">Layer Norm
- <span style="font-family: 'Bebas Neue'; font-size:1.2em;"> Two trainable params gamma and beta

In [8]:
class LayerNormalization(nn.Module):
    def __init__(self, shape, eps=1e-5):
        super().__init__()
        
        self.param_shape = shape # 768
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(shape)) # 768
        self.beta = nn.Parameter(torch.zeros(shape)) # 768
        
    def forward(self, x): # 32 x 256 x 768

        dim = -1

        mean = x.mean(dim = dim, keepdim=True) # Out shape: 30 x 256 x 1. Keepdim keeps the last dim and does not collapse.

        var = ((x - mean)**2).mean(dim = dim,keepdim = True) # Out shape: 30 x 256 x 1. Keepdim keeps the last dim and does not collapse.

        std = (var + self.eps).sqrt() # 30 x 256 x 1

        y = (x - mean) / std # 30 x 256 x 768 Normalized values based on mean and std.

        out = self.gamma * y + self.beta # learnable param same as embedding size. To incorporate lost information presumably or outliers/bias.
        
        return out

#### <span style="font-family: 'Bebas Neue'; font-size:20px">Embedding implementataion
- <span style="font-family: 'Bebas Neue'; font-size:16px"> Takes in vocab, gives embeddings for each token.
- <span style="font-family: 'Bebas Neue'; font-size:16px"> Position embeds are trainable, same as bert

In [9]:
class Embeddings(nn.Module):
    def __init__(self, vocab_size, vector_dim, max_len, drop):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, vector_dim) # 30552 x 768
        self.position_embed = nn.Embedding(max_len, vector_dim) # 256 x 768
        self.dropout = nn.Dropout(p = drop)
        

    def forward(self, x): # 32 x 256

        token_embeddings = self.token_embeddings(x) # 32 x 256 x 768

        seq_len = x.size(1) # Gets the seq len

        position_ids = torch.arange(seq_len,dtype = torch.long,device = x.device) # Gives a tensor with [0,...,255]

        position_ids = position_ids.unsqueeze(0).expand_as(x) # Unsqueeze adds a singleton dimension at pos 0. Then expanded with batch size same as x.

        position_embeddings = self.position_embed(position_ids) # out: 32 x 256 x 768

        embeddings = token_embeddings + position_embeddings # Add them

        embeddings = self.dropout(embeddings) # Drop out

        return embeddings # Out: 32 x 256 x 768

<span style="font-family: 'Bebas Neue'; font-size:16px"> Every word is broken down into three vectors q, k, v  

<span style="font-family: 'Bebas Neue'; font-size:16px"> Attention gives the probability of how much must we focus.  

<span style="font-family: 'Bebas Neue'; font-size:16px">The output is a vector with all the attention information incorporated...  

In [10]:
def sdpa(q: torch.tensor, k: torch.tensor, v:torch.tensor, mask = None): # 32 x 8 x 256 x 96

    d_k = q.shape[-1] # Takes the length of each head_dim i.e 64

    #Selectively transposes the last two dimension of k for matmul. Gives attention scores for each word and every other word Out: 32 x 8 x 256 x 256
    scaled = torch.matmul(q, k.transpose(-1,-2)) / math.sqrt(d_k) 
    
    if mask is not None:
        
        mask_expanded = mask.unsqueeze(1).unsqueeze(2) # Add two dimension in the middle
        
        mask_expanded = mask_expanded.expand_as(scaled) # Expand as the scaled dimensions

        scaled = scaled.masked_fill(mask_expanded == 0, float('-inf')) # Replace 0 values with - inf. Will not contribute to softmax -inf -> 0
    

    attention = F.softmax(scaled, dim = -1) # Squish the outputs, add up to one

    out = torch.matmul(attention, v) #The transformed vector with attention scores incorporated. Out: (32, 8, 256, 96) 

    return out

<span style="font-family: 'Bebas Neue'; font-size:20px">Multi-Head Attention
- <span style="font-family: 'Bebas Neue'; font-size:16px"> Split across many heads for parallel comuptation.
- <span style="font-family: 'Bebas Neue'; font-size:16px"> The intutuion is many heads can learn better attention scores than a single head

In [11]:
class MultiHeadAttention(nn.Module):
    def __init__(self,vector_dim,num_heads,drop):
        super(MultiHeadAttention, self).__init__()
        self.vector_dim = vector_dim # 768
        self.num_heads = num_heads # 8
        self.head_dims = vector_dim // num_heads # 768 / 8 = 96
        self.qkv_layer = nn.Linear(vector_dim, vector_dim * 3) # shape: (768, 2304)
        self.linear = nn.Linear(vector_dim, vector_dim) #shape: (768, 768)
        self.dropout = nn.Dropout(p = drop)
        self.layer_norm = LayerNormalization(vector_dim)

    def forward(self, x, mask = None): # 32 x 256 x 768

        bs, max_len, _ = x.shape # Gets the batch size and max length

        qkv = self.qkv_layer(x) # Splits the inp embedding into three diff values but together for faster speeds. Out: 32 x 256 x 2304

        qkv = qkv.reshape(bs, max_len, self.num_heads, self.head_dims * 3) # Divide according to the number of heads. Out: 32 x 256 x 8 x 288

        qkv = qkv.permute(0,2,1,3) # Reshaping to get it ready for inp into attention head. Out: 32 x 8 x 256 x 288

        q, k, v = qkv.chunk(3,dim = -1) # Chunk it into 3 parts Q, K and V Out: (32, 8, 256, 96)

        out = sdpa(q, k, v, mask) # Scaled dot product attention Out: (32, 8, 256, 96)

        out = out.reshape(bs, max_len, self.num_heads * self.head_dims) # Back to vector_size with attention info incorporated Out: (32, 256, 768)

        out = self.dropout(out) # Drop out

        out = self.linear(out) # Final Linear transformation Out: Out: (32, 256, 768)

        out = self.layer_norm(out) # Normalization
        
        out = self.dropout(out) # Drop out

        return out

<span style="font-family: 'Bebas Neue'; font-size:20px"> Feed Forward network

In [12]:
class PosFFN(nn.Module):
    def __init__(self, vector_dim, hidden_dim, drop):
        super().__init__()
        self.lin1 = nn.Linear(vector_dim,hidden_dim) # 768 x 2048
        self.lin2 = nn.Linear(hidden_dim,vector_dim) # 2048 x 768
        self.gelu = nn.GELU()
        self.drop = nn.Dropout(p=drop)
        self.layer_norm = LayerNormalization(hidden_dim)

    def forward(self, x): # 32 x 256 x 768

        x = self.lin1(x) # Out: 32 x 256 x 2048

        x = self.gelu(x) # Act Out: 32 x 256 x 2048

        x = self.layer_norm(x) # Out: 32 x 256 x 2048

        x = self.drop(x) # Dropout Out: 32 x 256 x 2048

        x = self.lin2(x) #Linear projection to vector_dim. Out:32 x 256 x 768

        return x

<span style="font-family: 'Bebas Neue'; font-size:20px"> Encoder block
- <span style="font-family: 'Bebas Neue'; font-size:16px"> The Multi-head attention and FFN to get a single encoder block

In [13]:
class EncoderBlock(nn.Module):
    def __init__(self,vector_dim,num_heads,hidden_dim,drop):
        super(EncoderBlock, self).__init__()
        self.attention = MultiHeadAttention(vector_dim,num_heads,drop)
        self.ffn = PosFFN(vector_dim = vector_dim,hidden_dim = hidden_dim,drop = drop)
    
    def forward(self, x, attention_mask): # 32 x 256 x 768
        res_x = x # Residual Out: 32 x 256 x 768

        x = self.attention(x,mask = attention_mask) # Out: 32 x 256 x 768

        x = x + res_x # Add the residual

        res_x = x # Take the residual again Out: 32 x 256 x 768

        x = self.ffn(x) # Position wise feed forward Out: 32 x 256 x 768

        x = x + res_x # Add the residual

        return x

<span style="font-family: 'Bebas Neue'; font-size:20px"> Stacking the encoder blocks

In [14]:
class Encoder(nn.Module):
    def __init__(self,vector_dim,num_heads,drop,hidden_dim,num_layers):
        super().__init__()
        self.layers = nn.Sequential(*[EncoderBlock(vector_dim,num_heads,hidden_dim,drop) for _ in range(num_layers)])
    def forward(self, x, attention_mask):
        for layer in self.layers:
            x = layer(x, attention_mask)
        return x

<span style="font-family: 'Bebas Neue'; font-size:20px"> The final model.
- <span style="font-family: 'Bebas Neue'; font-size:16px">Ready for token prediction.

In [15]:
vector_dim = 768
num_heads = 8
drop = 0.1
max_len = 256
hidden_dim = 2048
num_layers = 2
vocab_size = len(tokenizer.get_vocab()) # 30552

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, vector_dim, max_len, num_heads, drop, hidden_dim, num_layers):
        super().__init__()
        self.embeddings = Embeddings(vocab_size, vector_dim, max_len, drop)
        self.encoder = Encoder(vector_dim, num_heads, drop, hidden_dim, num_layers)
        self.linear = nn.Linear(vector_dim, vocab_size)  # Add a linear layer for output projection
    
    def forward(self, x, attention_mask):
        x = self.embeddings(x) # Out: 32 x 256 x 768
        
        x = self.encoder(x, attention_mask) # Out: 32 x 256 x 768

        x = self.linear(x)  # Project the output to the vocabulary size
        
        return x

device = 'mps'

bs = 32
input_tensor = torch.randint(0, vocab_size, (bs, max_len))
attention_mask = torch.ones(bs, max_len)  # Create an attention mask with all ones

model = TransformerModel(vocab_size, vector_dim, max_len, num_heads, drop, hidden_dim, num_layers).to(device)

# Generate a random input tensor and attention mask
bs = 32
input_tensor = torch.randint(0, vocab_size, (bs, max_len)).to(device)
attention_mask = torch.ones(bs, max_len).to(device) # Create an attention mask with all ones

# Pass the input tensor and attention mask through the model
output = model(input_tensor, attention_mask)

# Print the shapes of the input and output tensors
print("Input tensor shape:", input_tensor.shape)
print("Attention mask shape:", attention_mask.shape)
print("Output tensor shape:", output.shape)

Input tensor shape: torch.Size([32, 256])
Attention mask shape: torch.Size([32, 256])
Output tensor shape: torch.Size([32, 256, 30522])


### Training Loop

In [16]:
vector_dim = 768
num_heads = 8
drop = 0.1
max_len = 256
hidden_dim = 2048
num_layers = 2
vocab_size = len(tokenizer.get_vocab()) # 30552

from tqdm import tqdm
model = TransformerModel(vocab_size, vector_dim, max_len, num_heads, drop, hidden_dim, num_layers)

model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

num_epochs = 100
for epoch in range(num_epochs):
    train_loss = 0.0
    val_loss = 0.0

    pbar = tqdm(train_dl, desc=f'Epoch {epoch+1}')
    model.train()
    for batch in pbar:
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass
        outputs = model(input_ids, attention_mask).to(device)
        loss = criterion(outputs.view(-1, vocab_size), labels.view(-1))
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        
    print(f'Train Loss: {train_epoch_loss}')
    model.eval()
    with torch.no_grad():
        pbar2 = tqdm(test_dl, desc=f'Epoch {epoch + 1}')
        
        for batch in pbar2:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            outputs = model(input_ids, attention_mask).to(device)
            loss = criterion(outputs.view(-1, vocab_size), labels.view(-1))

            val_loss += loss.item()
    train_epoch_loss = train_loss / len(train_dl)
    valid_loss = val_loss / len(test_dl)
    print(f'Valid Loss: {valid_loss}')

<span style="font-family: 'Bebas Neue'; font-size:16px"> WIP: Routing with attention machanism
- <span style="font-family: 'Bebas Neue'; font-size:16px"> Trying it with 1 attention head 

In [17]:
# class SingleHeadRouterAttention(nn.Module):
#     def __init__(self,vector_dim,num_experts,num_heads = 1,bs=32,max_len=256):
#         super().__init__()
#         self.vector_dim = vector_dim
#         self.num_heads = num_heads
#         self.head_dims = vector_dim // num_heads # 768
#         self.qkv_layer = nn.Linear(vector_dim, vector_dim * 3) # shape: (768, 2304)
#         self.linear = nn.Linear(vector_dim,num_experts) #shape: (768, num_experts)
#         self.bs = bs
#         self.max_len = max_len

#     def forward(self, x, mask = None): # 32 x 256 x 768
        
#         bs, max_len, _ = x.shape # Gets the batch size and max length

#         qkv = self.qkv_layer(x) # Splits the inp embedding into three diff values but together for faster speeds. Out: 32 x 256 x 2304

#         qkv = qkv.reshape(bs, max_len, self.num_heads, self.head_dims * 3) # Divide according to the number of heads. Out: 32 x 256 x 8 x 288

#         qkv = qkv.permute(0,2,1,3) # Reshaping to get it ready for inp into attention head. Out: 32 x 8 x 256 x 288

#         q, k, v = qkv.chunk(3,dim = -1) # Chunk it into 3 parts Q, K and V Out: (32, 8, 256, 96)

#         out = sdpa(q, k, v, mask) # Scaled dot product attention Out: (32, 8, 256, 96)

#         out = out.reshape(bs, max_len, self.num_heads * self.head_dims) # Back to vector_size with attention info incorporated Out: (32, 256, 768)

#         out = self.linear(out) # Final Linear transformation Out: Out: (32, 256, 768)

#         return out

<span style="font-family: 'Bebas Neue'; font-size:20px">Sparse Feed Forward Network
- <span style="font-family: 'Bebas Neue'; font-size:16px"> Takes in number of experts and top_k, the top k experts to route the tokens.
- <span style="font-family: 'Bebas Neue'; font-size:16px"> Implements a linear gating mechanism 
- <span style="font-family: 'Bebas Neue'; font-size:16px"> The output is multiplied by the routing weights, appropriately.
- <span style="font-family: 'Bebas Neue'; font-size:16px"> Capacity factor decides how many token each expert can process. The rest are projected as it is without any transformation. Set to 1.25 recommended from the paper.

In [19]:
class SwitchFFN(nn.Module):
    def __init__(self, vector_dim, hidden_dim, num_experts, expert_drop = 0.4,top_k = 1,capacity_factor = 1.25):
        super().__init__()
        self.vector_dim = vector_dim
        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.top_k = int(top_k)
        self.capacity_factor = capacity_factor
        self.gate = nn.Linear(self.vector_dim, self.num_experts, bias=False) # 768 x 4
        self.experts = nn.ModuleList([PosFFN(vector_dim,hidden_dim,expert_drop) for _ in range(self.num_experts)]) #out shape: (32, 256, 768)

    def forward(self, x): # 32 x 256 x 768
        
        batch_size, sequence_length, vector_dim = x.shape 

        x = x.view(-1, vector_dim) #Flattens to num_tokens [8192] -> 32 X 256, vector_dim -> 768

        router_logits = self.gate(x) #Passes it through the routing gate. In shape: 8192 x 768 Out shape: 8192 x 4 -> Router logits for each token
        
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # Probabilities of routing
        
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) # Gives 2 tensors. 1 > routing weights 2 > The expert index

        routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # Scales the routing scores according to the experts
        
        routing_weights = routing_weights.to(x.dtype) # Cast it back to fp16. Unfortunately i have mps, i have no fp16

        #initializes the intermediate output to all zeros. Need this to add out in specific indexes
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, vector_dim), dtype=x.dtype, device=x.device
        )

        # Calculate expert capacity
        tokens_per_batch = batch_size * sequence_length
        expert_capacity = int(tokens_per_batch / self.num_experts) * self.capacity_factor

        #Shape: 4,1,8192. A one hot encoded mask for each token num_experts x top_k x tokens
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) 
        
        # Looping through the all the experts and perform comp on each expert.
        for expert_idx in range(self.num_experts):

            expert_layer = self.experts[expert_idx] # Gets the expert layer

            idx, top_x = torch.where(expert_mask[expert_idx]) # Gets the indexes for the current expert

            #If no experts are assigned, the loop contunues to the next expert
            if top_x.shape[0] == 0:
                continue

            # Computes token to be processed based on expert capacity and the leftover tokens.
            if top_x.shape[0] > expert_capacity:
                leftover_tokens = top_x[expert_capacity:]
                top_x = top_x[:expert_capacity]
            
            current_state = x[None, top_x].reshape(-1, vector_dim) # In: 8192 x 768

            #Indexes into the routing weights and multiplies with out from the expert layer
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            final_hidden_states[top_x] = current_hidden_states.to(x.dtype) # Injects the computed hidden states to indexes of final hidden state

            if leftover_tokens.shape[0] > 0:
                final_hidden_states[leftover_tokens] = x[None, leftover_tokens].reshape(-1, vector_dim) # Injects left over tokens as it is
        

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, vector_dim) #reshapes 32 x 256 x 768
        print(final_hidden_states.shape)
        return final_hidden_states, router_logits

<span style="font-family: 'Bebas Neue'; font-size:20px">Switch Encoder
- <span style="font-family: 'Bebas Neue'; font-size:16px"> Inserts a switch layer after three normal FFN layers
- <span style="font-family: 'Bebas Neue'; font-size:16px"> Takes the router logits if switch layer. (required to calculate the load balancing loss)

In [20]:
class SwitchEncoderBlock(nn.Module):
    def __init__(self, vector_dim, num_heads, hidden_dim, num_layers, num_experts,top_k = 1,expert_drop = 0.4,drop = 0.1):  # Added MoE parameters
        super().__init__()
        self.attention = MultiHeadAttention(vector_dim, num_heads, drop)
        self.layers = nn.Sequential(*[
            PosFFN(vector_dim, hidden_dim, drop) if i % 4 != 3  # Insert EncoderBlock at most positions
            else SwitchFFN(vector_dim, hidden_dim, num_experts, top_k, expert_drop)  # Insert SparseMoEBlock every 4th layer
            for i in range(num_layers) 
        ])

    def forward(self, x, attention_mask):
        total_router_logits = []

        for layer in self.layers:
            if isinstance(layer, SwitchFFN):
                #Attention
                residual_x = x
                x = self.attention(x, attention_mask)
                x = residual_x + x
                residual_x = x
                
                x, router_logits = layer(x)  # Collect from MoE
                x = residual_x + x
                total_router_logits.append(router_logits)
            else:
                residual_x = x

                x = self.attention(x, attention_mask)

                x = residual_x + x

                residual_x = x

                x = layer(x)

                x = residual_x + x
        total_router_logits = torch.cat(total_router_logits, dim=0)
        return x, total_router_logits

<span style="font-family: 'Bebas Neue'; font-size:20px">The Final Switch Encoder

In [21]:
vector_dim = 768
num_heads = 8
drop = 0.1
max_len = 256
hidden_dim = 2048
num_layers = 8
num_experts = 4
expert_drop = 0.4
top_k = 1
vocab_size = len(tokenizer.get_vocab()) # 30552

class SwitchTransformer(nn.Module):
    def __init__(self, vocab_size, vector_dim, max_len, num_heads, drop, hidden_dim, num_layers,top_k = 1, expert_drop = 0.4):
        super().__init__()
        self.embeddings = Embeddings(vocab_size, vector_dim, max_len, drop)
        self.encoder = SwitchEncoderBlock(vector_dim, num_heads, hidden_dim, num_layers, num_experts,top_k,expert_drop) #Initialize the switch encoder
        self.linear = nn.Linear(vector_dim, vocab_size)  # Add a linear layer for output projection
    
    def forward(self, x, attention_mask):
        x = self.embeddings(x) # Out: 32 x 256 x 768
        
        x, router_logits = self.encoder(x, attention_mask) # Out: 32 x 256 x 768

        x = self.linear(x)  # Project the output to the vocabulary size
        
        return x, router_logits

device = 'mps'

bs = 32
input_tensor = torch.randint(0, vocab_size, (bs, max_len))
attention_mask = torch.ones(bs, max_len)  # Create an attention mask with all ones

model = SwitchTransformer(vocab_size, vector_dim, max_len, num_heads, drop, hidden_dim, num_layers).to(device)

# Generate a random input tensor and attention mask
bs = 32
input_tensor = torch.randint(0, vocab_size, (bs, max_len)).to(device)
attention_mask = torch.ones(bs, max_len).to(device) # Create an attention mask with all ones

# Pass the input tensor and attention mask through the model
output,router_logits = model(input_tensor, attention_mask)

# Print the shapes of the input and output tensors
print("Input tensor shape:", input_tensor.shape)
print("Attention mask shape:", attention_mask.shape)
print("Output tensor shape:", output.shape)
print("Router Logits shape:", router_logits.shape)

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])
Input tensor shape: torch.Size([32, 256])
Attention mask shape: torch.Size([32, 256])
Output tensor shape: torch.Size([32, 256, 30522])
Router Logits shape: torch.Size([16384, 4])


<span style="font-family: 'Bebas Neue'; font-size:20px">Load balancing loss
- <span style="font-family: 'Bebas Neue'; font-size:16px">Calculates token processed per expert.
- <span style="font-family: 'Bebas Neue'; font-size:16px">Route percentages per expert
- <span style="font-family: 'Bebas Neue'; font-size:16px">Mean and Sum both of them
- <span style="font-family: 'Bebas Neue'; font-size:16px">Scales by a hyper param alpha set as 10e-2 recommended by the paper

In [22]:
def load_loss(router_logits,num_experts = num_experts,top_k = top_k,attention_mask = None,alpha=10e-2):
    """
    inp: router logits from the sparse layers, top_k experts no., attention_mask, tunable hyper-param to scale the loss
    """
   
    # Apply softmax to router_logits to get the probabilities
    router_probs = F.softmax(router_logits, dim=1)

    _, selected_experts = torch.topk(router_probs, top_k, dim=-1) # Gives selected ecpdrt for each token: Shape: (8192,1)

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)# OHE mask for each token Out: (8192, 1, 4)

    tok_per_expert = torch.mean(expert_mask.float(), dim=0) # Gives token fraction per expert i Out: (1, 4)

    rout_per_expert = torch.mean(router_probs, dim=0) # Gives fraction of router prob allotted to expert i Out: (4)

    loss = torch.sum(tok_per_expert * rout_per_expert.unsqueeze(0))
    
    # Scale the loss by alpha
    loss = alpha * loss

    return loss

### Training SwitchTransformer

In [23]:
model = SwitchTransformer(vocab_size, vector_dim, max_len, num_heads, drop, hidden_dim, num_layers)

model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

In [24]:
from tqdm import tqdm
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = 0.0

    pbar = tqdm(train_dl, desc=f'Epoch {epoch+1}')
    model.train()
    for batch in pbar:
        
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass
        outputs,router_logits = model(input_ids, attention_mask)
        outputs.to(device),router_logits.to(device)

        #Aggregrate Loss
        load_balance_loss = load_loss(router_logits)
        loss = criterion(outputs.view(-1, vocab_size), labels.view(-1))
        loss += load_balance_loss
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        pbar.set_postfix(loss=train_loss / (pbar.n + 1))
        
    print(f'Train_loss:{train_loss/len(train_dl)}')

Epoch 1:   0%|          | 0/149 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:   1%|          | 1/149 [00:05<13:27,  5.46s/it, loss=15.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:   1%|▏         | 2/149 [00:06<07:37,  3.11s/it, loss=15.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:   2%|▏         | 3/149 [00:07<05:17,  2.17s/it, loss=16.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:   3%|▎         | 4/149 [00:08<03:58,  1.65s/it, loss=17]  

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:   3%|▎         | 5/149 [00:09<03:16,  1.37s/it, loss=17.3]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:   4%|▍         | 6/149 [00:10<02:49,  1.18s/it, loss=17.4]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:   5%|▍         | 7/149 [00:11<02:31,  1.07s/it, loss=17.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:   5%|▌         | 8/149 [00:12<02:19,  1.01it/s, loss=17.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:   6%|▌         | 9/149 [00:12<02:10,  1.07it/s, loss=17.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:   7%|▋         | 10/149 [00:13<02:05,  1.10it/s, loss=17.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:   7%|▋         | 11/149 [00:14<02:01,  1.13it/s, loss=17.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:   8%|▊         | 12/149 [00:15<01:58,  1.16it/s, loss=17.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:   9%|▊         | 13/149 [00:16<01:56,  1.17it/s, loss=17.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:   9%|▉         | 14/149 [00:17<01:54,  1.18it/s, loss=17.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  10%|█         | 15/149 [00:17<01:52,  1.19it/s, loss=17.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  11%|█         | 16/149 [00:18<01:51,  1.20it/s, loss=17.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  11%|█▏        | 17/149 [00:19<01:49,  1.20it/s, loss=17.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  12%|█▏        | 18/149 [00:20<01:49,  1.20it/s, loss=17.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  13%|█▎        | 19/149 [00:21<01:48,  1.20it/s, loss=17.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  13%|█▎        | 20/149 [00:22<01:47,  1.20it/s, loss=17.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  14%|█▍        | 21/149 [00:22<01:46,  1.20it/s, loss=17.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  15%|█▍        | 22/149 [00:23<01:46,  1.20it/s, loss=17.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  15%|█▌        | 23/149 [00:24<01:44,  1.20it/s, loss=17.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  16%|█▌        | 24/149 [00:25<01:44,  1.20it/s, loss=17.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  17%|█▋        | 25/149 [00:26<01:42,  1.21it/s, loss=17.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  17%|█▋        | 26/149 [00:27<01:41,  1.21it/s, loss=17.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  18%|█▊        | 27/149 [00:27<01:41,  1.20it/s, loss=17.4]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  19%|█▉        | 28/149 [00:28<01:41,  1.20it/s, loss=17.2]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  19%|█▉        | 29/149 [00:29<01:40,  1.19it/s, loss=17]  

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  20%|██        | 30/149 [00:30<01:39,  1.20it/s, loss=17.1]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  21%|██        | 31/149 [00:31<01:38,  1.20it/s, loss=17.3]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  21%|██▏       | 32/149 [00:32<01:37,  1.20it/s, loss=17.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  22%|██▏       | 33/149 [00:32<01:35,  1.21it/s, loss=17.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  23%|██▎       | 34/149 [00:33<01:35,  1.20it/s, loss=18.1]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  23%|██▎       | 35/149 [00:34<01:34,  1.20it/s, loss=18.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  24%|██▍       | 36/149 [00:35<01:33,  1.20it/s, loss=19]  

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  25%|██▍       | 37/149 [00:36<01:32,  1.21it/s, loss=19.4]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  26%|██▌       | 38/149 [00:37<01:32,  1.20it/s, loss=19.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  26%|██▌       | 39/149 [00:37<01:31,  1.21it/s, loss=20]  

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  27%|██▋       | 40/149 [00:38<01:31,  1.20it/s, loss=20.2]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  28%|██▊       | 41/149 [00:39<01:30,  1.20it/s, loss=20.2]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  28%|██▊       | 42/149 [00:40<01:29,  1.20it/s, loss=20.4]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  29%|██▉       | 43/149 [00:41<01:28,  1.20it/s, loss=20.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  30%|██▉       | 44/149 [00:42<01:27,  1.20it/s, loss=20.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  30%|███       | 45/149 [00:42<01:26,  1.21it/s, loss=20.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  31%|███       | 46/149 [00:43<01:27,  1.18it/s, loss=20.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  32%|███▏      | 47/149 [00:44<01:26,  1.19it/s, loss=20.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  32%|███▏      | 48/149 [00:45<01:24,  1.19it/s, loss=20.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  33%|███▎      | 49/149 [00:46<01:23,  1.20it/s, loss=21.1]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  34%|███▎      | 50/149 [00:47<01:22,  1.20it/s, loss=21.3]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  34%|███▍      | 51/149 [00:47<01:21,  1.20it/s, loss=21.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  35%|███▍      | 52/149 [00:48<01:20,  1.21it/s, loss=21.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  36%|███▌      | 53/149 [00:49<01:19,  1.21it/s, loss=22.2]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  36%|███▌      | 54/149 [00:50<01:18,  1.21it/s, loss=22.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  37%|███▋      | 55/149 [00:51<01:17,  1.21it/s, loss=22.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  38%|███▊      | 56/149 [00:52<01:17,  1.21it/s, loss=23.2]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  38%|███▊      | 57/149 [00:52<01:16,  1.20it/s, loss=23.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  39%|███▉      | 58/149 [00:53<01:16,  1.18it/s, loss=23.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  40%|███▉      | 59/149 [00:54<01:15,  1.18it/s, loss=23.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  40%|████      | 60/149 [00:55<01:15,  1.18it/s, loss=24.1]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  41%|████      | 61/149 [00:56<01:13,  1.19it/s, loss=24.3]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  42%|████▏     | 62/149 [00:57<01:12,  1.20it/s, loss=24.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  42%|████▏     | 63/149 [00:57<01:11,  1.20it/s, loss=24.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  43%|████▎     | 64/149 [00:58<01:10,  1.20it/s, loss=24.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  44%|████▎     | 65/149 [00:59<01:10,  1.19it/s, loss=24.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  44%|████▍     | 66/149 [01:00<01:10,  1.19it/s, loss=24.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  45%|████▍     | 67/149 [01:01<01:09,  1.19it/s, loss=24.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  46%|████▌     | 68/149 [01:02<01:07,  1.20it/s, loss=24.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  46%|████▋     | 69/149 [01:02<01:06,  1.21it/s, loss=24.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  47%|████▋     | 70/149 [01:03<01:05,  1.20it/s, loss=24.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  48%|████▊     | 71/149 [01:04<01:04,  1.20it/s, loss=24.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  48%|████▊     | 72/149 [01:05<01:04,  1.20it/s, loss=24.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  49%|████▉     | 73/149 [01:06<01:03,  1.19it/s, loss=24.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  50%|████▉     | 74/149 [01:07<01:03,  1.18it/s, loss=24.4]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  50%|█████     | 75/149 [01:08<01:02,  1.18it/s, loss=24.4]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  51%|█████     | 76/149 [01:08<01:01,  1.18it/s, loss=24.4]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  52%|█████▏    | 77/149 [01:09<01:00,  1.19it/s, loss=24.3]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  52%|█████▏    | 78/149 [01:10<00:59,  1.19it/s, loss=24.3]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  53%|█████▎    | 79/149 [01:11<00:58,  1.19it/s, loss=24.2]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  54%|█████▎    | 80/149 [01:12<00:57,  1.19it/s, loss=24.2]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  54%|█████▍    | 81/149 [01:13<00:56,  1.19it/s, loss=24.2]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  55%|█████▌    | 82/149 [01:13<00:55,  1.20it/s, loss=24.3]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  56%|█████▌    | 83/149 [01:14<00:55,  1.20it/s, loss=24.3]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  56%|█████▋    | 84/149 [01:15<00:54,  1.19it/s, loss=24.4]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  57%|█████▋    | 85/149 [01:16<00:53,  1.19it/s, loss=24.4]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  58%|█████▊    | 86/149 [01:17<00:52,  1.19it/s, loss=24.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  58%|█████▊    | 87/149 [01:18<00:52,  1.19it/s, loss=24.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  59%|█████▉    | 88/149 [01:18<00:51,  1.20it/s, loss=24.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  60%|█████▉    | 89/149 [01:19<00:49,  1.20it/s, loss=24.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  60%|██████    | 90/149 [01:20<00:49,  1.19it/s, loss=24.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  61%|██████    | 91/149 [01:21<00:48,  1.20it/s, loss=24.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  62%|██████▏   | 92/149 [01:22<00:47,  1.20it/s, loss=24.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  62%|██████▏   | 93/149 [01:23<00:47,  1.19it/s, loss=24.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  63%|██████▎   | 94/149 [01:23<00:46,  1.19it/s, loss=24.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  64%|██████▍   | 95/149 [01:24<00:45,  1.19it/s, loss=24.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  64%|██████▍   | 96/149 [01:25<00:44,  1.19it/s, loss=24.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  65%|██████▌   | 97/149 [01:26<00:43,  1.20it/s, loss=24.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  66%|██████▌   | 98/149 [01:27<00:42,  1.19it/s, loss=24.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  66%|██████▋   | 99/149 [01:28<00:41,  1.19it/s, loss=24.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  67%|██████▋   | 100/149 [01:29<00:41,  1.18it/s, loss=24.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  68%|██████▊   | 101/149 [01:29<00:40,  1.19it/s, loss=24.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  68%|██████▊   | 102/149 [01:30<00:39,  1.19it/s, loss=24.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  69%|██████▉   | 103/149 [01:31<00:38,  1.19it/s, loss=24.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  70%|██████▉   | 104/149 [01:32<00:37,  1.19it/s, loss=24.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  70%|███████   | 105/149 [01:33<00:37,  1.19it/s, loss=24.4]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  71%|███████   | 106/149 [01:34<00:36,  1.19it/s, loss=24.4]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  72%|███████▏  | 107/149 [01:34<00:35,  1.19it/s, loss=24.3]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  72%|███████▏  | 108/149 [01:35<00:34,  1.19it/s, loss=24.3]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  73%|███████▎  | 109/149 [01:36<00:33,  1.19it/s, loss=24.3]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  74%|███████▍  | 110/149 [01:37<00:32,  1.19it/s, loss=24.2]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  74%|███████▍  | 111/149 [01:38<00:32,  1.19it/s, loss=24.2]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  75%|███████▌  | 112/149 [01:39<00:31,  1.19it/s, loss=24.1]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  76%|███████▌  | 113/149 [01:39<00:30,  1.19it/s, loss=24.1]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  77%|███████▋  | 114/149 [01:40<00:29,  1.18it/s, loss=24]  

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  77%|███████▋  | 115/149 [01:41<00:28,  1.19it/s, loss=24]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  78%|███████▊  | 116/149 [01:42<00:27,  1.19it/s, loss=24]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  79%|███████▊  | 117/149 [01:43<00:26,  1.19it/s, loss=23.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  79%|███████▉  | 118/149 [01:44<00:26,  1.18it/s, loss=23.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  80%|███████▉  | 119/149 [01:44<00:25,  1.19it/s, loss=23.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  81%|████████  | 120/149 [01:45<00:24,  1.18it/s, loss=23.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  81%|████████  | 121/149 [01:46<00:23,  1.18it/s, loss=23.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  82%|████████▏ | 122/149 [01:47<00:22,  1.19it/s, loss=23.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  83%|████████▎ | 123/149 [01:48<00:21,  1.19it/s, loss=23.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  83%|████████▎ | 124/149 [01:49<00:20,  1.19it/s, loss=23.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  84%|████████▍ | 125/149 [01:50<00:20,  1.19it/s, loss=23.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  85%|████████▍ | 126/149 [01:50<00:19,  1.19it/s, loss=23.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  85%|████████▌ | 127/149 [01:51<00:18,  1.19it/s, loss=23.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  86%|████████▌ | 128/149 [01:52<00:17,  1.19it/s, loss=23.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  87%|████████▋ | 129/149 [01:53<00:16,  1.19it/s, loss=23.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  87%|████████▋ | 130/149 [01:54<00:16,  1.19it/s, loss=23.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  88%|████████▊ | 131/149 [01:55<00:15,  1.19it/s, loss=23.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  89%|████████▊ | 132/149 [01:55<00:14,  1.20it/s, loss=23.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  89%|████████▉ | 133/149 [01:56<00:13,  1.20it/s, loss=23.4]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  90%|████████▉ | 134/149 [01:57<00:12,  1.20it/s, loss=23.4]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  91%|█████████ | 135/149 [01:58<00:11,  1.19it/s, loss=23.4]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  91%|█████████▏| 136/149 [01:59<00:10,  1.19it/s, loss=23.3]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  92%|█████████▏| 137/149 [02:00<00:10,  1.19it/s, loss=23.2]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  93%|█████████▎| 138/149 [02:00<00:09,  1.18it/s, loss=23.2]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  93%|█████████▎| 139/149 [02:01<00:08,  1.19it/s, loss=23.1]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  94%|█████████▍| 140/149 [02:02<00:07,  1.19it/s, loss=23.1]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  95%|█████████▍| 141/149 [02:03<00:06,  1.19it/s, loss=23]  

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  95%|█████████▌| 142/149 [02:04<00:05,  1.19it/s, loss=22.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  96%|█████████▌| 143/149 [02:05<00:05,  1.19it/s, loss=22.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  97%|█████████▋| 144/149 [02:05<00:04,  1.19it/s, loss=22.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  97%|█████████▋| 145/149 [02:06<00:03,  1.19it/s, loss=22.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  98%|█████████▊| 146/149 [02:07<00:02,  1.19it/s, loss=22.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  99%|█████████▊| 147/149 [02:08<00:01,  1.19it/s, loss=22.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 1:  99%|█████████▉| 148/149 [02:09<00:00,  1.18it/s, loss=22.5]

torch.Size([13, 256, 768])
torch.Size([13, 256, 768])


Epoch 1: 100%|██████████| 149/149 [02:09<00:00,  1.15it/s, loss=22.5]


Train_loss:22.472335015367342


Epoch 2:   0%|          | 0/149 [00:00<?, ?it/s]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:   1%|          | 1/149 [00:01<02:42,  1.10s/it, loss=12]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:   1%|▏         | 2/149 [00:01<02:17,  1.07it/s, loss=11.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:   2%|▏         | 3/149 [00:02<02:10,  1.12it/s, loss=11.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:   3%|▎         | 4/149 [00:03<02:06,  1.14it/s, loss=12]  

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:   3%|▎         | 5/149 [00:04<02:04,  1.16it/s, loss=12]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:   4%|▍         | 6/149 [00:05<02:01,  1.17it/s, loss=12]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:   5%|▍         | 7/149 [00:06<02:00,  1.17it/s, loss=12.1]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:   5%|▌         | 8/149 [00:06<01:59,  1.18it/s, loss=12]  

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:   6%|▌         | 9/149 [00:07<01:57,  1.19it/s, loss=11.9]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:   7%|▋         | 10/149 [00:08<01:56,  1.20it/s, loss=11.8]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:   7%|▋         | 11/149 [00:09<01:54,  1.20it/s, loss=11.7]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:   8%|▊         | 12/149 [00:10<01:54,  1.19it/s, loss=11.6]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:   9%|▊         | 13/149 [00:11<01:54,  1.19it/s, loss=11.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:   9%|▉         | 14/149 [00:11<01:53,  1.19it/s, loss=11.5]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:  10%|█         | 15/149 [00:12<01:51,  1.20it/s, loss=11.4]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:  11%|█         | 16/149 [00:13<01:51,  1.20it/s, loss=11.3]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:  11%|█▏        | 17/149 [00:14<01:50,  1.19it/s, loss=11.2]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:  12%|█▏        | 18/149 [00:15<01:50,  1.19it/s, loss=11.1]

torch.Size([32, 256, 768])
torch.Size([32, 256, 768])


Epoch 2:  12%|█▏        | 18/149 [00:15<01:56,  1.13it/s, loss=11.1]


KeyboardInterrupt: 