In [1]:
import json
import spacy
import en_core_web_sm
import numpy as np
import random
import pickle
from collections import defaultdict
import pandas as pd
from nltk.tokenize import RegexpTokenizer
import time


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as Func
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.autograd import Variable

In [3]:
train = pd.read_json('./data/train.jsonl', lines= True)
valid = pd.read_json('./data/valid.jsonl', lines= True)
test = pd.read_json('./data/test.jsonl', lines= True)

In [4]:
train.head(1)

Unnamed: 0,id,summary,text,sent_bounds,extractive_summary
0,1000000,A seven-hundred-year old oak gate at Salisbury...,The Grade I listed Harnham Gate was hit by a w...,"[[0, 107], [107, 255], [255, 362]]",1


In [5]:
embeddings_dict = {}
with open("./glove.6B/glove.6B.300d.txt", 'r') as f:
    for line in f:
        values = line.split()
        word = values[0]
        vector = np.asarray(values[1:], "float32")
        embeddings_dict[word] = vector

In [6]:
# add SOS and EOS
embeddings_dict['_sos_'] =  np.random.rand(300, )
embeddings_dict['_eos_'] =  np.random.rand(300, )
embeddings_dict['_unk_'] =  np.random.rand(300, )

In [7]:
class words_dict():
    def __init__(self, glove):
        self.word_count = {}
        self.id_to_word = {0: '_sos_', 1: '_eos_', 2: '_unk_'}
        self.word_to_id = {'_sos_': 0, '_eos_': 1, '_unk_': 2}
        self.n_words = 3
        self.tokenizer =  RegexpTokenizer(r'\w+')
        self.remain_id = []
        self.glove = glove
        
    def add_word(self, sentence):
        tokens = self.tokenizer.tokenize(sentence)
        for token in tokens:
            token = token.lower()
            if token in self.glove.keys():
                if not self.word_to_id.get(token) :
                    self.word_to_id[token] = self.n_words
                    self.id_to_word[self.n_words] = token
                    self.n_words += 1
                    self.word_count[token] = 1
                else:
                    self.word_count[token] += 1
                    
    def sort_dict(self):
        self.remain_id.append(0)
        self.remain_id.append(1)
        self.remain_id.append(2)
        sort_d = sorted(self.word_count.items(), key = lambda x: x[1], reverse = True)[:int(self.n_words *0.3)]
        for (i, j) in sort_d:
            self.remain_id.append(self.word_to_id[i])

        self.reconstruct()
    
    
    def reconstruct(self):
                # reconstruct dict
        n_words =  3
        id_to_word = {0: '_sos_', 1: '_eos_', 2: '_unk_'}
        word_to_id = {'_sos_': 0, '_eos_': 1, '_unk_': 2}
        for i in self.remain_id:
            if not word_to_id.get(i):
                word_to_id[self.id_to_word[i]] = n_words
                id_to_word[n_words] = self.id_to_word[i]
                n_words += 1
        self.n_words = n_words
        self.id_to_word = id_to_word
        self.word_to_id = word_to_id
        self.remain_id = [i for i in range(n_words)]
        
    def get_emb(self, data):
        if self.word_to_id.get(data, -1) != -1:
            if self.word_to_id[data] in self.remain_id:
                return self.glove[data]
        return self.glove['_unk_']
        
    def get_word_id(self, data):
        if data == []:
            return -1
        if self.word_to_id.get(data, -1) != -1:
            if self.word_to_id[data] in self.remain_id:
                return self.word_to_id[data]
        return 2
        
        
        

In [8]:
merge_df = train.append(valid, ignore_index= True)
merge_df.shape

(91604, 5)

In [9]:
dictionary = words_dict(embeddings_dict)
for i in range(len(merge_df)):
    text = merge_df.loc[i, 'text']
    summary = merge_df.loc[i, 'summary']
    data = text + summary
    dictionary.add_word(data)
dictionary.sort_dict()

In [10]:
dictionary.n_words

29664

In [11]:
class SummaryDataset(Dataset):
    def __init__(self, data, dic, test = False):
        self.data = data
        self.dic = dic
        self.test = test
        self.tokenizer =  RegexpTokenizer(r'\w+')
        self.test = test
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        text = '_sos_ ' + self.data.loc[idx, 'text'] + ' _eos_'
        text_emb, text_word = self.get_emb(text)
        if not self.test:
            summary = self.data.loc[idx, 'summary'] + ' _eos_'
            summary_emb, summary_word = self.get_emb(summary)
            summary_word_id = self.get_summary_id_list(summary_word)
            length = len(summary_word_id)
            return  torch.tensor(text_emb), length, torch.tensor(summary_word_id)
        id = self.data.loc[idx, 'id']
        return torch.tensor(text_emb), text_word, id

    def get_emb(self, data):
        tokens = self.tokenizer.tokenize(data)
        embeddings = []
        words = []
        for idx, token in enumerate(tokens):
            token = token.lower()
            emb = self.dic.get_emb(token)
            if len(emb) > 0:
                embeddings.append(emb)
                words.append(token)
        if len(embeddings) ==0:
            return [[0.0 for i in range(300)]], [[]]
        return embeddings, words
    
    def get_summary_id_list(self, words):
        ans = []
        for i in words:
            word_id = self.dic.get_word_id(i)
            if word_id != -1:
                ans.append(word_id)
        return ans

In [110]:
def create_mini_batch(samples):
    text_emb, length, summary_word_id = zip(*samples)
    text_emb = pad_sequence(text_emb, batch_first=True)
    summary_word_id = pad_sequence(summary_word_id, batch_first=True, padding_value=1)
    return text_emb, length, summary_word_id

def create_mini_batch_test(samples):
    text_emb, text_word, id = zip(*samples)
    text_emb = pad_sequence(text_emb, batch_first=True)
    return text_emb, text_word, id

In [118]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers =1, bidirectional=False, dropout = 0):
        super(EncoderRNN, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            dropout= dropout, bidirectional=bidirectional, batch_first = True)
        if bidirectional :
            self.l1 = nn.Linear(2*hidden_size, hidden_size)
        else:
            self.l1 = nn.Linear(hidden_size, hidden_size)
        self.relu = nn.ReLU()
        self.tan = nn.Tanh()
        self.init_weights()
        self.bidirectional = bidirectional
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        self.lstm.flatten_parameters()
        out, (hn, cn) = self.lstm(x)  # out: tensor of shape (batch_size, seq_length, hidden_size)
        if  self.bidirectional:
            hn = torch.cat((hn[0], hn[1]),1)
            hn = hn.unsqueeze(0)
        hn = self.tan(self.drop(self.l1(hn)))
        return hn

    def init_weights(self):
        for name, p in self.lstm.named_parameters():
            if 'weight' in name:
                nn.init.orthogonal_(p)
            elif 'bias' in name:
                nn.init.constant_(p, 0)
                
class DecoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers =1, dropout = 0):
        super(DecoderRNN, self).__init__()
        
        self.l1 = nn.Linear(input_size + hidden_size, dictionary.n_words)
        self.lstm = nn.LSTM(input_size + hidden_size, hidden_size, num_layers,
                            dropout= dropout, batch_first = True)
        
        self.relu = nn.ReLU()
        self.init_weights()

    def forward(self, x, h = None, c= None):
        self.lstm.flatten_parameters()
        out, (hn, cn) = self.lstm(x, (h, c))  # out: tensor of shape (batch_size, seq_length, hidden_size)
        return hn, cn
    
    def predict(self, x):
        out = self.l1(x)
        val, idx = out.max(-1)
        return out, idx
    
    def test(self, x):
        out = self.l1(x)
        return out
    
    def init_weights(self):
        for name, p in self.lstm.named_parameters():
            if 'weight' in name:
                nn.init.orthogonal_(p)
            elif 'bias' in name:
                nn.init.constant_(p, 0)


In [119]:
class AutoEncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, bidirectional=False, dropout = 0):
        super(AutoEncoderRNN, self).__init__()
        self.encoder = EncoderRNN(input_size, hidden_size, num_layers, bidirectional, dropout=dropout)
        self.decoder = DecoderRNN(input_size, hidden_size, num_layers, dropout=dropout)  
        
        

In [120]:
input_size = 300
hidden_size = 300
batch_size = 16
lr = 1e-3
epoch = 3
teacher_forcing = True
num_layers = 1
bidirectional = True
dropout = 0

In [121]:
model = AutoEncoderRNN(input_size, hidden_size, bidirectional=bidirectional).cuda()

In [122]:
batch_size = 16
train_dataset = SummaryDataset(train, dictionary)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, collate_fn = create_mini_batch ,drop_last = True, shuffle= True)

valid_dataset = SummaryDataset(valid, dictionary)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size = batch_size, collate_fn = create_mini_batch, drop_last = True, )

In [125]:
min_loss = 10000000000
for i in range(epoch):
    iteration = 0
    total_loss= 0
    total_words = 0
    opt = torch.optim.Adam(model.parameters(), lr=lr)
#     opt_e = torch.optim.Adam(model.encoder.parameters(), lr=lr)
#     opt_d = torch.optim.Adam(model.decoder.parameters(), lr=lr)
    loss_f = nn.CrossEntropyLoss()
    for text_emb, length, summary_word_id in train_loader:
        text_emb = text_emb.float().cuda()
        summary_word_id = summary_word_id.cuda()
        batch_loss = 0
        context = model.encoder(text_emb)  #torch.Size([1, 5, 150])
        hn = context
#       hn = Variable(torch.zeros(1, batch_size, hidden_size)).cuda()
        cn = Variable(torch.zeros(1, batch_size, hidden_size)).cuda()
        
        # first input with SOS token
        SOS =  torch.tensor([[dictionary.get_emb('_sos_') for i in range(batch_size)]]).float().cuda()
        inputs = torch.cat((context, SOS), 2) #torch.Size([1, B, 450])
        inputs = inputs.permute(1,0,2)  # torch.Size([B, 1, 450])
#         inputs = SOS.permute(1,0,2)
        words = SOS
        index = 0
        thres = int(summary_word_id.shape[1])

        while True:
            hn, cn = model.decoder(inputs, hn, cn)  # torch.Size([1, 5, 150])
            combined = torch.cat((context, hn), -1)
            values, predict = model.decoder.predict(combined)  #torch.Size([1, 20, 98862]) torch.Size([1, 20])

           
            for j in range(batch_size):
                if length[j] >= index:
                    labels = summary_word_id[:,index].long().cuda()
                    loss = loss_f(values[0], labels)
                    batch_loss += loss


            # reconstruct input
            if teacher_forcing :
                words = [dictionary.id_to_word[labels.tolist()[j]] for j in range(batch_size)]
                words = torch.tensor([dictionary.get_emb(words[j]) for j in range(len(words))]).float().cuda()
            else:
                words = [dictionary.id_to_word[predict.view(-1).tolist()[j]] for j in range(batch_size)]
                words = torch.tensor([dictionary.get_emb(words[j]) for j in range(len(words))]).float().cuda()
   
            words = words.unsqueeze(0)
            inputs = torch.cat((hn, words), -1).permute(1,0,2) # h[0] torch.Size([B, 98862])
            index += 1
            
            #if predict summary exceed thres
            if index >= thres:
                break

        batch_loss.backward()
        total_words += sum(length)
        total_loss += batch_loss
        opt.step()
        opt.zero_grad()
#         opt_e.step()
#         opt_e.zero_grad()
#         opt_d.step()
#         opt_d.zero_grad()
        iteration += 1
        print(f' Epoch : {i}, Iteration: {iteration}, batch_loss : {batch_loss/sum(length)}, avg_loss: {total_loss/ total_words}', end = '\r')
    valid_loss = validate()
    if valid_loss < min_loss:
        print(f'Validation loss improve from {min_loss} to {valid_loss} ')
        min_loss = valid_loss
        best_model = model
        with open(f'./model/model_abtractive_0404_1500.pkl', 'wb') as output:
            pickle.dump(best_model, output)
    else:
        print(f'Validation loss did not improve from original {min_loss} to {valid_loss} ')
        break
    
    
    
    
    

torch.Size([16, 1, 600])
torch.Size([16, 1, 600]), batch_loss : 8.522150039672852, avg_loss: 8.522150039672852
torch.Size([16, 1, 600]), batch_loss : 8.033958435058594, avg_loss: 8.264249801635742
torch.Size([16, 1, 600]), batch_loss : 7.623191833496094, avg_loss: 8.064986228942871
torch.Size([16, 1, 600]), batch_loss : 7.6953582763671875, avg_loss: 7.972710609436035
torch.Size([16, 1, 600]), batch_loss : 6.99948263168335, avg_loss: 7.7895612716674805
torch.Size([16, 1, 600]), batch_loss : 6.935104846954346, avg_loss: 7.649357318878174
torch.Size([16, 1, 600]), batch_loss : 7.29809045791626, avg_loss: 7.600976943969727
torch.Size([16, 1, 600]), batch_loss : 6.892031669616699, avg_loss: 7.5090436935424805


KeyboardInterrupt: 

In [110]:
# with open(f'./model/model_abtractive_0404_1500.pkl', 'wb') as output:
#     pickle.dump(best_model, output)

## Validation

In [111]:
def validate():
    with torch.no_grad():
        iteration = 0
        total_loss= 0
        total_words = 0
        loss_f = nn.CrossEntropyLoss()
        for text_emb, length, summary_word_id in valid_loader:
            text_emb = text_emb.float().cuda()
            summary_word_id = summary_word_id.cuda()
            batch_loss = 0
            context = model.encoder(text_emb)  #torch.Size([1, 5, 150])
            if bidirectional:
        #             hn = torch.cat((context[:,:,:300], context[:,:,300:]), 0)
                hn = Variable(torch.zeros(2, batch_size, hidden_size)).cuda()
                cn = Variable(torch.zeros(2, batch_size, hidden_size)).cuda()
            else:
                hn = context
#                 hn = Variable(torch.zeros(1, batch_size, hidden_size)).cuda()
                cn = Variable(torch.zeros(1, batch_size, hidden_size)).cuda()

            # first input with SOS token
            SOS =  torch.tensor([[dictionary.get_emb('_sos_') for i in range(batch_size)]]).float().cuda()
            inputs = torch.cat((context, SOS), 2) #torch.Size([1, B, 450])
            inputs = inputs.permute(1,0,2)  # torch.Size([B, 1, 450])
        #         inputs = SOS.permute(1,0,2)

            words = SOS
            index = 0
            thres = int(summary_word_id.shape[1])

            while True:
                hn, cn = model.decoder(inputs, hn, cn)  # torch.Size([1, 5, 150])
                combined = torch.cat((context, hn), -1)
                values, predict = model.decoder.predict(combined)  #torch.Size([1, 20, 98862]) torch.Size([1, 20])


                for j in range(batch_size):
                    if length[j] >= index:
                        labels = summary_word_id[:,index].long().cuda()
                        loss = loss_f(values[0], labels)
                        batch_loss += loss

                # reconstruct input
                if teacher_forcing :
                    words = [dictionary.id_to_word[labels.tolist()[j]] for j in range(batch_size)]
                    words = torch.tensor([dictionary.get_emb(words[j]) for j in range(len(words))]).float().cuda()
                else:
                    words = [dictionary.id_to_word[predict.view(-1).tolist()[j]] for j in range(batch_size)]
                    words = torch.tensor([dictionary.get_emb(words[j]) for j in range(len(words))]).float().cuda()

                words = words.unsqueeze(0)
                inputs = torch.cat((hn, words), -1).permute(1,0,2) # h[0] torch.Size([B, 98862])
                if bidirectional:
                    hn = torch.cat((hn[:,:,:300], hn[:,:,300:]), 0)
        #             inputs = words.permute(1,0,2)
                index += 1
                #if predict summary exceed thres
                if index >= thres:
                    break
            total_words += sum(length)
            total_loss += batch_loss
            iteration += 1
            print(f' Validation Iteration: {iteration}, batch_loss : {batch_loss/sum(length)}, avg_loss: {total_loss/ total_words}', end = '\r')
        return total_loss/total_words

## Prediction

In [143]:
with open(f'./model/model_abtractive_0404_1500.pkl', 'rb') as inputs:
    model = pickle.load(inputs)

In [144]:
batch_size = 50
test_dataset = SummaryDataset(valid, dictionary, test = True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = batch_size, collate_fn = create_mini_batch_test)

In [145]:
prediction = ''
iteration = 0
with torch.no_grad():
    for text_emb, text_word, id in test_loader:
        text_emb = text_emb.float().cuda()
        context = model.encoder(text_emb)  #torch.Size([1, 5, 150])
        
        hn = context
#         hn = Variable(torch.zeros(1, batch_size, hidden_size)).cuda()
        cn = Variable(torch.zeros(1, batch_size, hidden_size)).cuda()
        
        SOS =  torch.tensor([[dictionary.get_emb('_sos_') for i in range(batch_size)]]).float().cuda()
        inputs = torch.cat((context, SOS), 2) #torch.Size([1, 5, 450])
        inputs = inputs.permute(1,0,2)
#         inputs = SOS.permute(1,0,2)  # torch.Size([5, 1, 450])
        words = SOS

#       stop criteria
        thres = len(text_emb[0]) * 0.3
        ans = [[] for i in range(batch_size)]
        index = 0
        while True:
            hn, cn = model.decoder(inputs, hn, cn)  # torch.Size([1, 5, 150])
            combined = torch.cat((context, hn), -1)
            values, predict = model.decoder.predict(combined)  #torch.Size([1, 20, 98862]) torch.Size([1, 20])
        
            val, pred = values.topk(3)
#             print(pred)
            for i in range(batch_size):
                pred_word = dictionary.id_to_word[predict.tolist()[0][i]]
                ans[i].append(pred_word)
            
            words = torch.tensor([dictionary.get_emb(ans[j][index]) for j in range(batch_size)]).float().cuda()
            words = words.unsqueeze(0)
            
            inputs = torch.cat((hn, words), -1).permute(1,0,2)
#             inputs = words.permute(1,0,2)
            index += 1
            
            #if predict summary exceed 40 words then stop
            if index >= thres:
                break

        print(f'Validation loop, Iteration: {iteration}', end = '\r')
        iteration += 1
        for idx in range(batch_size):
            try:
                eos_idx = ans[idx].index('_eos_') + 1
                ans[idx] =  ans[idx][:eos_idx]
                prediction += json.dumps({"id":str(id[idx]), "predict": ' '.join(ans[idx])}) + '\n'
            except:
                prediction += json.dumps({"id":str(id[idx]), "predict": ' '.join(ans[idx])}) + '\n'

Validation loop, Iteration: 399

In [147]:
with open('prediction_ab_0404.json','w') as f:
    f.write(prediction)

# Neet to improve

1. padding loss recalculation (done)
2. torch.save
3. pytorch lightning
