## DrQA 

This notebook implements the model proposed in the paper [Reading Wikipedia to Answer Open-Domain Questions](https://arxiv.org/abs/1704.00051), which the authors call DrQA. Specifically, DrQA is an end-to-end system for open-domain question answering that also includes an information retrieval component. This notebook, however, only covers the deep learning model proposed by them. This model is very similar to the one described in [this](https://arxiv.org/abs/1606.02858) paper, and both papers share the same first author. The latter model is also known as the "Stanford Attentive Reader" and is one of the models explained in Chris Manning's lecture on QA.   
The flow of all the notebooks will be as follows:
* Data Preprocessing: This section prepares the data for training and involves standard NLP preprocessing steps. The functions called here are imported from the script `preprocess.py`.
* Model: I've tried to explain the intuition behind each layer/component of the model.
 > *Texts coming from the paper are in block quotes like this.*   
 
 Along with the intuition, I explain, to the best of my ability, how the equations written in the paper can be transformed into code.
* Training: While I have tried my best to follow the training procedures mentioned in the paper, there can be some differences, since I do not have unlimited access to GPUs. 
* References: I provide an exhaustive list of resources/references that I used at the end of each notebook.


### Tensor Based Approach
All the notebooks are based on this approach. Ultimately, building neural nets is all about working with tensors, and I have found it very useful to know the shape and contents of each tensor. Hence, after each line of code, I have commented the tensor shape and the changes that result from the transformations in code. This makes it more intuitive to understand what's going on inside the neural net.
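
As a small illustration of this commenting style, here is a minimal sketch with made-up sizes (the variable names and dimensions are assumptions chosen for the example, not code from the notebooks). Each transformation is followed by a comment with the shape it produces:

```python
import torch
from torch import nn

# Hypothetical sizes, chosen only for illustration.
batch_size, seq_len, vocab_size, emb_dim = 4, 7, 50, 16

token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
# token_ids = [bs, seq_len]

embedding = nn.Embedding(vocab_size, emb_dim)
embedded = embedding(token_ids)
# embedded = [bs, seq_len, emb_dim]

pooled = embedded.mean(dim=1)
# pooled = [bs, emb_dim]
```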



## Data Preprocessing

In [1]:
import pandas as pd
import numpy as np
import torchtext
import torch
from torch import nn
import json, re, unicodedata, string, typing, time
import torch.nn.functional as F
import spacy
from collections import Counter
import pickle
from nltk import word_tokenize
nlp = spacy.load('en')
from preprocess import *
%load_ext autoreload
%autoreload 2

In [4]:
# load dataset json files

train_data = load_json('./data/squad_train.json')
valid_data = load_json('./data/squad_dev.json')

# parse the json structure to return the data as a list of dictionaries

train_list = parse_data(train_data)
valid_list = parse_data(valid_data)

print('Train list len: ',len(train_list))
print('Valid list len: ',len(valid_list))

# converting the lists into dataframes

train_df = pd.DataFrame(train_list)
valid_df = pd.DataFrame(valid_list)

In [7]:
def normalize_spaces(text):
    '''
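    Replaces every whitespace character (newline, tab, etc.) in the context
    with a plain space. The text length and character offsets are unchanged.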
    '''
    text = re.sub(r'\s', ' ', text)
    return text

train_df.context = train_df.context.apply(normalize_spaces)
valid_df.context = valid_df.context.apply(normalize_spaces)

In [9]:
train_df.head()

| | id | context | question | label | answer |
|---|---|---|---|---|---|
| 0 | 5733be284776f41900661182 | Architecturally, the school has a Catholic cha... | To whom did the Virgin Mary allegedly appear i... | [515, 541] | Saint Bernadette Soubirous |
| 1 | 5733be284776f4190066117f | Architecturally, the school has a Catholic cha... | What is in front of the Notre Dame Main Building? | [188, 213] | a copper statue of Christ |
| 2 | 5733be284776f41900661180 | Architecturally, the school has a Catholic cha... | The Basilica of the Sacred heart at Notre Dame... | [279, 296] | the Main Building |
| 3 | 5733be284776f41900661181 | Architecturally, the school has a Catholic cha... | What is the Grotto at Notre Dame? | [381, 420] | a Marian place of prayer and reflection |
| 4 | 5733be284776f4190066117e | Architecturally, the school has a Catholic cha... | What sits on top of the Main Building at Notre... | [92, 126] | a golden statue of the Virgin Mary |


In [10]:
# gather text to build vocabularies

%time vocab_text = gather_text_for_vocab([train_df, valid_df])
print("Number of sentences in dataset: ", len(vocab_text))

Wall time: 616 ms


In [13]:
# build word vocabulary

%time word2idx, idx2word, word_vocab = build_word_vocab(vocab_text)

raw-vocab: 111086
vocab-length: 111088
word2idx-length: 111088
Wall time: 42.5 s


In [16]:
# numericalize context and questions for training and validation set


%time train_df['context_ids'] = train_df.context.apply(context_to_ids, word2idx=word2idx)
%time valid_df['context_ids'] = valid_df.context.apply(context_to_ids, word2idx=word2idx)

%time train_df['question_ids'] = train_df.question.apply(question_to_ids,  word2idx=word2idx)
%time valid_df['question_ids'] = valid_df.question.apply(question_to_ids,  word2idx=word2idx)

Wall time: 1min 49s
Wall time: 45.8 s
Wall time: 10.6 s
Wall time: 4.07 s


In [17]:
# get indices with tokenization errors and drop those indices 

train_err = get_error_indices(train_df, idx2word)
valid_err = get_error_indices(valid_df, idx2word)

train_df.drop(train_err, inplace=True)
valid_df.drop(valid_err, inplace=True)

Number of error indices: 1002
Number of error indices: 431


In [19]:
# get start and end positions of answers from the context
# this is basically the label for training QA models

train_label_idx = train_df.apply(index_answer, axis=1, idx2word=idx2word)
valid_label_idx = valid_df.apply(index_answer, axis=1, idx2word=idx2word)

train_df['label_idx'] = train_label_idx
valid_df['label_idx'] = valid_label_idx

### Dump data to pickle files 
This ensures that we can directly access the preprocessed dataframe next time.

In [37]:
import pickle
with open('drqastoi.pickle','wb') as handle:
    pickle.dump(word2idx, handle)
    
train_df.to_pickle('drqatrain.pkl')
valid_df.to_pickle('drqavalid.pkl')

### Read data from pickle files

You only need to run the preprocessing once. Some preprocessing functions can take up to 3 minutes, so pickling the preprocessed data saves a lot of time.
Once the preprocessed files are saved, you can directly start from here.

In [21]:
train_df = pd.read_pickle('drqatrain.pkl')
valid_df = pd.read_pickle('drqavalid.pkl')

## Dataset/ Dataloader

In [32]:
class SquadDataset:
    '''
    -Divides the dataframe in batches.
    -Pads the contexts and questions dynamically for each batch by padding 
     the examples to the maximum-length sequence in that batch.
    -Calculates masks for context and question.
    -Calculates spans for contexts.
    '''
    
    def __init__(self, data, batch_size):
        
        self.batch_size = batch_size
        data = [data[i:i+self.batch_size] for i in range(0, len(data), self.batch_size)]
        self.data = data
    
    def get_span(self, text):
        
        text = nlp(text, disable=['parser','tagger','ner'])
        span = [(w.idx, w.idx+len(w.text)) for w in text]

        return span

    def __len__(self):
        return len(self.data)
    
    def __iter__(self):
        '''
        Creates batches of data and yields them.
        
        Each yield consists of:
        :padded_context: padded tensor of contexts for each batch 
        :padded_question: padded tensor of questions for each batch 
        :context_mask & question_mask: padding masks (True at padded positions)
        :label: start and end index wrt context_ids
        :context_text, answer_text: used during validation to calculate metrics
        :ids: question ids used in evaluation
        
        Note: context spans are computed for each batch but are not part of the yielded tuple.
        '''
        
        for batch in self.data:
                            
            spans = []
            context_text = []
            answer_text = []
            
            max_context_len = max([len(ctx) for ctx in batch.context_ids])
            padded_context = torch.LongTensor(len(batch), max_context_len).fill_(1)
            
            for ctx in batch.context:
                context_text.append(ctx)
                spans.append(self.get_span(ctx))
            
            for ans in batch.answer:
                answer_text.append(ans)
                
            for i, ctx in enumerate(batch.context_ids):
                padded_context[i, :len(ctx)] = torch.LongTensor(ctx)
            
            max_question_len = max([len(ques) for ques in batch.question_ids])
            padded_question = torch.LongTensor(len(batch), max_question_len).fill_(1)
            
            for i, ques in enumerate(batch.question_ids):
                padded_question[i,: len(ques)] = torch.LongTensor(ques)
                
            
            label = torch.LongTensor(list(batch.label_idx))
            context_mask = torch.eq(padded_context, 1)
            question_mask = torch.eq(padded_question, 1)
            
            ids = list(batch.id)  
            
            yield (padded_context, padded_question, context_mask, 
                   question_mask, label, context_text, answer_text, ids)
            
            

In [33]:
train_dataset = SquadDataset(train_df, 32)

In [34]:
valid_dataset = SquadDataset(valid_df, 32)

In [58]:
a = next(iter(train_dataset))

In [59]:
a[0].shape, a[1].shape, a[2].shape, a[3].shape, a[4].shape

(torch.Size([32, 253]),
 torch.Size([32, 19]),
 torch.Size([32, 253]),
 torch.Size([32, 19]),
 torch.Size([32, 2]))

## Model

Before we dive deep into the intricacies of the model, let's set up the notation. An input example during training consists of 
* a paragraph / context $p$ consisting of $l$ tokens { $p_{1}$, $p_{2}$,..., $p_{l}$ }
* a question $q$ consisting of $m$ tokens { $q_{1}$, $q_{2}$,..., $q_{m}$ }
* a start and an end position, i.e. the indices of the first and last token of the answer within the context  

The following flowchart shows the flow of the model. It might not make sense now, but as we progress down the chart and build all the components, things will become clearer.


<img src="images/drqaflow.PNG" width="700" height="800"/>


### Word Embedding

The first transformation for both the question and the context tokens is that they are passed through an embedding layer initialized with pre-trained GloVe word vectors. The 300-dimensional vectors from the 840B-token web crawl version are used here. This version of GloVe has a vocabulary of 2.2M words. Out-of-vocabulary (OOV) words, i.e. words that are present in your dataset but not in the pretrained GloVe vocabulary, are initialized with a zero vector.  
These word vectors project a word into a floating-point vector whose dimensions encode various features associated with the word. Such a conversion is necessary because neural networks cannot operate on strings directly, but they work efficiently with floating-point matrices.   
The cosine similarity between the vectors of semantically similar words is close to 1, and it is lower for unrelated words.
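
Similarity between word vectors is usually measured with cosine similarity, i.e. the dot product of the length-normalized vectors. The sketch below uses three invented 4-dimensional vectors (they are not real GloVe values, which have 300 dimensions) just to show the computation:

```python
import torch
import torch.nn.functional as F

# Toy vectors invented for illustration; real GloVe vectors are 300-dimensional.
car = torch.tensor([0.9, 0.1, 0.8, 0.0])
vehicle = torch.tensor([0.8, 0.2, 0.9, 0.1])
banana = torch.tensor([0.0, 0.9, 0.1, 0.8])

# cosine similarity = dot(a, b) / (|a| * |b|)
sim_related = F.cosine_similarity(car, vehicle, dim=0)
sim_unrelated = F.cosine_similarity(car, banana, dim=0)

print(sim_related.item(), sim_unrelated.item())
```

With these toy values the first similarity is much higher than the second, which is the behaviour we expect from real embeddings of related and unrelated words.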




In [35]:
def create_glove_matrix():
    '''
    Parses the glove word vectors text file and returns a dictionary with the words as
    keys and their respective pretrained word vectors as values.

    '''
    glove_dict = {}
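    with open("./glove.840B.300d/glove.840B.300d.txt", "r", encoding="utf-8") as f:
        for line in f:
            values = line.rstrip().split(' ')
            # take the last 300 fields as the vector so that a token
            # containing spaces does not break the float conversion
            word = ' '.join(values[:-300])
            vector = np.asarray(values[-300:], dtype="float32")
            glove_dict[word] = vector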
    
    return glove_dict

In [1]:
glove_dict = create_glove_matrix()

In [None]:
def create_word_embedding(glove_dict):
    '''
    Creates a weight matrix of the words that are common in the GloVe vocab and
    the dataset's vocab. Initializes OOV words with a zero vector.
    '''
    weights_matrix = np.zeros((len(word_vocab), 300))
    words_found = 0
    for i, word in enumerate(word_vocab):
        try:
            weights_matrix[i] = glove_dict[word]
            words_found += 1
        except KeyError:
            # OOV word: keep the zero vector
            pass
    return weights_matrix, words_found

In [None]:
weights_matrix, words_found = create_word_embedding(glove_dict)

In [None]:
print("Total words found in glove vocab: ", words_found)

In [64]:
np.save('drqaglove_vt.npy',weights_matrix)

### Align Question Embedding

The paper has different encoding procedures for the context and the question. The context/paragraph encoding is more exhaustive and comprises the following additional features:

* exact match: a binary feature that indicates whether $p_{i}$ can be exactly matched to a word in the question in its original, lowercase or lemma form
* token features: POS, NER and TF of the context tokens
* aligned question embedding ($f_{align}$)  

In this re-implementation I have only implemented the aligned question embedding. The other features can be added easily but they do not affect the metrics by a large margin (~2).  
$f_{align}$ has been formulated as shown below:

$$ f_{align}(p_{i}) = \sum_{j}a_{i,j}E(q_{j}) $$ 

where $E()$ represents the GloVe embeddings and

<img src="images/drqa1.PNG" width="400" height="200"/>

where $\alpha()$ is a single dense layer with ReLU non-linearity. This transformation can be thought of as a projection into a new vector space. The weights of the projection matrix are learnt via backpropagation.
These equations can be converted into code quite easily. Let's break this down into smaller chunks and understand what is actually going on. 
<img src="images/drqa2.PNG" width="200" height="150"/>
This is simply the dot product of the projections of the GloVe embeddings of a context token and a question token. Careful inspection of the equation for $a_{i,j}$ reveals that it is a softmax of this product, taken over the question tokens. The equations above describe everything at the token level, where $i$ indexes a context token and $j$ indexes a question token. In practice we vectorize the computations and deal with tensors directly.
$f_{align}$ is a weighted sum of the question embeddings. $a_{i,j}$ are the weights, which must be non-negative and sum to 1 over the question tokens, hence the softmax.  
#### Intuition
This feature tells the model, for each context token, which question words it is most related to, and therefore which portions of the context are relevant with respect to the question. The dot products of the projections, taken at token level, yield higher values when similar words from the question and context are compared. Quoting the paper,
> *these features add soft alignments between similar but non-identical words (e.g., car and vehicle).* 

This is achieved via backpropagation, by training the weights of the dense layer. While this might seem a bit weird initially, we have to trust the process of backpropagation.   

While implementing, we first calculate the projections of context and question vectors. We then use `torch.bmm` to calculate the product in the numerator of $a_{i,j}$, mask the product and then pass it through the softmax function to get $a_{i,j}$. Finally, we multiply this with the question embeddings. The output of this layer is an additional context embedding which is then concatenated with the glove embeddings.
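
The steps described above can be sketched on toy tensors as follows. This is only an illustration under assumed sizes: the random tensors stand in for the projected and raw embeddings, and all variable names are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, ctx_len, qtn_len, emb_dim = 2, 5, 3, 8

ctx_proj = torch.relu(torch.randn(bs, ctx_len, emb_dim))  # stands in for alpha(E(p_i))
qtn_proj = torch.relu(torch.randn(bs, qtn_len, emb_dim))  # stands in for alpha(E(q_j))
qtn_embed = torch.randn(bs, qtn_len, emb_dim)             # stands in for E(q_j)

# True marks padding; the second question has one padded token.
question_mask = torch.tensor([[False, False, False],
                              [False, False, True]])

scores = torch.bmm(ctx_proj, qtn_proj.transpose(1, 2))
# scores = [bs, ctx_len, qtn_len]

# Padded question tokens get -inf so that softmax assigns them zero weight.
scores = scores.masked_fill(question_mask.unsqueeze(1), float('-inf'))
a = F.softmax(scores, dim=-1)
# a = [bs, ctx_len, qtn_len], each row sums to 1 over the question tokens

f_align = torch.bmm(a, qtn_embed)
# f_align = [bs, ctx_len, emb_dim]
```

Every context token ends up with its own weighted summary of the question, which is why the output has the same length as the context.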

In [None]:
class AlignQuestionEmbedding(nn.Module):
    
    def __init__(self, input_dim):        
        
        super().__init__()
        
        self.linear = nn.Linear(input_dim, input_dim)
        
        self.relu = nn.ReLU()
        
    def forward(self, context, question, question_mask):
        
        # context = [bs, ctx_len, emb_dim]
        # question = [bs, qtn_len, emb_dim]
        # question_mask = [bs, qtn_len]
    
        ctx_ = self.linear(context)
        ctx_ = self.relu(ctx_)
        # ctx_ = [bs, ctx_len, emb_dim]
        
        qtn_ = self.linear(question)
        qtn_ = self.relu(qtn_)
        # qtn_ = [bs, qtn_len, emb_dim]
        
        qtn_transpose = qtn_.permute(0,2,1)
        # qtn_transpose = [bs, emb_dim, qtn_len]
        
        align_scores = torch.bmm(ctx_, qtn_transpose)
        # align_scores = [bs, ctx_len, qtn_len]
        
        qtn_mask = question_mask.unsqueeze(1).expand(align_scores.size())
        # qtn_mask = [bs, 1, qtn_len] => [bs, ctx_len, qtn_len]
        
        # Fills elements of self tensor(align_scores) with value(-float(inf)) where mask is True. 
        # The shape of mask must be broadcastable with the shape of the underlying tensor.
        align_scores = align_scores.masked_fill(qtn_mask == 1, -float('inf'))
        # align_scores = [bs, ctx_len, qtn_len]
        
        align_scores_flat = align_scores.view(-1, question.size(1))
        # align_scores_flat = [bs*ctx_len, qtn_len]
        
        alpha = F.softmax(align_scores_flat, dim=1)
        alpha = alpha.view(-1, context.shape[1], question.shape[1])
        # alpha = [bs, ctx_len, qtn_len]
        
        align_embedding = torch.bmm(alpha, question)
        # align_embedding = [bs, ctx_len, emb_dim]
        
        return align_embedding

## Stacked BiLSTM

The paragraph/context encoding which now has two features (glove and $f_{align}$) is then passed to a multilayer (3 layers) bidirectional LSTM. According to the paper,

> *Specifically, we choose to use a multi-layer bidirectional long short-term memory network (LSTM), and take the concatenation of each layer’s hidden units in the end.*

To achieve this functionality we cannot directly use a multi-layer PyTorch recurrent layer. Every recurrent layer in PyTorch returns a tuple `(output, hidden)` where `output` holds the hidden states of all the timesteps from the __last layer only__. We need to access the hidden states of the intermediate layers and then concatenate them at the end.
The following figure illustrates this point in more detail.

<img src="images/bilstm.png" width="700" height="600"/>

This figure shows a 3-layer bidirectional LSTM with an input sequence of size $n$. The green blocks denote the forward LSTMs and the blue blocks backward. Each block is labelled with the value that it calculates. The subscript denotes the time-step and the superscript denotes the depth or the layer-number. For example $hf_{1}^{(0)}$ calculates the first hidden state in forward LSTM in the first layer. 

As highlighted in the diagram, we need the intermediate hidden states passed between the layers along with the final output. To create this in code, we create a `nn.ModuleList` and add 3 LSTM layers to it. The input size of the first layer remains the same but for subsequent LSTMs the input size must be twice the hidden size. This is because the `output` of the first LSTM will have the dimension of `[batch_size, seq_len, hidden_size*num_directions]` and `num_directions` is 2 in our case. In the forward method, we loop through the LSTMs, store the hidden states of each layer and finally return the concatenated output. 
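
The difference in output shapes can be checked with a small sketch. The sizes below are arbitrary assumptions; the point is only that a built-in 3-layer `nn.LSTM` exposes the last layer's states, while separately stacked single-layer LSTMs let us keep and concatenate every layer's output.

```python
import torch
from torch import nn

bs, seq_len, input_dim, hidden_dim, num_layers = 2, 5, 10, 4, 3
x = torch.randn(bs, seq_len, input_dim)

# Built-in multi-layer LSTM: `output` only covers the last layer.
builtin = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers,
                  batch_first=True, bidirectional=True)
builtin_out, _ = builtin(x)
# builtin_out = [bs, seq_len, 2*hidden_dim]

# Separate single-layer LSTMs: the output of every layer is kept.
layers = nn.ModuleList([
    nn.LSTM(input_dim if i == 0 else 2 * hidden_dim, hidden_dim,
            batch_first=True, bidirectional=True)
    for i in range(num_layers)
])
per_layer, layer_input = [], x
for lstm in layers:
    layer_input, _ = lstm(layer_input)
    per_layer.append(layer_input)

stacked_out = torch.cat(per_layer, dim=2)
# stacked_out = [bs, seq_len, num_layers*2*hidden_dim]
```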



In [None]:
class StackedBiLSTM(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, num_layers, dropout):
        
        super().__init__()
        
        self.dropout = dropout
        
        self.num_layers = num_layers
        
        self.lstms = nn.ModuleList()
        
        for i in range(self.num_layers):
            
            input_dim = input_dim if i == 0 else hidden_dim * 2
            
            self.lstms.append(nn.LSTM(input_dim, hidden_dim,
                                      batch_first=True, bidirectional=True))
           
    
    def forward(self, x):
        # x = [bs, seq_len, feature_dim]

        outputs = [x]
        for i in range(self.num_layers):

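            lstm_input = outputs[-1]
            # apply dropout to the layer input (the result must be fed to the LSTM);
            # training=self.training disables dropout in eval mode
            lstm_input = F.dropout(lstm_input, p=self.dropout, training=self.training)
            lstm_out, (hidden, cell) = self.lstms[i](lstm_input)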
           
            outputs.append(lstm_out)

    
        output = torch.cat(outputs[1:], dim=2)
        # [bs, seq_len, num_layers*num_dir*hidden_dim]
        
        output = F.dropout(output, p=self.dropout, training=self.training)
      
        return output

## Linear Attention Layer

The paper does not give this layer a specific name, so I have named it based on my understanding of its functionality. The previous layers were mainly about encoding and representing the context. This layer is used to encode the question and is much simpler than the previous layers. The question tokens are first passed through the GloVe embedding layer, then through the biLSTM layer, and finally reach this layer. 
This layer is used to calculate the importance of each word in the question. To give the model learning capacity, the inputs are multiplied by a trainable weight vector $w$, which produces one score per question token, and these scores are then passed through a softmax function.  
This layer calculates the weights as 
<img src="images/drqab.PNG" width="300" height="300"/>

Essentially, the layer performs "attention" over its inputs. In code, $w$ is implemented as a linear layer with a single output unit.

In [None]:
class LinearAttentionLayer(nn.Module):
    
    def __init__(self, input_dim):
        
        super().__init__()
        
        self.linear = nn.Linear(input_dim, 1)
        
    def forward(self, question, question_mask):
        
        # question = [bs, qtn_len, input_dim] = [bs, qtn_len, bi_lstm_hid_dim]
        # question_mask = [bs,  qtn_len]
        
        qtn = question.view(-1, question.shape[-1])
        # qtn = [bs*qtn_len, hid_dim]
        
        attn_scores = self.linear(qtn)
        # attn_scores = [bs*qtn_len, 1]
        
        attn_scores = attn_scores.view(question.shape[0], question.shape[1])
        # attn_scores = [bs, qtn_len]
        
        attn_scores = attn_scores.masked_fill(question_mask == 1, -float('inf'))
        
        alpha = F.softmax(attn_scores, dim=1)
        # alpha = [bs, qtn_len]
        
        return alpha
        

The following function just multiplies the weights calculated in the previous layer by the outputs of the question bilstm layer. This allows the model to assign higher values to important words in each question.

$$ q = \sum_{j} b_{j} q_{j} $$
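
As a quick sanity check of this formula, the sketch below computes the weights $b_{j}$ with a masked softmax and verifies that the batched matrix product gives the same result as an explicit sum over $j$. The sizes, the random inputs and the variable names are assumptions made for the example.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
bs, qtn_len, hid_dim = 2, 4, 6
q_states = torch.randn(bs, qtn_len, hid_dim)   # stands in for the biLSTM outputs q_j
question_mask = torch.tensor([[False, False, False, True],
                              [False, False, True, True]])  # True marks padding

w = nn.Linear(hid_dim, 1)   # plays the role of w (nn.Linear also adds a bias)
scores = w(q_states).squeeze(-1).masked_fill(question_mask, float('-inf'))
b = F.softmax(scores, dim=1)
# b = [bs, qtn_len]

# Vectorized weighted sum
q_vec = torch.bmm(b.unsqueeze(1), q_states).squeeze(1)
# q_vec = [bs, hid_dim]

# The same sum written out over j
q_loop = torch.zeros(bs, hid_dim)
for j in range(qtn_len):
    q_loop = q_loop + b[:, j:j + 1] * q_states[:, j, :]
```

Padded positions receive a weight of exactly zero, so they do not contribute to the question vector.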

In [None]:
def weighted_average(x, weights):
    # x = [bs, len, dim]
    # weights = [bs, len]
    
    weights = weights.unsqueeze(1)
    # weights = [bs, 1, len]
    
    w = weights.bmm(x).squeeze(1)
    # w = [bs, 1, dim] => [bs, dim]
    
    return w

## Introduction to Attention in NLP

Attention as a concept in NLP was introduced in 2015 by Bahdanau et al. to improve neural machine translation systems. Before this paper, NMT systems were largely based on seq2seq architectures, which had an encoder to encode a representation of the source language and a decoder to decode this representation into the target language. Such models were trained on large quantities of parallel text data in two languages. One major drawback of this architecture was that it didn't work well for longer documents/sequences. This is because the entire information in the source sentence was being crammed into a single vector. If this vector fails to capture the important information from the source language, the system is going to perform poorly.   
<img src="images/seq2seq.PNG" width="600" height="700"/>

When we claim that these neural nets mimic the human brain, this is certainly not how the human brain works. While *learning* about some topic, we do not simply read 2-3 pages of content and expect our brain to remember all the details in the first go. We usually revisit various concepts, recollect and refer to the material back and forth before mastering it. The attention mechanism in NMT was designed to do this. While decoding at any particular time step, the encoder hidden states from all the time-steps are made available to the decoder. The decoder can then look back at the encoder hidden states, i.e. the source sentence, and make a more informed prediction at a particular time-step. This alleviates the problem of all the information from the source language being crammed into a single vector.  
To illustrate this with equations, consider that the hidden states of the encoder RNN are represented by $H$ = {$h_{1}, h_{2}, h_{3},...,h_{t}$}. While decoding the token at position $t$, the input to the decoder unit is the hidden state from the previous unit $s_{t-1}$ and an attention vector, which is a selective summary of the encoder hidden states and helps the decoder pay more attention to particular encoder states. 
The similarity between the encoder hidden states $H$ and the decoder hidden state so far $s_{t-1}$ is computed by,  
$$ \alpha = \tanh (W [H ; s_{t-1}]) $$   

$\alpha$ is then passed through a softmax layer to obtain the attention distribution, such that $\sum_{t} \alpha_{t} = 1$.
The final step is calculating the attention vector by taking a weighted sum of the encoder hidden states,
$$ \sum_{t} \alpha_{t} h_{t} $$

The following diagram illustrates this process.  
 
<img src="images/attnkj.PNG" width="600" height="100"/>


Since then, many different forms of attention have been proposed and used in the literature. Attention is not limited to NMT systems and has evolved into a more general concept in NLP. At its heart, attention is about summarizing a particular entity/representation by *attending* to the important parts of this representation. 
A more general definition of attention is as follows:

> *__Given a set of vectors `values`, and a single vector `query`, attention is a method to calculate a weighted sum of the values, where the query determines which values to focus on.__*

It is a way to obtain a fixed size representation of an arbitrary set of representations (values), dependent on some other representation (query).   

In our earlier NMT example, the encoder hidden states {$h_{1}, h_{2}, h_{3},...,h_{t}$} are the __*values*__ and the decoder hidden state $s_{t-1}$ is the __*query*__.  

### A More General Take On Attention

In general there are 3 steps when calculating attention. Consider that the values are represented by {$h_{1}, h_{2}, h_{3},...,h_{N}$} and the query is $s$. Then attention always involves,

1. Calculating the energy $e$, or attention scores, between the query and each of the values,
$e \in \mathbb{R}^{N}$
2. Taking softmax to get an attention distribution $\alpha$, $\alpha \in \mathbb{R}^{N}$

$$ \alpha = softmax(e)$$ 
$$ \sum_{t}^{N} \alpha_{t} = 1 $$

3. Taking the weighted sum of the `values` by using $\alpha$
$$ a = \sum_{t}^{N}\alpha_{t}h_{t} $$
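
The three steps can be sketched in a few lines. This minimal example assumes a plain dot product as the energy, which is only one of several possible choices, and the function name and toy tensors are made up for illustration.

```python
import torch
import torch.nn.functional as F

def dot_product_attention(query, values):
    # query = [dim], values = [N, dim]

    # Step 1: energy between the query and each value
    energy = values @ query
    # energy = [N]

    # Step 2: softmax turns the energies into a distribution
    alpha = F.softmax(energy, dim=0)
    # alpha = [N]

    # Step 3: weighted sum of the values
    output = alpha @ values
    # output = [dim]

    return output, alpha

values = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
query = torch.tensor([2.0, 0.0])
output, alpha = dot_product_attention(query, values)
```

The first and third values have the same dot product with the query, so they receive the same weight, and both are weighted more heavily than the second value.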


Now there are different ways to calculate the energy between `query` and `values`. 
* **Basic Dot Product Attention**    
$$ e_{t} = s^{T}h_{t}$$      
* **Additive Attention**
$$ e_{t} = v^{T} tanh (W [h_{t};s])$$  
This is nothing but the Bahdanau attention first proposed for NMT systems.
* **Scaled Dot Product Attention**
$$ e_{t} = s^{T}h_{t}/\sqrt{d}$$
where $d$ is the dimensionality of the vectors (the model size). A modified version of this, proposed in the Transformer paper by Vaswani et al., is now employed in almost every NLP system.

* **Bilinear Attention**
$$ e_{t} = s^{T} W h_{t}$$
where $W$ is a trainable weight matrix.
This is the method used in this paper to predict the start and end position of the answer from the context.    
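
The four energy functions listed above can be written side by side as in the following sketch. The layer sizes are arbitrary, `attn_dim` is a hypothetical intermediate size for the additive variant, and the biases are switched off so that each line matches its formula.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
N, dim = 5, 8
values = torch.randn(N, dim)   # h_1 ... h_N
query = torch.randn(dim)       # s

# Basic dot product: e_t = s^T h_t
e_dot = values @ query

# Scaled dot product: the same scores divided by the square root of the vector dimension
e_scaled = e_dot / math.sqrt(dim)

# Additive: e_t = v^T tanh(W [h_t ; s])
attn_dim = 16   # hypothetical size of the intermediate projection
W_add = nn.Linear(2 * dim, attn_dim, bias=False)
v = nn.Linear(attn_dim, 1, bias=False)
concat = torch.cat([values, query.expand(N, dim)], dim=1)
# concat = [N, 2*dim]
e_additive = v(torch.tanh(W_add(concat))).squeeze(1)

# Bilinear: e_t = s^T W h_t
W_bil = nn.Linear(dim, dim, bias=False)
e_bilinear = W_bil(values) @ query

# every variant yields one energy per value: [N]
```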


To implement this layer, we represent $W$ by a linear layer.
First the linear layer is applied to the question, which is equivalent to the product $Wq$. This product is then multiplied by the context using `torch.bmm`.   
Note that no softmax is applied here to get the weights; it is taken care of when we calculate the loss using cross entropy. The following layer does not actually compute attention as a weighted sum. It just uses the bilinear scores to predict the span. However, the intuition behind the bilinear term remains the same.
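
To see why the softmax can be left out of the layer, the sketch below checks that `F.cross_entropy` on raw (masked) scores equals a negative log-likelihood loss on their log-softmax. The scores, mask and target positions are made up for the example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, ctx_len = 3, 6
scores = torch.randn(bs, ctx_len)
context_mask = torch.tensor([[False] * 6,
                             [False] * 4 + [True] * 2,
                             [False] * 5 + [True]])   # True marks padding
scores = scores.masked_fill(context_mask, float('-inf'))

# Made-up gold start positions; they never point at a padded token.
start_targets = torch.tensor([2, 0, 4])

# cross_entropy applies log-softmax internally ...
loss_ce = F.cross_entropy(scores, start_targets)
# ... so it matches the explicit two-step computation.
loss_manual = F.nll_loss(F.log_softmax(scores, dim=1), start_targets)

# At inference time the softmax can be applied explicitly to read off probabilities.
start_probs = F.softmax(scores, dim=1)
# start_probs = [bs, ctx_len]
```

The masked positions end up with zero probability, so the model can never select a padding token as the start of the answer.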

In [None]:
class BilinearAttentionLayer(nn.Module):
    
    def __init__(self, context_dim, question_dim):
        
        super().__init__()
        
        self.linear = nn.Linear(question_dim, context_dim)
        
    def forward(self, context, question, context_mask):
        
        # context = [bs, ctx_len, ctx_hid_dim] = [bs, ctx_len, hid_dim*6] = [bs, ctx_len, 768]
        # question = [bs, qtn_hid_dim] = [bs, 768]
        # context_mask = [bs, ctx_len]
        
        qtn_proj = self.linear(question)
        # qtn_proj = [bs, ctx_hid_dim]
        
        qtn_proj = qtn_proj.unsqueeze(2)
        # qtn_proj = [bs, ctx_hid_dim, 1]
        
        scores = context.bmm(qtn_proj)
        # scores = [bs, ctx_len, 1]
        
        scores = scores.squeeze(2)
        # scores = [bs, ctx_len]
        
        scores = scores.masked_fill(context_mask == 1, -float('inf'))
        
        #alpha = F.log_softmax(scores, dim=1)
        # alpha = [bs, ctx_len]
        
        return scores

## Putting it together

The following module brings all the components discussed so far together. It takes in the context and question tokens as inputs and returns the start and end positions of the answer from the context.  

<img src="images/drqaflow.PNG" width="600" height="600"/>

  
Going down the flowchart, the following steps are performed in sequence:  
* The context and question tokens are passed through the GloVe embedding layer. The GloVe embeddings are partially fine-tuned during training. According to the paper,  
> *We keep most of the pre-trained word embeddings fixed and only fine-tune the 1000 most frequent question words because the representations of some key words such as what, how, which, many could be crucial for QA systems.*   

In code, this is done by using hooks in PyTorch. Hooks work as callback functions and are executed after the `forward` or `backward` function is called for a particular tensor or module. Here, a hook on the embedding weight zeroes the gradient of every row except the first 1000. You should read more about hooks in the PyTorch documentation.

* The aligned question embedding is calculated for the context and concatenated (using `torch.cat`) to the context representation. If $d$ is the embedding dimension, then each context token is represented by a vector in $\mathbb{R}^{2d}$ and each question token by a vector in $\mathbb{R}^{d}$.
* Context and question representations are then passed to the biLSTM layers to obtain tensors of dimension `[batch_size, seq_len, hidden_dim*6]`, since the LSTM is bidirectional and has 3 layers.
* The question biLSTM output is also passed through the linear attention layer, and the resulting weights are used to take a weighted sum of the biLSTM outputs.
* Both these representations are finally passed through the bilinear attention layer to predict the start and end position of the answer.   

An intriguing point here is that exactly the same inputs are passed to both bilinear attention layers. How, then, do they predict different things? This is left to the neural network to learn. Our loss function ensures that the objective is to predict two different positions from the context. It is now the neural net's responsibility to learn different weights for each layer. It is sort of a "black box" and we have to trust the process of backpropagation.
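
The partial fine-tuning of the embeddings through a gradient hook, mentioned in the first step above, can be illustrated on a toy embedding. This sketch assumes that the rows of the embedding matrix are ordered by word frequency, so that "the first k rows" means "the k most frequent words"; the sizes and the name `keep_top_rows` are hypothetical.

```python
import torch
from torch import nn

vocab_size, emb_dim, num_tuned = 10, 4, 3   # toy sizes; the paper tunes 1000 words

embedding = nn.Embedding(vocab_size, emb_dim)

def keep_top_rows(grad, rows=num_tuned):
    # Return a new gradient in which all rows beyond the first `rows` are zero.
    grad = grad.clone()
    grad[rows:] = 0
    return grad

# The hook runs every time a gradient is computed for the weight tensor.
embedding.weight.register_hook(keep_top_rows)

before = embedding.weight.detach().clone()
optimizer = torch.optim.SGD(embedding.parameters(), lr=0.1)

token_ids = torch.arange(vocab_size)
loss = embedding(token_ids).sum()
loss.backward()
optimizer.step()

after = embedding.weight.detach()
changed_rows = (after != before).any(dim=1)
# changed_rows = [vocab_size], True only for rows that were updated
```

After one optimizer step only the first `num_tuned` rows have moved; the remaining rows keep their pretrained values.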

In [2]:
class DocumentReader(nn.Module):
    
    def __init__(self, hidden_dim, embedding_dim, num_layers, num_directions, dropout, device):
        
        super().__init__()
        
        self.device = device
        
        #self.embedding = self.get_glove_embedding()
        
        self.context_bilstm = StackedBiLSTM(embedding_dim * 2, hidden_dim, num_layers, dropout)
        
        self.question_bilstm = StackedBiLSTM(embedding_dim, hidden_dim, num_layers, dropout)
        
        self.glove_embedding = self.get_glove_embedding()
        
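        def tune_embedding(grad, words=1000):
            # zero the gradient of every row except the first `words` rows;
            # a hook should return a new tensor rather than modify `grad` in place
            grad = grad.clone()
            grad[words:] = 0
            return grad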
        
        self.glove_embedding.weight.register_hook(tune_embedding)
        
        self.align_embedding = AlignQuestionEmbedding(embedding_dim)
        
        self.linear_attn_question = LinearAttentionLayer(hidden_dim*num_layers*num_directions) 
        
        self.bilinear_attn_start = BilinearAttentionLayer(hidden_dim*num_layers*num_directions, 
                                                          hidden_dim*num_layers*num_directions)
        
        self.bilinear_attn_end = BilinearAttentionLayer(hidden_dim*num_layers*num_directions,
                                                        hidden_dim*num_layers*num_directions)
        
        self.dropout = nn.Dropout(dropout)
   
        
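    def get_glove_embedding(self):
        
        # pretrained GloVe vectors, one row per vocabulary index
        weights_matrix = np.load('drqaglove_vt.npy')
        # freeze=False keeps the weights trainable; the gradient hook
        # registered in __init__ limits which rows actually change
        embedding = nn.Embedding.from_pretrained(torch.FloatTensor(weights_matrix).to(self.device), freeze=False)

        return embedding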
    
    
    def forward(self, context, question, context_mask, question_mask):
       
        # context = [bs, len_c]
        # question = [bs, len_q]
        # context_mask = [bs, len_c]
        # question_mask = [bs, len_q]
        
        
        ctx_embed = self.glove_embedding(context)
        # ctx_embed = [bs, len_c, emb_dim]
        
        ques_embed = self.glove_embedding(question)
        # ques_embed = [bs, len_q, emb_dim]
        

        ctx_embed = self.dropout(ctx_embed)
     
        ques_embed = self.dropout(ques_embed)
             
        align_embed = self.align_embedding(ctx_embed, ques_embed, question_mask)
        # align_embed = [bs, len_c, emb_dim]  
        
        ctx_bilstm_input = torch.cat([ctx_embed, align_embed], dim=2)
        # ctx_bilstm_input = [bs, len_c, emb_dim*2]
                
        ctx_outputs = self.context_bilstm(ctx_bilstm_input)
        # ctx_outputs = [bs, len_c, hid_dim*layers*dir] = [bs, len_c, hid_dim*6]
       
        qtn_outputs = self.question_bilstm(ques_embed)
        # qtn_outputs = [bs, len_q, hid_dim*6]
    
        qtn_weights = self.linear_attn_question(qtn_outputs, question_mask)
        # qtn_weights = [bs, len_q]
            
        qtn_weighted = weighted_average(qtn_outputs, qtn_weights)
        # qtn_weighted = [bs, hid_dim*6]
        
        start_scores = self.bilinear_attn_start(ctx_outputs, qtn_weighted, context_mask)
        # start_scores = [bs, len_c]
         
        end_scores = self.bilinear_attn_end(ctx_outputs, qtn_weighted, context_mask)
        # end_scores = [bs, len_c]
        
      
        return start_scores, end_scores
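
The gradient hook registered on the embedding weights is what restricts fine-tuning to part of the vocabulary. The following is a minimal sketch of that idea on a toy embedding. The sizes, the `TUNE_TOP_K` constant and the SGD optimizer are assumptions made for illustration, and the trick only matches the paper's intent if the vocabulary is ordered so that the rows to be tuned come first.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

emb = nn.Embedding(6, 4)               # toy vocabulary of 6 words, 4-dim vectors
before = emb.weight.detach().clone()

TUNE_TOP_K = 2                         # hypothetical: only rows 0 and 1 are fine-tuned


def keep_top_k(grad, k=TUNE_TOP_K):
    # return a modified copy rather than editing the incoming gradient in place
    grad = grad.clone()
    grad[k:] = 0
    return grad


emb.weight.register_hook(keep_top_k)

opt = torch.optim.SGD(emb.parameters(), lr=0.1)
tokens = torch.tensor([[0, 1, 3, 5]])  # rows 3 and 5 are used but must stay fixed
loss = emb(tokens).sum()
loss.backward()
opt.step()

changed = (emb.weight.detach() != before).any(dim=1)
print(changed.tolist())
# [True, True, False, False, False, False]
```

Rows 3 and 5 take part in the forward pass, but their gradients are zeroed by the hook, so the optimizer leaves them untouched.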

###  Hyperparameters

> *We use 3-layer bidirectional LSTMs with h = 128 hidden units for both paragraph and question encoding. Dropout with p = 0.3 is applied to word embeddings and all the hidden units of LSTMs.*

In [43]:
device = torch.device('cuda')
HIDDEN_DIM = 128
EMB_DIM = 300
NUM_LAYERS = 3
NUM_DIRECTIONS = 2
DROPOUT = 0.3

model = DocumentReader(HIDDEN_DIM,
                       EMB_DIM, 
                       NUM_LAYERS, 
                       NUM_DIRECTIONS, 
                       DROPOUT, 
                       device).to(device)

## Training

In [None]:
optimizer = torch.optim.Adamax(model.parameters())

In [45]:
def count_parameters(model):
    '''Returns the number of trainable parameters in the model.'''
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 37,367,549 trainable parameters


In [22]:
def train(model, train_dataset):
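    '''
    Trains the model for one epoch and returns the mean loss per batch.
    Relies on the global `optimizer` and `device` defined above.
    '''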
    
    print("Starting training ........")
    
    train_loss = 0.
    batch_count = 0
    
    # put the model in training mode
    model.train()
    
    # iterate through training data
    for batch in train_dataset:

        if batch_count % 500 == 0:
            print(f"Starting batch: {batch_count}")
        batch_count += 1

        context, question, context_mask, question_mask, label, ctx, ans, ids = batch
        
        # place the tensors on GPU
        context, context_mask, question, question_mask, label = context.to(device), context_mask.to(device),\
                                    question.to(device), question_mask.to(device), label.to(device)
        
        # forward pass, get the predictions
        preds = model(context, question, context_mask, question_mask)

        start_pred, end_pred = preds
        
        # separate labels for start and end position
        start_label, end_label = label[:,0], label[:,1]
        
        # calculate loss
        loss = F.cross_entropy(start_pred, start_label) + F.cross_entropy(end_pred, end_label)
        
        # backward pass, calculates the gradients
        loss.backward()
        
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        
        # update the parameters using the clipped gradients
        optimizer.step()
        
        # zero the gradients to prevent them from accumulating
        optimizer.zero_grad()

        train_loss += loss.item()

    return train_loss/len(train_dataset)
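
The loss used in `train` is the sum of two cross-entropy terms, one for the start position and one for the end position. Below is a small sketch on hand-written scores, assuming the model returns unnormalised scores of shape `[bs, len_c]` (which is what `F.cross_entropy` expects) and that `label` stores `[start, end]` per example. The numbers are made up.

```python
import torch
import torch.nn.functional as F

# toy scores for a batch of 2 contexts with 4 tokens each
start_pred = torch.tensor([[2.0, 0.5, -1.0, 0.0],
                           [0.0, 0.0, 3.0, 0.0]])
end_pred = torch.tensor([[0.0, 2.5, 0.0, -1.0],
                         [0.0, 0.0, 0.5, 3.0]])
# start_pred, end_pred = [bs, len_c]

label = torch.tensor([[0, 1],
                      [2, 3]])
# label = [bs, 2] -> column 0 is the start index, column 1 the end index
start_label, end_label = label[:, 0], label[:, 1]

loss = F.cross_entropy(start_pred, start_label) + F.cross_entropy(end_pred, end_label)

# the same quantity by hand: negative log-probability of the gold
# start plus that of the gold end, averaged over the batch
start_logp = F.log_softmax(start_pred, dim=1)
end_logp = F.log_softmax(end_pred, dim=1)
rows = torch.arange(label.size(0))
manual = -(start_logp[rows, start_label] + end_logp[rows, end_label]).mean()

print(loss.item(), manual.item())
```

The two printed values agree, which shows that minimising this loss maximises the log-probability the model assigns to the gold start and end positions.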

In [49]:
def valid(model, valid_dataset):
    '''
    Performs validation.
    '''
    
    print("Starting validation .........")
   
    valid_loss = 0.

    batch_count = 0
    
    f1, em = 0., 0.
    
    # puts the model in eval mode. Turns off dropout
    model.eval()
    
    predictions = {}
    
    for batch in valid_dataset:

        if batch_count % 500 == 0:
            print(f"Starting batch {batch_count}")
        batch_count += 1

        context, question, context_mask, question_mask, label, context_text, answers, ids = batch

        context, context_mask, question, question_mask, label = context.to(device), context_mask.to(device),\
                                    question.to(device), question_mask.to(device), label.to(device)

        with torch.no_grad():

            preds = model(context, question, context_mask, question_mask)

            p1, p2 = preds

            y1, y2 = label[:,0], label[:,1]

            loss = F.cross_entropy(p1, y1) + F.cross_entropy(p2, y2)

            valid_loss += loss.item()

            
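            # get the start and end index positions from the model preds:
            # pick the (start, end) pair with the highest joint log-probability,
            # subject to start <= end
            
            batch_size, c_len = p1.size()
            ls = nn.LogSoftmax(dim=1)
            # mask[s, e] = -inf where s > e (end before start), 0 elsewhere
            mask = (torch.ones(c_len, c_len) * float('-inf')).to(device).tril(-1).unsqueeze(0).expand(batch_size, -1, -1)
            
            # score[b, s, e] = log P(start = s) + log P(end = e)
            score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
            score, s_idx = score.max(dim=1)   # best start for every end
            score, e_idx = score.max(dim=1)   # best end overall
            # squeeze(1) keeps the batch dimension even when batch_size == 1
            s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze(1)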
            
            # map each predicted span back to text, keyed by question id
            for i in range(batch_size):
                qid = ids[i]
                pred = context[i][s_idx[i]:e_idx[i]+1]
                pred = ' '.join([idx2word[idx.item()] for idx in pred])
                predictions[qid] = pred
            
            
            
    em, f1 = evaluate(predictions)            
    return valid_loss/len(valid_dataset), em, f1
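
The span-selection step inside `valid` is easier to follow on a tiny example. The sketch below runs the same masking logic on made-up scores for two contexts of four tokens each, on CPU. In the first example, taking the best start and the best end independently would give start 1 and end 0, which is not a valid span.

```python
import torch
import torch.nn as nn

# made-up start and end scores, p1 / p2 = [bs, len_c]
p1 = torch.tensor([[0.1, 2.0, 0.3, 0.2],
                   [1.5, 0.0, 0.2, 0.1]])
p2 = torch.tensor([[3.0, 0.1, 1.5, 0.2],
                   [0.1, 0.2, 0.3, 2.0]])

batch_size, c_len = p1.size()
ls = nn.LogSoftmax(dim=1)

# mask[s, e] = -inf where s > e, 0 elsewhere
mask = (torch.ones(c_len, c_len) * float('-inf')).tril(-1)
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
# mask = [bs, len_c, len_c]

# score[b, s, e] = log P(start = s) + log P(end = e), invalid pairs at -inf
score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
score, s_idx = score.max(dim=1)   # best start for every end position
score, e_idx = score.max(dim=1)   # best end position overall
s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze(1)

print(s_idx.tolist(), e_idx.tolist())
# [1, 0] [2, 3]
```

For the first example the constrained search returns the span (1, 2) instead of the invalid pair (1, 0), because every pair with the end before the start has a score of negative infinity.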
                

In [51]:
def evaluate(predictions):
    '''
    Gets a dictionary of predictions with question_id as key
    and prediction as value. The validation dataset has multiple 
    answers for a single question. Hence we compare our prediction
    with all the answers and choose the one that gives us
    the maximum metric (em or f1). 
    This method first parses the JSON file, gets all the answers
    for a given id and then passes the list of answers and the 
    predictions to calculate em, f1.
    
    
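    :param dict predictions: maps question id to the predicted answer text
    Returns
    : exact_match: percentage of questions whose prediction matches
      at least one ground truth exactly after normalization.
    : f1_score: token-level F1 as a percentage, taking the best ground
      truth for each question and averaging over all questions.
    Questions missing from predictions count as zero for both metrics.
    '''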
    with open('./data/squad_dev.json','r',encoding='utf-8') as f:
        dataset = json.load(f)
        
    dataset = dataset['data']
    f1 = exact_match = total = 0
    for article in dataset:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                total += 1
                if qa['id'] not in predictions:
                    continue
                
                ground_truths = list(map(lambda x: x['text'], qa['answers']))
                
                prediction = predictions[qa['id']]
                
                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)
                
                f1 += metric_max_over_ground_truths(
                    f1_score, prediction, ground_truths)
                
    
    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    
    return exact_match, f1



In [24]:
def normalize_answer(s):
    '''
    Performs a series of cleaning steps on the ground truth and 
    predicted answer.
    '''
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    '''
    Returns the maximum value of a metric for the model's prediction
    against multiple ground truths.
    
    :param func metric_fn: can be 'exact_match_score' or 'f1_score'
    :param str prediction: predicted answer span by the model
    :param list ground_truths: list of ground truths against which
                               metrics are calculated. Maximum values of 
                               metrics are chosen.
                            
    
    '''
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
        
    return max(scores_for_ground_truths)


def f1_score(prediction, ground_truth):
    '''
    Returns f1 score of two strings.
    '''
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    '''
    Returns exact_match_score of two strings.
    '''
    return (normalize_answer(prediction) == normalize_answer(ground_truth))


def epoch_time(start_time, end_time):
    '''
    Helper function to record epoch time.
    '''
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs
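
As a worked example of these metrics, the sketch below scores one prediction against two ground truths. The functions are compact re-statements of the ones above so that the snippet runs on its own, and the strings are invented for illustration.

```python
import re
import string
from collections import Counter

PUNCT = set(string.punctuation)


def normalize_answer(s):
    # lowercase, strip punctuation, drop articles, collapse whitespace
    s = ''.join(ch for ch in s.lower() if ch not in PUNCT)
    s = re.sub(r'\b(a|an|the)\b', ' ', s)
    return ' '.join(s.split())


def f1_score(prediction, ground_truth):
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(ground_truth).split()
    num_same = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if num_same == 0:
        return 0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


prediction = "the Eiffel Tower in Paris"
ground_truths = ["Eiffel Tower", "The Eiffel Tower, Paris"]

# the metric is computed against every ground truth and the best one is kept
f1_per_truth = [f1_score(prediction, gt) for gt in ground_truths]
best_f1 = max(f1_per_truth)
best_em = max(exact_match_score(prediction, gt) for gt in ground_truths)

print([round(x, 3) for x in f1_per_truth], best_em)
# [0.667, 0.857] False
```

After normalization the prediction has four tokens. Against the second ground truth three of them overlap, giving precision 3/4, recall 1 and F1 of about 0.857. Exact match stays `False` because the extra token "in" is not in either ground truth.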

In [23]:

train_losses = []
valid_losses = []
ems = []
f1s = []
epochs = 5

for epoch in range(epochs):
    print(f"Epoch {epoch+1}")
    
    start_time = time.time()
    
    train_loss = train(model, train_dataset)
    valid_loss, em, f1 = valid(model, valid_dataset)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    ems.append(em)
    f1s.append(f1)
    
    print(f"Epoch train loss : {train_loss}| Time: {epoch_mins}m {epoch_secs}s")
    print(f"Epoch valid loss: {valid_loss}")
    print(f"Epoch EM: {em}")
    print(f"Epoch F1: {f1}")
    print("====================================================================================")
    

## References

* Papers read/referenced
    1. https://arxiv.org/abs/1704.00051
    2. https://arxiv.org/abs/1606.02858
    3. https://arxiv.org/abs/1409.0473
* Other helpful links
    1. https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html
    2. https://github.com/facebookresearch/DrQA
    3. https://github.com/hitvoice/DrQA. Special thanks to [Runqi Yang](https://github.com/hitvoice) who helped me clarify some doubts with respect to preprocessing the SQuAD dataset.
    4. https://towardsdatascience.com/the-definitive-guide-to-bidaf-part-3-attention-92352bbdcb07. Good introduction to attention.
    5. https://web.stanford.edu/class/archive/cs/cs224n/cs224n.1184/lectures/lecture10.pdf. The attention section of this notebook is largely inspired and derived from these slides.
* The following links are about debugging neural nets, which I was stuck on for quite some time while training these models.
    1. https://datascience.stackexchange.com/questions/410/choosing-a-learning-rate
    2. https://www.jeremyjordan.me/nn-learning-rate/
    3. https://towardsdatascience.com/estimating-optimal-learning-rate-for-a-deep-neural-network-ce32f2556ce0
    4. https://towardsdatascience.com/learning-rate-schedules-and-adaptive-learning-rate-methods-for-deep-learning-2c8f433990d1
    5. https://towardsdatascience.com/checklist-for-debugging-neural-networks-d8b2a9434f21
    6. https://arxiv.org/abs/1708.07120
    7. https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html