### Load Libraries

In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
import math
import pandas as pd
import numpy as np
import torch.nn as nn
from torchtext import data, vocab
from torch import optim
from torchtext.vocab import Vectors
import torch.nn.functional as F

dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
#declare the directory path to dataset  
DATA_PATH = 'data/'
SAMPLE_DATA_PATH = f'{DATA_PATH}sample_data/'
PROCESSED_DATA_PATH = f'{DATA_PATH}processed_data/'

In [4]:
# enter the processed data absolute path here
path = f'{SAMPLE_DATA_PATH}'

# enter the absolute path to the embeddings
emb_path = 'glove/glove.6B.300d.txt'

### Prep source and target Sequences

In [5]:
%%time
tokenizer = data.get_tokenizer('spacy')
TEXT = data.Field(tokenize=tokenizer, lower=True, eos_token='_eos_')


trn_data_fields = [("source", TEXT),
                   ("target", TEXT)]

trn, vld = data.TabularDataset.splits(path=path,
                                     train='train.csv', validation='valid.csv',
                                     format='csv', skip_header=True, fields=trn_data_fields)

Wall time: 1.21 s


In [6]:
# a sample of the preprocessed data
print(trn[1].source)
print(trn[1].target)

['chris', 'martin', 'claimed', 'three', 'wickets', 'in', 'the', 'first', '#', '#', 'minutes', 'of', 'the', 'first', 'test', 'cricket', 'between', 'new', 'zealand', 'and', 'bangladesh', 'on', 'friday', 'to', 'help', 'send', 'the', 'tourists', 'to', 'lunch', 'at', '#', '#', 'for', 'four', '.']
['bangladesh', '#', '#', '-', '#', 'at', 'lunch', 'on', '#', 'st', 'day', '#', 'st', 'test']


### Load Embeddings 

In [7]:
def get_embs(file_path):
    word_embeddings = {}
    f = open(file_path, encoding='utf-8')
    for line in f:
        values = line.split()
        word = values[0]
        coefs = np.asarray(values[1:], dtype='float32')
        word_embeddings[word] = coefs
    f.close()
    return word_embeddings

# get the embeddings
vectors = get_embs(emb_path)

vec_obj = Vectors(emb_path)

In [8]:
%%time
TEXT.build_vocab(trn, vectors=vec_obj)

Wall time: 11 ms


In [9]:
#10 most frequent words in the vocab
TEXT.vocab.freqs.most_common(10)

[('#', 113),
 ('the', 103),
 ('.', 84),
 ('of', 71),
 (',', 68),
 ('a', 59),
 ('in', 55),
 ('to', 50),
 ('and', 36),
 ('-', 34)]

### Create the DataLoader

In [10]:
# set batch size
batch_size = 45

train_iter, val_iter = data.BucketIterator.splits(
                        (trn, vld), batch_sizes=(batch_size, int(batch_size*1.6)),
                        device='cuda' if torch.cuda.is_available() else "cpu", 
                        sort_key=lambda x: len(x.source),
                        shuffle=True, sort_within_batch=False, repeat=False)

In [11]:
# Dataloader class
class BatchTuple():
    def __init__(self, dataset, x_var, y_var):
        self.dataset, self.x_var, self.y_var = dataset, x_var, y_var
        
    def __iter__(self):
        for batch in self.dataset:
            x = getattr(batch, self.x_var) 
            y = getattr(batch, self.y_var)                 
            yield (x, y)
            
    def __len__(self):
        return len(self.dataset)

In [12]:
#returns tuple of article-title pair tensors
trn_dl = BatchTuple(train_iter, "source", "target")
val_dl = BatchTuple(val_iter, "source", "target")

In [13]:
x, y = next(trn_dl.__iter__())
x.size(), y.size()

(torch.Size([48, 19]), torch.Size([14, 19]))

In [14]:
#lets look at an example pair
sample_source = x.transpose(1,0)[0].data.cpu().numpy()
sample_target = y.transpose(1,0)[0].data.cpu().numpy()

print("source:\n%s \n\ncorresponding tensor:\n%s \n" %(' '.join([TEXT.vocab.itos[o] for o in sample_source]), sample_source))
print("target:\n%s \n\ncorresponding tensor:\n%s \n" %(' '.join([TEXT.vocab.itos[o] for o in sample_target]), sample_target))

source:
the morning began with an embrace and negotiations involving pat conroy and his latest novel , `` beach music . _eos_ <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 

corresponding tensor:
[  4 813 471  26  22 623  11 822 740 867 544  11  23  87 836   7  29 469
 817   5   2   1   1   1   1   1   1   1   1   1   1   1   1   1   1   1
   1   1   1   1   1   1   1   1   1   1   1   1] 

target:
a long schmooze with booksellers _eos_ <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 

corresponding tensor:
[  8  89 960  26 482   2   1   1   1   1   1   1   1   1] 



In [15]:
#set maximum target summary size 
its = [next(trn_dl.__iter__())[1] for i in range(10)]
max_tgt_len = int(np.percentile([its[o].size()[0] for o in range(len(its))], 99))
max_tgt_len

16

### Define Model: Seq2Seq model with bi-GRU(RNN), added  teacher forking and Attention mechanism

In [16]:
def create_emb(vecs, itos, em_sz):
    emb = nn.Embedding(len(itos), em_sz, padding_idx=1)
    wgts = emb.weight.data
    miss = []
    for i,w in enumerate(itos):
        try: wgts[i] = torch.from_numpy(vecs[w])
        except: miss.append(w)
    print(len(miss),miss[5:10])
    return emb

def rand_t(*sz): return torch.randn(sz)/math.sqrt(sz[0])
def rand_p(*sz): return nn.Parameter(rand_t(*sz))

class Seq2SeqAttnBiRNN(nn.Module):
    def __init__(self, vecs, itos, em_sz, nh, out_sl, nl=2):
        super().__init__()
        self.emb_enc = create_emb(vecs, itos, em_sz)
        self.nl,self.nh,self.out_sl = nl,nh,out_sl
        self.gru_enc = nn.GRU(em_sz, nh, num_layers=nl, 
                              dropout=0.25, bidirectional=True)
        self.out_enc = nn.Linear(nh*2, em_sz, bias=False)
        
        self.drop_enc = nn.Dropout(0.25)
        self.emb_dec = nn.Embedding(len(itos), em_sz, padding_idx=1)
        self.gru_dec = nn.GRU(em_sz, em_sz, num_layers=nl, 
                              dropout=0.1)
        self.emb_enc_drop = nn.Dropout(0.15)
        self.out_drop = nn.Dropout(0.35)
        self.out = nn.Linear(em_sz, len(itos))
        
        self.emb_dec.weight  = self.emb_enc.weight
        self.out.weight.data = self.emb_dec.weight.data
        
        self.W1 = rand_p(nh*2, em_sz)
        self.l2 = nn.Linear(em_sz, em_sz)
        self.l3 = nn.Linear(em_sz+nh*2, em_sz)
        self.V = rand_p(em_sz)
        self.pr_force = 1.
        
    def forward(self, inp, y=None):
        sl,bs = inp.size()
        h = self.initHidden(bs).to(dev)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, h = self.gru_enc(emb, h)
        h = h.view(2,2,bs,-1).permute(0,2,1,3).contiguous().view(2,bs,-1)
        h = self.out_enc(self.drop_enc(h))
        
        dec_inp = torch.zeros(bs).long().to(dev)
        res,attns = [],[]
        w1e = enc_out @ self.W1
        for i in range(self.out_sl):
            w2h = self.l2(h[-1])
            u = torch.tanh(w1e + w2h)
            a = F.softmax(u @ self.V, 0)
            attns.append(a)
            Xa = (a.unsqueeze(2) * enc_out).sum(0)
            emb = self.emb_dec(dec_inp)
            wgt_enc = self.l3(torch.cat([emb, Xa], 1))
            
            outp, h = self.gru_dec(wgt_enc.unsqueeze(0), h)
            outp = self.out(self.out_drop(outp[0]))
            res.append(outp)
            dec_inp = outp.data.max(1)[1]
            if (dec_inp==1).all(): break
            if (y is not None) and (np.random.random()<self.pr_force):
                if i>=len(y): break
                dec_inp = y[i].to(dev)
        return torch.stack(res)
    
    def initHidden(self, bs): 
        return torch.zeros(self.nl*2, bs, self.nh)

### Define loss and fit function

In [17]:
def seq2seq_loss(input, target):
    sl,bs = target.size()
    sl_in,bs_in,nc = input.size()
    if sl>sl_in: input = F.pad(input, (0,0,0,0,0,sl-sl_in))
    input = input[:sl]
    return F.cross_entropy(input.view(-1,nc), target.view(-1))#, ignore_index=1)


def loss_batch(model, loss_func, xb, yb, opt=None, inf=False):
    if inf: loss = loss_func(model(xb, yb), yb)
    else: loss = loss_func(model(xb), yb)

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)

def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    for epoch in range(epochs):
        model.train() # imp to put for traning and eval time for batchnorm and dropout
        tr_losses, tr_bsz = [], []
        model.pr_force = (10-epoch)*0.1 if epoch<10 else 0
        inf = True if epoch<10 else False
        for xb, yb in train_dl:
            loss, bs = loss_batch(model, loss_func, xb, yb, opt, inf=inf)
            tr_losses.append(loss); tr_bsz.append(bs)
            
        model.eval()
        #inf = False
        with torch.no_grad():
            losses, bsz = zip(*(loss_batch(model, loss_func, xb, yb, inf=True) for xb, yb in valid_dl))
        valid_loss = np.sum(np.multiply(losses, bsz)) / np.sum(bsz)
        train_loss = np.sum(np.multiply(tr_losses, tr_bsz)) / np.sum(tr_bsz)
        print(f"epoch:{epoch}", f'train_loss:{train_loss}',f'valid_loss:{valid_loss}') 

### Training

In [18]:
model = Seq2SeqAttnBiRNN(vectors, TEXT.vocab.itos, 300, 200, max_tgt_len).to(dev)
opt = optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-1, betas=(0.9, 0.999))
fit(20, model, seq2seq_loss, opt, trn_dl, val_dl)

9 ['a\\/h#n', '-member', '-nation', '-year']
epoch:0 train_loss:19.909557630430978 valid_loss:22.819141387939453
epoch:1 train_loss:23.67210442165159 valid_loss:18.611751556396484
epoch:2 train_loss:22.196086019839882 valid_loss:22.349088668823242
epoch:3 train_loss:16.74916620074578 valid_loss:20.890520095825195
epoch:4 train_loss:14.641716663654034 valid_loss:14.591445922851562
epoch:5 train_loss:10.947546203181428 valid_loss:11.754252433776855
epoch:6 train_loss:7.573131341200608 valid_loss:9.427680015563965
epoch:7 train_loss:7.496777407328287 valid_loss:13.132740020751953
epoch:8 train_loss:5.938907722257218 valid_loss:14.67599868774414
epoch:9 train_loss:5.9101277864896336 valid_loss:11.508200645446777
epoch:10 train_loss:6.2073125659294845 valid_loss:8.656683921813965
epoch:11 train_loss:5.421783034006754 valid_loss:8.24914264678955
epoch:12 train_loss:6.65495096468458 valid_loss:11.582647323608398
epoch:13 train_loss:5.79215853030865 valid_loss:8.480308532714844
epoch:14 train_

### Inference

In [19]:
def test_pred(dl, k=1):
    x,y = next(iter(dl))
    probs = model(x)
    preds = probs.max(2)[1].cpu().numpy()
    x = np.array([i.cpu().numpy() for i in x])
    y = np.array([i.cpu().numpy() for i in y])

    # Inference
    for i in range(k):
        print('src:\n')
        print(' '.join([TEXT.vocab.itos[o] for o in x[:,i] if o != 1]))
        print('\ntarget:\n')
        print(' '.join([TEXT.vocab.itos[o] for o in y[:,i] if o != 1]))
        print('\npred:\n')
        print(' '.join([TEXT.vocab.itos[o] for o in preds[:,i] if o!=1]))
        print()

In [20]:
test_pred(val_dl, 2)

src:

a light <unk> <unk> <unk> <unk> on wednesday , killing three people <unk> , the russian <unk> <unk> ministry said . _eos_

target:

light <unk> <unk> <unk> <unk> killing three _eos_

pred:

eu

src:

us <unk> of <unk> <unk> <unk> on sunday <unk> <unk> states to <unk> <unk> <unk> towards <unk> a peace <unk> with israel . _eos_

target:

<unk> <unk> <unk> states to <unk> <unk> for <unk> peace _eos_

pred:

china eu

