# **Machine Translation with Transformers from Scratch**

## **Project Overview**
This project focuses on building a machine translation system from **English to Russian** using **transformers implemented from scratch**. The primary goal is to experiment with and compare different **attention mechanism variations** to evaluate their impact on translation quality.

## **Objectives**
- Implement a **transformer-based** translation model from scratch.
- Explore and compare **various attention mechanisms** (general, multiplicative/scaled dot-product, and additive attention).
- Train the model on an **English-Russian parallel dataset**.
- Evaluate translation quality using the **perplexity** metric.
- Optimize performance by fine-tuning architectural components.

## **Key Components**
- **Data Preprocessing:** Tokenization, text normalization, and preparation of parallel English-Russian datasets.
- **Model Architecture:** Implementation of the transformer model, including encoder-decoder structures and attention mechanisms.
- **Training & Optimization:** Training the model with effective hyperparameters and loss functions.
- **Evaluation:** Measuring translation quality with perplexity (and BLEU scores) and analyzing model performance.

## **Expected Outcomes**
- A working machine translation model that translates **English to Russian** with high accuracy.
- Insights into how **different attention mechanisms** affect translation quality.
- Potential improvements in efficiency and translation fluency by fine-tuning model components.


In [1]:
# File imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import datasets
import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

import random, math, time, copy, gc

In [2]:
# Global vars
# randomness
SEED = 42

# data processing
SRC_LANGUAGE = 'en'
TRG_LANGUAGE = 'ru'
UNK_IDX, PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2, 3

# model_related
BATCH_SIZE = 8
HID_DIM = 256
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1

In [3]:
# Pre-configs
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
# device2 = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# lib versions
print(torch.__version__, torchtext.__version__)

2.2.0+cu121 0.16.2+cpu


## ETL: Loading the dataset

The dataset is the English-Russian subset of [OPUS-100](https://huggingface.co/datasets/Helsinki-NLP/opus-100/viewer/en-ru), which consists of randomly sampled sentence pairs.

OPUS-100 is an English-centric multilingual corpus covering 100 languages (including English): every training pair has English on either the source or the target side. The languages were selected based on the volume of parallel data available in OPUS.

In [None]:
# books = datasets.load_dataset("opus_books", "en-ru") # too small; OPUS-100 provides more data
dataset = datasets.load_dataset('opus100', 'en-ru')
dataset

In [None]:
subset_size = 50000 # use only the first 50k training pairs to keep training time manageable
dataset['train'] = dataset['train'].select(range(subset_size))
dataset

## Preprocessing
For preprocessing, we split each `translation` dict into separate `en` and `ru` columns.

In [None]:
preprocessed = dataset.map(lambda text: {lang: text['translation'][lang] for lang in ('en', 'ru')}, remove_columns=['translation'])
preprocessed

### Tokenizing

Note: the spaCy tokenizer models must first be downloaded from the command line:
```
python3 -m spacy download en_core_web_sm
python3 -m spacy download ru_core_news_sm
```
Since we work with two languages, we reuse the `SRC_LANGUAGE` and `TRG_LANGUAGE` constants defined above and create two dicts: one holding the tokenizers and one holding the vocabularies, which assign an integer to each unique token.

#### Text to integers (Numericalization)
Next, we create the functions (torchtext calls them vocabs) that turn these tokens into integers. We use the built-in factory function `build_vocab_from_iterator`, which accepts an iterator that yields lists (or iterators) of tokens.
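Conceptually, a vocab is a token-to-index lookup built from token counts. The plain-Python sketch below mimics the settings used in the next cell: special symbols first, tokens seen fewer than `min_freq` times dropped, the remaining tokens ordered by descending frequency, and unknown tokens mapped to `UNK_IDX`. The `build_vocab` and `numericalize` helpers are hypothetical (not part of torchtext), and breaking frequency ties alphabetically is an assumption of this sketch.

```python
from collections import Counter

def build_vocab(token_lists, specials, min_freq=1):
    """Hypothetical helper: return (stoi, itos) with the special symbols first."""
    counts = Counter(tok for tokens in token_lists for tok in tokens)
    kept = [t for t, c in counts.items() if c >= min_freq and t not in specials]
    # most frequent first; ties broken alphabetically (assumption of this sketch)
    kept.sort(key=lambda t: (-counts[t], t))
    itos = list(specials) + kept
    stoi = {tok: idx for idx, tok in enumerate(itos)}
    return stoi, itos

def numericalize(tokens, stoi, unk_idx=0):
    """Hypothetical helper: map tokens to ids, falling back to <unk>."""
    return [stoi.get(tok, unk_idx) for tok in tokens]

corpus = [["i", "love", "sushi"], ["i", "love", "tea"], ["you", "love", "sushi"]]
stoi, itos = build_vocab(corpus, ["<unk>", "<pad>", "<sos>", "<eos>"], min_freq=2)
print(itos)                                    # ['<unk>', '<pad>', '<sos>', '<eos>', 'love', 'i', 'sushi']
print(numericalize(["i", "love", "tea"], stoi))  # [5, 4, 0]  ('tea' occurs once, so it is <unk>)
```

With `min_freq=2`, words that appear only once ("tea", "you") never get their own index, which is also why a rare name such as "Ulugbek" in the next cell is likely to map to `UNK_IDX`.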

In [None]:
# Place-holders
token_transform = {}
vocab_transform = {}

token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')
token_transform[TRG_LANGUAGE] = get_tokenizer('spacy', language='ru_core_news_sm')

print(preprocessed['train'][23])
print(token_transform[SRC_LANGUAGE](preprocessed['train'][SRC_LANGUAGE][23]))
print(token_transform[TRG_LANGUAGE](preprocessed['train'][TRG_LANGUAGE][23]))

def yield_tokens(data, language):
    # yield the token list of every sentence in the given language
    for data_sample in data:
        yield token_transform[language](data_sample[language])


special_symbols = ['<unk>', '<pad>', '<sos>', '<eos>']

for ln in [SRC_LANGUAGE, TRG_LANGUAGE]:
    vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(preprocessed['train'], ln), 
                                                    min_freq=2,   # tokens seen fewer than 2 times are left out and map to <unk>
                                                    specials=special_symbols,
                                                    special_first=True) #indicates whether to insert symbols at the beginning or at the end                                            
# Set UNK_IDX as the default index. This index is returned when the token is not found. 
# If not set, it throws RuntimeError when the queried token is not found in the Vocabulary. 
for ln in [SRC_LANGUAGE, TRG_LANGUAGE]:
    vocab_transform[ln].set_default_index(UNK_IDX)

print(vocab_transform[SRC_LANGUAGE](['hi', 'my', 'name', 'is', 'Ulugbek']))
print(vocab_transform[TRG_LANGUAGE](['Здравствуйте', 'меня', 'зовут', 'Улугбек']))

In [None]:
# saving vocabulary for future usage
torch.save(vocab_transform, 'mt_enru_vocab_opus100.pt')

## Preparing the dataloader
We defined the special symbols `<unk>`, `<pad>`, `<sos>`, `<eos>` with indices 0, 1, 2, 3 respectively. Their meanings are:
```
<unk>: unknown
<pad>: padding
<sos>: Start of Sentence
<eos>: End of Sentence
```

In [None]:
# train = [(data[SRC_LANGUAGE], data[TRG_LANGUAGE]) for data in preprocessed['train']]
# test  = [(data[SRC_LANGUAGE], data[TRG_LANGUAGE]) for data in preprocessed['test']]
# val   = [(data[SRC_LANGUAGE], data[TRG_LANGUAGE]) for data in preprocessed['validation']]

In [None]:
mapping = vocab_transform[SRC_LANGUAGE].get_itos()

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids):
    return torch.cat((torch.tensor([SOS_IDX]), 
                      torch.tensor(token_ids), 
                      torch.tensor([EOS_IDX])))


# src and trg language text transforms to convert raw strings into tensors indices
text_transform = {}
for ln in [SRC_LANGUAGE, TRG_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln], #Tokenization
                                               vocab_transform[ln], #Numericalization
                                               tensor_transform) # Add BOS/EOS and create tensor


# function to collate data samples into batch tensors
def collate_batch(batch):
    src_batch, src_len_batch, trg_batch = [], [], []
    for src_sample, trg_sample in batch:
        processed_text = text_transform[SRC_LANGUAGE](src_sample.rstrip("\n"))
        src_batch.append(processed_text)
        trg_batch.append(text_transform[TRG_LANGUAGE](trg_sample.rstrip("\n")))
        src_len_batch.append(processed_text.size(0))

    # batch_first=True gives [batch size, seq len], the layout the model's modules expect
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX, batch_first = True)
    trg_batch = pad_sequence(trg_batch, padding_value=PAD_IDX, batch_first = True)
    return src_batch, torch.tensor(src_len_batch, dtype=torch.int64), trg_batch
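To see what `tensor_transform` and `pad_sequence` produce, here is a standalone sketch on two already-numericalized sentences. The token ids 10, 11, 12 and 20 are hypothetical, and `add_sos_eos` is a local copy of `tensor_transform`; the special indices match the ones defined above.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_IDX, SOS_IDX, EOS_IDX = 1, 2, 3  # same indices as the notebook

def add_sos_eos(token_ids):
    # same logic as tensor_transform: wrap the ids in <sos> ... <eos>
    return torch.cat((torch.tensor([SOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

seqs = [add_sos_eos([10, 11, 12]), add_sos_eos([20])]  # hypothetical token ids
lengths = torch.tensor([s.size(0) for s in seqs], dtype=torch.int64)
batch = pad_sequence(seqs, padding_value=PAD_IDX, batch_first=True)

print(batch.shape)       # torch.Size([2, 5]) -> [batch size, seq len]
print(batch.tolist())    # [[2, 10, 11, 12, 3], [2, 20, 3, 1, 1]]
print(lengths.tolist())  # [5, 3]
```

The shorter sentence is right-padded with `PAD_IDX`; those positions are later hidden from attention by `make_src_mask`/`make_trg_mask` and excluded from the loss by `ignore_index`.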

In [None]:
# train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
# valid_loader = DataLoader(val,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
# test_loader  = DataLoader(test,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

In [None]:
# for en, _, ru in train_loader:
#     print(en.shape)
#     print(ru.shape)
#     # print(en)
#     # print(ru)
#     break

# torch.Size([8, 52])
# torch.Size([8, 47])

### Model architecture
### Encoder

In [None]:
class EncoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, dropout, atten_type, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm        = nn.LayerNorm(hid_dim)
        self.self_attention       = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, atten_type, device)
        self.feedforward          = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout              = nn.Dropout(dropout)

    def forward(self, src, src_mask):
        #src = [batch size, src len, hid dim]
        #src_mask = [batch size, 1, 1, src len]   # 0 (False) at padding positions, 1 (True) elsewhere
        _src, _ = self.self_attention(src, src, src, src_mask)
        src     = self.self_attn_layer_norm(src + self.dropout(_src))
        #src: [batch_size, src len, hid dim]

        _src    = self.feedforward(src)
        src     = self.ff_layer_norm(src + self.dropout(_src))
        #src: [batch_size, src len, hid dim]

        return src
    
class Encoder(nn.Module):
    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout, atten_type, device, max_length = 552):
        super().__init__()
        self.device = device
        self.atten_type = atten_type
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers        = nn.ModuleList([EncoderLayer(hid_dim, n_heads, pf_dim, dropout, atten_type,device)
                                           for _ in range(n_layers)])
        self.dropout       = nn.Dropout(dropout)
        self.scale         = torch.sqrt(torch.FloatTensor([hid_dim])).to(self.device)
        
    def forward(self, src, src_mask):
        
        #src = [batch size, src len]
        #src_mask = [batch size, 1, 1, src len]
        
        batch_size = src.shape[0]
        src_len    = src.shape[1]
        
        pos        = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        #pos: [batch_size, src_len]
        
        src        = self.dropout((self.tok_embedding(src) * self.scale) + self.pos_embedding(pos))
        #src: [batch_size, src_len, hid_dim]
        
        for layer in self.layers:
            src = layer(src, src_mask)
        #src: [batch_size, src_len, hid_dim]
        
        return src


### Attention layers

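Each attention head computes scaled dot-product attention, where $d_k$ is the per-head dimension (`head_dim`): $$ \text{Attention}(Q, K, V) = \text{Softmax} \big( \frac{QK^T}{\sqrt{d_k}} \big)V $$ 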
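The `atten_type` argument changes only how the energy $e_{ij}$ between the $i$-th query and the $j$-th key is computed inside each head. Written out from the branches of `MultiHeadAttentionLayer.forward` ($q_i, k_j$ are the per-head vectors after `fc_q`/`fc_k`; the bias terms of the `nn.Linear` layers are omitted for readability):

$$\begin{align*}
\text{general:} \quad e_{ij} &= q_i^\top k_j\\
\text{multiplicative:} \quad e_{ij} &= \frac{q_i^\top k_j}{\sqrt{d_k}}\\
\text{additive:} \quad e_{ij} &= v^\top \tanh\left(W_q q_i + W_k k_j\right)
\end{align*}$$

In all three cases the attention weights are $\text{Softmax}_j(e_{ij})$ and the output is the corresponding weighted sum of the value vectors, so only the multiplicative variant matches the scaled formula above exactly.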

In [None]:
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, atten_type, device):
        super().__init__()
        
        self.hid_dim  = hid_dim
        self.n_heads  = n_heads
        self.head_dim = hid_dim // n_heads
        self.atten_type = atten_type

        assert hid_dim % n_heads == 0, "hid_dim must be divisible by n_heads"
        
        self.fc_q     = nn.Linear(hid_dim, hid_dim)
        self.fc_k     = nn.Linear(hid_dim, hid_dim)
        self.fc_v     = nn.Linear(hid_dim, hid_dim)
        
        self.fc_o     = nn.Linear(hid_dim, hid_dim)
        
        self.dropout  = nn.Dropout(dropout)
        
        self.scale    = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

        if atten_type == 'additive':
            self.W_q = nn.Linear(self.head_dim, self.head_dim)
            self.W_k = nn.Linear(self.head_dim, self.head_dim)
            self.v = nn.Linear(self.head_dim, 1)
                
    def forward(self, query, key, value, mask = None):
        #src, src, src, src_mask
        #query = [batch size, query len, hid dim]
        #key = [batch size, key len, hid dim]
        #value = [batch size, value len, hid dim]
        
        batch_size = query.shape[0]
        
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        #Q=K=V: [batch_size, src len, hid_dim]
        
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        #Q = [batch_size, n heads, query len, head_dim]

        # Calculate attention scores based on the selected attention variant
        if self.atten_type == 'general':
            energy = torch.matmul(Q, K.permute(0, 1, 3, 2))
        elif self.atten_type == "multiplicative":
            energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        elif self.atten_type == "additive":
            Q_exp  = Q.unsqueeze(3)
            K_exp  = K.unsqueeze(2)
            energy = self.v(torch.tanh(self.W_q(Q_exp) + self.W_k(K_exp))).squeeze(-1)            
        else:
            raise Exception("Choose between 'multiplicative', 'general', or 'additive'")
        
        #Q = [batch_size, n heads, query len, head_dim] @ K = [batch_size, n heads, head_dim, key len]
        #energy = [batch_size, n heads, query len, key len]
        
        # masked positions (padding, or future tokens in the decoder) get ~0 weight after the softmax
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
            
        attention = torch.softmax(energy, dim = -1)
        #attention = [batch_size, n heads, query len, key len]
        
        x = torch.matmul(self.dropout(attention), V)
        #[batch_size, n heads, query len, key len] @ [batch_size, n heads, value len, head_dim]
        #x = [batch_size, n heads, query len, head dim]
        
        x = x.permute(0, 2, 1, 3).contiguous()  # make memory contiguous so that .view can be applied
        #x = [batch_size, query len, n heads, head dim]
        
        x = x.view(batch_size, -1, self.hid_dim)
        #x = [batch_size, query len, hid dim]
        
        x = self.fc_o(x)
        #x = [batch_size, query len, hid dim]
        
        return x, attention


class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc1 = nn.Linear(hid_dim, pf_dim)
        self.fc2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        #x = [batch size, src len, hid dim]
        x = self.dropout(torch.relu(self.fc1(x)))
        x = self.fc2(x)
        
        return x
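A quick standalone shape check of the three scoring variants, using random tensors in place of real projections (all sizes are arbitrary, hypothetical values). It mirrors the scoring branches and the padding-mask step of `MultiHeadAttentionLayer.forward` outside the class:

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, heads, q_len, k_len, head_dim = 2, 4, 3, 5, 8  # arbitrary toy sizes
Q = torch.randn(batch, heads, q_len, head_dim)
K = torch.randn(batch, heads, k_len, head_dim)

# 'general': raw dot product -> [batch, heads, q_len, k_len]
general = Q @ K.transpose(-2, -1)
# 'multiplicative': the same dot product scaled by sqrt(head_dim)
multiplicative = general / math.sqrt(head_dim)
# 'additive': v^T tanh(W_q q_i + W_k k_j), broadcast over every (i, j) pair;
# the intermediate tensor is [batch, heads, q_len, k_len, head_dim]
W_q, W_k = nn.Linear(head_dim, head_dim), nn.Linear(head_dim, head_dim)
v = nn.Linear(head_dim, 1)
additive = v(torch.tanh(W_q(Q.unsqueeze(3)) + W_k(K.unsqueeze(2)))).squeeze(-1)

# padding mask in the notebook's convention: True = real token, False = <pad>
mask = torch.ones(batch, 1, 1, k_len, dtype=torch.bool)
mask[..., -2:] = False  # pretend the last two keys are padding

attn_maps = {}
for name, energy in [("general", general), ("multiplicative", multiplicative), ("additive", additive)]:
    attn_maps[name] = torch.softmax(energy.masked_fill(mask == 0, -1e10), dim=-1)
    print(name, tuple(attn_maps[name].shape))  # (2, 4, 3, 5) for every variant
```

All three variants yield the same `[batch, heads, query len, key len]` attention map, but the additive one materializes a tensor with an extra `head_dim` axis, which is likely why the training loop uses a much smaller batch size for `additive`.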


### Decoder

In [None]:
class DecoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, dropout, atten_type, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm  = nn.LayerNorm(hid_dim)
        self.ff_layer_norm        = nn.LayerNorm(hid_dim)
        self.self_attention       = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, atten_type, device)
        self.encoder_attention    = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, atten_type, device)
        self.feedforward          = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout              = nn.Dropout(dropout)
        
    def forward(self, trg, enc_src, trg_mask, src_mask):
        
        #trg = [batch size, trg len, hid dim]
        #enc_src = [batch size, src len, hid dim]
        #trg_mask = [batch size, 1, trg len, trg len]
        #src_mask = [batch size, 1, 1, src len]
        
        _trg, _ = self.self_attention(trg, trg, trg, trg_mask)
        trg     = self.self_attn_layer_norm(trg + self.dropout(_trg))
        #trg = [batch_size, trg len, hid dim]
        
        _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        trg             = self.enc_attn_layer_norm(trg + self.dropout(_trg))
        #trg = [batch_size, trg len, hid dim]
        #attention = [batch_size, n heads, trg len, src len]
        
        _trg = self.feedforward(trg)
        trg  = self.ff_layer_norm(trg + self.dropout(_trg))
        #trg = [batch_size, trg len, hid dim]
        
        return trg, attention

class Decoder(nn.Module):
    def __init__(self, output_dim, hid_dim, n_layers, n_heads, 
                 pf_dim, dropout, atten_type, device,max_length = 552):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        self.layers        = nn.ModuleList([DecoderLayer(hid_dim, n_heads, pf_dim, dropout, atten_type, device)
                                            for _ in range(n_layers)])
        self.fc_out        = nn.Linear(hid_dim, output_dim)
        self.dropout       = nn.Dropout(dropout)
        self.scale         = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
        
    def forward(self, trg, enc_src, trg_mask, src_mask):
        
        #trg = [batch size, trg len]
        #enc_src = [batch size, src len, hid dim]
        #trg_mask = [batch size, 1, trg len, trg len]
        #src_mask = [batch size, 1, 1, src len]
        
        batch_size = trg.shape[0]
        trg_len    = trg.shape[1]
        
        pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
        #pos: [batch_size, trg len]
        
        trg = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos))
        #trg: [batch_size, trg len, hid dim]
        
        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
            
        #trg: [batch_size, trg len, hid dim]
        #attention: [batch_size, n heads, trg len, src len]
        
        output = self.fc_out(trg)
        #output = [batch_size, trg len, output_dim]
        
        return output, attention

### Putting them together (become Seq2Seq!)

Our `trg_sub_mask` will look something like this (for a target with 5 tokens):

$$\begin{matrix}
1 & 0 & 0 & 0 & 0\\
1 & 1 & 0 & 0 & 0\\
1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 1 & 0\\
1 & 1 & 1 & 1 & 1\\
\end{matrix}$$

The "subsequent" mask is then combined with the padding mask using a logical AND, so that neither subsequent tokens nor padding tokens can be attended to. For example, if the last two tokens were `<pad>` tokens, the mask would look like:

$$\begin{matrix}
1 & 0 & 0 & 0 & 0\\
1 & 1 & 0 & 0 & 0\\
1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 0 & 0\\
1 & 1 & 1 & 0 & 0\\
\end{matrix}$$
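Both masks can be reproduced in isolation with the same operations `make_trg_mask` uses. The sketch assumes a single target of five tokens whose last two are `<pad>`; the token ids are hypothetical.

```python
import torch

PAD_IDX = 1
# one target of 5 tokens whose last two are <pad> (hypothetical ids)
trg = torch.tensor([[2, 57, 3, PAD_IDX, PAD_IDX]])
trg_len = trg.shape[1]

trg_pad_mask = (trg != PAD_IDX).unsqueeze(1).unsqueeze(2)          # [1, 1, 1, 5]
trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len))).bool()   # [5, 5]
trg_mask = trg_pad_mask & trg_sub_mask                             # [1, 1, 5, 5]

print(trg_sub_mask.int().tolist())   # the lower-triangular "subsequent" mask
print(trg_mask[0, 0].int().tolist())
# [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 0, 0]]
```

Because `trg_pad_mask` has shape `[batch size, 1, 1, trg len]`, the AND broadcasts it across rows, so padding is masked column-wise: no query position may attend to a `<pad>` key.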

In [None]:
class Seq2SeqTransformer(nn.Module):
    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx, device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        
        #src = [batch size, src len]
        
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        #src_mask = [batch size, 1, 1, src len]

        return src_mask
    
    def make_trg_mask(self, trg):
        
        #trg = [batch size, trg len]
        
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        #trg_pad_mask = [batch size, 1, 1, trg len]
        
        trg_len = trg.shape[1]
        
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool()
        #trg_sub_mask = [trg len, trg len]
            
        trg_mask = trg_pad_mask & trg_sub_mask
        #trg_mask = [batch size, 1, trg len, trg len]
        
        return trg_mask

    def forward(self, src, trg):
        
        #src = [batch size, src len]
        #trg = [batch size, trg len]
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        
        #src_mask = [batch size, 1, 1, src len]
        #trg_mask = [batch size, 1, trg len, trg len]
        
        enc_src = self.encoder(src, src_mask)
        #enc_src = [batch size, src len, hid dim]
                
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        
        #output = [batch size, trg len, output dim]
        #attention = [batch size, n heads, trg len, src len]
        
        return output, attention


### Training

In [None]:
# def initialize_weights(m):
#     if hasattr(m, 'weight') and m.weight.dim() > 1:
#         nn.init.xavier_uniform_(m.weight.data)

In [None]:
# INPUT_DIM = len(vocab_transform[SRC_LANGUAGE])
# OUTPUT_DIM = len(vocab_transform[TRG_LANGUAGE])

# ATTEN_TYPE = 'additive'

# enc = Encoder(INPUT_DIM, 
#               HID_DIM, 
#               ENC_LAYERS, 
#               ENC_HEADS, 
#               ENC_PF_DIM, 
#               ENC_DROPOUT, 
#               ATTEN_TYPE,
#               device)

# dec = Decoder(OUTPUT_DIM, 
#               HID_DIM, 
#               DEC_LAYERS, 
#               DEC_HEADS, 
#               DEC_PF_DIM, 
#               DEC_DROPOUT, 
#               ATTEN_TYPE,
#               device)

# SRC_PAD_IDX = PAD_IDX
# TRG_PAD_IDX = PAD_IDX

# model = Seq2SeqTransformer(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)
# model.apply(initialize_weights)

In [None]:
# def count_parameters(model):
#     params = [p.numel() for p in model.parameters() if p.requires_grad]
#     for item in params:
#         print(f'{item:>6}')
#     print(f'______\n{sum(params):>6}')

# count_parameters(model)

In [None]:
# lr = 0.0005

# #training hyperparameters
# optimizer = optim.Adam(model.parameters(), lr=lr)
# criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX) #combine softmax with cross entropy

Next, we define the training loop.

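Because we want the model to predict the `<eos>` token but not receive it as input, we slice the last token off the end of the sequence. (For sentences shorter than the longest one in the batch, the removed token is actually `<pad>` and `<eos>` stays in the input; this is harmless because the position it would predict is padding, which the loss ignores via `ignore_index`.) Thus: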

$$\begin{align*}
\text{trg} &= [sos, x_1, x_2, x_3, eos]\\
\text{trg[:-1]} &= [sos, x_1, x_2, x_3]
\end{align*}$$

$x_i$ denotes the actual target sequence elements. We then feed this into the model to get a predicted sequence that should end with the `<eos>` token:

$$\begin{align*}
\text{output} &= [y_1, y_2, y_3, eos]
\end{align*}$$

$y_i$ denotes the predicted target sequence elements. We then calculate the loss against the original `trg` tensor with the `<sos>` token sliced off the front, leaving the `<eos>` token:

$$\begin{align*}
\text{output} &= [y_1, y_2, y_3, eos]\\
\text{trg[1:]} &= [x_1, x_2, x_3, eos]
\end{align*}$$

We then calculate our losses and update our parameters as is standard.
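A small standalone check of the slicing and of `ignore_index`, using hypothetical token ids and random logits as a stand-in for the model output; the second target is padded.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
PAD_IDX, SOS_IDX, EOS_IDX = 1, 2, 3
# batch of two targets (hypothetical token ids); the second one is padded
trg = torch.tensor([[SOS_IDX, 10, 11, 12, EOS_IDX],
                    [SOS_IDX, 20, EOS_IDX, PAD_IDX, PAD_IDX]])

decoder_input = trg[:, :-1]  # fed to the model (teacher forcing)
gold = trg[:, 1:]            # what the model should predict
print(decoder_input.tolist())  # [[2, 10, 11, 12], [2, 20, 3, 1]]
print(gold.tolist())           # [[10, 11, 12, 3], [20, 3, 1, 1]]

vocab_size = 30
logits = torch.randn(2, 4, vocab_size)  # stand-in for the decoder output
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
loss = criterion(logits.reshape(-1, vocab_size), gold.reshape(-1))

# the same loss by hand: per-token losses, averaged over non-<pad> targets only
per_token = nn.functional.cross_entropy(logits.reshape(-1, vocab_size),
                                        gold.reshape(-1), reduction="none")
keep = gold.reshape(-1) != PAD_IDX  # 6 real target tokens
print(torch.allclose(loss, per_token[keep].mean()))  # True
```

In the second row, `<eos>` (id 3) is fed as input but its target is `<pad>`, so that position contributes nothing to the loss.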

In [None]:
def train(model, loader, optimizer, criterion, clip, loader_length):
    
    model.train()
    
    epoch_loss = 0
    
    for src, src_len, trg in loader:
        
        src = src.to(device)
        trg = trg.to(device)

        optimizer.zero_grad()
        
        # trg[:, :-1] drops the last token, e.g. "<sos> I love sushi": with teacher forcing the input does not need <eos>
        output, _ = model(src, trg[:,:-1])

        #output = [batch size, trg len - 1, output dim]
        #trg    = [batch size, trg len]

        output_dim = output.shape[-1]

        output = output.reshape(-1, output_dim)
        trg = trg[:,1:].reshape(-1) # trg[:, 1:] drops <sos>, e.g. "I love sushi <eos>": with teacher forcing the target does not include <sos>

        #output = [batch size * (trg len - 1), output dim]
        #trg    = [batch size * (trg len - 1)]
            
        loss = criterion(output, trg)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / loader_length

def evaluate(model, loader, criterion, loader_length):
    
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
    
        for src, src_len, trg in loader:
        
            src = src.to(device)
            trg = trg.to(device)

            output, _ = model(src, trg[:,:-1])
            
            #output = [batch size, trg len - 1, output dim]
            #trg = [batch size, trg len]
            
            output_dim = output.shape[-1]
            
            output = output.contiguous().view(-1, output_dim)
            trg = trg[:,1:].contiguous().view(-1)
            
            #output = [batch size * (trg len - 1), output dim]
            #trg = [batch size * (trg len - 1)]
            
            loss = criterion(output, trg)

            epoch_loss += loss.item()
        
    return epoch_loss / loader_length

### Putting everything together

Finally, we train the actual models, one for each attention variant.

**Note:** like a convolutional sequence-to-sequence model, this Transformer is always trained with a teacher forcing ratio of 1, i.e. it always receives the ground-truth previous tokens of the target sequence. This is because all target positions are processed in parallel rather than generated one step at a time. As a result, its perplexity cannot be compared directly with recurrent models trained with a teacher forcing ratio below 1; rerunning such models with a ratio of 1 also yields much lower perplexity.
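The training loop reports perplexity as `math.exp` of the average cross-entropy loss, so lower is better, and a perplexity of $V$ means the model is as uncertain as a uniform guess over $V$ tokens. A minimal plain-Python sketch of that relationship (the `perplexity` helper is hypothetical):

```python
import math

def perplexity(token_log_probs):
    """Hypothetical helper: exp of the mean negative log-likelihood per token."""
    nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(nll)

# a model that is uniform over a 1000-word vocabulary has perplexity 1000
print(round(perplexity([math.log(1 / 1000)] * 4), 6))  # 1000.0
# a model that gives every gold token probability 0.5 has perplexity 2
print(round(perplexity([math.log(0.5)] * 3), 6))       # 2.0
```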

In [None]:
# train_loader_length = len(list(iter(train_loader)))
# val_loader_length   = len(list(iter(valid_loader)))
# test_loader_length  = len(list(iter(test_loader)))

# # train_loader_length, val_loader_length, test_loader_length

In [24]:
# train_loader_length = len(list(iter(train_loader)))
# val_loader_length   = len(list(iter(valid_loader)))
# test_loader_length  = len(list(iter(test_loader)))

# train_max_seq = max([max(batch[0].shape[1], batch[2].shape[1]) for batch in list(iter(train_loader))])
# val_max_seq = max([max(batch[0].shape[1], batch[2].shape[1]) for batch in list(iter(valid_loader))])
# test_max_seq = max([max(batch[0].shape[1], batch[2].shape[1]) for batch in list(iter(test_loader))])

# train_max_seq, val_max_seq, test_max_seq # required for model matching

In [25]:
# def epoch_time(start_time, end_time):
#     elapsed_time = end_time - start_time
#     elapsed_mins = int(elapsed_time / 60)
#     elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
#     return elapsed_mins, elapsed_secs

In [None]:
# best_valid_loss = float('inf')
# num_epochs = 15
# clip       = 1

# save_path = f'models/{time.asctime()}_{model.__class__.__name__}.pt'

# train_losses = []
# valid_losses = []

# for epoch in range(num_epochs):
    
#     start_time = time.time()

#     train_loss = train(model, train_loader, optimizer, criterion, clip, train_loader_length)
#     print("TRAIN ONE ITER TIME: ", str(time.time() - start_time))
#     valid_loss = evaluate(model, test_loader, criterion, test_loader_length)
    
#     #for plotting
#     train_losses.append(train_loss)
#     valid_losses.append(valid_loss)
    
#     end_time = time.time()
    
#     epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    
#     if valid_loss < best_valid_loss:
#         best_valid_loss = valid_loss
#         torch.save(model.state_dict(), save_path)
    
#     print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
#     print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
#     print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')
    
#     #lower perplexity is better

In [None]:
train_data = [(data[SRC_LANGUAGE], data[TRG_LANGUAGE]) for data in preprocessed['train']]
test_data  = [(data[SRC_LANGUAGE], data[TRG_LANGUAGE]) for data in preprocessed['test']]
val_data   = [(data[SRC_LANGUAGE], data[TRG_LANGUAGE]) for data in preprocessed['validation']]

INPUT_DIM = len(vocab_transform[SRC_LANGUAGE])
OUTPUT_DIM = len(vocab_transform[TRG_LANGUAGE])
SRC_PAD_IDX = PAD_IDX
TRG_PAD_IDX = PAD_IDX
N_EPOCHS = 5

clip = 1
lr = 0.0005

def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

def count_parameters(model):
    params = [p.numel() for p in model.parameters() if p.requires_grad]
    # for item in params:
        # print(f'{item:>6}')
    print(f'______\n{sum(params):>6}')

# count_parameters(model)

def plot_train(train_loss, valid_loss, attention):
    fig = plt.figure(figsize=(10, 5))
    ax = fig.add_subplot(1, 1, 1)
    
    ax.plot(train_loss, label='Training Loss', color='blue')
    ax.plot(valid_loss, label='Validation Loss', color='orange')
    
    plt.title(f'{attention}: Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    plt.savefig(f'plots/traingraph_{attention}_{time.asctime()}.png')
    # plt.show()

def display_attention(sentence, translation, attention, fname='multihead_attention'):
    
    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(111)
    
    attention = attention.squeeze(1).cpu().detach().numpy()
    
    cax = ax.matshow(attention, cmap='bone')
   
    ax.tick_params(labelsize=10)
    
    # one tick per token so the labels line up with the attention matrix
    # (rows = translation tokens, columns = source tokens)
    ax.set_xticks(range(len(sentence)))
    ax.set_xticklabels(sentence, rotation=45)
    ax.set_yticks(range(len(translation)))
    ax.set_yticklabels(translation)

    plt.savefig(f'plots/attention_{fname}_{time.asctime()}.png')

    # plt.show()
    # plt.close()


for attention in ['general', 'multiplicative']:  # add 'additive' to also train additive attention
    if attention == 'additive':
        BATCH_SIZE = 2
    else:
        BATCH_SIZE = 16

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    valid_loader = DataLoader(val_data,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
    test_loader  = DataLoader(test_data,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

    # number of batches per split (len() avoids running the collate function over the whole dataset)
    train_loader_length = len(train_loader)
    val_loader_length   = len(valid_loader)
    test_loader_length  = len(test_loader)

    # longest padded source/target length across all splits; sizes the positional embeddings
    train_max_seq = max(max(src.shape[1], trg.shape[1]) for src, _, trg in train_loader)
    val_max_seq   = max(max(src.shape[1], trg.shape[1]) for src, _, trg in valid_loader)
    test_max_seq  = max(max(src.shape[1], trg.shape[1]) for src, _, trg in test_loader)

    max_seq_len = max(train_max_seq, val_max_seq, test_max_seq)

    enc = Encoder(INPUT_DIM, 
                  HID_DIM, 
                  ENC_LAYERS, 
                  ENC_HEADS, 
                  ENC_PF_DIM, 
                  ENC_DROPOUT, 
                  attention,
                  device,
                  max_length=max_seq_len)

    dec = Decoder(OUTPUT_DIM, 
                  HID_DIM, 
                  DEC_LAYERS, 
                  DEC_HEADS, 
                  DEC_PF_DIM, 
                  DEC_DROPOUT, 
                  attention,
                  device,
                  max_length=max_seq_len)
    

    model = Seq2SeqTransformer(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)
    model.apply(initialize_weights)

    #training hyperparameters
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX) #combine softmax with cross entropy

    save_path = f'models/{attention}_{time.asctime()}_{model.__class__.__name__}.pt'

    best_train_loss = float('inf')
    best_valid_loss = float('inf')
    best_model = None
    train_losses = []
    valid_losses = []

    print(f'\n\t\t\t {attention}')

    start_training_time = time.time()
    avg_epoch_time = 0

    for epoch in range(N_EPOCHS):

        print("============= ", epoch)

        start_time = time.time()

        train_loss = train(model, train_loader, optimizer, criterion, clip, train_loader_length)
        valid_loss = evaluate(model, valid_loader, criterion, val_loader_length)

        #for plotting
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        end_time = time.time()
        avg_epoch_time += end_time - start_time
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)


        if valid_loss <= best_valid_loss:
            best_train_loss = train_loss
            best_valid_loss = valid_loss
            best_model = copy.deepcopy(model)
            torch.save(model.state_dict(), save_path)

        print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
        print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

        #lower perplexity is better

    plot_train(train_losses, valid_losses, attention)

    best_train_ppl = math.exp(best_train_loss)
    best_valid_ppl = math.exp(best_valid_loss)
    
    # Calculate time taken for the training
    avg_time = avg_epoch_time / N_EPOCHS
    overall_time = epoch_time(start_training_time, end_time)
    
    print(f"Best Training Loss: {best_train_loss:.3f}")
    print(f"Best Validation Loss: {best_valid_loss:.3f}")
    print(f"Best Training PPL: {best_train_ppl:.3f}")
    print(f"Best Validation PPL: {best_valid_ppl:.3f}")
    print(f"Average Time per epoch: {avg_time}")
    print(f"Overall time taken: {overall_time[0]}m {overall_time[1]}s")

    test_loss = evaluate(best_model, test_loader, criterion, test_loader_length)

    print(f'\n| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

    src_txt = text_transform[SRC_LANGUAGE](preprocessed['test'][23]['en']).to(device)
    trg_txt = text_transform[TRG_LANGUAGE](preprocessed['test'][23]['ru']).to(device)
    src_txt = src_txt.reshape(1, -1)
    trg_txt = trg_txt.reshape(1, -1)
    text_length = torch.tensor([src_txt.size(1)]).to(dtype=torch.int64)  # source length: dim 1 after reshape(1, -1)

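    best_model.eval()  # make sure dropout is disabled for inference
    start_time = time.time()
    with torch.no_grad():
        # The full target is fed to the decoder (teacher forcing), so this times a single
        # forward pass rather than autoregressive generation
        output, attention_plot = best_model(src_txt, trg_txt)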
    
    end_time = time.time()
    
    inference_time = end_time - start_time
    
    print(f"Inference time: {inference_time}s")

    output = output.squeeze(0)
    output = output[1:]
    output_max = output.argmax(1) #returns max indices

    attention_plot = attention_plot[0, 0, :, :]
    src_tokens = ['<sos>'] + token_transform[SRC_LANGUAGE](preprocessed['test'][23]['en']) + ['<eos>']
    trg_tokens = ['<sos>'] + [mapping[token.item()] for token in output_max]

    display_attention(src_tokens, trg_tokens, attention_plot, fname=attention)
    
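    # Free memory: drop every reference to the model (the optimizer also holds its
    # parameters), collect garbage, then release cached CUDA blocks
    del enc
    del dec
    del model
    del best_model
    del optimizer
    gc.collect()
    torch.cuda.empty_cache()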

# Additive
# Best Training Loss: 6.728
# Best Validation Loss: 6.441
# Best Training PPL: 835.796
# Best Validation PPL: 626.890
# Average Time per epoch: 1214.0751595973968
# Overall time taken: 101m 10s

# | Test Loss: 6.435 | Test PPL: 623.397 |
# Inference time: 0.010747194290161133s

# General

# 			 general
# =============  0
# Epoch: 01 | Time: 2m 34s
# 	Train Loss: 6.444 | Train PPL: 628.898
# 	 Val. Loss: 5.997 |  Val. PPL: 402.082
# =============  1
# Epoch: 02 | Time: 2m 34s
# 	Train Loss: 6.059 | Train PPL: 427.850
# 	 Val. Loss: 6.003 |  Val. PPL: 404.625
# =============  2
# Epoch: 03 | Time: 2m 33s
# 	Train Loss: 6.086 | Train PPL: 439.830
# 	 Val. Loss: 6.324 |  Val. PPL: 557.725
# =============  3
# Epoch: 04 | Time: 2m 32s
# 	Train Loss: 6.257 | Train PPL: 521.510
# 	 Val. Loss: 6.545 |  Val. PPL: 695.625
# =============  4
# Epoch: 05 | Time: 2m 29s
# 	Train Loss: 6.052 | Train PPL: 425.169
# 	 Val. Loss: 6.141 |  Val. PPL: 464.631
# Best Training Loss: 6.444
# Best Validation Loss: 5.997
# Best Training PPL: 628.898
# Best Validation PPL: 402.082
# Average Time per epoch: 152.9643747806549
# Overall time taken: 12m 44s

# | Test Loss: 5.977 | Test PPL: 394.320 |
# Inference time: 0.011830329895019531s

# 			 multiplicative
# =============  0
# Epoch: 01 | Time: 2m 54s
# 	Train Loss: 5.969 | Train PPL: 391.149
# 	 Val. Loss: 5.469 |  Val. PPL: 237.140
# =============  1
# Epoch: 02 | Time: 2m 35s
# 	Train Loss: 5.282 | Train PPL: 196.823
# 	 Val. Loss: 5.118 |  Val. PPL: 166.928
# =============  2
# Epoch: 03 | Time: 2m 53s
# 	Train Loss: 4.934 | Train PPL: 138.982
# 	 Val. Loss: 4.939 |  Val. PPL: 139.599
# =============  3
# Epoch: 04 | Time: 2m 54s
# 	Train Loss: 4.672 | Train PPL: 106.889
# 	 Val. Loss: 4.880 |  Val. PPL: 131.669
# =============  4
# Epoch: 05 | Time: 2m 45s
# 	Train Loss: 4.465 | Train PPL:  86.908
# 	 Val. Loss: 4.834 |  Val. PPL: 125.679
# Best Training Loss: 4.465
# Best Validation Loss: 4.834
# Best Training PPL: 86.908
# Best Validation PPL: 125.679
# Average Time per epoch: 168.6279025554657
# Overall time taken: 14m 8s

# | Test Loss: 4.808 | Test PPL: 122.545 |
# Inference time: 0.0119171142578125s

# Multiplicative performs best after 5 epochs, so let's train it longer (30 epochs)

In [None]:

INPUT_DIM = len(vocab_transform[SRC_LANGUAGE])
OUTPUT_DIM = len(vocab_transform[TRG_LANGUAGE])
SRC_PAD_IDX = PAD_IDX
TRG_PAD_IDX = PAD_IDX
N_EPOCHS = 5

clip = 1
lr = 0.0005

def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

def count_parameters(model):
    params = [p.numel() for p in model.parameters() if p.requires_grad]
    # for item in params:
        # print(f'{item:>6}')
    print(f'______\n{sum(params):>6}')

# count_parameters(model)

def plot_train(train_loss, valid_loss, attention):
    fig = plt.figure(figsize=(10, 5))
    ax = fig.add_subplot(1, 1, 1)
    
    ax.plot(train_loss, label='Training Loss', color='blue')
    ax.plot(valid_loss, label='Validation Loss', color='orange')
    
    plt.title(f'{attention}: Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    plt.savefig(f'plots/traingraph_{attention}_{time.asctime()}.png')
    # plt.show()
    plt.close(fig)  # avoid accumulating open figures across runs

def display_attention(sentence, translation, attention, fname='multihead_attention'):
    
    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(111)
    
    attention = attention.squeeze(1).cpu().detach().numpy()
    
    cax = ax.matshow(attention, cmap='bone')
   
    ax.tick_params(labelsize=10)
    
    # One tick per token; fixing the tick positions first keeps labels aligned with
    # matrix cells (avoids the [''] padding hack and matplotlib's FixedFormatter warning)
    ax.set_xticks(range(len(sentence)))
    ax.set_xticklabels(sentence, rotation=45)
    ax.set_yticks(range(len(translation)))
    ax.set_yticklabels(translation)

    plt.savefig(f'plots/attention_{fname}_{time.asctime()}.png')

    # plt.show()
    plt.close(fig)  # release the figure; otherwise one stays open per call
    
# Train multiplicative attention longer: 30 epochs
for attention in ['multiplicative']:  # 'general' and 'additive' were compared above
    if attention == 'additive':
        BATCH_SIZE = 2
    else:
        BATCH_SIZE = 16

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    valid_loader = DataLoader(val_data,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
    test_loader  = DataLoader(test_data,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

    # len() of a DataLoader over a list is the number of batches; no need to materialise them
    train_loader_length = len(train_loader)
    val_loader_length   = len(valid_loader)
    test_loader_length  = len(test_loader)
    
    train_max_seq = max([max(batch[0].shape[1], batch[2].shape[1]) for batch in list(iter(train_loader))])
    val_max_seq = max([max(batch[0].shape[1], batch[2].shape[1]) for batch in list(iter(valid_loader))])
    test_max_seq = max([max(batch[0].shape[1], batch[2].shape[1]) for batch in list(iter(test_loader))])

    max_seq_len = max(train_max_seq, val_max_seq, test_max_seq)

    enc = Encoder(INPUT_DIM, 
              HID_DIM, 
              ENC_LAYERS, 
              ENC_HEADS, 
              ENC_PF_DIM, 
              ENC_DROPOUT, 
              attention,
              device,
              max_length=max_seq_len)

    dec = Decoder(OUTPUT_DIM, 
                  HID_DIM, 
                  DEC_LAYERS, 
                  DEC_HEADS, 
                  DEC_PF_DIM, 
                  DEC_DROPOUT, 
                  attention,
                  device,
                  max_length=max_seq_len)
    

    model = Seq2SeqTransformer(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)
    model.apply(initialize_weights)

    #training hyperparameters
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX) #combine softmax with cross entropy

    save_path = f'models/{attention}_{time.asctime()}_{model.__class__.__name__}.pt'

    best_train_loss = float('inf')
    best_valid_loss = float('inf')
    best_model = None
    train_losses = []
    valid_losses = []

    print(f'\n\t\t\t {attention}')

    start_training_time = time.time()
    avg_epoch_time = 0

    for epoch in range(30):

        print("============= ", epoch)

        start_time = time.time()

        train_loss = train(model, train_loader, optimizer, criterion, clip, train_loader_length)
        valid_loss = evaluate(model, valid_loader, criterion, val_loader_length)

        #for plotting
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        end_time = time.time()
        avg_epoch_time += end_time - start_time
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)


        if valid_loss <= best_valid_loss:
            best_train_loss = train_loss
            best_valid_loss = valid_loss
            best_model = copy.deepcopy(model)
            torch.save(model.state_dict(), save_path)

        print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
        print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

        #lower perplexity is better

    plot_train(train_losses, valid_losses, attention)

    best_train_ppl = math.exp(best_train_loss)
    best_valid_ppl = math.exp(best_valid_loss)
    
    # Calculate time taken for the training
    avg_time = avg_epoch_time / len(train_losses)  # one loss entry per completed epoch
    overall_time = epoch_time(start_training_time, end_time)
    
    print(f"Best Training Loss: {best_train_loss:.3f}")
    print(f"Best Validation Loss: {best_valid_loss:.3f}")
    print(f"Best Training PPL: {best_train_ppl:.3f}")
    print(f"Best Validation PPL: {best_valid_ppl:.3f}")
    print(f"Average Time per epoch: {avg_time}")
    print(f"Overall time taken: {overall_time[0]}m {overall_time[1]}s")

    test_loss = evaluate(best_model, test_loader, criterion, test_loader_length)

    print(f'\n| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

    src_txt = text_transform[SRC_LANGUAGE](preprocessed['test'][23]['en']).to(device)
    trg_txt = text_transform[TRG_LANGUAGE](preprocessed['test'][23]['ru']).to(device)
    src_txt = src_txt.reshape(1, -1)
    trg_txt = trg_txt.reshape(1, -1)
    text_length = torch.tensor([src_txt.size(1)]).to(dtype=torch.int64)  # source length: dim 1 after reshape(1, -1)

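    best_model.eval()  # make sure dropout is disabled for inference
    start_time = time.time()
    with torch.no_grad():
        # The full target is fed to the decoder (teacher forcing), so this times a single
        # forward pass rather than autoregressive generation
        output, attention_plot = best_model(src_txt, trg_txt)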
    
    end_time = time.time()
    
    inference_time = end_time - start_time
    
    print(f"Inference time: {inference_time}s")

    output = output.squeeze(0)
    output = output[1:]
    output_max = output.argmax(1) #returns max indices

    attention_plot = attention_plot[0, 0, :, :]
    src_tokens = ['<sos>'] + token_transform[SRC_LANGUAGE](preprocessed['test'][23]['en']) + ['<eos>']
    trg_tokens = ['<sos>'] + [mapping[token.item()] for token in output_max]

    display_attention(src_tokens, trg_tokens, attention_plot, fname=attention)
    
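    # Free memory: drop every reference to the model (the optimizer also holds its
    # parameters), collect garbage, then release cached CUDA blocks
    del enc
    del dec
    del model
    del best_model
    del optimizer
    gc.collect()
    torch.cuda.empty_cache()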


### Analysis

The validation loss stops going down during training, which suggests the model overfits. We therefore select the best model by validation perplexity. Because training takes a long time, the analysis below is based on the results obtained so far (5 epochs per attention variant).
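Since the best checkpoint is chosen by validation loss anyway, training could also stop once that loss has not improved for a few epochs, which would save time on long runs such as the 30-epoch one. The sketch below shows that idea; `BestCheckpoint`, `patience` and `min_delta` are hypothetical names, not part of this notebook, and in the training loop `state` would be `model.state_dict()`.

```python
import copy
import math


class BestCheckpoint:
    """Hypothetical helper: keep the best validation loss and signal early stopping."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience      # epochs without improvement before stopping
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best_loss = math.inf
        self.best_state = None
        self.bad_epochs = 0

    def step(self, valid_loss, state):
        """Record one epoch; return True if training should stop."""
        if valid_loss < self.best_loss - self.min_delta:
            self.best_loss = valid_loss
            # deep copy so later optimizer steps do not mutate the saved weights
            self.best_state = copy.deepcopy(state)
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Usage: replay the validation losses logged for general attention
tracker = BestCheckpoint(patience=2)
stopped_at = None
for epoch, loss in enumerate([5.997, 6.003, 6.324, 6.545, 6.141]):
    if tracker.step(loss, {"epoch": epoch}):
        stopped_at = epoch
        break

print(tracker.best_loss, tracker.best_state, stopped_at)
```

With `patience=2`, the general-attention run would have stopped after its third epoch while keeping the first-epoch weights, which are the ones the notebook saved.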

| Attention | Training Loss | Training PPL | Validation Loss | Validation PPL | Test Loss | Test PPL | Avg. time per epoch | Overall time taken |
|---|---|---|---|---|---|---|---|---|
| General Attention | 6.444 | 628.898 | 5.997 | 402.082 | 5.977 | 394.320 | 152.9 s | 12m 44s |
| Multiplicative Attention | 4.465 | 86.908 | 4.834 | 125.679 | 4.808 | 122.545 | 168.6 s | 14m 8s |
| Additive Attention | 6.728 | 835.796 | 6.441 | 626.890 | 6.435 | 623.397 | 1214.1 s | 101m 10s |
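Perplexity is the exponential of the mean per-token cross-entropy (the `math.exp(loss)` calls in the training loop), so each loss column fully determines its PPL column. The short check below recomputes a few entries from the training logs above; small differences come only from the losses being rounded to three decimals.

```python
import math

# (loss, reported PPL) pairs taken from the training logs above
reported = {
    "General, validation": (5.997, 402.082),
    "Multiplicative, validation": (4.834, 125.679),
    "Additive, test": (6.435, 623.397),
}

recomputed = {}
for name, (loss, ppl) in reported.items():
    # Perplexity = exp(mean cross-entropy in nats)
    recomputed[name] = math.exp(loss)
    print(f"{name}: exp({loss}) = {recomputed[name]:.1f}, reported {ppl}")

# A loss gap of 1.0 nat corresponds to a perplexity ratio of e (about 2.72)
ratio = math.exp(5.997 - 4.834)
print(f"General / Multiplicative validation PPL ratio: {ratio:.2f}")
```

This also explains why a loss difference of about 1.16 between general and multiplicative attention shows up as a roughly threefold perplexity gap.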


In [24]:
INPUT_DIM = len(vocab_transform[SRC_LANGUAGE])
OUTPUT_DIM = len(vocab_transform[TRG_LANGUAGE])
SRC_PAD_IDX = PAD_IDX
TRG_PAD_IDX = PAD_IDX
N_EPOCHS = 5

train_data = [(data[SRC_LANGUAGE], data[TRG_LANGUAGE]) for data in preprocessed['train']]
test_data  = [(data[SRC_LANGUAGE], data[TRG_LANGUAGE]) for data in preprocessed['test']]
val_data   = [(data[SRC_LANGUAGE], data[TRG_LANGUAGE]) for data in preprocessed['validation']]

BATCH_SIZE = 16

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_loader = DataLoader(val_data,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
test_loader  = DataLoader(test_data,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

train_loader_length = len(list(iter(train_loader)))
val_loader_length   = len(list(iter(valid_loader)))
test_loader_length  = len(list(iter(test_loader)))

train_max_seq = max([max(batch[0].shape[1], batch[2].shape[1]) for batch in list(iter(train_loader))])
val_max_seq = max([max(batch[0].shape[1], batch[2].shape[1]) for batch in list(iter(valid_loader))])
test_max_seq = max([max(batch[0].shape[1], batch[2].shape[1]) for batch in list(iter(test_loader))])

max_seq_len = max(train_max_seq, val_max_seq, test_max_seq)

enc = Encoder(INPUT_DIM, 
          HID_DIM, 
          ENC_LAYERS, 
          ENC_HEADS, 
          ENC_PF_DIM, 
          ENC_DROPOUT, 
          'multiplicative',
          device,
          max_length=max_seq_len)

dec = Decoder(OUTPUT_DIM, 
              HID_DIM, 
              DEC_LAYERS, 
              DEC_HEADS, 
              DEC_PF_DIM, 
              DEC_DROPOUT, 
              'multiplicative',
              device,
              max_length=max_seq_len)


model = Seq2SeqTransformer(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)

save_path = 'models/multiplicative_Sun Feb  2 09:26:59 2025_Seq2SeqTransformer.pt'

criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX) #combine softmax with cross entropy

model.load_state_dict(torch.load(save_path, map_location=device))  # map_location lets a GPU checkpoint load on CPU
test_loss = evaluate(model, test_loader, criterion, test_loader_length)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')
#| Test Loss: 4.677 | Test PPL: 107.489 |

| Test Loss: 4.677 | Test PPL: 107.489 |


### Evaluation of Attention Mechanisms

The results show that **Multiplicative Attention** outperforms the other mechanisms on every loss and perplexity metric, achieving the lowest training loss (4.465), training perplexity (86.908), validation loss (4.834), validation perplexity (125.679), test loss (4.808), and test perplexity (122.545). This suggests that Multiplicative Attention learns patterns effectively and generalizes well to unseen data, at a slightly higher computational cost than General Attention: an average epoch time of 168.6 seconds and an overall training time of 14 minutes 8 seconds. The multiplicative checkpoint reloaded in the cell above reaches an even lower test loss of 4.677 (test PPL 107.489).

**General Attention** performs moderately well, with higher loss and perplexity across training, validation, and test datasets. It is faster, with an average epoch time of 152.9 seconds and total training time of 12 minutes 44 seconds, making it a viable option when computational efficiency is critical. However, its generalization to unseen data is less effective compared to Multiplicative Attention.

**Additive Attention** underperforms significantly, with the highest losses and perplexities on all three splits (training loss 6.728 / PPL 835.796, validation loss 6.441 / PPL 626.890, test loss 6.435 / PPL 623.397). Because its training and validation losses are both high and close to each other, this points to underfitting rather than overfitting: the model had learned little after 5 epochs. It is also by far the most expensive, with an average epoch time of 1214 seconds and an overall training time of 101 minutes 10 seconds. Note that it was trained with a batch size of 2 instead of 16 (presumably to fit in GPU memory), so its timing and results are not directly comparable with the other two variants.
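The cost gap has a plausible structural explanation. The `Encoder`/`Decoder` classes defined earlier take the attention type as a string, and their code is not shown in this section, so the sketch below assumes the textbook score functions: general e = q·k, multiplicative e = qᵀWk, additive e = vᵀ tanh(W₁q + W₂k). All tensor sizes are illustrative. The first two reduce to a single matrix product, while additive attention materialises a tensor with one d-dimensional vector per query-key pair before reducing it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, H, Lq, Lk, d = 2, 8, 10, 12, 32  # batch, heads, query len, key len, head dim (illustrative)
q = torch.randn(B, H, Lq, d)
k = torch.randn(B, H, Lk, d)

# General (dot-product) score: e_ij = q_i . k_j
general = q @ k.transpose(-2, -1)                      # [B, H, Lq, Lk]

# Multiplicative (bilinear) score: e_ij = q_i^T W k_j
W = nn.Linear(d, d, bias=False)
multiplicative = q @ W(k).transpose(-2, -1)            # [B, H, Lq, Lk]

# Additive score: e_ij = v^T tanh(W1 q_i + W2 k_j)
W1 = nn.Linear(d, d, bias=False)
W2 = nn.Linear(d, d, bias=False)
v = nn.Linear(d, 1, bias=False)
# Broadcasting builds an explicit [B, H, Lq, Lk, d] tensor before v reduces over d
hidden = torch.tanh(W1(q).unsqueeze(3) + W2(k).unsqueeze(2))
additive = v(hidden).squeeze(-1)                       # [B, H, Lq, Lk]

print(general.shape, multiplicative.shape, additive.shape)
print("additive intermediate elements:", hidden.numel(), "vs score elements:", general.numel())
```

Under these assumptions the additive intermediate is `d` times larger than the score matrix (and must be kept for backpropagation), which is consistent with additive attention needing a much smaller batch and far longer epochs.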

### Conclusion
The results suggest that **Multiplicative Attention** is the optimal choice for tasks requiring high accuracy and generalization, albeit with a slightly higher computational cost. **General Attention** may be considered for resource-constrained scenarios where training time is a priority over accuracy. **Additive Attention** requires further tuning or architectural improvements to enhance its performance and computational efficiency.

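Inference time is similar for all three models: roughly 0.01 s (0.0107 to 0.0119 s in the logged runs) for one forward pass on a single test sentence. This pass feeds the full target sentence to the decoder, so it does not include autoregressive generation.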
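The timings above wrap a single call with `time.time()`. On a GPU, CUDA kernel launches are asynchronous, so such a measurement can stop the clock before the work has finished. A more reliable sketch, assuming nothing beyond standard PyTorch, averages several runs after a warm-up and synchronises before reading the clock; `time_forward` is a hypothetical helper name.

```python
import time

import torch


def time_forward(model, *inputs, n_warmup=3, n_runs=20):
    """Hypothetical helper: average wall-clock seconds per forward pass."""
    model.eval()
    use_cuda = torch.cuda.is_available() and any(
        isinstance(x, torch.Tensor) and x.is_cuda for x in inputs
    )
    with torch.no_grad():
        for _ in range(n_warmup):  # warm-up: CUDA context, kernel selection, caches
            model(*inputs)
        if use_cuda:
            torch.cuda.synchronize()  # wait for queued kernels before starting the clock
        start = time.perf_counter()
        for _ in range(n_runs):
            model(*inputs)
        if use_cuda:
            torch.cuda.synchronize()  # wait for the timed kernels to finish
        return (time.perf_counter() - start) / n_runs


# Usage with a tiny stand-in model on CPU; in the notebook the call would be
# time_forward(best_model, src_txt, trg_txt)
toy = torch.nn.Linear(16, 16)
avg = time_forward(toy, torch.randn(1, 16))
print(f"{avg * 1e3:.3f} ms per forward pass")
```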

# Final result

Of the three attention variants, I choose multiplicative attention, which achieved the lowest validation and test perplexity, and will deploy the Dash application with this model.
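For deployment the model cannot be given the reference translation, so the teacher-forced forward pass used above has to be replaced by autoregressive decoding. Below is a minimal greedy-decoding sketch. It assumes that `model(src, trg)` returns `(logits, attention)` with logits shaped `[batch, trg_len, vocab]`, as in the `best_model(src_txt, trg_txt)` calls above, and that the decoder masks future positions as it must during training. `greedy_translate`, `sos_idx`, `eos_idx` and the toy `CountingModel` are hypothetical; the resulting indices would be mapped back to tokens with `mapping[idx]` as in the notebook.

```python
import torch


def greedy_translate(model, src, sos_idx, eos_idx, max_len=50):
    """Hypothetical greedy decoder: feed back the argmax token until <eos> or max_len."""
    model.eval()
    trg = torch.tensor([[sos_idx]], device=src.device)  # [1, 1]: start with <sos>
    attention = None
    with torch.no_grad():
        for _ in range(max_len):
            logits, attention = model(src, trg)          # re-run on the whole prefix
            next_token = logits[:, -1].argmax(-1, keepdim=True)  # [1, 1]
            trg = torch.cat([trg, next_token], dim=1)
            if next_token.item() == eos_idx:
                break
    return trg.squeeze(0).tolist(), attention


class CountingModel(torch.nn.Module):
    """Toy stand-in for Seq2SeqTransformer: always predicts (previous token + 1)."""

    def __init__(self, vocab=10):
        super().__init__()
        self.vocab = vocab

    def forward(self, src, trg):
        logits = torch.nn.functional.one_hot((trg + 1) % self.vocab, self.vocab).float()
        return logits, None


tokens, _ = greedy_translate(CountingModel(), torch.zeros(1, 4, dtype=torch.long),
                             sos_idx=2, eos_idx=5)
print(tokens)  # [2, 3, 4, 5]
```

Re-running the decoder on the full prefix at every step is quadratic in the output length, which is acceptable for single sentences in an interactive app.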