#### TODO
[last updated: 23/03/2025]
1. Blog and line-by-line explanation of transformers architecture
2. Separate files for machine translation testing
3. ADD: unit tests for each class

In [None]:
import os
from os.path import exists
import math
import copy
import time
import warnings

import numpy as np
import pandas as pd
import altair as alt
import spacy
import GPUtil

# PyTorch imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.functional import log_softmax, pad
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

# torchtext imports
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import Multi30k
from torchtext.data.functional import to_map_style_dataset
from torch.nn.utils.rnn import pad_sequence
from typing import List, Tuple, Iterator

# Device configuration: prefer Apple MPS, then CUDA, otherwise fall back to CPU
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device  # notebook cell output: show which device was selected

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention.

    Queries, keys and values are projected by separate linear layers, split
    into `num_heads` heads of size `d_k = d_model // num_heads`, attended per
    head with softmax(Q K^T / sqrt(d_k)) V, merged back to `d_model` features
    and passed through a final linear projection.

    Args:
        d_model: feature size of the inputs and of the output
        num_heads: number of attention heads; must divide `d_model`
    """

    def __init__(
            self,
            d_model, num_heads
    ):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads     # Dimension of each head's key/query/value

        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out_linear = nn.Linear(d_model, d_model)

    def forward(self,
                q, k, v,
                mask = None,
    ):
        """
        Args:
            q, k, v: tensors viewed as [batch, seq_len, d_model]
            mask: optional tensor broadcastable to [batch, heads, q_len, k_len];
                positions where `mask == 0` are excluded from attention

        Returns:
            (output, attn_weights): output is [batch, q_len, d_model] and
            attn_weights is [batch, heads, q_len, k_len]
        """
        batch_size = q.size(0)

        # Reshape to separate the attention heads
        # [batch, seq_len, d_model] --> [batch, heads, seq_len, d_k]
        q = self.q_linear(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention scores: Q K^T / sqrt(d_k)
        # Dimension: [batch, heads, q_len, k_len]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Masked positions (padding / future tokens) must receive ~zero weight
        # after the softmax, so they are filled with the most negative finite
        # value of the dtype. A small positive fill such as 1e-9 would leave
        # them fully visible. A finite value (rather than -inf) keeps a row
        # that is masked everywhere from turning into NaN.
        if mask is not None:
            scores = scores.masked_fill(
                mask == 0, torch.finfo(scores.dtype).min,
            )

        # Softmax along dim -1 (the key positions)
        attn_weights = F.softmax(scores, dim=-1)

        # Information aggregation: each position becomes a weighted sum of the
        # values from all positions it attends to
        output = torch.matmul(attn_weights, v)

        # Combine the heads again:
        # [batch, heads, seq_len, d_k] -> [batch, seq_len, heads, d_k] -> [batch, seq_len, d_model]
        # `contiguous` is required because `view` needs the transposed tensor
        # laid out in a single block of memory
        output = output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )

        output = self.out_linear(output)

        return output, attn_weights

In [None]:
class PositionalwiseFeedForward(nn.Module):
    """
    Two-layer feed-forward network applied to the last dimension of its input.

    Computes linear2(relu(linear1(x))): `d_model` features are expanded to
    `d_ff` and then projected back to `d_model`.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Expansion layer: d_model -> d_ff
        self.linear1 = nn.Linear(d_model, d_ff)
        # Projection layer: d_ff -> back to the original d_model
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.linear1(x)
        activated = self.relu(hidden)
        return self.linear2(activated)

In [None]:
class PositionalEncoding(nn.Module):
    """
    Since the transformer architecture contains no recurrence or convolution,
    it has no inherent way to understand the order of tokens in a sequence.

    PosEnc injects position information by adding fixed sinusoids to the
    embeddings (positional encoding):

    PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

    Research Advancements:
    - Rotary Positional Embedding (RoPE): integrates PosEnc into the attention mechanism

    Args:
        d_model: embedding size (even or odd)
        max_seq_len: number of positions to precompute
    """

    def __init__(self,
                 d_model, max_seq_len = 5000):
        super().__init__()
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)

        # Division term: one frequency per sin/cos pair of dimensions.
        # It equals 1 for the first pair and shrinks for later pairs, so the
        # low dimensions oscillate fastest and the high dimensions slowest.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term        # [max_seq_len, ceil(d_model / 2)]

        # sin on the even indices of the embedding dim
        pe[:, 0::2] = torch.sin(angles)
        # cos on the odd indices of the embedding dim. For an odd d_model there
        # is one odd column fewer than even columns, so the last frequency is
        # dropped instead of failing with a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        pe = pe.unsqueeze(0)                # [1, max_seq_len, d_model]

        # Register Buffer: PE is not trained but saved and restored in state_dict
        self.register_buffer('pe', pe)

    def forward(self, x):
        # og_paper: adds the encoding rather than concatenating it.
        # x is indexed as [batch, seq_len, ...]; only the first seq_len
        # positions of the table are used.
        return x + self.pe[:, :x.size(1)]

In [None]:
class EncoderLayer(nn.Module):
    """
    Each encoder layer consists of
    1. Multi-head self-attention layer
    2. Position-wise feed forward network

    Each sub-layer is followed by a residual connection and layer normalization.

    Args:
        d_model: feature size of the layer input and output
        num_heads: number of attention heads
        d_ff: hidden size of the feed-forward network
        dropout: dropout probability applied to each sub-layer output
    """
    def __init__(self,
                 d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionalwiseFeedForward(d_model, d_ff)

        # Layer norm is used rather than batch norm
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply self-attention and the feed-forward network to `x`.

        `mask` is forwarded unchanged to the self-attention block.
        """
        # Sub-layer 1: self-attention + residual + norm.
        # The additions are out-of-place on purpose: `x += ...` would overwrite
        # the caller's tensor and the input that the attention projections
        # were computed from.
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))

        # Sub-layer 2: feed-forward + residual + norm.
        # The residual must add the feed-forward output, not `x` itself.
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))

        return x

In [None]:
class DecoderLayer(nn.Module):
    """
    Each decoder layer consists of
    1. Masked multi-head self-attention (to prevent attending to future positions)
    2. Multi-head cross-attention over the encoder output
    3. Position-wise feed forward network

    * Each sub-layer uses a residual connection and layer normalization
    * The self-attention mask ensures the autoregressive property

    Args:
        d_model: feature size of the layer input and output
        num_heads: number of attention heads
        d_ff: hidden size of the feed-forward network
        dropout: dropout probability applied to each sub-layer output
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionalwiseFeedForward(d_model, d_ff)

        # Layer norm is used rather than batch norm
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, self_mask = None, cross_mask = None):
        """Run one decoder layer.

        Args:
            x: decoder input
            enc_output: encoder output, used as keys and values in cross-attention
            self_mask: mask forwarded to the self-attention block
            cross_mask: mask forwarded to the cross-attention block
        """
        # The residual additions are out-of-place on purpose: `x += ...` would
        # overwrite the caller's tensor and the inputs of the preceding
        # attention / feed-forward projections.
        self_attn_output, _ = self.self_attn(x, x, x, self_mask)
        x = self.norm1(x + self.dropout(self_attn_output))

        # Cross attention: queries come from the decoder, keys/values from the
        # encoder output, so the decoder can attend to all encoder positions
        cross_attn_output, _ = self.cross_attn(x, enc_output, enc_output, cross_mask)
        x = self.norm2(x + self.dropout(cross_attn_output))

        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))

        return x

In [None]:
class Encoder(nn.Module):
    """
    Stack of `num_layers` encoder layers: the output of each layer is the input
    of the next one, and a final layer normalization is applied at the end.
    """

    def __init__(self, d_model, num_heads, d_ff, num_layers, dropout=0.1):
        super().__init__()

        # Build the identical encoder layers first, then register them
        stack = [
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(stack)

        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        hidden = x
        for encoder_layer in self.layers:
            # The same mask is handed to every layer
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)

In [None]:
class Decoder(nn.Module):
    """
    Stack of `num_layer` identical decoder layers in which the output of each
    layer serves as the input of the next. A final layer normalization is
    applied to the output of the last decoder layer.
    """

    def __init__(self, d_model, num_heads, d_ff, num_layer, dropout=0.1):
        super().__init__()

        stack = [
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layer)
        ]
        self.layers = nn.ModuleList(stack)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, enc_output, self_mask=None, cross_mask=None):
        hidden = x
        for decoder_layer in self.layers:
            # Every layer sees the same encoder output and the same two masks
            hidden = decoder_layer(hidden, enc_output, self_mask, cross_mask)
        return self.norm(hidden)

In [None]:
class Transformer(nn.Module):
    """
    Encoder-decoder Transformer for sequence-to-sequence tasks.

    Architecture families, for context:
    - Encoder-only models: BERT > excel at understanding tasks
    - Decoder-only models: GPT > excel at general tasks
    - Encoder-decoder models: T5, BART > seq2seq tasks

    Args:
        src_vocab_size: number of rows of the source embedding table
        tgt_vocab_size: number of rows of the target embedding table and size
            of the output logits
        d_model: embedding / hidden size
        num_heads: attention heads per attention block
        d_ff: hidden size of the position-wise feed-forward networks
        num_layers: number of encoder layers and of decoder layers
        dropout: dropout probability
        max_seq_len: number of positions precomputed by the positional encoding
        pad_idx: token id treated as padding when building attention masks;
            defaults to 0, the value that used to be hard-coded
    """

    def __init__(self,
                 src_vocab_size, tgt_vocab_size,
                 d_model=512, num_heads=8, d_ff=2048, num_layers=6,
                 dropout=0.1, max_seq_len=5000, pad_idx=0):
        super().__init__()

        self.d_model = d_model
        self.pad_idx = pad_idx

        # Converting token indices to embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        # Adding Positional Encoding (shared by source and target)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)

        # Encoder & Decoder stack
        self.encoder = Encoder(d_model, num_heads, d_ff, num_layers, dropout)
        self.decoder = Decoder(d_model, num_heads, d_ff, num_layers, dropout)

        # Linear Layer: projects the decoder output to vocabulary logits
        self.linear = nn.Linear(d_model, tgt_vocab_size)

        self.dropout = nn.Dropout(dropout)

        # Initialize parameters with Xavier
        self._init_parameters()

    def _init_parameters(self):
        """Xavier-uniform init for every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def create_masks(self, src, tgt):
        """
        Padding and sequence masks for attention
        1. Padding mask: prevents attending to padding tokens (`self.pad_idx`)
        2. Look-ahead mask: prevents the decoder from attending to future positions

        Args:
            src: source token ids, [batch, src_len]
            tgt: target token ids, [batch, tgt_len]

        Returns:
            (src_mask, tgt_mask): boolean tensors where True marks a position
            that may be attended to; src_mask is [batch, 1, 1, src_len] and
            tgt_mask is [batch, 1, tgt_len, tgt_len]
        """

        # Padding mask for src: True for real tokens, False for padding tokens
        src_mask = (src != self.pad_idx).unsqueeze(1).unsqueeze(2)

        # Padding mask for tgt
        tgt_mask = (tgt != self.pad_idx).unsqueeze(1).unsqueeze(2)

        # Look-ahead mask: strictly upper-triangular matrix that is True at
        # future positions; created directly on the device of `tgt`
        seq_len = tgt.size(1)
        look_ahead_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=tgt.device),
            diagonal=1
        ).bool()

        # Combine padding and look-ahead masks:
        # a position is valid if it is a non-padding token AND not in the future
        tgt_mask = tgt_mask & ~look_ahead_mask

        return src_mask, tgt_mask

    def forward(self, src, tgt):
        """Return logits of shape [batch, tgt_len, tgt_vocab_size]."""
        src_mask, tgt_mask = self.create_masks(src, tgt)

        # As in the og_paper, embeddings are scaled by sqrt(d_model) before the
        # positional encoding is added
        src_embedded = self.dropout(
            self.positional_encoding(
                self.src_embedding(src) * math.sqrt(self.d_model)
            )
        )

        tgt_embedded = self.dropout(
            self.positional_encoding(
                self.tgt_embedding(tgt) * math.sqrt(self.d_model)
            )
        )

        enc_output = self.encoder(src_embedded, src_mask)
        # Cross-attention reuses the source padding mask
        dec_output = self.decoder(tgt_embedded, enc_output, tgt_mask, src_mask)

        # Project output to vocabulary space to get logits
        output = self.linear(dec_output)

        return output

In [None]:
# Example usage: model with the standard hyperparameters
def create_transformer_model(
        src_vocab_size = 10000, tgt_vocab_size = 10000
):
    """
    Build a Transformer with the standard architecture:
    - d_model = 512
    - num_heads = 8
    - d_ff = 2048
    - num_layers = 6
    - dropout = 0.1

    Only the two vocabulary sizes are configurable.
    """
    standard_hparams = dict(
        d_model=512,
        num_heads=8,
        d_ff=2048,
        num_layers=6,
        dropout=0.1,
    )
    return Transformer(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        **standard_hparams,
    )

# Instantiate the default model and report how many parameters it holds
model = create_transformer_model()
print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")

# Shape check with random token ids drawn from [1, 10000)
src = torch.randint(low=1, high=10000, size=(2, 20))  # [batch_size, src_seq_len]
tgt = torch.randint(low=1, high=10000, size=(2, 15))  # [batch_size, tgt_seq_len]

# One forward pass; expected shape is [batch_size, tgt_seq_len, tgt_vocab_size]
output = model(src, tgt)
print(f"Output shape: {output.shape}")

### Testing with Dataset

In [None]:
def load_dataset():
    """Load the Multi30k train/valid/test splits as (English, German) pairs.

    torchtext's default `language_pair` is ('de', 'en'); the rest of this
    pipeline tokenizes element 0 of every sample with the English tokenizer and
    translates an English sentence, so the pair is requested explicitly as
    ('en', 'de').

    Returns:
        (train_iter, valid_iter, test_iter)
    """
    train_iter, valid_iter, test_iter = Multi30k(
        split=('train', 'valid', 'test'),
        language_pair=('en', 'de'),
    )
    return train_iter, valid_iter, test_iter

In [None]:
def get_tokenizers():
    """Return `(tokenizer_en, tokenizer_de)`, two callables mapping text to tokens.

    The spaCy pipelines 'en_core_web_sm' and 'de_core_news_sm' are preferred
    and their tokens are lower-cased. If spaCy or one of its models cannot be
    loaded, a message is printed and basic tokenizers are returned instead:
    torchtext's 'basic_english' for English and lower-cased whitespace
    splitting for German.
    """
    # Only the loading step can fail, so only it lives inside the `try`.
    # A missing spaCy package raises ImportError; a missing spaCy model
    # surfaces as OSError, so both must trigger the fallback.
    try:
        spacy_en = get_tokenizer('spacy', language='en_core_web_sm')
        spacy_de = get_tokenizer('spacy', language='de_core_news_sm')
    except (ImportError, OSError):
        print("SpaCy model not found. Using basic tokenization")
        basic_en = get_tokenizer('basic_english')

        def basic_de(text):
            return text.lower().split()

        return basic_en, basic_de

    def tokenizer_en(text):
        return [token.lower() for token in spacy_en(text)]

    def tokenizer_de(text):
        return [token.lower() for token in spacy_de(text)]

    return tokenizer_en, tokenizer_de

def build_vocabularies(train_iter, tokenize_src, tokenize_tgt):
    """Build the source and target vocabularies from the training pairs.

    Args:
        train_iter: iterable of (src_text, tgt_text) samples. It is read once
            here; NOTE(review): callers that reuse it afterwards rely on it
            being re-iterable (a list or a torchtext datapipe) — confirm for
            the installed torchtext version.
        tokenize_src: callable turning a source string into a list of tokens
        tokenize_tgt: callable turning a target string into a list of tokens

    Returns:
        (src_vocab, tgt_vocab). Tokens seen fewer than twice are left out and
        unknown tokens map to '<unk>'.
    """
    # '<pad>' comes first so that it receives index 0, the token id that
    # Transformer.create_masks treats as padding by default.
    special_symbols = ['<pad>', '<unk>', '<bos>', '<eos>']

    # Materialize the samples once so both vocabularies are built from exactly
    # the data that was passed in. (Re-loading with `a, _ = Multi30k(split='train')`
    # would try to unpack a whole dataset into two names and fail.)
    train_samples = list(train_iter)

    # Helper: yield the token list of element `index` of every sample
    def yield_tokens(samples, tokenizer, index):
        for data_sample in samples:
            yield tokenizer(data_sample[index])

    # Building src vocab
    src_vocab = build_vocab_from_iterator(
        yield_tokens(train_samples, tokenize_src, 0),
        min_freq=2,
        specials=special_symbols,
    )

    # Build target vocabulary
    tgt_vocab = build_vocab_from_iterator(
        yield_tokens(train_samples, tokenize_tgt, 1),
        min_freq=2,
        specials=special_symbols,
    )

    # Set default index to <unk>
    src_vocab.set_default_index(src_vocab['<unk>'])
    tgt_vocab.set_default_index(tgt_vocab['<unk>'])

    return src_vocab, tgt_vocab

def process_data(data_iterator, tokenize_src, tokenize_tgt, src_vocab, tgt_vocab):
    """Convert raw text pairs into pairs of token-id tensors.

    Every sequence is wrapped in '<bos>' ... '<eos>', using the special-token
    ids of its own vocabulary.

    Args:
        data_iterator: iterable of (src_text, tgt_text) samples
        tokenize_src / tokenize_tgt: callables mapping text to tokens
        src_vocab / tgt_vocab: mappings from token to id, indexed with `[]`

    Returns:
        list of (src_tensor, tgt_tensor) tuples, one per sample
    """
    # Special-token ids are looked up per vocabulary: the target side must not
    # borrow the source vocabulary's ids.
    src_bos, src_eos = src_vocab['<bos>'], src_vocab['<eos>']
    tgt_bos, tgt_eos = tgt_vocab['<bos>'], tgt_vocab['<eos>']

    data_pairs = []
    for src_text, tgt_text in data_iterator:
        # Tokenize and convert to indices
        src_tokens = [src_bos] + [src_vocab[token] for token in tokenize_src(src_text)] + [src_eos]
        tgt_tokens = [tgt_bos] + [tgt_vocab[token] for token in tokenize_tgt(tgt_text)] + [tgt_eos]

        data_pairs.append((torch.tensor(src_tokens), torch.tensor(tgt_tokens)))

    return data_pairs

def create_batch(data_batch):
    """Collate (src, tgt) tensor pairs into two padded, batch-first tensors.

    The padding ids are read from the module-level `src_vocab` and `tgt_vocab`,
    so both must exist before this function is used as a `collate_fn`.
    """
    sources = []
    targets = []
    for pair in data_batch:
        source, target = pair
        sources.append(source)
        targets.append(target)

    # Pad every sequence to the longest one of its side of the batch
    src_pad = src_vocab['<pad>']
    padded_src = pad_sequence(sources, batch_first=True, padding_value=src_pad)
    tgt_pad = tgt_vocab['<pad>']
    padded_tgt = pad_sequence(targets, batch_first=True, padding_value=tgt_pad)

    return padded_src, padded_tgt

def evaluate(model, val_loader, criterion, pad_idx):
    """Return the mean per-batch loss of `model` over `val_loader`.

    The model is switched to eval mode and no gradients are tracked. Batches
    are moved to the module-level `device`. `pad_idx` is accepted for
    signature compatibility and is not used here; padding is expected to be
    ignored by `criterion`.
    """
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for src, tgt in val_loader:
            src = src.to(device)
            tgt = tgt.to(device)

            # Teacher forcing: the decoder sees the target without its last
            # token and is scored against the target without its first (BOS)
            decoder_input, expected = tgt[:, :-1], tgt[:, 1:]

            logits = model(src, decoder_input)

            # Flatten to [tokens, vocab] and [tokens] for the loss
            flat_logits = logits.contiguous().view(-1, logits.shape[-1])
            flat_expected = expected.contiguous().view(-1)

            batch_loss = criterion(flat_logits, flat_expected)
            total_loss += batch_loss.item()

    return total_loss / len(val_loader)

def train_model(model, train_loader, val_loader, optimizer, criterion, num_epochs, pad_idx):
    """Train `model` for `num_epochs` epochs and keep the best checkpoint.

    After every epoch the model is evaluated on `val_loader`; whenever the
    validation loss improves, the state dict is written to
    'best_transformer_model.pt' in the current working directory.

    Args:
        model: module called as `model(src, tgt_input)` and returning logits
        train_loader / val_loader: iterables of (src, tgt) batches
        optimizer: optimizer over the model parameters
        criterion: loss taking flattened logits and flattened target ids
        num_epochs: number of passes over `train_loader`
        pad_idx: forwarded to `evaluate`
    """
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        start_time = time.time()

        for batch_number, (src, tgt) in enumerate(train_loader, start=1):
            src, tgt = src.to(device), tgt.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Teacher forcing: decoder input drops the last token,
            # the prediction target drops the first (BOS) token
            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]

            # Forward pass
            outputs = model(src, tgt_input)

            # Reshape for loss calculation
            outputs = outputs.contiguous().view(-1, outputs.shape[-1])
            tgt_output = tgt_output.contiguous().view(-1)

            # Calculate loss (padding is ignored by the criterion)
            loss = criterion(outputs, tgt_output)

            # Backward pass and optimize
            loss.backward()
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            epoch_loss += loss.item()

            # Print batch statistics
            if batch_number % 100 == 0:
                print(f'Epoch: {epoch+1}, Batch: {batch_number}, Loss: {loss.item():.4f}')

        # Calculate epoch statistics (the reported time includes validation)
        epoch_loss = epoch_loss / len(train_loader)
        val_loss = evaluate(model, val_loader, criterion, pad_idx)
        end_time = time.time()
        epoch_mins, epoch_secs = divmod(end_time - start_time, 60)

        # Print epoch statistics; divmod on floats returns a float quotient,
        # so the minutes are converted to int to print "3m" instead of "3.0m"
        print(f'Epoch: {epoch+1:02} | Time: {int(epoch_mins)}m {epoch_secs:.2f}s')
        print(f'Train Loss: {epoch_loss:.4f} | Val Loss: {val_loss:.4f}')

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_transformer_model.pt')
            print(f'Best model saved with validation loss: {best_val_loss:.4f}')

def translate_sentence(model, sentence, tokenize_src, src_vocab, tgt_vocab, max_len=50):
    """Greedily translate one source sentence.

    The sentence is tokenized, wrapped in '<bos>' ... '<eos>', encoded once,
    and the target is generated token by token (arg-max) until '<eos>' is
    produced or `max_len` tokens have been generated.

    Args:
        model: Transformer exposing `src_embedding`, `tgt_embedding`,
            `positional_encoding`, `dropout`, `encoder`, `decoder`, `linear`
            and `d_model`
        sentence: source text
        tokenize_src: callable mapping the text to source tokens
        src_vocab / tgt_vocab: vocabularies indexed by token; `tgt_vocab` must
            also provide `get_itos()`
        max_len: maximum number of tokens to generate

    Returns:
        The generated tokens joined by spaces, without '<bos>' and '<eos>'.
    """
    model.eval()

    # Work on whatever device the model's weights live on
    model_device = next(model.parameters()).device

    tokens = ['<bos>'] + list(tokenize_src(sentence)) + ['<eos>']   # Add BOS and EOS tokens
    src_indices = [src_vocab[token] for token in tokens]            # Convert to indices
    # Convert to tensor and add batch dimension -> [1, src_len]
    src_tensor = torch.tensor(src_indices, dtype=torch.long, device=model_device).unsqueeze(0)

    bos_idx = tgt_vocab['<bos>']
    eos_idx = tgt_vocab['<eos>']

    # Source padding mask (True = may be attended to); it does not change
    # during decoding, so it is built once
    src_mask = (src_tensor != src_vocab['<pad>']).unsqueeze(1).unsqueeze(2)

    generated = [bos_idx]

    with torch.no_grad():
        # Encode the source sentence once. Masks are passed positionally:
        # Encoder.forward takes (x, mask).
        src_embedded = model.dropout(
            model.positional_encoding(
                model.src_embedding(src_tensor) * math.sqrt(model.d_model)
            )
        )
        encoder_outputs = model.encoder(src_embedded, src_mask)

        # Generate tokens auto-regressively
        for _ in range(max_len):
            # Target generated so far -> [1, cur_len]
            tgt_tensor = torch.tensor([generated], dtype=torch.long, device=model_device)
            cur_len = tgt_tensor.size(1)

            # Causal mask with the polarity the attention expects:
            # True on and below the diagonal (allowed), False for the future
            future = torch.triu(
                torch.ones(cur_len, cur_len, device=model_device), diagonal=1
            ).bool()
            tgt_mask = ~future

            tgt_embedded = model.dropout(
                model.positional_encoding(
                    model.tgt_embedding(tgt_tensor) * math.sqrt(model.d_model)
                )
            )

            # Decoder.forward takes (x, enc_output, self_mask, cross_mask)
            decoder_output = model.decoder(
                tgt_embedded,
                encoder_outputs,
                tgt_mask,
                src_mask,
            )

            # Greedy choice for the last position
            prediction = model.linear(decoder_output)
            pred_token = prediction[0, -1].argmax().item()
            generated.append(pred_token)

            # Stop as soon as EOS is predicted
            if pred_token == eos_idx:
                break

    # Drop BOS, and drop EOS only if it was actually produced; when max_len is
    # reached first, the last token is a real word and must be kept
    generated = generated[1:]
    if generated and generated[-1] == eos_idx:
        generated = generated[:-1]

    # Convert indices back to tokens (the lookup table is fetched once)
    itos = tgt_vocab.get_itos()
    return ' '.join(itos[idx] for idx in generated)

# Helper: generate look-ahead mask for decoder
def generate_subsequent_mask(size):
    """Return a [size, size] boolean look-ahead mask on the module-level `device`.

    Entry (i, j) is True when j > i, i.e. True marks FUTURE positions. The
    attention layers in this file block positions where `mask == 0`, so the
    result has to be inverted (`~mask`) before it is used as an attention mask.
    """
    mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
    return mask.to(device)

# Attach as a static method: a plain function assigned to the class would be
# bound on instance access, so `model.generate_subsequent_mask(n)` would pass
# the model itself as `size`.
Transformer.generate_subsequent_mask = staticmethod(generate_subsequent_mask)

In [None]:
def main():
    """Train and evaluate the Transformer on the translation dataset.

    Steps: load the dataset and tokenizers, build the vocabularies, convert the
    splits to tensors, train with Adam, report the test loss and translate one
    sample sentence.
    """
    train_iter, valid_iter, test_iter = load_dataset()
    tokenize_src, tokenize_tgt = get_tokenizers()

    # Published as globals because `create_batch` (the DataLoader collate_fn)
    # reads the vocabularies from module scope
    global src_vocab, tgt_vocab
    src_vocab, tgt_vocab = build_vocabularies(train_iter, tokenize_src, tokenize_tgt)

    train_data = process_data(train_iter, tokenize_src, tokenize_tgt, src_vocab, tgt_vocab)
    valid_data = process_data(valid_iter, tokenize_src, tokenize_tgt, src_vocab, tgt_vocab)
    test_data = process_data(test_iter, tokenize_src, tokenize_tgt, src_vocab, tgt_vocab)

    BATCH_SIZE = 64
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=create_batch)
    valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=create_batch)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=create_batch)

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    # Standard hyperparameters
    D_MODEL = 512
    NUM_HEADS = 8
    NUM_LAYERS = 6
    D_FF = 2048
    DROPOUT = 0.1

    NUM_EPOCHS = 10
    LEARNING_RATE = 0.0001
    # A fixed learning rate is used; the og_paper schedule would be
    # lr = D_MODEL**-0.5 * min(step_num**-0.5, step_num * WARMUP_STEPS**-1.5)

    model = Transformer(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        d_model=D_MODEL,
        num_heads=NUM_HEADS,
        d_ff=D_FF,
        num_layers=NUM_LAYERS,
        dropout=DROPOUT
    ).to(device)

    # Setup: training parameters.
    # The loss is computed on TARGET token ids, so the ignored padding id must
    # come from the target vocabulary.
    PAD_IDX = tgt_vocab['<pad>']
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.98), eps=1e-9)

    train_model(model, train_loader, valid_loader, optimizer, criterion, NUM_EPOCHS, PAD_IDX)

    # Evaluation: Test dataset
    test_loss = evaluate(model, test_loader, criterion, PAD_IDX)
    print(f'Test Loss: {test_loss:.4f}')

    # Inference: sample query
    sample_sentence = "The young boy is playing soccer in the park."
    translation = translate_sentence(model, sample_sentence, tokenize_src, src_vocab, tgt_vocab)
    print(f'English: {sample_sentence}')
    print(f'Translation: {translation}')

In [None]:
# Run the pipeline only when executed directly (a notebook cell or a script),
# not when this code is imported as a module
if __name__ == "__main__":
    main()