In [1]:
# ALL IMPORTS FOR A NEW NOTEBOOK

import os, sys, random, math
import numpy as np
import pandas as pd
import matplotlib.pylab as plt
import seaborn as sns
import itertools as it
import scipy
import glob
import matplotlib
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
from torch.optim import Optimizer
import torchvision.transforms.transforms as txf
import torch.optim.lr_scheduler as lr_scheduler
from collections import OrderedDict

from sklearn import metrics
from sklearn import preprocessing as pp
from sklearn import model_selection as ms

import torch_utils
from tqdm.notebook import tqdm_notebook as tqdm
import time

import torchtext
from torchtext import data, datasets
import spacy


font = {'size'   : 20}

matplotlib.rc('font', **font)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
SEED = 947
torch_utils.seed_everything(SEED)

In [3]:
spacy_de = spacy.load("de")
spacy_en = spacy.load("en")

def tokenize_de(txt):
    return [tok.text for tok in spacy_de.tokenizer(txt)]
def tokenize_en(txt):
    return [tok.text for tok in spacy_en.tokenizer(txt)]

In [4]:
SRC = data.Field(tokenize=tokenize_de, init_token="<sos>", eos_token="<eos>", lower=True, batch_first=True)
TRG = data.Field(tokenize=tokenize_en, init_token="<sos>", eos_token="<eos>", lower=True, batch_first=True)

In [5]:
train_data, valid_data, test_data = datasets.Multi30k.splits(exts=(".de",".en"), fields=(SRC, TRG))

In [6]:
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)

In [7]:
BATCH_SIZE = 128
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    device=device
)

In [8]:
class PositionwiseFeedForwardLayer(nn.Module):
    def __init__(self, hid_dim, pff_dim, dropout):
        super(PositionwiseFeedForwardLayer, self).__init__()
        
        self.fc1 = nn.Linear(hid_dim, pff_dim)
        self.fc2 = nn.Linear(pff_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        # x -- batch_size X seq_len X hid_dim
        x = self.dropout(torch.relu(self.fc1(x)))
        # x -- batch_size X seq_len X pff_dim
        return self.fc2(x)

In [9]:
class SelfAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super(SelfAttentionLayer, self).__init__()
        
        assert hid_dim% n_heads == 0, "hidden dimension must be divisible into n_heads"
        
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim//n_heads
        
        # for projection of the keys, values and queries
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        
        self.fc_final = nn.Linear(hid_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
    
    def forward(self, query, key, value, mask=None):
        # query -- batch_size X query_len X hid_dim
        # key -- batch_size X key_len X hid_dim
        # value -- batch_size X value_len X hid_dim
        batch_size = query.shape[0]
        
        # projections
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        
        
        # split into heads, basically we're dividing hid_dim into n_heads*head_dim
        # we need all of these vectors to be of , batch_size X n_heads X seq_len X head_dim, so permute
        
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        
        # now calculate energy/overlap between K and Q, Q*K/dk
        energy = torch.matmul(Q, K.permute(0,1,3,2))/self.scale
        
        if mask is not None:
            energy = energy.masked_fill(mask==0, -1e10)
        
        attention = self.dropout(torch.softmax(energy, dim=-1))
        # attention -- batch_size X n_heads X query_len X key_len
        x = torch.matmul(attention, V)
        # x --  batch_size X n_heads X query_len X head_dim
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        # concatenated n_heads
        
        return self.fc_final(x), attention

In [10]:
class EncoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pff_dim, self_attention_layer, position_wise_feedforward_layer, dropout, device):
        super(EncoderLayer, self).__init__()
        
        self.layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = self_attention_layer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = position_wise_feedforward_layer(hid_dim, pff_dim, dropout)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, src, src_mask):
        _src, _ = self.self_attention(src, src, src, src_mask)
        src = self.layer_norm(src+self.dropout(_src))
        src = self.layer_norm(src+self.dropout(self.positionwise_feedforward(src)))
        return src

In [11]:
class Encoder(nn.Module):
    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, encoder_layer, self_attention_layer, positionwise_ff_layer, dropout, device):
        super(Encoder, self).__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = torch_utils.PositionalEncoding(hid_dim, dropout)
        
        self.layers = nn.ModuleList([
            encoder_layer(hid_dim, n_heads, pf_dim, self_attention_layer, positionwise_ff_layer, dropout, device)
            for _ in range(n_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
    
    def forward(self, src, src_mask):
        batch_size, src_len = src.shape
        src_tok = self.tok_embedding(src) * self.scale
        src = self.pos_embedding(src_tok)
        for layer in self.layers:
            src = layer(src, src_mask)
        return src

In [12]:
class DecoderLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, pf_dim, self_attention_layer, positionwise_feedforward_layer, dropout, device):
        super(DecoderLayer, self).__init__()
        
        self.layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = self_attention_layer(hid_dim, n_heads, dropout, device)
        self.encoder_attention = self_attention_layer(hid_dim, n_heads, dropout, device)
        self.positionwise_ffn = positionwise_feedforward_layer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, trg, enc_src, trg_mask, src_mask):
        _trg, _ = self.self_attention(trg, trg, trg, trg_mask)
        trg = self.layer_norm(trg+self.dropout(_trg))
        _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        trg = self.layer_norm(trg+self.dropout(_trg))
        trg = self.layer_norm(trg+self.dropout(self.positionwise_ffn(trg)))
        
        return trg, attention

In [13]:
class Decoder(nn.Module):
    def __init__(self, output_dim, hid_dim, n_layers, n_heads, pf_dim, decoder_layer, self_attention_layer, positionwise_feedforward_layer, dropout, device):
        super(Decoder, self).__init__()
        
        self.device = device
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        
        self.pos_embedding = torch_utils.PositionalEncoding(hid_dim, dropout)
        
        self.layers = nn.ModuleList([decoder_layer(hid_dim, n_heads, pf_dim, self_attention_layer, positionwise_feedforward_layer, dropout, device)])
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
    
    def forward(self, trg, enc_src, trg_mask, src_mask):
        batch_size, trg_len = trg.shape
        trg_embed = self.tok_embedding(trg)*self.scale
        trg = self.pos_embedding(trg_embed)
        for layer in self.layers:
            trg, attention = layer(trg, enc_src, trg_mask, src_mask)
        output = self.fc_out(trg)
        return output, attention

In [14]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx, trg_sos_idx, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.trg_sos_idx = trg_sos_idx
        self.device = device
    
    def make_src_mask(self, src):
        src_mask = (src!=self.src_pad_idx).unsqueeze(dim=1).unsqueeze(dim=2)
        # src_mask -- batch_size X 1 X 1 X src_len
        return src_mask
    def make_trg_mask(self, trg):
        trg_pad_mask = (trg!=self.trg_pad_idx).unsqueeze(dim=1).unsqueeze(dim=3)
        # trg_pad_mask --  batch_size X 1 X trg_len X 1
        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool()
        trg_mask = trg_pad_mask & trg_sub_mask
        return trg_mask
    
    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        
        enc_src = self.encoder(src, src_mask)
        
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        
        return output, attention

In [15]:
def init_weights(m):
    if hasattr(m, "weight") and m.weight.dim()>1:
        nn.init.xavier_uniform_(m.weight.data)

In [16]:
def train(model, iterator, optimizer, criterion, clip):
    model.train()
    el = 0
    for batch in tqdm(iterator):
        src = batch.src
        trg = batch.trg
        
        optimizer.zero_grad()
        
        output, _ = model(src, trg[:,:-1])
        output_dim = output.shape[-1]
        
        output = output.contiguous().view(-1, output_dim)
        trg = trg[:, 1:].contiguous().view(-1)
        
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        
        el+=loss.item()
    
    return el/len(iterator)

In [17]:
def evaluate(model, iterator, criterion):
    model.eval()
    el = 0
    with torch.no_grad():
        for batch in tqdm(iterator):
            src = batch.src
            trg = batch.trg

            output, _ = model(src, trg[:,:-1])
            output_dim = output.shape[-1]

            output = output.contiguous().view(-1, output_dim)
            trg = trg[:, 1:].contiguous().view(-1)

            loss = criterion(output, trg)

            el+=loss.item()
    
    return el/len(iterator)

In [18]:
INP_DIM = len(SRC.vocab)
OUT_DIM = len(TRG.vocab)
HID_DIM = 256
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.15
DEC_DROPOUT = 0.15

SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
TRG_SOS_IDX = TRG.vocab.stoi[TRG.init_token]


enc = Encoder(INP_DIM, HID_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, EncoderLayer, SelfAttentionLayer, PositionwiseFeedForwardLayer, ENC_DROPOUT, device)
dec = Decoder(OUT_DIM, HID_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DecoderLayer, SelfAttentionLayer, PositionwiseFeedForwardLayer, DEC_DROPOUT, device)
model = Seq2Seq(enc, dec, SRC_PAD_IDX,TRG_PAD_IDX, TRG_SOS_IDX, device).to(device).apply(init_weights)

In [19]:
print(f"TOTAL PARAMS = {torch_utils.count_model_params(model)}")

TOTAL PARAMS = 7403525


In [20]:
criterion = nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)
optimizer = torch_utils.RAdam(model.parameters(), lr=0.0005)

In [None]:
N_EPOCHS = 30
CLIP = 1
history = pd.DataFrame()
best_model = "ATTN_IS_ALL_YOU_NEED.pt"
ea = torch_utils.EarlyStopping(patience=5, verbose=True, save_model_name=best_model)

for e in range(N_EPOCHS):
    st = time.time()
    tl = train(model, train_iterator, optimizer, criterion, CLIP)
    vl = evaluate(model, valid_iterator, criterion)
    
    TPL = np.exp(tl)
    VPL = np.exp(vl)
    ea(VPL, model)
    
    torch_utils.print_epoch_stat(e, time.time()-st, history, tl, valid_loss=vl)
    
    print(f"\t\tTPL : \t\t{TPL:0.5f}")
    print(f"\t\tVPL : \t\t{VPL:0.5f}")
    
    if ea.early_stop:
        print("STOPPING EARLY!")
        break

HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (inf --> 75.665609).  Saving model ...


EPOCH 1 Completed, Time Taken: 0:00:12.269908
	Train Loss 	6.22600696
	Valid Loss 	4.32632375
		TPL : 		505.73204
		VPL : 		75.66561


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (75.665609 --> 28.065509).  Saving model ...


EPOCH 2 Completed, Time Taken: 0:00:12.110225
	Train Loss 	3.97044868
	Valid Loss 	3.33454138
		TPL : 		53.00831
		VPL : 		28.06551


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (28.065509 --> 14.944775).  Saving model ...


EPOCH 3 Completed, Time Taken: 0:00:12.509177
	Train Loss 	3.2455308
	Valid Loss 	2.70436174
		TPL : 		25.67533
		VPL : 		14.94477


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (14.944775 --> 10.191247).  Saving model ...


EPOCH 4 Completed, Time Taken: 0:00:11.838966
	Train Loss 	2.74037308
	Valid Loss 	2.32152918
		TPL : 		15.49276
		VPL : 		10.19125


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (10.191247 --> 7.982975).  Saving model ...


EPOCH 5 Completed, Time Taken: 0:00:11.829717
	Train Loss 	2.37884255
	Valid Loss 	2.07731114
		TPL : 		10.79240
		VPL : 		7.98297


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (7.982975 --> 6.774808).  Saving model ...


EPOCH 6 Completed, Time Taken: 0:00:11.729448
	Train Loss 	2.10597416
	Valid Loss 	1.91321108
		TPL : 		8.21510
		VPL : 		6.77481


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (6.774808 --> 6.038275).  Saving model ...


EPOCH 7 Completed, Time Taken: 0:00:12.088450
	Train Loss 	1.892443
	Valid Loss 	1.79811844
		TPL : 		6.63556
		VPL : 		6.03828


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (6.038275 --> 5.581215).  Saving model ...


EPOCH 8 Completed, Time Taken: 0:00:12.055277
	Train Loss 	1.72003912
	Valid Loss 	1.71940653
		TPL : 		5.58475
		VPL : 		5.58122


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (5.581215 --> 5.305896).  Saving model ...


EPOCH 9 Completed, Time Taken: 0:00:11.700305
	Train Loss 	1.57769993
	Valid Loss 	1.66881858
		TPL : 		4.84380
		VPL : 		5.30590


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (5.305896 --> 5.133511).  Saving model ...


EPOCH 10 Completed, Time Taken: 0:00:11.627280
	Train Loss 	1.45890912
	Valid Loss 	1.63578977
		TPL : 		4.30126
		VPL : 		5.13351


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (5.133511 --> 4.955053).  Saving model ...


EPOCH 11 Completed, Time Taken: 0:00:12.433437
	Train Loss 	1.35543368
	Valid Loss 	1.60040787
		TPL : 		3.87844
		VPL : 		4.95505


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (4.955053 --> 4.915455).  Saving model ...


EPOCH 12 Completed, Time Taken: 0:00:11.172405
	Train Loss 	1.26334918
	Valid Loss 	1.59238429
		TPL : 		3.53725
		VPL : 		4.91545


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


EarlyStopping counter: 1 out of 5


EPOCH 13 Completed, Time Taken: 0:00:11.754394
	Train Loss 	1.1882219
	Valid Loss 	1.59650153
		TPL : 		3.28124
		VPL : 		4.93573


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))

In [None]:
ax = history["train_loss"].plot()
history["valid_loss"].plot(ax=ax)

In [None]:
test_loss = evaluate(model, test_iterator, criterion)
print(f"Possibly overfitted model: Loss = {test_loss:0.5f} | PPL = {np.exp(test_loss):0.5f}")

In [None]:
model.load_state_dict(torch.load(best_model, map_location=device))

In [None]:
test_loss = evaluate(model, test_iterator, criterion)
print(f"Best model: Loss = {test_loss:0.5f} | PPL = {np.exp(test_loss):0.5f}")

In [None]:
def translate_sentence(sentence, src_field, trg_field, model, device, max_len = 50):
    
    model.eval()
        
    if isinstance(sentence, str):
        nlp = spacy.load('de')
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    tokens = [src_field.init_token] + tokens + [src_field.eos_token]
        
    src_indexes = [src_field.vocab.stoi[token] for token in tokens]

    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
    
    src_mask = model.make_src_mask(src_tensor)
    
    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)

    trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]

    for i in range(max_len):

        trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)

        trg_mask = model.make_trg_mask(trg_tensor)
        
        with torch.no_grad():
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
        
        pred_token = output.argmax(2)[:,-1].item()
        
        trg_indexes.append(pred_token)

        if pred_token == trg_field.vocab.stoi[trg_field.eos_token]:
            break
    
    trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]
    
    return trg_tokens[1:], attention

In [None]:
from matplotlib import ticker
def display_attention(sentence, translation, attention, n_heads = 8, n_rows = 4, n_cols = 2):
    
    assert n_rows * n_cols == n_heads
    
    fig = plt.figure(figsize=(15,25))
    
    for i in range(n_heads):
        
        ax = fig.add_subplot(n_rows, n_cols, i+1)
        
        _attention = attention.squeeze(0)[i].cpu().detach().numpy()

        cax = ax.matshow(_attention, cmap='bone')

        ax.tick_params(labelsize=12)
        ax.set_xticklabels(['']+['<sos>']+[t.lower() for t in sentence]+['<eos>'], 
                           rotation=45)
        ax.set_yticklabels(['']+translation)

        ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()
    plt.close()

In [None]:
example_idx = 5
src = vars(train_data.examples[example_idx])["src"]
trg = vars(train_data.examples[example_idx])["trg"]
print(f"SRC == {src}")
print(f"TRG == {trg}")
translation, attention = translate_sentence(src, SRC, TRG, model, device)
print(f"TRNSLATED == {translation}")
display_attention(src, translation, attention)

In [None]:
example_idx = 6
src = vars(valid_data.examples[example_idx])["src"]
trg = vars(valid_data.examples[example_idx])["trg"]
print(f"SRC == {src}")
print(f"TRG == {trg}")
translation, attention = translate_sentence(src, SRC, TRG, model, device)
print(f"TRNSLATED == {translation}")
display_attention(src, translation, attention)

In [None]:
example_idx = 10
src = vars(test_data.examples[example_idx])["src"]
trg = vars(test_data.examples[example_idx])["trg"]
print(f"SRC == {src}")
print(f"TRG == {trg}")
translation, attention = translate_sentence(src, SRC, TRG, model, device)
print(f"TRNSLATED == {translation}")
display_attention(src, translation, attention)

In [None]:
def calculate_bleu(data, src_field, trg_field, model, device, max_len = 50):
    
    trgs = []
    pred_trgs = []
    
    for datum in tqdm(data):
        
        src = vars(datum)['src']
        trg = vars(datum)['trg']
        
        pred_trg, _ = translate_sentence(src, src_field, trg_field, model, device, max_len)
        
        #cut off <eos> token
        pred_trg = pred_trg[:-1]
        
        pred_trgs.append(pred_trg)
        trgs.append([trg])
        
    return torch_utils.bleu_score(pred_trgs, trgs)

In [None]:
100.0*calculate_bleu(test_data, SRC, TRG, model, device)