In [1]:
import csv
import json
import math
import nltk
import random

import numpy as np
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
class Vocab:
    '''
    Bidirectional token <-> index mapping with four reserved special tokens.

    Indices 0-3 are reserved for pad (0), unk (1), sos (2) and eos (3); corpus
    tokens that pass the frequency filter follow, in the iteration order of
    ``counter``.

    Inputs:
        counter:  mapping token -> count (e.g. collections.Counter)
        sos, eos, pad, unk: strings used for the special tokens
        min_freq: a corpus token is kept only if its count is STRICTLY greater
                  than this value; None keeps every token with a positive count
    '''
    def __init__(self, counter, sos, eos, pad, unk, min_freq=5):
        self.sos = sos
        self.eos = eos
        self.pad = pad
        self.unk = unk

        self.pad_idx = 0
        self.unk_idx = 1
        self.sos_idx = 2
        self.eos_idx = 3

        # Insertion order (sos, eos, pad, unk) determines the order of self.tokens.
        self._token2idx = {
            self.sos: self.sos_idx,
            self.eos: self.eos_idx,
            self.pad: self.pad_idx,
            self.unk: self.unk_idx,
        }
        self._idx2token = {idx: token for token, idx in self._token2idx.items()}

        threshold = 0 if min_freq is None else min_freq
        for token, count in counter.items():
            # Skip corpus tokens that collide with a special token: re-adding one
            # would remap e.g. '<unk>' away from unk_idx and leave a stale entry
            # in _idx2token.
            if count > threshold and token not in self._token2idx:
                idx = len(self._token2idx)  # indices stay contiguous
                self._token2idx[token] = idx
                self._idx2token[idx] = token

        self.vocab_size = len(self._token2idx)
        self.tokens = list(self._token2idx.keys())

    def token2idx(self, token):
        '''Index of ``token``; out-of-vocabulary tokens map to ``unk_idx``.'''
        # Previously fell back to pad_idx, which made unknown words look like
        # padding (and get masked out by the encoder/decoder masks).
        return self._token2idx.get(token, self.unk_idx)

    def idx2token(self, idx):
        '''Token for ``idx``; indices not in the vocabulary map to the pad token.'''
        return self._idx2token.get(idx, self.pad)

    def __len__(self):
        '''Number of entries, special tokens included (equals ``vocab_size``).'''
        return len(self._token2idx)

In [3]:
def pad_single_seq(sequence, pad_idx, max_length):
    '''
    Right-pad a list of token indices up to ``max_length``.

    Inputs:
        sequence:   list of token indices (left unmodified)
        pad_idx:    value appended as padding
        max_length: desired length; a longer sequence is returned as-is
                    (no truncation happens)
    Output:
        a new list
    '''
    missing = max_length - len(sequence)
    padding = [pad_idx] * missing  # empty when missing <= 0
    return sequence + padding

In [4]:
class Dataset(object):
    '''
    Question-answering examples read from a JSON file and encoded as
    vocabulary indices.

    Each element of ``self.data`` is a tuple
    ``(context_ids, question_ids, answer_ids)`` of Python lists of ints; one
    element is produced per answer of every question that has answers.
    '''
    def __init__(self, path, val=False, vocab=None):
        '''
        Inputs:
            path:  JSON file with a top-level 'data' list; each entry holds
                   'paragraphs', each with a 'context' string and 'qas'
                   entries (a 'question' string and a list of 'answers', each
                   with a 'text' string).
            val:   validation mode -- data stays in file order and
                   get_batch() returns the whole dataset.
            vocab: optional existing Vocab (e.g. the training set's) used to
                   encode this file, so indices agree across splits; when None
                   a new Vocab is built from this file's token counts.
        '''
        self.val = val

        # JSON text is UTF-8; don't depend on the platform default encoding.
        with open(path, encoding='utf-8') as json_file:
            json_data = json.load(json_file)

        word_data, words_counter = self._collect_examples(json_data)

        if vocab is None:
            vocab = Vocab(words_counter, "<sos>", "<eos>", "<pad>", "<unk>")
        self.words_vocab = vocab

        self.data = [
            (self._encode(context), self._encode(question), self._encode(answer))
            for context, question, answer in word_data
        ]

        # Shuffle only once self.data is filled; shuffling it while still empty
        # (as before) was a no-op.
        if not val:
            random.shuffle(self.data)

    def __len__(self):
        '''Number of encoded (context, question, answer) examples.'''
        return len(self.data)

    @staticmethod
    def _collect_examples(json_data, tokenize=None):
        '''
        Tokenise the parsed JSON into (context, question, answer) token lists
        and count token frequencies.

        Each text is counted once per occurrence in the corpus -- a context
        once per paragraph, a question once per question, an answer once per
        answer -- rather than once per generated example, so the Vocab
        min_freq filter treats all three sources alike. Questions with an
        empty 'answers' list produce no examples and are not counted, nor is
        a paragraph none of whose questions has answers.

        Inputs:
            json_data: parsed JSON as described in __init__
            tokenize:  str -> list of tokens; defaults to nltk.word_tokenize
        Returns:
            (examples, counter): list of token-list triples, and a Counter
        '''
        if tokenize is None:
            tokenize = nltk.word_tokenize
        examples = []
        counter = Counter()
        for data_cell in json_data['data']:
            for paragraph in data_cell['paragraphs']:
                context = tokenize(paragraph['context'])
                context_counted = False
                for qa in paragraph['qas']:
                    if not qa['answers']:
                        continue
                    question = tokenize(qa['question'])
                    # Count the context on its first use so tokens are
                    # first seen in context -> question -> answer order.
                    if not context_counted:
                        counter.update(context)
                        context_counted = True
                    counter.update(question)
                    for ans in qa['answers']:
                        answer = tokenize(ans['text'])
                        counter.update(answer)
                        examples.append((context, question, answer))
        return examples, counter

    def _encode(self, tokens):
        '''Map a list of tokens to vocabulary indices.'''
        return [self.words_vocab.token2idx(token) for token in tokens]

    def get_batch(self, batch_size, sort=False, device=None):
        '''
        Return padded (contexts, questions, answers) LongTensors, each of shape
        (batch x longest_sequence_in_batch), right-padded with the vocab's
        pad_idx.

        Inputs:
            batch_size: number of examples drawn uniformly WITH replacement in
                        training mode; ignored in val mode, where the whole
                        dataset is returned
            sort:       accepted for compatibility; currently ignored
            device:     target device; defaults to a notebook-level ``device``
                        global if one is defined, otherwise CPU
        Raises:
            ValueError: if the dataset is empty
        '''
        if not self.data:
            raise ValueError("get_batch() called on an empty Dataset")
        if device is None:
            # Older cells relied on an implicit notebook global named `device`.
            device = globals().get('device', torch.device('cpu'))

        if self.val:
            batch_data = self.data
        else:
            random_ids = np.random.randint(0, len(self.data), batch_size)
            batch_data = [self.data[idx] for idx in random_ids]

        pad_idx = self.words_vocab.pad_idx
        contexts, questions, answers = zip(*batch_data)
        return (self._pad_to_tensor(contexts, pad_idx, device),
                self._pad_to_tensor(questions, pad_idx, device),
                self._pad_to_tensor(answers, pad_idx, device))

    @staticmethod
    def _pad_to_tensor(sequences, pad_idx, device):
        '''Pad index lists to a common length and build one LongTensor on ``device``.'''
        max_length = max(len(seq) for seq in sequences)
        padded = [pad_single_seq(seq, pad_idx, max_length) for seq in sequences]
        # A single tensor/transfer per field instead of one per example.
        return torch.tensor(padded, dtype=torch.long, device=device)

In [5]:
# Build the training split; its vocabulary is constructed from this file.
train_dataset = Dataset(path='train-v2.0.json')

# Transformer Model

### Encoder

In [None]:
class TransformerEncoderLayer(nn.Module):
    '''
    One encoder block: self-attention with a residual connection, followed by
    the position-wise feed-forward network and a final LayerNorm.
    '''
    def __init__(self, model_size, num_heads, ff_size):
        super().__init__()
        # Creation order kept: it fixes parameter registration / init order.
        self.self_attention = MultiHeadAttention(num_heads, model_size)
        self.positionwise_ff = PositionWiseFeedForward(model_size, ff_size)
        self.layer_norm = nn.LayerNorm(model_size)

    def forward(self, x, mask=None):
        '''
        Inputs:
            x:    (batch x seq_len x model_size)
            mask: (batch x seq_len x seq_len); as built by TransformerEncoder
                  it is True at padded key positions

        Outputs:
            (batch x seq_len x model_size)
        '''
        attended, _attn_weights = self.self_attention(x, x, x, mask)
        # NOTE(review): no residual is added around the feed-forward sub-layer
        # here -- confirm PositionWiseFeedForward adds one internally.
        hidden = self.positionwise_ff(attended + x)
        return self.layer_norm(hidden)

In [None]:
class TransformerEncoder(nn.Module):
    '''
    Token embedding + positional encoding followed by a stack of
    TransformerEncoderLayer blocks.

    Inputs:
        vocab_size:  size of the source vocabulary
        num_layers:  number of encoder blocks
        model_size:  embedding / hidden width
        num_heads:   attention heads per block
        ff_size:     inner width of the feed-forward sub-layer
        padding_idx: index of the pad token (its embedding is zero and it is
                     masked out of attention)
    '''
    def __init__(self,
                 vocab_size,
                 num_layers,
                 model_size,
                 num_heads,
                 ff_size,
                 padding_idx):
        super().__init__()

        self.padding_idx = padding_idx

        # Module creation order kept: it fixes parameter registration / init order.
        self.embedding = nn.Embedding(vocab_size, model_size, padding_idx=padding_idx)
        self.positional_enc = PositionalEncoding(model_size)

        layers = [TransformerEncoderLayer(model_size, num_heads, ff_size)
                  for _ in range(num_layers)]
        self.enc_blocks = nn.ModuleList(layers)

    def forward(self, source):
        '''
        Inputs:
            source: (batch_size, source_len) token indices

        Outputs:
            hidden:       (batch, source_len, model_size)
            pad_positions: (batch, source_len) bool, True where source is padding
        '''
        pad_positions = source.eq(self.padding_idx)
        # Same key-padding pattern for every query row: (batch, source_len, source_len).
        attn_mask = pad_positions.unsqueeze(1).repeat(1, source.size(1), 1)

        hidden = self.positional_enc(self.embedding(source))
        for block in self.enc_blocks:
            hidden = block(hidden, attn_mask)

        return hidden, pad_positions

### Decoder

In [None]:
class TransformerDecoderLayer(nn.Module):
    '''
    One decoder block: masked self-attention (+ residual, LayerNorm), then
    encoder-decoder attention (+ residual), the position-wise feed-forward
    network and a final LayerNorm.

    The self-attention and the encoder-decoder attention now have separate
    parameters (previously both sub-layers called ``self.self_attention``, so
    the cross-attention shared weights with the masked self-attention), and
    each LayerNorm application has its own module.
    '''
    def __init__(self, model_size, num_heads, ff_size):
        super(TransformerDecoderLayer, self).__init__()

        self.self_attention = MultiHeadAttention(num_heads, model_size)
        self.positionwise_ff = PositionWiseFeedForward(model_size, ff_size)
        self.layer_norm = nn.LayerNorm(model_size)      # after self-attention
        self.enc_attention = MultiHeadAttention(num_heads, model_size)
        self.out_layer_norm = nn.LayerNorm(model_size)  # after feed-forward

    def forward(self, x, enc_outputs, enc_mask=None, subseq_mask=None):
        '''
        Inputs:
            x:           (batch x target_len x model_size)
            enc_outputs: (batch x source_len x model_size)
            enc_mask:    (batch x target_len x source_len); as built by
                         TransformerDecoder, True at padded source positions
            subseq_mask: (batch x target_len x target_len); as built by
                         TransformerDecoder, True at future or padded targets

        Outputs:
            (batch x target_len x model_size)
        '''
        residual = x
        x, _ = self.self_attention(x, x, x, subseq_mask)
        x = self.layer_norm(x + residual)

        residual = x
        x, _ = self.enc_attention(x, enc_outputs, enc_outputs, enc_mask)
        # NOTE(review): no residual is added around the feed-forward sub-layer
        # here -- confirm PositionWiseFeedForward adds one internally.
        x = self.positionwise_ff(x + residual)
        return self.out_layer_norm(x)

In [None]:
class TransformerDecoder(nn.Module):
    '''
    Target embedding + positional encoding, a stack of TransformerDecoderLayer
    blocks and a bias-free projection to vocabulary logits.

    Inputs:
        vocab_size:  size of the target vocabulary
        emb_size:    embedding width
                     (NOTE(review): everything downstream is built with
                     hidden_size, so the two presumably must be equal)
        hidden_size: model width of the decoder blocks
        num_layers:  number of decoder blocks
        pad_idx:     pad token index; used for the embedding and for both the
                     target and the source padding masks (the Dataset encodes
                     all fields with one shared vocabulary)
        num_heads:   attention heads per block
        ff_size:     inner width of the feed-forward sub-layer
        num_heads / ff_size default to notebook-level globals of the same
        name, which the original cell silently relied on.
    '''
    def __init__(self, vocab_size, emb_size, hidden_size, num_layers, pad_idx,
                 num_heads=None, ff_size=None):
        super(TransformerDecoder, self).__init__()

        if num_heads is None:
            num_heads = self._notebook_global('num_heads')
        if ff_size is None:
            ff_size = self._notebook_global('ff_size')

        self.padding_idx = pad_idx
        self.vocab_size = vocab_size

        # Module creation order kept: it fixes parameter registration / init order.
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=pad_idx)
        self.positional_enc = PositionalEncoding(hidden_size)
        self.dec_blocks = nn.ModuleList(
            [TransformerDecoderLayer(hidden_size, num_heads, ff_size)
             for _ in range(num_layers)])
        self.linear_out = nn.Linear(hidden_size, vocab_size, bias=False)

    @staticmethod
    def _notebook_global(name):
        '''Look up a legacy notebook-level hyper-parameter, or fail clearly.'''
        try:
            return globals()[name]
        except KeyError:
            raise TypeError(
                "TransformerDecoder() missing '%s': pass it as a keyword "
                "argument or define it as a notebook global" % name) from None

    @staticmethod
    def _decoder_masks(target, source, pad_idx):
        '''
        Build the boolean attention masks (True = position is masked out).

        Returns:
            dec_mask: (batch, target_len, target_len) -- future positions
                      (strict upper triangle) OR padded target keys
            enc_mask: (batch, target_len, source_len) -- padded source keys
        '''
        target_len = target.size(1)
        causal = torch.triu(
            torch.ones((target_len, target_len), device=target.device),
            diagonal=1).bool()
        target_pad = (target == pad_idx).unsqueeze(1)        # (batch, 1, target_len)
        dec_mask = causal.unsqueeze(0) | target_pad          # broadcast to (b, T, T)
        enc_mask = (source == pad_idx).unsqueeze(1).repeat(1, target_len, 1)
        return dec_mask, enc_mask

    def forward(self, target, enc_outputs, batch_words, val=False):
        '''
        Inputs:
            target:      (batch_size, target_len) target token indices
            enc_outputs: (batch_size, source_len, hidden_size) encoder output
            batch_words: (batch_size, source_len) source token indices, used
                         only to build the source padding mask
            val:         unused

        Outputs:
            logits: (batch_size * target_len, vocab_size)
        '''
        dec_mask, enc_mask = self._decoder_masks(target, batch_words, self.padding_idx)

        x = self.positional_enc(self.embedding(target))
        for layer in self.dec_blocks:
            x = layer(x, enc_outputs, enc_mask, dec_mask)

        logits = self.linear_out(x)
        return logits.view(-1, self.vocab_size)

In [None]:
class Model(nn.Module):
    '''
    Encoder-decoder wrapper: encodes the source batch and decodes the shifted
    target batch into vocabulary logits.

    Inputs:
        dataset:     stored as ``self.dataset``; not used by forward
        encoder:     module returning (outputs, mask) for a source batch
        decoder:     module called as decoder(target, enc_outputs, source, val)
        hidden_size: accepted for compatibility; not used
    '''
    def __init__(self, dataset, encoder, decoder, hidden_size):
        super().__init__()
        self.dataset = dataset
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, batch_words, batch_trans_in, source_lens=None, mask=None):
        '''
        Inputs:
            batch_words:    source token indices
            batch_trans_in: decoder input token indices
            source_lens, mask: accepted for compatibility; not used

        Outputs:
            whatever the decoder returns (logits for TransformerDecoder)
        '''
        # The encoder's mask is discarded; the decoder rebuilds its own
        # source mask from batch_words.
        encoded, _source_mask = self.encoder(batch_words)
        return self.decoder(batch_trans_in, encoded, batch_words, False)