In [11]:
import torch
from torchtext.data.metrics import bleu_score

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F

from math import log
from copy import deepcopy

from dataset import nmtDataset

from tqdm import tqdm

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
hyp_params = {
    "batch_size": 64,
    "num_epochs": 10,

    # Encoder parameters
    "encoder_embedding_size": 512,
    "encoder_dropout": 0, # Disabled dropout because now we are only using single layer LSTM

    # Decoder parameters
    "decoder_dropout": 0,
    "decoder_embedding_size": 512,

    # Common parameters
    "hidden_size": 512,
    "num_layers": 1
}

In [14]:
class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super(Attention, self).__init__()
        
        self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)
        
    def forward(self, hidden, encoder_outputs, mask):
        # Shape (hidden): --> [num_layers * 2, batch_size, hidden_size]
        # encoder_outputs: Shape --> [Sequence_length , batch_size , hidden_size * 2]
        
        # Making hidden layer -> [batch size, hidden_size]
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)))
        
        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]
        
        # Repeating hidden state for every sentences to the maximum sentence length
        # Shape (hidden): --> [batch size, src len, hidden_size]
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        
        # Sentences are in column (see enc_out shape), so instead we want each sentence
        # in a row. Hence, we are permuting it
        # Shape (encoder_outputs): --> [batch size, src len, hidden_size * 2]
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        
        # Concatenation will put hidden states with relavant sentence and with the each
        # word of the sentence
        # Shape (energy): --> [batch size, src len, hidden_size]
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2))) 
        
        # Shape (attention): --> [batch size, src len]
        attention = self.v(energy).squeeze(2)
        
        attention = attention.masked_fill(mask == 0, -1e10)
        
        return F.softmax(attention, dim=1)

In [15]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers, dropout):
        super(Encoder, self).__init__()
        
        self.dropout = nn.Dropout(dropout)
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        self.LSTM = nn.LSTM(embedding_dim, hidden_size, num_layers, dropout=dropout, bidirectional = True)
        
    def forward(self, x):
        # Shape (embedding) --> [Sequence_length , batch_size , embedding dims]
        embedding = self.dropout(self.embedding(x))
        
        # ************** Multiplied by 2 because of bi-directional LSTM
        # Shape --> (output) [Sequence_length , batch_size , hidden_size * 2]
        # Shape --> (hs, cs) [num_layers * 2, batch_size, hidden_size]
        outputs, (hidden_state, cell_state) = self.LSTM(embedding)
        
        return outputs, hidden_state, cell_state
    
class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers, dropout, output_size):
        super(Decoder, self).__init__()
        
        self.dropout = nn.Dropout(dropout)
        self.attention = Attention(hidden_size, hidden_size)
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
         # ************** Multiplying 2 because of bi-directional LSTM
        self.LSTM = nn.LSTM((hidden_size * 2) + embedding_dim, hidden_size, num_layers * 2, dropout=dropout)
        
        self.fc = nn.Linear((hidden_size* 2) + hidden_size + embedding_dim, output_size)
    
    def forward(self, x, enc_outputs, hidden_state, cell_state, mask):
        # As we are not feeding whole sentence we will each token a time
        # hence our sequence length would be just 1 however shape of x is batch_size
        # to add sequence length we will unsequeeze it
        # Shape (x) --> [batch_size] (see seq2seq model) so making it [1, batch_size]
        x = x.unsqueeze(0)
        
        # Shape (embedded) --> (1, batch_size, embedding dims)
        embedded = self.dropout(self.embedding(x))
        
        # Shape (a): --> [batch_size, src len]
        a = self.attention(hidden_state, enc_outputs, mask)
        
        # Shape (a): --> [batch_size, 1, src len]
        a = a.unsqueeze(1)
        
        # Shape (encoder_outputs): --> [batch_size, src len, hidden_size * 2]
        enc_outputs = enc_outputs.permute(1, 0, 2)
        
        # Shape (weighted): --> [batch_size, 1, hidden_size * 2]
        weighted = torch.bmm(a, enc_outputs)
        
        # Shape (weighted): --> [1, batch_size, hidden_size * 2]
        weighted = weighted.permute(1, 0, 2)
        
        # Shape (rnn_input): --> [1, batch_size, (hidden_size * 2) + embedding dims]
        rnn_input = torch.cat((embedded, weighted), dim = 2)
        
        
        # Shape (output): --> [1, batch_size, hidden_size]
        # Shape (hidden_state, cell_state): --> [num_layers * num_directions, batch_size, hidden_size]
        outputs, (hidden_state, cell_state) = self.LSTM(rnn_input, (hidden_state, cell_state))
        
        embedded = embedded.squeeze(0) # Shape (output): --> [batch_size, hidden_size]
        outputs = outputs.squeeze(0) # Shape (output): --> [batch_size, hidden_size]
        weighted = weighted.squeeze(0) # Shape (output): --> [batch_size, hidden_size * 2]
        
        # Shape (predictions): --> [batch_size, output_size]
        predictions = self.fc(torch.cat((outputs, weighted, embedded), dim = 1))
        
        return predictions, hidden_state, cell_state, a.squeeze(1)

class SeqtoSeq(nn.Module):
    def __init__(self, gen_params, target_vocab, src_pad_idx, device):
        super(SeqtoSeq, self).__init__()

        self.Encoder = Encoder(gen_params["input_size_encoder"],
                          gen_params["encoder_embedding_size"],
                          gen_params["hidden_size"],
                          gen_params["num_layers"],
                          gen_params["encoder_dropout"]).to(device)

        self.Decoder = Decoder(gen_params["input_size_decoder"],
                          gen_params["decoder_embedding_size"],
                          gen_params["hidden_size"],
                          gen_params["num_layers"],
                          gen_params["decoder_dropout"],
                          gen_params["output_size"]).to(device)

        self.target_vocab = target_vocab
        self.src_pad_idx = src_pad_idx
        self.device = device
    
    def create_mask(self, src):
        mask = (src != self.src_pad_idx).permute(1, 0)
        return mask
    
    def forward(self, source, target, tfr=0.5):
        # (source) Shape -> (Sentence length, Batch_size)
        batch_size = source.shape[1]

        target_len = target.shape[0]  # Length of target sentences
        target_vocab_size = len(self.target_vocab)
        
        # here we will store all the outputs
        # so outputs is arrange in a way that sentences are in column and batch size is row and every element
        # will consist of probability of each word from the vocab
        outputs = torch.zeros(target_len, batch_size, target_vocab_size).to(self.device)

        # Shape --> (hs, cs) (num_layers * 2, batch_size size, hidden_size) (contains encoder's hs, cs - context vectors)
        enc_outputs, hidden_state, cell_state = self.Encoder(source)

        # Shape (target) -> (Sentence length, Batch_size)
        # Shape (x) --> (batch_size)
        x = target[0]  # First token (Trigger)
        
        mask = self.create_mask(source)
        
        for i in range(1, target_len):
            # Shape (output) --> (batch_size, target_vocab_size)
            # Shape (hs, cl) --> (num_layers * 2, batch_size , hidden_size)
            # _ is attention which we dont need here!
            output, hidden_state, cell_state, _ = self.Decoder(x, enc_outputs, hidden_state, cell_state, mask)
            outputs[i] = output
            best_guess = output.argmax(1)  # 0th dimension is batch size, 1st dimension is word embedding
            # Schedule sampling
            x = target[
                i] if random.random() < tfr else best_guess  # Either pass the next word correctly from the dataset
            # or use the earlier predicted word

        # Shape --> (sentence length, batch size, vocab size)
        return outputs

In [16]:
nmtds_train = nmtDataset('datasets/Multi30k/', 'train')
nmtds_test = nmtDataset('datasets/Multi30k/', 'test', nmtds_train)

In [17]:
hyp_params["input_size_encoder"] = len(nmtds_train.src_vocab)
hyp_params["input_size_decoder"] = len(nmtds_train.trg_vocab)
hyp_params["output_size"] = len(nmtds_train.trg_vocab)
pad_idx = nmtds_train.trg_vocab["<pad>"]

model_l = SeqtoSeq(hyp_params, nmtds_train.trg_vocab, pad_idx, device=device)
model_l.load_state_dict(torch.load('model-attention.pt', map_location=device)["model_state_dict"])

<All keys matched successfully>

In [18]:
def translate(snt, dataset, model, attention, device):
    tokens = dataset.tokenizers['en'](snt.lower().strip())
    indices = [dataset.src_vocab['<sos>']] + dataset.src_vocab.lookup_indices(tokens) + [dataset.src_vocab['<eos>']]
    inp_tensor = torch.tensor(indices).unsqueeze(1).to(device)

    # Build encoder hidden, cell state
    with torch.no_grad():
        eouts, hidden, cell = model.Encoder(inp_tensor)

    start = dataset.trg_vocab["<sos>"]
    beam_width = 2
    beam = []
    
    for _ in range(50):
        with torch.no_grad():
            mask = model.create_mask(inp_tensor)

            if len(beam) == 0:
                previous_word = torch.LongTensor([start]).to(device)
                output, hidden, cell, _ = model.Decoder(previous_word, eouts, hidden, cell, mask)
                
                probs = output.topk(2).values.squeeze().numpy()
                vals = output.topk(2).indices.squeeze().numpy()

                for p, v in zip(probs, vals):
                    beam.append([{"indices": v, "prob": p, "h": hidden, "c": cell}, log(p)])
            elif len(beam) > 0:
                temp_beam = []
                both_eos = 0
                for w in beam:
                    previous_word = torch.LongTensor([w[-2]["indices"]]).to(device)
                    output, hidden, cell, _ = model.Decoder(previous_word, eouts, w[-2]["h"], w[-2]["c"], mask)
                    
                    print(output.shape)

                    probs = output.topk(2).values.squeeze().numpy()
                    vals = output.topk(2).indices.squeeze().numpy()
                    for p, v in zip(probs, vals):
                        temp_beam.append(w[:-1] + [{"indices": v, "prob": p, "h": hidden, "c": cell}, log(p) + w[-1]])
                    
                    
                
                temp_beam.sort(key=lambda x: x[-1])
                beam = temp_beam[-2:]
                
                

    
    beam[0] = [{ "p": d["prob"], "i": d["indices"] } for d in beam[0][:-1] if d["indices"] != dataset.trg_vocab['<eos>']]
    beam[1] = [{ "p": d["prob"], "i": d["indices"] } for d in beam[1][:-1] if d["indices"] != dataset.trg_vocab['<eos>']]
    
    if sum([p["p"] for p in beam[0]]) > sum([p["p"] for p in beam[1]]):
        res = [i["i"] for i in beam[0]]
    else:
        res = [i["i"] for i in beam[1]]
    
    return dataset.trg_vocab.lookup_tokens(res)


In [19]:
def bleu(model, dataset, attention, device):
    targets = []
    outputs = []

    for example in tqdm(dataset):
        src = example["src"][1:-1]
        trg = example["trg"][1:-1]
        
        src = ' '.join(dataset.src_vocab.lookup_tokens(src))
        trg = dataset.trg_vocab.lookup_tokens(trg)

        prediction = translate(src, dataset, model, attention, device)
        prediction = prediction[1:-1]  # remove <eos> token
        
        targets.append([trg])
        outputs.append(prediction)

    return bleu_score(outputs, targets)

In [20]:
bleu(model_l, nmtds_test, True, device)

  0%|                                                                                                                                                                                                                                                                                                                                              | 0/1000 [00:00<?, ?it/s]

torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size

  0%|▎                                                                                                                                                                                                                                                                                                                                     | 1/1000 [00:00<11:36,  1.43it/s]

torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size

  0%|▋                                                                                                                                                                                                                                                                                                                                     | 2/1000 [00:01<10:58,  1.52it/s]

torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size

  0%|▉                                                                                                                                                                                                                                                                                                                                     | 3/1000 [00:01<10:46,  1.54it/s]

torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])


  0%|▉                                                                                                                                                                                                                                                                                                                                     | 3/1000 [00:02<11:38,  1.43it/s]


torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])
torch.Size([1, 7853])


KeyboardInterrupt: 

In [7]:

import torch
a = torch.tensor([  [0.1, 0.7, 0.1, 0.1],
                    [0.7, 0.1, 0.1, 0.1],
                    [0.1, 0.1, 0.6, 0.2],
                    [0.1, 0.1, 0.1, 0.7],
                    [0.4, 0.3, 0.2, 0.1]])
a.shape

torch.Size([5, 4])

In [3]:
a.shape

torch.Size([100, 32])

In [5]:
a.transpose(0, 1).shape

torch.Size([32, 100])