In [26]:
## ALL IMPORTS FOR A NEW NOTEBOOK

import os, sys, random, math
import numpy as np
import pandas as pd
import matplotlib.pylab as plt
import seaborn as sns
import itertools as it
import scipy
import glob
import matplotlib
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
from torch.optim import Optimizer
import torchvision.transforms.transforms as txf
import torch.optim.lr_scheduler as lr_scheduler
from collections import OrderedDict

from sklearn import metrics
from sklearn import preprocessing as pp
from sklearn import model_selection as ms

import torch_utils
from tqdm.notebook import tqdm_notebook as tqdm
import time

font = {'size'   : 20}

matplotlib.rc('font', **font)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [27]:
SEED = 947
torch_utils.seed_everything(SEED)

In [28]:
import spacy
spacy_de = spacy.load("de")
spacy_en = spacy.load("en")

In [29]:
def tokenize_de(txt):
    return [tok.text for tok in spacy_de.tokenizer(txt)]
def tokenize_en(txt):
    return [tok.text for tok in spacy_en.tokenizer(txt)]

In [30]:
from torchtext import data, datasets

In [31]:
SRC = data.Field(tokenize=tokenize_de,
                 init_token="<sos>",
                 eos_token="<eos>",
                 lower=True
                )
TRG = data.Field(tokenize=tokenize_en,
                 init_token="<sos>",
                 eos_token="<eos>",
                 lower=True
                )

In [32]:
train_data, valid_data, test_data = datasets.Multi30k.splits(exts=(".de", ".en"),fields=(SRC, TRG))

In [33]:
SRC.build_vocab(train_data, min_freq=2)
TRG.build_vocab(train_data, min_freq=2)

In [46]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 128
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_sizes=(BATCH_SIZE,BATCH_SIZE,BATCH_SIZE),
    device=device
)

In [47]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, dropout):
        super(Encoder, self).__init__()
        self.hid_dim = hid_dim
        
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hid_dim)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, src):
        embedded = self.embedding(src)
        embedded = self.dropout(embedded)
        ot, ht = self.rnn(embedded)
        return ht

In [60]:
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, dropout):
        super(Decoder, self).__init__()
        self.hid_dim = hid_dim
        self.output_dim = output_dim
        
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim+hid_dim, hid_dim)
        self.fc_out = nn.Linear(emb_dim+2*hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, trg, ht, cntx):
        # trg = batch_size
        trg = trg.unsqueeze(dim=0)
        # trg = (1, batch_size)
        embedded = self.embedding(trg)
        # embedded = (1, batch_size, embed_size)
        embedded = self.dropout(embedded)
        # embedded = (1, batch_size, embed_size)
        emb_con = torch.cat((embedded, cntx), dim=2)
        # emb_con = (1, batch_size, embed_size+hid_dim)
        ot, ht = self.rnn(emb_con, ht)
        # ot = (1, batch_size, hidden_dim)
        # ht = (1, batch_size, hidden_dim)
#         print(ht==ot)
        ot = torch.cat((embedded.squeeze(dim=0), cntx.squeeze(dim=0), ht.squeeze(dim=0)), dim=1)
        
        # ot = (batch_size, emb_dim+2*hid_dim)
        preds = self.fc_out(ot)
        
        return preds, ht

In [61]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
    
    
    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = trg.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        
        context = self.encoder(src)
        
        hidden = context
        
        current_token = trg[0,:]
        
        for t in range(1, trg_len):
            ot, ht = self.decoder(current_token, hidden, context)
            outputs[t] = ot
            teacher_force = random.random()<teacher_forcing_ratio
            
            top1 = ot.argmax(1)
            
            current_token = trg[t] if teacher_force else top1
        
        return outputs

In [62]:
def train(model, iterator, optimizer, criterion, clip):
    model.train()
    l = 0
    for i, batch in enumerate(tqdm(iterator)):
        src = batch.src
        trg = batch.trg
        optimizer.zero_grad()
        outputs = model(src, trg)
        output_dim = outputs.shape[-1]
        outputs = outputs[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)
        loss = criterion(outputs, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        l+=loss.item()
    
    return l/len(iterator)

In [63]:
def evaluate(model, iterator, criterion):
    model.eval()
    l = 0
    with torch.no_grad():
        for i, batch in enumerate(tqdm(iterator)):
            src = batch.src
            trg = batch.trg
            outputs = model(src, trg)
            output_dim = outputs.shape[-1]
            outputs = outputs[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)
            loss = criterion(outputs, trg)
            l+=loss.item()
    
    return l/len(iterator)

In [86]:
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.normal_(param.data, mean=0, std=0.01)

In [87]:
INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HID_DIM = 1024
ENC_DROPOUT = 0.3
DEC_DROPOUT = 0.3

enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, DEC_DROPOUT)

torch_utils.clear_cuda()

model = Seq2Seq(enc, dec, device).to(device)
model = model.apply(init_weights)

In [88]:
torch_utils.count_model_params(model)

28125189

In [None]:
optimizer = torch_utils.RAdam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=TRG.vocab.stoi[TRG.pad_token])
N_EPOCHS = 30
CLIP =1
ea = torch_utils.EarlyStopping(patience=5, verbose=True)
history = pd.DataFrame()

for e in range(N_EPOCHS):
    torch_utils.clear_cuda()
    st = time.time()
    tl = train(model, train_iterator, optimizer, criterion, CLIP)
    vl = evaluate(model, valid_iterator, criterion)
    tpl = math.exp(tl)
    vpl = math.exp(vl)
    
    ea(vpl, model)
    
    torch_utils.print_epoch_stat(e, time.time()-st, history, tl, valid_loss=vl)
    print("\t\tTPL:\t{:0.5}".format(tpl))
    print("\t\tVPL:\t{:0.5}".format(vpl))
    
    history.loc[e, "TPL"] = tpl
    history.loc[e, "VPL"] = vpl
    
    if ea.early_stop:
        print("STOPPING EARLY")
        break

HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (inf --> 190.549845).  Saving model ...


EPOCH 1 Completed, Time Taken: 0:00:42.501535
	Train Loss 	5.87053417
	Valid Loss 	5.24991381
		TPL:	354.44
		VPL:	190.55


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (190.549845 --> 108.096790).  Saving model ...


EPOCH 2 Completed, Time Taken: 0:00:42.268514
	Train Loss 	4.9172923
	Valid Loss 	4.68302703
		TPL:	136.63
		VPL:	108.1


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (108.096790 --> 87.039704).  Saving model ...


EPOCH 3 Completed, Time Taken: 0:00:42.456762
	Train Loss 	4.56937474
	Valid Loss 	4.46636438
		TPL:	96.484
		VPL:	87.04


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (87.039704 --> 64.644272).  Saving model ...


EPOCH 4 Completed, Time Taken: 0:00:41.784499
	Train Loss 	4.35311126
	Valid Loss 	4.16889951
		TPL:	77.72
		VPL:	64.644


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (64.644272 --> 55.267418).  Saving model ...


EPOCH 5 Completed, Time Taken: 0:00:42.438097
	Train Loss 	4.07594828
	Valid Loss 	4.01218355
		TPL:	58.906
		VPL:	55.267


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (55.267418 --> 37.129977).  Saving model ...


EPOCH 6 Completed, Time Taken: 0:00:42.373654
	Train Loss 	3.77059636
	Valid Loss 	3.61442465
		TPL:	43.406
		VPL:	37.13


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (37.129977 --> 30.008110).  Saving model ...


EPOCH 7 Completed, Time Taken: 0:00:41.869239
	Train Loss 	3.45387409
	Valid Loss 	3.40146768
		TPL:	31.623
		VPL:	30.008


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (30.008110 --> 28.399076).  Saving model ...


EPOCH 8 Completed, Time Taken: 0:00:41.904991
	Train Loss 	3.15946794
	Valid Loss 	3.3463566
		TPL:	23.558
		VPL:	28.399


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


Found better solution (28.399076 --> 22.060497).  Saving model ...


EPOCH 9 Completed, Time Taken: 0:00:42.501350
	Train Loss 	2.89212625
	Valid Loss 	3.09378853
		TPL:	18.032
		VPL:	22.06


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


EarlyStopping counter: 1 out of 5


EPOCH 10 Completed, Time Taken: 0:00:42.131136
	Train Loss 	2.6795998
	Valid Loss 	3.20038238
		TPL:	14.579
		VPL:	24.542


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


EarlyStopping counter: 2 out of 5


EPOCH 11 Completed, Time Taken: 0:00:42.397740
	Train Loss 	2.48063661
	Valid Loss 	3.13665703
		TPL:	11.949
		VPL:	23.027


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))




HBox(children=(IntProgress(value=0, max=8), HTML(value='')))


EarlyStopping counter: 3 out of 5


EPOCH 12 Completed, Time Taken: 0:00:41.397969
	Train Loss 	2.35834154
	Valid Loss 	3.1004557
		TPL:	10.573
		VPL:	22.208


HBox(children=(IntProgress(value=0, max=227), HTML(value='')))

In [None]:
loss = evaluate(model, test_iterator, criterion)
print("LOSS: ", loss," PPL: ", math.exp(loss))

In [None]:
ax = history["train_loss"].plot()
history["valid_loss"].plot(ax=ax)

In [None]:
ax = history["TPL"].plot()
history["VPL"].plot(ax=ax)

In [None]:
history["VPL"].plot(ax=ax)

In [None]:
model.load_state_dict(torch.load("checkpoint.pt", map_location=device))

In [None]:
loss = evaluate(model, test_iterator, criterion)
print("LOSS: ", loss," PPL: ", math.exp(loss))