In [1]:
import random
import math
import time
import gc
import copy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import Transformer
from torch.utils.data import DataLoader

import torchtext
from torchtext.data.metrics import bleu_score

import spacy
import numpy as np

from tqdm import tqdm

from dataset import nmtDataset
import helpers as utils

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
log = utils.Logger('logs/transformers.out')  

In [None]:
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm

In [5]:
hyp_params = {
    "batch_size": 128,
    "lr": 0.0005,
    "num_epochs": 10,
    
    # Same as presented in paper
    "d_model": 512,
    
    # No. of multi-head attention block (aka paralle self-attention layers)
    # Same as presented in paper
    "n_head": 8,
    
    # N in the paper and they used 6 of each
    "num_encoder_layers": 3,
    "num_decoder_layers": 3,
    
    "feedforward_dim": 128,
    
    # Following paper
    "dropout": 0.1
}

In [6]:
nmtds_train = nmtDataset('datasets/Multi30k/', 'train')
nmtds_valid = nmtDataset('datasets/Multi30k/', 'val', nmtds_train)
nmtds_test = nmtDataset('datasets/Multi30k/', 'test', nmtds_train)

SRC_PAD_IDX = nmtds_train.src_vocab["<pad>"]
TRG_PAD_IDX = nmtds_train.trg_vocab["<pad>"]

train_dataloader = DataLoader(nmtds_train, batch_size=hyp_params["batch_size"], shuffle=True,
                              collate_fn=lambda batch_size: utils.collate_fn(batch_size, SRC_PAD_IDX, device))

valid_dataloader = DataLoader(nmtds_valid, batch_size=hyp_params["batch_size"], shuffle=True,
                              collate_fn=lambda batch_size: utils.collate_fn(batch_size, SRC_PAD_IDX, device))

hyp_params["src_vocab_size"] = len(nmtds_train.src_vocab)
hyp_params["trg_vocab_size"] = len(nmtds_train.trg_vocab)

In [7]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, maxlen = 5000):
        super(PositionalEncoding, self).__init__()
        
        # A tensor consists of all the possible positions (index) e.g 0, 1, 2, ... max length of input
        # Shape (pos) --> [max len, 1]
        pos = torch.arange(0, maxlen).unsqueeze(1)
        
        pos_encoding = torch.zeros((maxlen, d_model))
        
        # In the paper, they had 2i in the positional encoding formula
        # where i is the dimension 
        sin_den = 10000 ** (torch.arange(0, d_model, 2)/d_model) # sin for even item of position's dimension
        cos_den = 10000 ** (torch.arange(1, d_model, 2)/d_model) # cos for odd 
        
        pos_encoding[:, 0::2] = torch.sin(pos / sin_den) 
        pos_encoding[:, 1::2] = torch.cos(pos / cos_den)
        
        # Shape (pos_embedding) --> [max len, d_model]
        pos_encoding = pos_encoding.unsqueeze(-2)
        # Shape (pos_embedding) --> [max len, 1, d_model]

        self.dropout = nn.Dropout(dropout)
        
        # We want pos_encoding be saved and restored in the `state_dict`, but not trained by the optimizer
        # hence registering it!
        # Source & credits: https://discuss.pytorch.org/t/what-is-the-difference-between-register-buffer-and-register-parameter-of-nn-module/32723/2
        self.register_buffer('pos_encoding', pos_encoding)

    def forward(self, token_embedding):
        
        # shape (token_embedding) --> [sentence len, batch size, d_model]
        
        # Concatenating embeddings with positional encodings
        # Note: As we made positional encoding with the size max length of sentence in our dataset 
        #       hence here we are picking till the sentence length in a batch
        #       Another thing to notice is in the paper they used FIXED positional encoding, there are
        #       methods where we can also learn them but we are doing as presented in the paper
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])
    
class InputEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super(InputEmbedding, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, tokens):
        # shape (tokens) --> [sentence len, batch size]
        # shape (inp_emb) --> [sentence len, batch size, d_model]
        # Multiplying with square root of d_model as they mentioned in the paper
        inp_emb = self.embedding(tokens.long()) * math.sqrt(self.d_model)
        return inp_emb

In [29]:
class Seq2SeqTransformer(nn.Module):
    def __init__(self, 
                 src_vocab_size, 
                 trg_vocab_size, 
                 d_model, 
                 dropout,
                 nhead,
                 num_encoder_layers,
                 num_decoder_layers,
                 dim_feedforward,
                 src_pad_idx,
                 trg_pad_idx
                ):
        super(Seq2SeqTransformer, self).__init__()
        
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        
        self.src_inp_emb = InputEmbedding(src_vocab_size, d_model)
        self.trg_inp_emb = InputEmbedding(trg_vocab_size, d_model)
        
        self.positional_encoding = PositionalEncoding(d_model, dropout=dropout)
        
        self.transformer = Transformer(d_model=d_model,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        
        self.linear = nn.Linear(d_model, trg_vocab_size)
    
    def forward(self, src, trg):
        src_emb = self.positional_encoding(self.src_inp_emb(src))
        trg_emb = self.positional_encoding(self.trg_inp_emb(trg))
        
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = self.create_mask(src, trg)
        
        outs = self.transformer(src = src_emb, 
                                tgt = trg_emb, 
                                src_mask = src_mask,
                                tgt_mask = tgt_mask, 
                                src_key_padding_mask = src_padding_mask, 
                                tgt_key_padding_mask = tgt_padding_mask,
                                memory_key_padding_mask = src_padding_mask
                               )
        return self.linear(outs)
        

    def create_mask(self, src, trg):
        src_seq_len = src.shape[0]
        trg_seq_len = trg.shape[0]

        # Subsequent mask aka "look ahead mask" is important as it wont let Decoder
        # to peek into future tokens.
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_len).to(device)
        
        src_mask = torch.zeros((src_seq_len, src_seq_len),device=device).type(torch.bool) # All False hence unchanged

        # Padding masking will allow attention to ignore padding <pad> tokens 
        src_padding_mask = (src == self.src_pad_idx).transpose(0, 1)
        trg_padding_mask = (trg == self.trg_pad_idx).transpose(0, 1)
        
        return src_mask, trg_mask, src_padding_mask, trg_padding_mask
    
    # These two functions will only used while inferring
    
    def encode(self, src):
        src_mask = torch.zeros((src.shape[0], src.shape[0]),device=device).type(torch.bool)
        src_padding_mask = (src == self.src_pad_idx).transpose(0, 1)
        
        return self.transformer.encoder(self.positional_encoding(self.src_inp_emb(src)), src_mask, src_padding_mask)

    def decode(self, trg, memory):
        # memory is the output from the encoder 
        trg_seq_len = trg.shape[0]
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_len).type(torch.bool).to(device)
        trg_padding_mask = (trg == self.trg_pad_idx).transpose(0, 1)
        
        return self.transformer.decoder(tgt = self.positional_encoding(self.trg_inp_emb(trg)), 
                                        memory = memory,
                                        tgt_mask = trg_mask,
                                        tgt_key_padding_mask = trg_padding_mask
                                       )

In [9]:
def train_model(model, train_dataloader, criterion, optimizer):
    model.train()
    epoch_loss = 0
    for batch_idx, batch in enumerate(tqdm(train_dataloader)):
        # shape (src, trg) --> [seq len, batch size]
        src = batch["src"]
        trg = batch["trg"]

        # Clear the accumulating gradients
        optimizer.zero_grad()

        # shape (trg_inp, trg_out) --> [seq len - 1, batch size]
        trg_inp = trg[:-1, :]
        trg_out = trg[1:, :]

        # shape --> (seq len - 1) * batch size 
        # Making all target seqeunces in 1d tensor
        trg_out = trg_out.reshape(-1)

        # shape (logits) --> [seq len - 1, batch size, trg vocab size]
        logits = model(src, trg_inp)

        # shape (logits) --> [(seq len - 1) * batch size, trg vocab size]
        logits = logits.reshape(-1, logits.shape[-1])

        loss = criterion(logits, trg_out)

        loss.backward()

        optimizer.step()
        epoch_loss += loss.detach().cpu()

    return epoch_loss/len(train_dataloader)

def evaluate_model(model, valid_dataloader, criterion):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(valid_dataloader):
            # shape (src, trg) --> [seq len, batch size]
            src = batch["src"]
            trg = batch["trg"]

            # shape (trg_inp, trg_out) --> [seq len - 1, batch size]
            trg_inp = trg[:-1, :]
            trg_out = trg[1:, :]

            # shape --> (seq len - 1) * batch size 
            # Making all target seqeunces in 1d tensor
            trg_out = trg_out.reshape(-1)

            # shape (logits) --> [seq len - 1, batch size, trg vocab size]
            logits = model(src, trg_inp)

            # shape (logits) --> [(seq len - 1) * batch size, trg vocab size]
            logits = logits.reshape(-1, logits.shape[-1])

            loss = criterion(logits, trg_out)

            epoch_loss += loss.detach().cpu()
    
    return epoch_loss/len(valid_dataloader)

In [10]:
model = Seq2SeqTransformer(hyp_params["src_vocab_size"],
                                hyp_params["trg_vocab_size"],
                                hyp_params["d_model"],
                                hyp_params["dropout"],
                                hyp_params["n_head"],
                                hyp_params["num_encoder_layers"],
                                hyp_params["num_decoder_layers"],
                                hyp_params["feedforward_dim"],
                                SRC_PAD_IDX,
                                TRG_PAD_IDX
                                ).to(device)

# They did not mention it in paper however tutorials from 
# bentrevett and others uses it so -_(-.-)_-
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX).to(device)

# They did not used fixed learning rate in the paper infact 
# their optimizer would look like 
#   optimizer = torch.optim.Adam(transformer.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)
# with variable learning rate as they presented in the paper
# but for the sake of simplicity and also fixed lr works fine
optimizer = optim.Adam(model.parameters(), lr=hyp_params["lr"])

In [11]:
min_el = math.inf
patience = 1
best_model = {}
best_epoch = 0

epoch_loss = 0
for epoch in range(hyp_params["num_epochs"]):
    start = time.time()
    
    gc.collect()
    torch.cuda.empty_cache()
    
    epoch_loss = train_model(model, train_dataloader, criterion, optimizer)
    eval_loss = evaluate_model(model, valid_dataloader, criterion)
    
    log.log(f"Epoch: {epoch+1}, Train loss: {epoch_loss}, Eval loss: {eval_loss}, patience: {patience}. Time {time.time() - start}")

    
    if eval_loss < min_el:
        best_epoch = epoch+1
        min_el = eval_loss
        best_model = copy.deepcopy(model)
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'eval_loss': min_el
        }, 'model-transformer.pt')
        patience = 1
    else:
        patience += 1
    
    if patience == 10:
        log.log("[STOPPING] Early stopping in action..")
        log.log(f"Best epoch was {best_epoch} with {min_el} eval loss")
        break
        
log.log(f"Best epoch was {best_epoch} with {min_el} eval loss")
log.close()

100%|██████████| 227/227 [01:14<00:00,  3.03it/s]


Epoch: 1, Train loss: 4.186707019805908, Eval loss: 2.9108266830444336, patience: 1. Time 76.19188642501831


100%|██████████| 227/227 [01:13<00:00,  3.07it/s]


Epoch: 2, Train loss: 2.474423885345459, Eval loss: 1.926767349243164, patience: 1. Time 74.95539593696594


100%|██████████| 227/227 [01:13<00:00,  3.08it/s]


Epoch: 3, Train loss: 1.7578703165054321, Eval loss: 1.619139313697815, patience: 1. Time 74.64624404907227


100%|██████████| 227/227 [01:14<00:00,  3.06it/s]


Epoch: 4, Train loss: 1.3715559244155884, Eval loss: 1.460307002067566, patience: 1. Time 75.28154706954956


100%|██████████| 227/227 [01:13<00:00,  3.07it/s]


Epoch: 5, Train loss: 1.117287278175354, Eval loss: 1.3753738403320312, patience: 1. Time 75.0431854724884


100%|██████████| 227/227 [01:13<00:00,  3.07it/s]


Epoch: 6, Train loss: 0.9391891360282898, Eval loss: 1.3951774835586548, patience: 1. Time 74.9404308795929


100%|██████████| 227/227 [01:14<00:00,  3.06it/s]


Epoch: 7, Train loss: 0.8048578500747681, Eval loss: 1.3871698379516602, patience: 2. Time 75.14989805221558


100%|██████████| 227/227 [01:13<00:00,  3.07it/s]


Epoch: 8, Train loss: 0.7016175389289856, Eval loss: 1.407729983329773, patience: 3. Time 75.03720164299011


100%|██████████| 227/227 [01:14<00:00,  3.07it/s]


Epoch: 9, Train loss: 0.618704080581665, Eval loss: 1.441323161125183, patience: 4. Time 75.14790487289429


100%|██████████| 227/227 [01:14<00:00,  3.06it/s]


Epoch: 10, Train loss: 0.5506867170333862, Eval loss: 1.4814356565475464, patience: 5. Time 75.14989900588989
Best epoch was 5 with 1.3753738403320312 eval loss


In [None]:
model_l = Seq2SeqTransformer(hyp_params["src_vocab_size"],
                                hyp_params["trg_vocab_size"],
                                hyp_params["d_model"],
                                hyp_params["dropout"],
                                hyp_params["n_head"],
                                hyp_params["num_encoder_layers"],
                                hyp_params["num_decoder_layers"],
                                hyp_params["feedforward_dim"],
                                SRC_PAD_IDX,
                                TRG_PAD_IDX
                                ).to(device)

model_l.load_state_dict(torch.load('model-transformer.pt', map_location=device)["model_state_dict"])
model_l.eval()

In [31]:
def translate(snt, dataset, model, device):
    snt = torch.tensor(snt).view(-1,1).to(device)
    
    num_tokens = snt.shape[0]
    max_len = 50
    
    with torch.no_grad():
        memory = model.encode(snt).to(device)
    
    ys = torch.LongTensor([dataset.trg_vocab['<sos>']]).unsqueeze(0).to(device)

    for i in range(max_len):
        with torch.no_grad():
            out = model.decode(ys, memory)
            
        out = out.transpose(0, 1)
        prob = model.linear(out[:, -1])
        next_word = prob.argmax().detach().item()

        ys = torch.cat([ys, torch.tensor([next_word]).unsqueeze(1).to(device)])
        
        
        if next_word == dataset.trg_vocab['<eos>']:
            break
    
    return dataset.trg_vocab.lookup_tokens(ys.squeeze().cpu().numpy())


def bleu(model, dataset, device):
    targets = []
    outputs = []

    for example in tqdm(dataset):
        src = example["src"]
        trg = example["trg"]
        
        trg = dataset.trg_vocab.lookup_tokens(trg)    
        prediction = translate(src, dataset, model, device)
        
        prediction = prediction[1:-1]  # removing <sos> and <eos> tokens
        trg = trg[1:-1]
        
        targets.append([trg])
        outputs.append(prediction)

    return bleu_score(outputs, targets)


In [32]:
bleu(model_l, nmtds_test, device)

100%|██████████| 1000/1000 [01:01<00:00, 16.38it/s]


0.36346383547707484