In [1]:
import csv
import json
import math
import nltk
import random

import numpy as np
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available else "cpu")

In [2]:
class Vocab:
    def __init__(self, counter, sos, eos, pad, unk, min_freq=5):
        self.sos = sos
        self.eos = eos
        self.pad = pad
        self.unk = unk
        
        self.pad_idx = 0
        self.unk_idx = 1
        self.sos_idx = 2
        self.eos_idx = 3
        
        self._token2idx = {
            self.sos: self.sos_idx,
            self.eos: self.eos_idx,
            self.pad: self.pad_idx,
            self.unk: self.unk_idx,
        }
        self._idx2token = {idx:token for token, idx in self._token2idx.items()}
        
        idx = len(self._token2idx)
        min_freq = 0 if min_freq is None else min_freq
        
        for token, count in counter.items():
            if count > min_freq:
                self._token2idx[token] = idx
                self._idx2token[idx]   = token
                idx += 1
        
        self.vocab_size = len(self._token2idx)
        self.tokens     = list(self._token2idx.keys())
    
    def token2idx(self, token):
        return self._token2idx.get(token, self.pad_idx)
    
    def idx2token(self, idx):
        return self._idx2token.get(idx, self.pad)
    
    def __len__(self):
        return len(self._token2idx)

In [3]:
def pad_single_seq(sequence, pad_idx, max_length):
    '''
    Inputs:
        sequence: list of tokens
    '''    
    return sequence + [pad_idx]*(max_length - len(sequence))

In [12]:
print(nltk.word_tokenize(("HEy how are you?").lower()))

['hey', 'how', 'are', 'you', '?']


In [4]:
class Dataset(object):
    def __init__(self, path, val=False):

        shuffle  = True
        self.val = val
        self.data = []
        word_data = []       
        words_counter = Counter()
          
        with open(path) as json_data:
            json_data = json.load(json_data)
            
        for data_cell in json_data['data']:
            for paragraph in data_cell['paragraphs']:
                context = nltk.word_tokenize(paragraph['context'].lower())
                for qa in paragraph['qas']:
                    question = nltk.word_tokenize(qa['question'].lower())
                    for ans in qa['answers']:
                        
                        answer = nltk.word_tokenize(ans['text'].lower())

                        word_data.append((context, question, answer))

                        for token in context:
                            words_counter[token] += 1
                        for token in question:
                            words_counter[token] += 1
                        for token in answer:
                            words_counter[token] += 1  

                
        sos = "<sos>"
        eos = "<eos>"
        pad = "<pad>"
        unk = "<unk>"

        self.words_vocab = Vocab(words_counter, 
                            sos, eos, pad, unk)

 
        if not val:
            random.shuffle(self.data)
        
        for context, question, answer in word_data:

            cell_context = [self.words_vocab.token2idx(item) for item in context]    
            cell_question = [self.words_vocab.token2idx(item) for item in question]
            cell_answer = [self.words_vocab.token2idx(item) for item in answer]
            
            self.data.append((cell_context, cell_question, cell_answer))
                

    def __len__(self):
        return len(self.data)
        
    def get_batch(self, batch_size):
        
        random_ids = np.random.randint(0, len(self.data), batch_size)
        if not self.val:
            batch_data = [self.data[idx] for idx in random_ids]
        else:
            batch_data = self.data
        
        max_context_length = max([len(a) for (a, _, _) in batch_data])
        max_question_length = max([len(b) for (_, b, _) in batch_data])
        max_answer_length = max([len(c) for (_, _, c) in batch_data])

        contexts = []
        questions = []
        answers = []
        for a, b, c in batch_data:
            
            cell_context  = pad_single_seq(a, self.words_vocab.pad_idx, max_context_length)
            cell_question = pad_single_seq(b, self.words_vocab.pad_idx, max_question_length)
            cell_answer   = pad_single_seq(c, self.words_vocab.pad_idx, max_answer_length)  
            
            cell_context  = torch.LongTensor(cell_context).to(device)
            cell_question = torch.LongTensor(cell_question).to(device)
            cell_answer   = torch.LongTensor(cell_answer).to(device)
            
            contexts.append(cell_context)
            questions.append(cell_question)
            answers.append(cell_answer)
            

        contexts  = torch.stack(contexts, 0)
        questions = torch.stack(questions, 0)      
        answers   = torch.stack(answers, 0)#.squeeze(1)

        return contexts, questions, answers

In [5]:
train_dataset = Dataset('train-v2.0.json')
test_dataset  = Dataset('train-v2.0.json', val=True)

In [17]:
#Number of words in vocabulary
len(train_dataset.words_vocab._token2idx)

60958

# Transformer Model

In [6]:
class PositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=5000):
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
                             -(math.log(10000.0) / dim)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0)
        super(PositionalEncoding, self).__init__()
        self.register_buffer('pe', pe)
        self.dim = dim

    def forward(self, emb):
        emb = emb * math.sqrt(self.dim)
        emb = emb + self.pe[:,:emb.size(1),:]
        return emb

In [7]:
class MultiHeadAttention(nn.Module):
    
    def __init__(self, num_heads, model_size):
        super(MultiHeadAttention, self).__init__()
        
        assert model_size % num_heads == 0 
        
        self.model_size = model_size
        self.num_heads = num_heads
        self.head_size    = model_size // num_heads
        
        self.linear_query = nn.Linear(model_size, model_size, bias = False)
        self.linear_key   = nn.Linear(model_size, model_size, bias = False)
        self.linear_value = nn.Linear(model_size, model_size, bias = False)
        self.linear_out = nn.Linear(model_size, model_size, bias = False)
        
    def forward(self, query, key, value, mask):
        '''
        Inputs:
            query:   (batch, target_len, hidden)
            key:     (batch, source_len, hidden)
            value:   (batch, source_len, hidden)  
            mask:    (batch, target_len, source_len)

        Outputs:
            output:  (batch, target_len, model_size)
            weight:  (batch, num_heads, target_len, source_len)
        '''
        batch_size = query.size(0)        
        
        query = self.linear_query(query)
        key   = self.linear_key(key)
        value = self.linear_value(value)
      
        query = query.view(batch_size, -1, self.num_heads, self.head_size)
        key   = key.view(batch_size, -1, self.num_heads, self.head_size)
        value = value.view(batch_size, -1, self.num_heads, self.head_size)
        
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
             
        key = key.transpose(2, 3)
        
        logits = torch.matmul(query, key)
        logits = logits / math.sqrt(self.head_size)

        
        if mask is not None:
            mask = mask.unsqueeze(1)#.repeat()
            logits.masked_fill_(mask, -1e18)      
        
        weights = F.softmax(logits, dim = -1)

        output = torch.matmul(weights, value)       
        output = output.transpose(1,2).contiguous()
        output = output.view(batch_size, -1, self.model_size)
        output = self.linear_out(output)

        return output, weights

In [8]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, model_size, ff_size):
        super(PositionWiseFeedForward, self).__init__()
        
        self.linear_1 = nn.Linear(model_size, ff_size)
        self.linear_2 = nn.Linear(ff_size, model_size)
        self.layer_norm = nn.LayerNorm(model_size)
        
    def forward(self, x):
        '''
        Inputs:
            x: (batch_size x seq_len x model-size)
        Outputs:
            output: 
        '''        
        res = x
        x = self.layer_norm(x)
        x = self.linear_1(x)
        x = F.relu(x)
        x = self.linear_2(x)
        
        return x + res

### Encoder

In [9]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, model_size, num_heads, ff_size):
        super(TransformerEncoderLayer, self).__init__()
        
        self.self_attention = MultiHeadAttention(num_heads, model_size)
        self.positionwise_ff = PositionWiseFeedForward(model_size, ff_size)
        self.layer_norm = nn.LayerNorm(model_size)
        
    def forward(self, x, mask=None):
        '''
        Inputs:
            x: (batch x seq_len x model_size)
            mask: (batch x seq_len x seq_len)
            
        Outputs:
            output : (batch x seq_len, model_size)
        '''
        
        res  = x
        x, _ = self.self_attention(x, x, x, mask)       
        x    = x + res
        x    = self.positionwise_ff(x)
        x    = self.layer_norm(x)
        
        
        return x        

In [10]:
class TransformerEncoder(nn.Module):
    def __init__(self, 
                vocab_size,
                num_layers, 
                model_size, 
                num_heads,
                ff_size,
                padding_idx):
        super(TransformerEncoder, self).__init__()
        
        self.padding_idx = padding_idx

        self.embedding = nn.Embedding(vocab_size, model_size, padding_idx=padding_idx)
    
        self.positional_enc = PositionalEncoding(model_size)

        self.enc_blocks = nn.ModuleList(
                            [TransformerEncoderLayer(model_size, num_heads, ff_size)
                                        for _ in range(num_layers)])
        
    def forward(self, source):
        
        '''
        Inputs: 
            source: (batch_size, source_len)
            
        Outputs:
            x: (batch, source_len, hidden)
        '''
        source_mask = source == self.padding_idx
        mask = (source == self.padding_idx).unsqueeze(1).repeat(1, source.size(1), 1)

        source_emb = self.embedding(source)
        source_emb = self.positional_enc(source_emb)
        
        x = source_emb
        for layer in self.enc_blocks:       
            x = layer(x, mask)

        return x, source_mask   

### Decoder

In [11]:
class TransformerDecoderLayer(nn.Module):
    def __init__(self, model_size, num_heads, ff_size):
        super(TransformerDecoderLayer, self).__init__()
        
        self.self_attention = MultiHeadAttention(num_heads, model_size)
        self.positionwise_ff = PositionWiseFeedForward(model_size, ff_size)
        self.layer_norm = nn.LayerNorm(model_size)
        
    def forward(self, x, enc_outputs, enc_mask=None, subseq_mask=None):
        '''
        Inputs:
            x: (batch x target_len x model_size)
            enc_outputs: (batch x num_heads x model_size)
            enc_mask : (batch x target_len x source_len)
            subseq_mask: (batch x target_len x target_len)
            
        Outputs:
            output : (batch x ? x ?)
        '''      
        res  = x
        
        x, _ = self.self_attention(x, x, x, subseq_mask)     
        x    = x + res       
        x    = self.layer_norm(x)       
        
        res2 = x      
        #enc_mask = enc_mask.unsqueeze(1).repeat(1, res.size(1), 1)
        x, _ = self.self_attention(x, enc_outputs, enc_outputs, enc_mask)
        
        x = res2 + x        
        x = self.positionwise_ff(x)        
        x = self.layer_norm(x)        
        
        return x        

In [12]:
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, emb_size, hidden_size, num_layers, pad_idx):
        super(TransformerDecoder, self).__init__()
        
        self.padding_idx = pad_idx
        
        self.vocab_size = vocab_size
        
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=pad_idx)
        
        self.positional_enc = PositionalEncoding(hidden_size)
        
        self.dec_blocks = nn.ModuleList(
                            [TransformerDecoderLayer(hidden_size, num_heads, ff_size)
                                        for _ in range(num_layers)])
        self.linear_out = nn.Linear(hidden_size, vocab_size, bias = False)   
        
    def forward(self, target, enc_outputs, batch_words, val=False):        
        '''
        Inputs: 
            source: (batch_size, target_len)
            
        Outputs:
            x: (batch, source_len, hidden)
        '''

        batch_size, target_len = target.size()
        
        subseq_mask = torch.triu(
                torch.ones((target_len, target_len), device=target.device, dtype=torch.uint8), diagonal=1)
        subseq_mask = subseq_mask.unsqueeze(0).expand(batch_size, -1, -1)  # b x ls x ls
        
        dec_mask_ = (target == 0).unsqueeze(1).repeat(1, target_len, 1)
        
        dec_mask = (subseq_mask + dec_mask_).gt(0)
        
        target_emb = self.embedding(target)
        target_emb = self.positional_enc(target_emb) 
        
        
        enc_mask = batch_words == 0
        enc_mask = enc_mask.unsqueeze(1).repeat(1, target_len, 1)
        x = target_emb
        for layer in self.dec_blocks:       
            x = layer(x, enc_outputs, enc_mask, dec_mask)            
        
        
        logits = self.linear_out(x)
        print(logits.shape, 'logits')
        
        logits = logits.view(-1, self.vocab_size)
        
        #out = F.softmax(logits)
        
        return logits

In [13]:
class Model(nn.Module):
    def __init__(self, encoder, decoder, hidden_size):
        super(Model, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        
        
    def forward(self, batch_context, batch_question,mask = None):
        
        enc_outputs, _ = self.encoder(batch_context)
        out = self.decoder(batch_question, enc_outputs, batch_context, False)
        
        return out

In [14]:
hidden_size = 256
num_heads   = 8
num_layers = 2
ff_size = 128

encoder = TransformerEncoder(len(train_dataset.words_vocab), num_layers, 
                hidden_size, 
                num_heads,
                ff_size,
                0).to(device)

decoder = TransformerDecoder(len(train_dataset.words_vocab), hidden_size, hidden_size, num_layers, 0).to(device)

criterion = nn.CrossEntropyLoss()
model = Model(encoder, decoder, hidden_size).to(device)
optimizer = optim.Adam(model.parameters())

In [15]:
class Trainer:
    def __init__(self, train_dataset, test_dataset, model, optimizer, criterion, batch_size):
        
        
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.train_losses = []
        self.val_losses = []
        self.batch_size = batch_size
        
        self.optimizer = optimizer
        self.model = model
        self.criterion = criterion
        
               
    def train(self, n_epochs):
        
        mask_words = None
        for epoch in range(n_epochs):

            for batch_idx in range(len(self.train_dataset)//self.batch_size):
                self.optimizer.zero_grad()
                
                contexts, questions, answers = self.train_dataset.get_batch(32)
                
                logits = self.model(contexts, questions)
                print(logits.shape)
                print(answers.shape)

                loss = self.criterion(logits, answers)                        
                
                loss.backward()
                                
                self.optimizer.step()


                self.train_losses.append(loss.item())
                
                if batch_idx % 200 == 0:
                    val_loss = self.eval_()
                    self.val_losses.append(val_loss.item())
                    self.plot(epoch, batch_idx, self.train_losses, self.val_losses)
                  
        
    def eval_(self):
        
        val_words, val_trans_in, val_trans_out, val_words_lens, val_trans_lens = self.test_dataset.get_batch(len(self.test_dataset))
        val_mask = val_words != 0
        logits = self.model(val_words, val_trans_in, val_words_lens, val_mask)
        val_trans_out = val_trans_out.view(-1)                

        mask = val_trans_out != trans_vocab.pad_idx

        loss = self.criterion(logits[mask], val_trans_out[mask])

        return loss
        
    def plot(self, epoch, batch_idx, train_losses, val_losses):
        clear_output(True)
        plt.figure(figsize=(20,5))
        plt.subplot(131)
        plt.title('epoch %s. | batch: %s | loss: %s' % (epoch, batch_idx, train_losses[-1]))
        plt.plot(train_losses)
        plt.subplot(132)
        plt.title('epoch %s. | loss: %s' % (epoch, val_losses[-1]))
        plt.plot(val_losses)
        plt.show()  

In [16]:
trainer = Trainer(train_dataset, test_dataset, model, optimizer, criterion, batch_size = 32)

In [17]:
trainer.train(10)

torch.Size([32, 21, 60958]) logits
torch.Size([672, 60958])
torch.Size([32, 31])


ValueError: Expected input batch_size (672) to match target batch_size (32).