In [1]:
import os
import glob
import torch
import nltk
import copy
import math
import random
import numpy as np
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack


from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline
# Use the GPU only when one is actually usable. ``is_available`` must be
# *called*: the bare function object is always truthy.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [40]:
class Vocab:
    '''
    Bidirectional token <-> index table built from a token counter.

    Index 0 is reserved for the padding token and index 1 for the unknown
    token. Every token in ``counter`` whose count is strictly greater than
    ``min_freq`` gets the next free index, in the counter's iteration order.

    Args:
        counter:  mapping token -> count (e.g. ``collections.Counter``).
        sos, eos: start/end-of-sequence strings; stored as attributes only,
                  they are NOT given an index.
        pad, unk: padding / unknown token strings (indices 0 and 1).
        min_freq: tokens are kept when ``count > min_freq``; ``None`` means 0.
    '''
    def __init__(self, counter, sos, eos, pad, unk, min_freq=None):
        self.sos = sos
        self.eos = eos
        self.pad = pad
        self.unk = unk

        # Fixed indices for the two special tokens stored in the table.
        self.pad_idx = 0
        self.unk_idx = 1

        self._token2idx = {
            self.pad: self.pad_idx,
            self.unk: self.unk_idx,
        }
        self._idx2token = {idx: token for token, idx in self._token2idx.items()}

        idx = len(self._token2idx)
        min_freq = 0 if min_freq is None else min_freq

        # The counter's iteration order decides each kept token's index.
        for token, count in counter.items():
            if count > min_freq:
                self._token2idx[token] = idx
                self._idx2token[idx] = token
                idx += 1

        self.vocab_size = len(self._token2idx)
        self.tokens = list(self._token2idx.keys())

    def token2idx(self, token):
        '''Return the index of ``token``; out-of-vocabulary tokens map to ``unk_idx``.'''
        # Unknown words must stay distinguishable from padding.
        return self._token2idx.get(token, self.unk_idx)

    def idx2token(self, idx):
        '''Return the token stored at ``idx``; unknown indices map to the pad token.'''
        return self._idx2token.get(idx, self.pad)

    def __len__(self):
        '''Number of indexed tokens, including pad and unk.'''
        return len(self._token2idx)

def pad_contexts(list_, pad_idx=0):
    '''
    Right-pad every sequence in ``list_`` to the length of the longest one.

    Inputs:
        list_:   list of sequences (lists) of token indices
        pad_idx: value appended as padding (default 0, the Vocab pad index)
    Returns:
        a new list of padded lists; the input sequences are not modified.
        An empty input returns an empty list.
    '''
    max_length = max((len(seq) for seq in list_), default=0)
    return [list(seq) + [pad_idx] * (max_length - len(seq)) for seq in list_]
    
def pad_many_seq(sequences, pad_idx, max_length):
    '''
    Right-pad each sequence with ``pad_idx`` up to ``max_length`` entries.

    Inputs:
        sequences:  list of list of tokens
        pad_idx:    padding value to append
        max_length: target length; sequences already this long (or longer)
                    come back unchanged, as new list objects
    '''
    padded = []
    for seq in sequences:
        missing = max_length - len(seq)
        padded.append(seq + [pad_idx] * missing)
    return padded


def pad_single_seq(sequence, pad_idx, max_length):
    '''
    Right-pad one sequence with ``pad_idx`` up to ``max_length`` entries.

    Inputs:
        sequence:   list of tokens
        pad_idx:    padding value to append
        max_length: target length (longer sequences are not truncated)
    '''
    shortfall = max_length - len(sequence)
    padding = [pad_idx] * shortfall
    return sequence + padding

def pad_body(sequences, pad_idx, max_length, max_seq_length):
    '''
    Append all-padding rows so a story has ``max_length`` sentences.

    Inputs:
        sequences:      list of list of tokens, each already padded to
                        ``max_seq_length`` (e.g. by ``pad_many_seq``)
        pad_idx:        value used to fill the appended rows
        max_length:     target number of rows
        max_seq_length: length of each appended row
    Returns:
        a new list (the input list is not modified); inputs that already have
        ``max_length`` rows or more are returned with no rows added.
    '''
    missing_rows = max(0, max_length - len(sequences))
    # Plain int lists (not float tensors) keep the whole story convertible by
    # ``torch.LongTensor`` in one call; one fresh list per row avoids aliasing.
    filler = [[pad_idx] * max_seq_length for _ in range(missing_rows)]
    return sequences + filler
    
    
def words_tokenize(line):
    '''Split ``line`` into word tokens using NLTK's ``word_tokenize``.'''
    tokens = nltk.word_tokenize(line)
    return tokens


In [41]:
class Dataset(object):
    '''
    bAbI-style QA dataset read from one task file and tokenized with NLTK.

    Every example is a tuple ``(contexts, question, answer)`` of vocabulary
    indices:
        contexts: list of sentences (lists of int) seen since the story began
        question: list of int, up to and including the ``?`` token
        answer:   one-element list holding the answer token's index

    Args:
        path:  path to the task file.
        val:   when True, examples keep file order and ``get_batch`` returns
               the whole dataset; when False, examples are shuffled once and
               ``get_batch`` samples random examples with replacement.
        vocab: optional existing ``Vocab`` to index tokens with. Pass the
               training set's ``words_vocab`` when building a validation/test
               set so both use the same indices. When None, a new vocabulary
               is built from this file.
    '''
    def __init__(self, path, val, vocab=None):
        self.val = val
        self.data = []

        primal_data = self._read_examples(path)

        if vocab is None:
            sos = "<sos>"
            eos = "<eos>"
            pad = "<pad>"
            unk = "<unk>"
            vocab = Vocab(self._count_tokens(primal_data), sos, eos, pad, unk)
        self.words_vocab = vocab

        # Training examples are shuffled once; validation keeps file order.
        if not val:
            random.shuffle(primal_data)

        to_idx = self.words_vocab.token2idx
        for cell in primal_data:
            cell_context = [[to_idx(token) for token in sentence] for sentence in cell['context']]
            cell_question = [to_idx(token) for token in cell['question']]
            cell_answer = [to_idx(cell['answer'])]
            self.data.append((cell_context, cell_question, cell_answer))

    @staticmethod
    def _read_examples(path):
        '''Parse ``path`` into dicts with 'context', 'question' and 'answer' tokens.'''
        examples = []
        context = []
        with open(path) as f:
            for raw_line in f:
                line = nltk.word_tokenize(raw_line)
                if not line:
                    continue  # blank line: nothing to parse
                # A line id of '1' starts a new story.
                if line[0] == '1':
                    context = []
                if '?' in line:
                    q_end = line.index('?') + 1
                    examples.append({
                        'question': line[1:q_end],
                        # Snapshot, so later story lines do not leak into this example.
                        'context': list(context),
                        'answer': line[q_end],
                    })
                else:
                    context.append(line[1:])  # drop the leading line id
        return examples

    @staticmethod
    def _count_tokens(examples):
        '''Count context, question and answer tokens (insertion order fixes vocab indices).'''
        counter = Counter()
        for cell in examples:
            for sentence in cell['context']:
                counter.update(sentence)
            counter.update(cell['question'])
            counter[cell['answer']] += 1
        return counter

    def __len__(self):
        return len(self.data)

    def get_batch(self, batch_size, sort = False):
        '''
        Build one padded batch of LongTensors on ``device``.

        In training mode (``val=False``), ``batch_size`` examples are drawn
        uniformly at random with replacement. In validation mode, the whole
        dataset is returned and ``batch_size`` is ignored. ``sort`` is accepted
        but unused.

        Returns:
            contexts:  (batch, max_n_sentences, max_sentence_len)
            questions: (batch, max_question_len)
            answers:   (batch,)
        '''
        if self.val:
            batch_data = self.data
        else:
            random_ids = np.random.randint(0, len(self.data), batch_size)
            batch_data = [self.data[idx] for idx in random_ids]

        pad_idx = self.words_vocab.pad_idx
        max_context_length = max(max(map(len, a), default=0) for (a, _, _) in batch_data)
        max_question_length = max(len(b) for (_, b, _) in batch_data)
        max_contexts_length = max(len(a) for (a, _, _) in batch_data)

        contexts = []
        questions = []
        answers = []
        for a, b, c in batch_data:
            cell_context = pad_many_seq(a, pad_idx, max_context_length)
            cell_context = pad_body(cell_context, pad_idx, max_contexts_length, max_context_length)
            cell_question = pad_single_seq(b, pad_idx, max_question_length)

            contexts.append(torch.LongTensor(cell_context).to(device))
            questions.append(torch.LongTensor(cell_question).to(device))
            answers.append(torch.LongTensor(c).to(device))

        contexts = torch.stack(contexts, 0)
        questions = torch.stack(questions, 0)
        answers = torch.stack(answers, 0).squeeze(1)

        return contexts, questions, answers

    def __getitem__(self, idx):
        '''Return the ``(contexts, question, answer)`` index lists of example ``idx``.'''
        contexts, question, answer = self.data[idx]
        return contexts, question, answer

In [51]:
import os
import re
from itertools import chain

import numpy as np

import torch
import torch.utils.data
from torch.utils.data import DataLoader

def load_task(data_dir, task_id, only_supporting=False):
    """Load the parsed train and test stories of bAbI task ``task_id``.

    Files in ``data_dir`` are matched on their *file name*, not the full path,
    so directory names containing "train", "test" or "qa1_" cannot cause
    false matches. The name must contain ``qa{task_id}_`` plus ``train`` or
    ``test``. If several files match, the alphabetically first is used.

    Returns:
        (train_data, test_data), each as returned by ``get_stories``.
    Raises:
        AssertionError: if ``task_id`` is not in 1..20.
        IndexError: if no matching train or test file is found.
    """
    assert task_id > 0 and task_id <= 20
    prefix = "qa{}_".format(task_id)
    task_names = [name for name in sorted(os.listdir(data_dir)) if prefix in name]

    def _pick(split):
        # First file of this task whose name mentions the split.
        matches = [name for name in task_names if split in name]
        if not matches:
            raise IndexError("no {!r} file for task {} in {}".format(split, task_id, data_dir))
        return os.path.join(data_dir, matches[0])

    train_file = _pick('train')
    test_file = _pick('test')
    train_data = get_stories(train_file, only_supporting)
    test_data = get_stories(test_file, only_supporting)
    return train_data, test_data


def tokenize(sent):
    """Split ``sent`` into word and punctuation tokens, dropping whitespace.

    >>> tokenize('Bob dropped the apple. Where is the apple?')
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    """
    # The pattern must not match the empty string: since Python 3.7, re.split
    # also splits on empty matches, which would yield single characters and
    # ``None`` groups.
    return [x.strip() for x in re.split(r"(\W+)", sent) if x.strip()]

def parse_stories(lines, only_supporting=False):
    """Parse bAbI-format lines into ``(story, question, answer)`` triples.

    Each line is ``"<id> <text>"``, and id 1 starts a new story. Question
    lines hold a tab-separated question, answer and supporting line ids.
    Text is lower-cased and split with ``tokenize``. A trailing ``?`` token is
    dropped from questions and a trailing ``.`` from statements. Blank lines
    are skipped.

    Args:
        lines: iterable of raw text lines.
        only_supporting: if True, each story holds only the statements listed
            as supporting on the question line; otherwise, every non-empty
            statement seen so far in the current story.
    Returns:
        list of ``(substory, question_tokens, [answer])`` tuples.
    """
    data = []
    story = []
    for raw in lines:
        if not raw.strip():
            continue  # tolerate blank lines instead of failing the unpack below
        nid, line = raw.lower().split(" ", 1)
        nid = int(nid)
        if nid == 1:
            story = []
        if '\t' in line:  # question line
            q, a, supporting = line.split('\t')
            q = tokenize(q)
            if q and q[-1] == "?":
                q = q[:-1]

            if only_supporting:
                # Ids are 1-based line numbers. ``story`` keeps a "" placeholder
                # for every question line, so the ids index it directly.
                substory = [story[int(i) - 1] for i in supporting.split()]
            else:
                substory = [x for x in story if x]

            data.append((substory, q, [a]))
            story.append("")
        else:
            sent = tokenize(line)
            if sent and sent[-1] == '.':
                sent = sent[:-1]
            story.append(sent)
    return data


def get_stories(f, only_supporting=False):
    """Read the bAbI file at path ``f`` and parse it with ``parse_stories``."""
    with open(f) as handle:
        lines = handle.readlines()
    return parse_stories(lines, only_supporting=only_supporting)
    
    
def vectorize_data(data, word_idx, sentence_size, memory_size):
    """Convert parsed stories into padded index arrays.

    Args:
        data: iterable of ``(story, query, answer)``. ``story`` is a list of
            token lists, ``query`` a token list and ``answer`` a list of
            answer tokens (the layout ``parse_stories`` produces).
        word_idx: mapping token -> int, used for every story, query and answer
            token (unknown tokens raise KeyError). Its *length* also sets the
            time indices below. NOTE(review): this assumes the last
            ``memory_size`` entries are reserved "time" slots, as bAbIDataset
            builds it — confirm for other callers.
        sentence_size: length every sentence and query is padded to. The last
            slot of each story sentence is overwritten with a time index, so
            it should exceed the longest sentence by at least 1.
        memory_size: maximum number of sentences kept per story (the most
            recent ones). Shorter stories are padded with all-zero sentences.

    Returns:
        (S, Q, A) numpy arrays, assuming no sentence or query is longer than
        ``sentence_size``:
            S: (n_examples, memory_size, sentence_size) story indices
            Q: (n_examples, sentence_size) query indices
            A: (n_examples, len(word_idx) + 1) one-hot (multi-hot) answers
    """
    S, Q, A = [], [], []
    for story, query, answer in data:
        ss = []
        for i, sentence in enumerate(story, 1):
            ls = max(0, sentence_size - len(sentence))
            ss.append([word_idx[w] for w in sentence] + [0] * ls)
        
        # Keep only the last ``memory_size`` sentences, in original order.
        ss = ss[::-1][:memory_size][::-1]
        
        # Temporal encoding: sentence i gets
        # len(word_idx) - memory_size + (len(ss) - i), so the most recent
        # sentence gets the smallest value. It is written into the last slot.
        for i in range(len(ss)):
            ss[i][-1] = len(word_idx) - memory_size - i + len(ss)
        
        # Pad the memory with all-zero sentences up to memory_size rows.
        lm = max(0, memory_size - len(ss))
        for _ in range(lm):
            ss.append([0] * sentence_size)
        
        lq = max(0, sentence_size - len(query))
        q = [word_idx[w] for w in query] + [0] * lq
        
        # One slot per word_idx entry plus one; every answer token is set to 1.
        y = np.zeros(len(word_idx) + 1)
        for a in answer:
            y[word_idx[a]] = 1
        
        S.append(ss)
        Q.append(q)
        A.append(y)
    return np.array(S), np.array(Q), np.array(A)

class bAbIDataset(torch.utils.data.Dataset):
    """PyTorch dataset over one bAbI task, vectorized with ``vectorize_data``.

    The vocabulary is built from the train *and* test splits of the task, so
    both splits share indices. Words are numbered 1..V (0 is padding) and are
    followed by ``memory_size`` "time" entries. Those entries only enlarge
    ``word_idx`` so that ``vectorize_data`` has room for its temporal indices.

    Args:
        dataset_dir: directory holding the task's train/test files.
        task_id: bAbI task number (1-20).
        memory_size: minimum number of sentences kept per story; raised to
            the longest story length when that is larger.
        train: vectorize the train split if True, otherwise the test split.

    Items are ``(story, query, answer)`` LongTensors: story is
    (memory_size, sentence_size), query is (sentence_size,), and answer is a
    scalar index.
    """
    def __init__(self, dataset_dir, task_id=1, memory_size=50, train=True):
        self.train = train
        self.task_id = task_id
        self.dataset_dir = dataset_dir

        train_data, test_data = load_task(self.dataset_dir, task_id)
        data = train_data + test_data

        words = set()
        for story, query, answer in data:
            words.update(chain.from_iterable(story))
            words.update(query)
            words.update(answer)
        self.vocab = sorted(words)
        word_idx = {word: position for position, word in enumerate(self.vocab, start=1)}

        stories = [story for story, _, _ in data]
        self.max_story_size = max(len(story) for story in stories)
        self.query_size = max(len(query) for _, query, _ in data)
        self.sentence_size = max(len(sentence) for sentence in chain.from_iterable(stories))
        self.memory_size = max(memory_size, self.max_story_size)

        # Reserve one extra entry per memory slot for the temporal encoding.
        for slot in range(1, self.memory_size + 1):
            time_word = "time{}".format(slot)
            word_idx[time_word] = time_word

        self.num_vocab = len(word_idx)
        # +1 leaves room for the time index written into each sentence's last slot.
        self.sentence_size = max(self.sentence_size, self.query_size) + 1
        self.word_idx = word_idx

        self.mean_story_size = int(np.mean([len(story) for story in stories]))

        split = train_data if train else test_data
        story, query, answer = vectorize_data(split, self.word_idx, self.sentence_size, self.memory_size)

        self.data_story = torch.LongTensor(story)
        self.data_query = torch.LongTensor(query)
        self.data_answer = torch.LongTensor(np.argmax(answer, axis=1))

    def __getitem__(self, idx):
        """Return the ``(story, query, answer)`` tensors of example ``idx``."""
        return (self.data_story[idx], self.data_query[idx], self.data_answer[idx])

    def __len__(self):
        """Number of vectorized examples."""
        return self.data_story.size(0)

In [42]:
train_dataset = Dataset(
    'tasks_1-20_v1-2/en-10k/qa1_single-supporting-fact_train.txt',
    val=False,
)

  return _compile(pattern, flags).split(string, maxsplit)


In [43]:
class AtentionDecoder(nn.Module):
    '''
    Answer classifier: attends from the encoded question over the encoded
    context sentences and predicts a single vocabulary token.

    Context sentences and the question share one embedding table and are
    encoded by separate single-layer GRUs (final hidden state). Dot-product
    attention of the question over the sentence encodings yields a summary
    vector. ``[question; summary]`` then goes through Linear -> tanh -> Linear
    to produce vocabulary logits.

    Args:
        vocab_size:  number of output classes / embedding rows.
        emb_size:    embedding dimension.
        hidden_size: GRU hidden size.
        pad_idx:     embedding row kept at zero (``padding_idx``).
    '''
    def __init__(self, vocab_size, emb_size, hidden_size, pad_idx):
        super(AtentionDecoder, self).__init__()

        self.vocab_size = vocab_size
        self.emb_size = emb_size
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=pad_idx)

        self.GRU_context = nn.GRU(emb_size, hidden_size, batch_first = True)

        self.GRU_query = nn.GRU(emb_size, hidden_size, batch_first = True)

        self.linear_1 = nn.Linear(2 * hidden_size, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, vocab_size)

    def forward(self, context, query, mask):
        '''
        Inputs:
            context: (batch_size, N_contexts, seq_len) token indices
            query:   (batch_size, q_seq_len) token indices
            mask:    optional (batch_size, N_contexts) bool or 0/1 tensor;
                     sentences where it is False/0 receive no attention.
                     ``None`` attends to every sentence.
        Returns:
            logits: (batch_size, vocab_size)
        '''
        batch_size, n_contexts, seq_len = context.shape

        # Encode every sentence independently by folding sentences into the batch axis.
        emb_context = self.embedding(context).view(-1, seq_len, self.emb_size)
        _, hidden_context = self.GRU_context(emb_context)
        hidden_context = hidden_context.squeeze(0).view(batch_size, n_contexts, self.hidden_size)

        embedded_query = self.embedding(query)
        _, query_hidden = self.GRU_query(embedded_query)
        query_hidden = query_hidden.squeeze(0)

        _, att_outputs = self.attention(query_hidden, hidden_context, mask)
        output = torch.cat([query_hidden, att_outputs], dim = 1)
        hidden = torch.tanh(self.linear_1(output))
        return self.linear_2(hidden).view(-1, self.vocab_size)

    def attention(self, query, context, mask = None):
        '''
        Dot-product attention of one query vector per example over N context vectors.

        Inputs:
            query:   (batch_size, hidden)
            context: (batch_size, N, hidden)
            mask:    optional (batch_size, N); positions where it is False/0
                     are excluded from the softmax. A row with every position
                     masked yields NaN weights.
        Outputs:
            weights: (batch_size, 1, N)
            outputs: (batch_size, hidden) - attention-weighted sum of context
        '''
        logits = torch.matmul(context, query.unsqueeze(-1)).squeeze(-1)  # (batch, N)
        if mask is not None:
            logits = logits.masked_fill(~mask.bool(), float('-inf'))

        weights = F.softmax(logits, dim = -1).unsqueeze(1)              # (batch, 1, N)
        outputs = torch.matmul(weights, context).squeeze(1)             # (batch, hidden)

        return weights, outputs

In [44]:
class HierarchyRNN(nn.Module):
    '''
    Hierarchical GRU answer classifier.

    Every context sentence is embedded and encoded by ``GRU_context`` (final
    hidden state), and the question is encoded by ``GRU_query``. The question
    vector followed by the N sentence vectors forms one sequence for
    ``GRU_hierarchy``. Its final hidden state goes through
    Linear -> ReLU -> Linear to produce vocabulary logits.
    '''
    def __init__(self, vocab_size, emb_size, hidden_size, pad_idx):
        super(HierarchyRNN, self).__init__()

        self.vocab_size = vocab_size
        self.emb_size = emb_size
        self.hidden_size = hidden_size
        half_hidden = int(hidden_size / 2)

        # Keep this creation order: it fixes seeded parameter initialisation
        # and the state_dict layout.
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=pad_idx)
        self.GRU_context = nn.GRU(emb_size, hidden_size, batch_first = True)
        self.GRU_hierarchy = nn.GRU(hidden_size, hidden_size, batch_first = True)
        self.GRU_query = nn.GRU(emb_size, hidden_size, batch_first = True)
        self.linear_1 = nn.Linear(1 * hidden_size, half_hidden)
        self.linear_2 = nn.Linear(half_hidden, vocab_size)

    def forward(self, context, query, mask):
        '''
        Inputs:
            context: (batch_size, N_contexts, seq_len) token indices
            query:   (batch_size, q_seq_len) token indices
            mask:    accepted for interface compatibility; not used
        Returns:
            logits: (batch_size, vocab_size)
        '''
        n_batch, n_sentences, sentence_len = context.shape

        # Encode every sentence independently by folding sentences into the batch axis.
        flat_emb = self.embedding(context).view(-1, sentence_len, self.emb_size)
        _, sentence_state = self.GRU_context(flat_emb)
        sentence_vecs = sentence_state.squeeze(0).view(n_batch, n_sentences, self.hidden_size)

        _, question_state = self.GRU_query(self.embedding(query))
        question_vec = question_state.squeeze(0).unsqueeze(1)   # (batch, 1, hidden)

        # Question first, then the sentences, as one sequence for the top-level GRU.
        sequence = torch.cat((question_vec, sentence_vecs), dim=1)
        _, top_state = self.GRU_hierarchy(sequence)
        summary = top_state.squeeze(0)

        scores = self.linear_2(torch.relu(self.linear_1(summary)))
        return scores.view(-1, self.vocab_size)
    

In [45]:
att_decoder = HierarchyRNN(
    vocab_size=len(train_dataset.words_vocab),
    emb_size=64,
    hidden_size=512,
    pad_idx=train_dataset.words_vocab.pad_idx,
).to(device)

criterion = nn.CrossEntropyLoss()
decoder_optimizer = optim.Adam(att_decoder.parameters())

In [46]:
class Model(nn.Module):
    '''
    Wrapper holding a dataset, an optional encoder and a decoder.

    ``forward`` delegates straight to the decoder. ``generate`` runs a greedy
    decoding loop (see the review note there).
    '''
    def __init__(self, dataset, encoder, decoder,vocab_size):
        super(Model, self).__init__()
        self.dataset = dataset
        self.encoder = encoder
        self.decoder = decoder
        
        self.vocab_size = vocab_size
        
        
    def forward(self, context, query, mask):
        # Pure delegation: returns whatever the decoder returns.

        out = self.decoder(context, query, mask)
           
        return out   
    
    def generate(self, bos_idx, eos_idx, batch_words):
        '''
        Greedy token-by-token decoding: at most 100 steps, stopping after
        ``eos_idx`` is emitted.

        NOTE(review): this looks carried over from a sequence-to-sequence
        model and does not match the decoders in this notebook. It calls
        ``self.encoder(...)`` (the commented-out construction passes
        ``encoder=None``) and calls the decoder with four positional arguments,
        expecting ``(logits, hidden)`` back. ``AtentionDecoder`` and
        ``HierarchyRNN`` take ``(context, query, mask)`` and return logits only.
        Confirm before use.

        Returns:
            list of token indices, starting with ``bos_idx``.
        '''
        inp = [bos_idx]
        outputs, hidden = self.encoder(batch_words)
        # Presumably treats index 0 as padding (Vocab.pad_idx is 0) — confirm.
        mask_words  = batch_words != 0
        

        for _ in range(100):
            # Feed back only the most recent token as a (1, 1) batch.
            inp_tensor = torch.LongTensor([[inp[-1]]]).to(batch_words.device)
            logits, hidden   = self.decoder(hidden, outputs, inp_tensor, mask_words)
            # Greedy pick: index of the highest-probability entry in the last row.
            next_token = F.softmax(logits, dim=-1)[-1].topk(1)[1].item()
            inp.append(next_token)
            if next_token == eos_idx:
                break
        return inp

In [47]:
class Trainer:
    '''
    Minimal training loop for the answer-classification models above.

    Args:
        dataset:           training set exposing ``get_batch(batch_size)`` and ``len()``.
        test_dataset:      evaluation set with the same interface; evaluation
                           requests ``len(test_dataset)`` examples in one batch.
        model:             ``nn.Module`` called as ``model(contexts, questions, mask)``,
                           returning logits of shape (batch, n_classes).
        decoder_optimizer: optimizer over ``model``'s parameters.
        criterion:         loss called as ``criterion(logits, answers)``.
        batch_size:        training batch size.

    ``train_losses`` records every training loss. ``val_losses`` (despite its
    name) records validation *accuracy*, appended every ``eval_every`` batches.
    '''
    def __init__(self, dataset, test_dataset, model,  decoder_optimizer, criterion, batch_size):
        self.dataset = dataset
        self.test_dataset = test_dataset
        self.train_losses = []
        self.val_losses = []
        self.batch_size = batch_size

        self.model = model

        self.decoder_optimizer = decoder_optimizer
        self.criterion = criterion

    def train(self, n_epochs, eval_every=200):
        '''
        Train for ``n_epochs`` epochs.

        Each epoch runs ``len(dataset) // batch_size`` (at least 1) batches,
        sampled with ``dataset.get_batch``. Every ``eval_every`` batches
        (starting at batch 0) the model is evaluated and the curves re-plotted.
        '''
        mask_words = None  # the current models do not use a padding mask
        n_batches = max(1, len(self.dataset) // self.batch_size)
        for epoch in range(n_epochs):
            for batch_idx in range(n_batches):
                contexts, questions, answers = self.dataset.get_batch(self.batch_size)

                logits = self.model(contexts, questions, mask_words)
                loss = self.criterion(logits, answers)

                self.decoder_optimizer.zero_grad()
                loss.backward()
                self.decoder_optimizer.step()

                self.train_losses.append(loss.item())

                if batch_idx % eval_every == 0:
                    val_acc = self.eval_()
                    self.val_losses.append(val_acc)
                    self.plot(epoch, batch_idx, self.train_losses, self.val_losses)

    def eval_(self):
        '''Return the model's accuracy on the whole ``test_dataset``, without gradients.'''
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                contexts, questions, answers = self.test_dataset.get_batch(len(self.test_dataset))
                logits = self.model(contexts, questions, None)
                return self.accuracy(logits, answers)
        finally:
            # Restore whatever mode the model was in before evaluation.
            self.model.train(was_training)

    def accuracy(self, logits, answers):
        '''Fraction (Python float) of rows whose arg-max logit equals the answer index.'''
        predictions = logits.argmax(dim=-1)
        correct = (predictions == answers).sum().item()
        return correct / len(logits)

    def plot(self, epoch, batch_idx, train_losses, val_losses):
        '''Redraw the training-loss and validation-accuracy curves in the notebook.'''
        clear_output(True)
        plt.figure(figsize=(20,5))
        plt.subplot(131)
        plt.title('epoch %s. | batch: %s | loss: %s' % (epoch, batch_idx, train_losses[-1]))
        plt.plot(train_losses)
        plt.subplot(132)
        plt.title('epoch %s. | acc: %s' % (epoch, val_losses[-1]))
        plt.plot(val_losses)
        plt.show()

In [48]:
# Rebuild both splits. NOTE(review): each Dataset builds its vocabulary from
# the file it reads, so test-set indices are not guaranteed to match the
# training vocabulary the model was sized with. Share train_dataset's vocab
# if Dataset supports it.
train_dataset = Dataset('tasks_1-20_v1-2/en-10k/qa1_single-supporting-fact_train.txt', val=False)
test_dataset = Dataset('tasks_1-20_v1-2/en-10k/qa1_single-supporting-fact_test.txt', val=True)

trainer = Trainer(
    train_dataset,
    test_dataset,
    att_decoder,
    decoder_optimizer,
    criterion,
    batch_size=128,
)

In [49]:
trainer.train(n_epochs=1000)

KeyboardInterrupt: 

In [None]:
# Best validation accuracy recorded so far (val_losses stores accuracies).
max(acc for acc in trainer.val_losses)

In [None]:
def _print(val):
    '''
    Print the source, prediction and reference text for one example.

    NOTE(review): this cell cannot run as written.
    - ``data``, ``model``, ``batch_words`` and ``batch_trans_out`` are not
      defined in the visible notebook cells.
    - ``Vocab`` has no ``sos_idx``/``eos_idx`` attributes, and ``Dataset`` has
      no ``trans_vocab``.
    - The ``val`` argument is never used.
    It appears to be carried over from a different (sequence-to-sequence)
    notebook. Confirm, then rewrite it for the QA dataset or delete it.
    '''
    contexts, questions, answers = data.get_batch(1)

    inp = model.generate(data.words_vocab.sos_idx, data.words_vocab.eos_idx, batch_words)
            
    # Drop special tokens from the prediction before joining.
    tokens = [data.trans_vocab.idx2token(idx) for idx in inp if idx not in [data.trans_vocab.sos_idx,
                                                                       data.trans_vocab.eos_idx,
                                                                       data.trans_vocab.pad_idx]]
    print('Src: ', ''.join([data.words_vocab.idx2token(idx) for idx in batch_words[0].tolist()]))
    print('Pred:', ''.join(tokens))
    print('Real:', ''.join([data.trans_vocab.idx2token(idx) for idx in batch_trans_out[0].tolist() if idx not in [data.trans_vocab.sos_idx,
                                                                            data.trans_vocab.eos_idx,
                                                                            data.trans_vocab.pad_idx]]))