In [1]:
import random
import time
import math
import copy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F

from dataset import nmtDataset
import helpers as utils

In [None]:
# Fetch the spaCy English and German pipelines via shell commands (IPython `!`).
# NOTE(review): presumably needed by the tokenisers in dataset.py / helpers.py — confirm.
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Hyper-parameters shared by the data loaders and the model.
# Vocabulary-dependent entries are added later, once the datasets are built.
hyp_params = dict(
    batch_size=1,
    num_epochs=1,

    # Encoder settings
    encoder_embedding_size=512,
    encoder_dropout=0.5,

    # Decoder settings
    decoder_dropout=0.5,
    decoder_embedding_size=512,

    # Settings shared by encoder and decoder
    hidden_size=512,
    num_layers=2,
)

In [5]:
# Logger from helpers.py, used by the training loop below (log.log / log.close).
# NOTE(review): presumably writes to the given path and needs the `logs/` directory to exist — confirm.
log = utils.Logger('logs/emd512-enc2-dec2-bilstm.out')  

In [5]:
class Attention(nn.Module):
    """Additive attention over the encoder outputs.

    The last two slices of the recurrent hidden state are concatenated and
    projected to ``dec_hid_dim`` to form a query. The query is scored against
    every encoder output, and the scores are normalised with a softmax over
    the source positions.
    """

    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()

        # Projects the two concatenated hidden-state slices to dec_hid_dim.
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        # Scores one (query, encoder output) pair.
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        # Collapses each energy vector to a single scalar score.
        self.v = nn.Linear(dec_hid_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Return attention weights of shape [batch size, src len].

        hidden          -- recurrent hidden state; only its last two slices
                           along dim 0 are used, each [batch size, hid dim]
        encoder_outputs -- [src len, batch size, enc hid dim * 2]
        """
        # query = [batch size, dec hid dim]
        last_two = torch.cat((hidden[-2], hidden[-1]), dim=1)
        query = torch.tanh(self.fc(last_two))

        # Copy the query once per source position:
        # query = [batch size, src len, dec hid dim]
        src_len = encoder_outputs.shape[0]
        query = query.unsqueeze(1).repeat(1, src_len, 1)

        # Put the batch dimension first so that each row holds one sentence:
        # keys = [batch size, src len, enc hid dim * 2]
        keys = encoder_outputs.permute(1, 0, 2)

        # Pair the query with every source position of its own sentence:
        # energy = [batch size, src len, dec hid dim]
        paired = torch.cat((query, keys), dim=2)
        energy = torch.tanh(self.attn(paired))

        # scores = [batch size, src len]
        scores = self.v(energy).squeeze(2)

        return F.softmax(scores, dim=1)

In [6]:
class Encoder(nn.Module):
    """Bidirectional multi-layer LSTM encoder over embedded source tokens."""

    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers, dropout):
        super().__init__()

        # Applied to the token embeddings before they enter the LSTM.
        self.dropout = nn.Dropout(dropout)

        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        self.LSTM = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=True,
        )

    def forward(self, x):
        """Encode a batch of token indices.

        x -- [sequence length, batch size] token indices

        Returns (outputs, hidden_state, cell_state) where, because the LSTM
        is bidirectional,
            outputs                  = [sequence length, batch size, hidden_size * 2]
            hidden_state, cell_state = [num_layers * 2, batch size, hidden_size]
        """
        # embedded = [sequence length, batch size, embedding dims]
        embedded = self.embedding(x)
        embedded = self.dropout(embedded)

        outputs, final_state = self.LSTM(embedded)
        hidden_state, cell_state = final_state

        return outputs, hidden_state, cell_state
    
class Decoder(nn.Module):
    """Single-step LSTM decoder with attention over the encoder outputs.

    The LSTM has ``num_layers * 2`` unidirectional layers so that the
    ``[num_layers * 2, batch size, hidden_size]`` states produced by the
    bidirectional encoder can be used directly as its initial state.

    The attention context vector is used only by the output projection
    ``fc``. The LSTM itself consumes the token embedding alone, since its
    input size is ``embedding_dim``.
    NOTE(review): feeding the context into the LSTM as well would require
    widening the LSTM input size, which changes parameter shapes and would
    make previously saved checkpoints unloadable. That is deliberately not
    done here.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers, dropout, output_size):
        super(Decoder, self).__init__()

        self.dropout = nn.Dropout(dropout)
        self.attention = Attention(hidden_size, hidden_size)

        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Twice as many layers as the encoder: the bidirectional encoder
        # returns num_layers * 2 hidden/cell slices.
        self.LSTM = nn.LSTM(embedding_dim, hidden_size, num_layers * 2, dropout=dropout)

        # Input is the concatenation [LSTM output | context | embedding].
        self.fc = nn.Linear((hidden_size * 2) + hidden_size + embedding_dim, output_size)

    def forward(self, x, enc_outputs, hidden_state, cell_state):
        """Decode one target token for every sentence in the batch.

        x            -- [batch size] indices of the current input tokens
        enc_outputs  -- [src len, batch size, hidden_size * 2]
        hidden_state -- [num_layers * 2, batch size, hidden_size]
        cell_state   -- [num_layers * 2, batch size, hidden_size]

        Returns (predictions, hidden_state, cell_state) where predictions
        is [batch size, output_size] and the states are the updated LSTM
        states.
        """
        # One token per sentence is fed at a time, so add a sequence
        # dimension of length 1: x = [1, batch size]
        x = x.unsqueeze(0)

        # embedded = [1, batch size, embedding dims]
        embedded = self.dropout(self.embedding(x))

        # Attention weights over the source positions, computed from the
        # state before this step: a = [batch size, 1, src len]
        a = self.attention(hidden_state, enc_outputs).unsqueeze(1)

        # enc_outputs = [batch size, src len, hidden_size * 2]
        enc_outputs = enc_outputs.permute(1, 0, 2)

        # Attention-weighted sum of the encoder outputs:
        # weighted = [batch size, 1, hidden_size * 2]
        weighted = torch.bmm(a, enc_outputs)

        # outputs = [1, batch size, hidden_size]
        outputs, (hidden_state, cell_state) = self.LSTM(embedded, (hidden_state, cell_state))

        # Drop the length-1 dimensions and project to vocabulary scores.
        # predictions = [batch size, output_size]
        features = torch.cat(
            (outputs.squeeze(0), weighted.squeeze(1), embedded.squeeze(0)),
            dim=1,
        )
        predictions = self.fc(features)

        return predictions, hidden_state, cell_state

class SeqtoSeq(nn.Module):
    """Encoder-decoder translation model with scheduled teacher forcing."""

    def __init__(self, gen_params, target_vocab, device):
        super().__init__()

        encoder = Encoder(
            vocab_size=gen_params["input_size_encoder"],
            embedding_dim=gen_params["encoder_embedding_size"],
            hidden_size=gen_params["hidden_size"],
            num_layers=gen_params["num_layers"],
            dropout=gen_params["encoder_dropout"],
        )
        self.Encoder = encoder.to(device)

        decoder = Decoder(
            vocab_size=gen_params["input_size_decoder"],
            embedding_dim=gen_params["decoder_embedding_size"],
            hidden_size=gen_params["hidden_size"],
            num_layers=gen_params["num_layers"],
            dropout=gen_params["decoder_dropout"],
            output_size=gen_params["output_size"],
        )
        self.Decoder = decoder.to(device)

        self.target_vocab = target_vocab
        self.device = device

    def forward(self, source, target, tfr=0.5):
        """Translate a batch and return the per-step vocabulary scores.

        source -- [source sentence length, batch size]
        target -- [target sentence length, batch size]
        tfr    -- teacher-forcing ratio: probability of feeding the
                  reference token, rather than the model's own prediction,
                  as the next decoder input

        Returns a tensor of shape [target sentence length, batch size,
        target vocab size]. Row 0 is never written and stays all zeros.
        """
        target_len, batch_size = target.shape[0], source.shape[1]
        target_vocab_size = len(self.target_vocab)

        # Buffer for the scores of every decoding step.
        outputs = torch.zeros(target_len, batch_size, target_vocab_size, device=self.device)

        # The encoder's final states initialise the decoder:
        # hidden_state, cell_state = [num_layers * 2, batch size, hidden_size]
        enc_outputs, hidden_state, cell_state = self.Encoder(source)

        # The first target token triggers decoding: x = [batch size]
        x = target[0]

        for step in range(1, target_len):
            # scores = [batch size, target vocab size]
            scores, hidden_state, cell_state = self.Decoder(x, enc_outputs, hidden_state, cell_state)
            outputs[step] = scores

            # Scheduled sampling: feed either the reference token or the
            # model's own most likely token into the next step.
            if random.random() < tfr:
                x = target[step]
            else:
                x = scores.argmax(1)

        return outputs

In [7]:
# Build the train / val / test splits from the Multi30k directory.
# The val and test datasets receive the training dataset as a third argument.
# NOTE(review): presumably so they reuse the training vocabularies — confirm in dataset.py.
nmtds_train = nmtDataset('datasets/Multi30k/', 'train')
nmtds_valid = nmtDataset('datasets/Multi30k/', 'val', nmtds_train)
nmtds_test = nmtDataset('datasets/Multi30k/', 'test', nmtds_train)

In [8]:
def _collate_on_device(samples):
    """Collate one list of dataset samples through helpers.collate_fn for `device`."""
    return utils.collate_fn(samples, device)


train_dataloader = DataLoader(
    nmtds_train,
    batch_size=hyp_params['batch_size'],
    shuffle=True,
    collate_fn=_collate_on_device,
)

valid_dataloader = DataLoader(
    nmtds_valid,
    batch_size=hyp_params['batch_size'],
    shuffle=True,
    collate_fn=_collate_on_device,
)

In [9]:
# Vocabulary sizes are only known once the training dataset exists, so they
# are added to the hyper-parameters here.
source_vocab_size = len(nmtds_train.src_vocab)
target_vocab_size = len(nmtds_train.trg_vocab)
hyp_params.update(
    input_size_encoder=source_vocab_size,
    input_size_decoder=target_vocab_size,
    output_size=target_vocab_size,
)

model = SeqtoSeq(hyp_params, target_vocab=nmtds_train.trg_vocab, device=device)
optimizer = optim.Adam(model.parameters())

# Padding positions in the target do not contribute to the loss.
pad_idx = nmtds_train.trg_vocab["<pad>"]
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx).to(device)

In [None]:
# Train with early stopping on the validation loss.
#
# After every epoch the model is evaluated on the validation loader. Each
# time the validation loss improves, the model is deep-copied into
# `best_model` and a checkpoint is written. Training stops once `patience`
# reaches `max_patience` (it starts at 1 and grows by one per non-improving
# epoch).
min_el = math.inf
patience = 1
max_patience = 10
best_model = {}
best_epoch = 0  # 0-based index of the best epoch so far

# Same file name the evaluation cell loads from; it also matches the
# "bilstm" naming of this notebook's log file.
checkpoint_path = 'model-bilstm.pt'

epoch_loss = 0
try:
    for epoch in range(hyp_params["num_epochs"]):
        start = time.time()

        epoch_loss = utils.train_model(model, train_dataloader, criterion, optimizer)
        eval_loss = utils.evaluate_model(model, valid_dataloader, criterion)

        log.log(f"Epoch: {epoch+1}, Train loss: {epoch_loss}, Eval loss: {eval_loss}, patience: {patience}. Time {time.time() - start}")

        if eval_loss < min_el:
            best_epoch = epoch
            min_el = eval_loss
            best_model = copy.deepcopy(model)
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'eval_loss': min_el
            }, checkpoint_path)
            patience = 1
        else:
            patience += 1

        if patience >= max_patience:
            log.log("[STOPPING] Early stopping in action..")
            # Reported 1-based so it lines up with the "Epoch: N" lines above.
            log.log(f"Best epoch was {best_epoch + 1} with {min_el} eval loss")
            break
finally:
    # Close the log even when training raises or is interrupted.
    log.close()

In [None]:
# Rebuild the architecture, then restore the weights stored under
# 'model_state_dict' in the checkpoint file.
model_l = SeqtoSeq(hyp_params, target_vocab=nmtds_train.trg_vocab, device=device)
saved_weights = torch.load('model-bilstm.pt', map_location=device)["model_state_dict"]
model_l.load_state_dict(saved_weights)

In [11]:
# Score the restored model on the test split with helpers.bleu; the returned
# value is shown as the cell output.
# NOTE(review): confirm that helpers.bleu switches the model to eval mode (dropout off).
utils.bleu(model_l, nmtds_test, device)

100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:24<00:00, 40.60it/s]


0.22469250857830048