In [80]:
import copy
import gc
import math
import random
import time

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import spacy
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
from torch.nn import Transformer
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import nmtDataset
import helpers as utils

In [2]:
# Use the GPU when one is available; the model and batches are placed on this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch on all devices)
# so that weight initialisation, dropout and DataLoader shuffling are
# reproducible under Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Training progress is recorded through this project logger (helpers.Logger).
LOG_PATH = 'logs/transformers.out'
log = utils.Logger(LOG_PATH)

In [63]:
# Model / optimisation hyper-parameters. The vocabulary sizes
# (src_vocab_size / trg_vocab_size) are added later, once the datasets are loaded.
hyp_params = dict(
    batch_size=128,
    num_epochs=10,
    d_model=512,
    n_head=8,
    num_encoder_layers=3,
    num_decoder_layers=3,
    feedforward_dim=128,
    dropout=0.1,
)

In [64]:
# Load the Multi30k splits. The val/test datasets receive the training dataset,
# presumably so they share its vocabularies -- confirm in dataset.nmtDataset.
nmtds_train = nmtDataset('datasets/Multi30k/', 'train')
nmtds_valid = nmtDataset('datasets/Multi30k/', 'val', nmtds_train)
nmtds_test = nmtDataset('datasets/Multi30k/', 'test', nmtds_train)

SRC_PAD_IDX = nmtds_train.src_vocab["<pad>"]
TRG_PAD_IDX = nmtds_train.trg_vocab["<pad>"]


def _collate(batch):
    """Turn a list of dataset samples into one batch via ``utils.collate_fn``.

    A named function (rather than a lambda) can be pickled, so the loaders also
    work with ``num_workers > 0``.

    NOTE(review): only SRC_PAD_IDX is passed to ``utils.collate_fn``. If it also
    pads the target side with that index, the loss's ``ignore_index=TRG_PAD_IDX``
    only skips padding when both indices coincide -- confirm in helpers.collate_fn.
    """
    return utils.collate_fn(batch, SRC_PAD_IDX, device)


train_dataloader = DataLoader(nmtds_train, batch_size=hyp_params["batch_size"], shuffle=True,
                              collate_fn=_collate)

# No shuffling for validation: the epoch loss is a mean of per-batch means, so a
# fixed batch composition keeps the eval loss comparable across epochs and the
# early-stopping decision free of batching noise.
valid_dataloader = DataLoader(nmtds_valid, batch_size=hyp_params["batch_size"], shuffle=False,
                              collate_fn=_collate)

hyp_params["src_vocab_size"] = len(nmtds_train.src_vocab)
hyp_params["trg_vocab_size"] = len(nmtds_train.trg_vocab)

In [29]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding from "Attention Is All You Need".

        PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

    Each sin/cos column pair shares one frequency. The table is precomputed for
    ``maxlen`` positions and registered as a buffer, so it moves with
    ``.to(device)`` and is stored in ``state_dict`` but is not trained.

    Args:
        d_model: embedding dimension (odd values work; the last column is a sine).
        dropout: dropout probability applied to ``embedding + encoding``.
        maxlen: number of positions covered by the precomputed table.
    """

    def __init__(self, d_model, dropout, maxlen = 5000):
        super().__init__()

        # All possible positions 0, 1, ..., maxlen - 1
        # Shape (pos) --> [max len, 1]
        pos = torch.arange(maxlen, dtype=torch.float).unsqueeze(1)

        # Denominators 10000^(2i / d_model) for i = 0 .. ceil(d_model / 2) - 1
        # Shape (den) --> [ceil(d_model / 2)]
        den = 10000.0 ** (torch.arange(0, d_model, 2, dtype=torch.float) / d_model)

        pos_encoding = torch.zeros((maxlen, d_model))
        pos_encoding[:, 0::2] = torch.sin(pos / den)
        # Odd column 2i+1 reuses the frequency of even column 2i (as in the paper).
        # Slicing handles odd d_model, which has one fewer cosine column.
        pos_encoding[:, 1::2] = torch.cos(pos / den[: d_model // 2])

        # Shape (pos_encoding) --> [max len, 1, d_model]; the singleton dim broadcasts over the batch
        pos_encoding = pos_encoding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_encoding', pos_encoding)

    def forward(self, token_embedding):
        """Add positional encodings to ``token_embedding`` and apply dropout.

        Args:
            token_embedding: tensor of shape [sentence len, batch size, d_model]
                (sequence-first); sentence len must not exceed ``maxlen``.

        Returns:
            Tensor of the same shape as ``token_embedding``.
        """
        # Only the first `sentence len` rows of the precomputed table are needed.
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])
    
class InputEmbedding(nn.Module):
    """Token embedding multiplied by sqrt(d_model).

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: embedding dimension.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, tokens):
        """Embed ``tokens`` (cast to int64) and scale the result by sqrt(d_model).

        Shapes: [sentence len, batch size] -> [sentence len, batch size, d_model].
        """
        scale = math.sqrt(self.d_model)
        looked_up = self.embedding(tokens.long())
        return looked_up * scale

In [42]:
class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder Transformer for translation built on ``torch.nn.Transformer``.

    Inputs are sequence-first token-index tensors ([seq len, batch size]), matching
    the default ``batch_first=False`` layout of ``nn.Transformer``. ``forward``
    returns unnormalised logits over the target vocabulary.

    Args:
        src_vocab_size, trg_vocab_size: vocabulary sizes.
        d_model: embedding / model dimension.
        dropout: dropout probability for embeddings and transformer layers.
        nhead: number of attention heads.
        num_encoder_layers, num_decoder_layers: depth of the encoder / decoder stacks.
        dim_feedforward: hidden size of the feed-forward sublayers.
        src_pad_idx, trg_pad_idx: padding indices used to build key-padding masks.
    """

    def __init__(self, 
                 src_vocab_size, 
                 trg_vocab_size, 
                 d_model, 
                 dropout,
                 nhead,
                 num_encoder_layers,
                 num_decoder_layers,
                 dim_feedforward,
                 src_pad_idx,
                 trg_pad_idx
                ):
        super().__init__()
        
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        
        self.src_inp_emb = InputEmbedding(src_vocab_size, d_model)
        self.trg_inp_emb = InputEmbedding(trg_vocab_size, d_model)
        
        # One (parameter-free) positional encoding shared by source and target.
        self.positional_encoding = PositionalEncoding(d_model, dropout=dropout)
        
        self.transformer = Transformer(d_model=d_model,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        
        # Projects decoder outputs to target-vocabulary logits.
        self.generator = nn.Linear(d_model, trg_vocab_size)
    
    def forward(self, src, trg):
        """Return logits of shape [trg len, batch size, trg vocab size].

        Args:
            src: source token indices, [src len, batch size].
            trg: decoder input token indices, [trg len, batch size].
        """
        src_emb = self.positional_encoding(self.src_inp_emb(src))
        trg_emb = self.positional_encoding(self.trg_inp_emb(trg))
        
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = self.create_mask(src, trg)
        
        outs = self.transformer(src = src_emb, 
                                tgt = trg_emb, 
                                src_mask = src_mask,
                                tgt_mask = tgt_mask, 
                                src_key_padding_mask = src_padding_mask, 
                                tgt_key_padding_mask = tgt_padding_mask)
        return self.generator(outs)

    def create_mask(self, src, tgt):
        """Build the attention masks expected by ``nn.Transformer``.

        Every mask is boolean (True = position may NOT be attended to) and lives
        on the same device as its input. This makes the model work on CPU and GPU,
        and avoids mixing float and bool masks, which recent PyTorch deprecates.

        Returns:
            src_mask: [src len, src len], all False (encoder attention unrestricted).
            tgt_mask: [tgt len, tgt len], True strictly above the diagonal (causal).
            src_padding_mask: [batch size, src len], True at source padding tokens.
            tgt_padding_mask: [batch size, tgt len], True at target padding tokens.
        """
        src_seq_len = src.shape[0]
        tgt_seq_len = tgt.shape[0]

        # Causal mask: target position i may only attend to positions <= i.
        tgt_mask = torch.triu(
            torch.ones((tgt_seq_len, tgt_seq_len), dtype=torch.bool, device=tgt.device),
            diagonal=1,
        )
        # All False, hence encoder self-attention is unchanged.
        src_mask = torch.zeros((src_seq_len, src_seq_len), dtype=torch.bool, device=src.device)

        # Padding masks let attention ignore <pad> tokens; transposed to [batch, seq].
        src_padding_mask = (src == self.src_pad_idx).transpose(0, 1)
        tgt_padding_mask = (tgt == self.trg_pad_idx).transpose(0, 1)
        
        return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [77]:
def train_model(model, train_dataloader, criterion, optimizer):
    """Run one training epoch with teacher forcing and return the mean batch loss.

    Args:
        model: seq2seq model called as ``model(src, trg_inp)``; it is expected to
            return logits of shape [trg len - 1, batch size, trg vocab size].
        train_dataloader: sized iterable of dicts with "src" and "trg" tensors of
            shape [seq len, batch size].
        criterion: loss taking flattened logits and flattened target indices.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        Mean of the per-batch losses, as a Python float.
    """
    model.train()
    epoch_loss = 0.0
    for batch in tqdm(train_dataloader):
        # shape (src, trg) --> [seq len, batch size]
        src = batch["src"]
        trg = batch["trg"]

        # Teacher forcing: the decoder is fed tokens [0, n-1) and predicts tokens [1, n).
        # shape (trg_inp) --> [seq len - 1, batch size]
        trg_inp = trg[:-1, :]
        # Flatten all target sequences into one 1-D tensor: [(seq len - 1) * batch size]
        trg_out = trg[1:, :].reshape(-1)

        # Clear the accumulating gradients
        optimizer.zero_grad()

        # shape (logits) --> [seq len - 1, batch size, trg vocab size]
        logits = model(src, trg_inp)
        # shape (flattened logits) --> [(seq len - 1) * batch size, trg vocab size]
        loss = criterion(logits.reshape(-1, logits.shape[-1]), trg_out)

        loss.backward()
        optimizer.step()

        # .item() gives a plain float, so the running sum holds no graph or device tensor.
        epoch_loss += loss.item()

    return epoch_loss / len(train_dataloader)

def evaluate_model(transformer, valid_dataloader, criterion):
    """Return the mean per-batch loss of ``transformer`` on ``valid_dataloader``.

    The model is switched to eval mode (dropout off) and run under
    ``torch.no_grad()``; no parameters are updated. The model is left in eval
    mode.

    Args:
        transformer: seq2seq model called as ``transformer(src, trg_inp)``; it is
            expected to return logits of shape [trg len - 1, batch size, trg vocab size].
        valid_dataloader: sized iterable of dicts with "src" and "trg" tensors of
            shape [seq len, batch size].
        criterion: loss taking flattened logits and flattened target indices.

    Returns:
        Mean of the per-batch losses, as a Python float.
    """
    transformer.eval()
    epoch_loss = 0.0
    with torch.no_grad():
        for batch in valid_dataloader:
            # shape (src, trg) --> [seq len, batch size]
            src = batch["src"]
            trg = batch["trg"]

            # Same teacher-forcing split as in training.
            # shape (trg_inp) --> [seq len - 1, batch size]
            trg_inp = trg[:-1, :]
            # Flatten all target sequences into one 1-D tensor: [(seq len - 1) * batch size]
            trg_out = trg[1:, :].reshape(-1)

            # shape (logits) --> [seq len - 1, batch size, trg vocab size]
            logits = transformer(src, trg_inp)
            # shape (flattened logits) --> [(seq len - 1) * batch size, trg vocab size]
            loss = criterion(logits.reshape(-1, logits.shape[-1]), trg_out)

            epoch_loss += loss.item()

    return epoch_loss / len(valid_dataloader)

In [78]:
# Build the model. Keyword arguments make the mapping from hyp_params explicit
# (the constructor takes `dropout` before `nhead`, which is easy to mix up).
transformer = Seq2SeqTransformer(
    src_vocab_size=hyp_params["src_vocab_size"],
    trg_vocab_size=hyp_params["trg_vocab_size"],
    d_model=hyp_params["d_model"],
    dropout=hyp_params["dropout"],
    nhead=hyp_params["n_head"],
    num_encoder_layers=hyp_params["num_encoder_layers"],
    num_decoder_layers=hyp_params["num_decoder_layers"],
    dim_feedforward=hyp_params["feedforward_dim"],
    src_pad_idx=SRC_PAD_IDX,
    trg_pad_idx=TRG_PAD_IDX,
).to(device)

# Xavier-uniform initialisation for every parameter with more than one dimension
# (weight matrices); 1-D parameters keep their default initialisation.
for param in transformer.parameters():
    if param.dim() <= 1:
        continue
    nn.init.xavier_uniform_(param)

# Target positions equal to TRG_PAD_IDX are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)

# Adam with a fixed learning rate and betas/eps of (0.9, 0.98) / 1e-9.
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [None]:
# Training loop with early stopping: train + evaluate once per epoch, checkpoint
# whenever the validation loss improves, and stop after too many epochs without
# improvement.
EARLY_STOPPING_PATIENCE = 10  # stop when `patience` reaches this value
CHECKPOINT_PATH = 'model-attention-mask.pt'

min_el = math.inf
# 1 + number of consecutive epochs without a new best validation loss.
patience = 1
best_model = {}
best_epoch = 0

epoch_loss = 0
try:
    for epoch in range(hyp_params["num_epochs"]):
        start = time.time()

        # Free memory left over from the previous epoch / earlier cells.
        gc.collect()
        torch.cuda.empty_cache()

        epoch_loss = train_model(transformer, train_dataloader, criterion, optimizer)
        eval_loss = evaluate_model(transformer, valid_dataloader, criterion)

        log.log(f"Epoch: {epoch+1}, Train loss: {epoch_loss}, Eval loss: {eval_loss}, patience: {patience}. Time {time.time() - start}")

        if eval_loss < min_el:
            # Stored 1-based so it matches the "Epoch:" log lines.
            best_epoch = epoch + 1
            min_el = eval_loss
            best_model = copy.deepcopy(transformer)
            torch.save({
                'model_state_dict': transformer.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'eval_loss': min_el
            }, CHECKPOINT_PATH)
            patience = 1
        else:
            patience += 1

        if patience == EARLY_STOPPING_PATIENCE:
            log.log("[STOPPING] Early stopping in action..")
            break

    log.log(f"Best epoch was {best_epoch} with {min_el} eval loss")
finally:
    # Close the log even if training raises (e.g. out-of-memory or a keyboard interrupt).
    log.close()