# Natural Language Processing

## QANet

The papers that we've seen so far have been heavily based on recurrent neural nets and attention. However, RNNs are slow to train given their sequential nature and are also slow for inference. QANet was proposed in early 2018. This paper does away with recurrence and is only based on self-attention and convolutions. 

The key motivation behind the design of the model is that: convolution captures the **local** structure of the text, while the self-attention learns the **global** interaction between each pair of words.

In [1]:
import numpy as np
import time
import torch
from torch import nn
import torch.nn.functional as F

import spacy
nlp = spacy.load('en_core_web_sm')

## 1. Load preprocessed data

This time, we shall enjoy the privilege of only loading the pickles that we have made.

In [2]:
import pickle

with open('bidafw2id.pickle','rb') as handle:
    word2idx = pickle.load(handle)
with open('bidafc2id.pickle','rb') as handle:
    char2idx = pickle.load(handle)

In [3]:
import pandas as pd

train_df = pd.read_pickle('bidaftrain.pkl')
valid_df = pd.read_pickle('bidafvalid.pkl')

In [4]:
idx2word = {v:k for k,v in word2idx.items()}

## 2. Preparing Dataloader/Dataset

No changes from previous part

In [5]:
class SquadDataset:
    '''
    - Creates batches dynamically by padding to the length of largest example
      in a given batch.
    - Calulates character vectors for contexts and question.
    - Returns tensors for training.
    '''
    
    def __init__(self, data, batch_size):

        self.batch_size = batch_size
        data = [data[i:i+self.batch_size] for i in range(0, len(data), self.batch_size)]
        self.data = data
        
        
    def __len__(self):
        return len(self.data)
    
    def make_char_vector(self, max_sent_len, sentence, max_word_len=16):
        
        char_vec = torch.zeros(max_sent_len, max_word_len).type(torch.LongTensor)
        
        for i, word in enumerate(nlp(sentence, disable=['parser','ner'])):
            for j, ch in enumerate(word.text):
                if j == max_word_len:
                    break
                char_vec[i][j] = char2idx.get(ch, 0)
        
        return char_vec     
    
    def get_span(self, text):

        text = nlp(text, disable=['parser','ner'])
        span = [(w.idx, w.idx+len(w.text)) for w in text]

        return span

    def __iter__(self):
        '''
        Creates batches of data and yields them.
        
        Each yield comprises of:
        :padded_context: padded tensor of contexts for each batch 
        :padded_question: padded tensor of questions for each batch 
        :char_ctx & ques_ctx: character-level ids for context and question
        :label: start and end index wrt context_ids
        :context_text,answer_text: used while validation to calculate metrics
        :ids: question_ids for evaluation
        '''
        
        for batch in self.data:
            
            spans = []
            ctx_text = []
            answer_text = []
            
            for ctx in batch.context:
                ctx_text.append(ctx)
                spans.append(self.get_span(ctx))
            
            for ans in batch.answer:
                answer_text.append(ans)
                
            max_context_len = max([len(ctx) for ctx in batch.context_ids])
            padded_context = torch.LongTensor(len(batch), max_context_len).fill_(1)
            
            for i, ctx in enumerate(batch.context_ids):
                padded_context[i, :len(ctx)] = torch.LongTensor(ctx)
                
            max_word_ctx = 16
          
            char_ctx = torch.zeros(len(batch), max_context_len, max_word_ctx).type(torch.LongTensor)
            for i, context in enumerate(batch.context):
                char_ctx[i] = self.make_char_vector(max_context_len, context)
            
            max_question_len = max([len(ques) for ques in batch.question_ids])
            padded_question = torch.LongTensor(len(batch), max_question_len).fill_(1)
            
            for i, ques in enumerate(batch.question_ids):
                padded_question[i, :len(ques)] = torch.LongTensor(ques)
                
            max_word_ques = 16
            
            char_ques = torch.zeros(len(batch), max_question_len, max_word_ques).type(torch.LongTensor)
            for i, question in enumerate(batch.question):
                char_ques[i] = self.make_char_vector(max_question_len, question)
            
              
            label = torch.LongTensor(list(batch.label_idx))
            ids = list(batch.id)
            
            yield (padded_context, padded_question, char_ctx, char_ques, label, ctx_text, answer_text, ids)
            

In [6]:
# create dataloaders
train_dataset = SquadDataset(train_df,16)
valid_dataset = SquadDataset(valid_df,16)

## 3. Prepare embeddings

No changes...

In [7]:
#we can skip here
#we simply load the glove embedding we save earlier
#this is why saving is so nice!
#to load, do like this -> weights_matrix = np.load('drqaglove_vt.npy')
#we are using drqaglove because the paper uses 300d version of glove...

## 4. Model

### 4.1 Depthwise Separable Convolutions

Depthwise separable convolutions serve the same purpose as normal convolutions with the only difference being that they are faster because they reduce the number of multiplication operations. This is done by breaking the convolution operation into two parts: depthwise convolution and pointwise convolution.

#### Depthwise convolution

<img src="images/depthconv.PNG" width="800" height="900"/>

#### Pointwise convolution

<img src="images/pointconv.PNG" width="700" height="700"/>

In code, the depthwise phase of the convolution is done by assigning `groups` as `in_channels`. According to the documentation, 

> *At groups= `in_channels`, each `input channel is convolved with its own set of filters, of size: $\left\lfloor\frac{out\_channels}{in\_channels}\right\rfloor$*

In [8]:
class DepthwiseSeparableConvolution(nn.Module):
    
    def __init__(self, in_channels, out_channels, kernel_size, dim=1):
        
        super().__init__()
        self.dim = dim
        if dim == 2:
            
            self.depthwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels,
                                        kernel_size=kernel_size, groups=in_channels, padding=kernel_size//2)
        
            self.pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        
        else:
        
            self.depthwise_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels,
                                            kernel_size=kernel_size, groups=in_channels, padding=kernel_size//2,
                                            bias=False)

            self.pointwise_conv = nn.Conv1d(in_channels, out_channels, kernel_size=1, padding=0, bias=True)

    def forward(self, x):
        # x = [bs, seq_len, emb_dim]
        if self.dim == 1:
            x = x.transpose(1,2)
            x = self.pointwise_conv(self.depthwise_conv(x))
            x = x.transpose(1,2)
        else:
            x = self.pointwise_conv(self.depthwise_conv(x))
        #print("DepthWiseConv output: ", x.shape)
        return x

### 4.2 Highway Networks

No changes.

In [9]:
class HighwayLayer(nn.Module):
    
    def __init__(self, layer_dim, num_layers=2):
    
        super().__init__()
        self.num_layers = num_layers
        
        self.flow_layers = nn.ModuleList([nn.Linear(layer_dim, layer_dim) for _ in range(num_layers)])
        self.gate_layers = nn.ModuleList([nn.Linear(layer_dim, layer_dim) for _ in range(num_layers)])
    
    def forward(self, x):
        #print("Highway input: ", x.shape)
        for i in range(self.num_layers):
            
            flow = self.flow_layers[i](x)
            gate = torch.sigmoid(self.gate_layers[i](x))
            
            x = gate * flow + (1 - gate) * x
            
        #print("Highway output: ", x.shape)
        return x

### 4.3 Embedding Layer

This layer:
* converts word-level tokens into a 300-dim pre-trained glove embedding vector 
* creates trainable character embeddings using 2-D convolutions
* concatenates character and word embeddings and passes them through a highway network  

The details of calculating character embeddings has been discussed in detail in the previous notebook. The only difference here is that instead of max-pooling, `torch.max` is used to get a fixed-size representation of each word.

In [10]:
class EmbeddingLayer(nn.Module):
    
    def __init__(self, char_vocab_dim, char_emb_dim, kernel_size, device):
        
        super().__init__()
        
        self.device = device
        
        self.char_embedding = nn.Embedding(char_vocab_dim, char_emb_dim)
        
        self.word_embedding = self.get_glove_word_embedding()
        
        self.conv2d = DepthwiseSeparableConvolution(char_emb_dim, char_emb_dim, kernel_size,dim=2)
        
        self.highway = HighwayLayer(self.word_emb_dim + char_emb_dim)
    
        
    def get_glove_word_embedding(self):
        
        weights_matrix = np.load('drqaglove_vt.npy')
        num_embeddings, embedding_dim = weights_matrix.shape
        self.word_emb_dim = embedding_dim
        embedding = nn.Embedding.from_pretrained(torch.FloatTensor(weights_matrix).to(self.device),freeze=True)

        return embedding
    
    def forward(self, x, x_char):
        # x = [bs, seq_len]
        # x_char = [bs, seq_len, word_len(=16)]
        
        word_emb = self.word_embedding(x)
        # word_emb = [bs, seq_len, word_emb_dim]
                
        word_emb = F.dropout(word_emb,p=0.1)
        
        char_emb = self.char_embedding(x_char)
        # char_embed = [bs, seq_len, word_len, char_emb_dim]
              
        char_emb = F.dropout(char_emb.permute(0,3,1,2), p=0.05)
        # [bs, char_emb_dim, seq_len, word_len] == [N, Cin, Hin, Win]
        
        conv_out = F.relu(self.conv2d(char_emb))
        # [bs, char_emb_dim, seq_len, word_len] 
        # the depthwise separable conv does not change the shape of the input
        
        char_emb, _ = torch.max(conv_out, dim=3)
        # [bs, char_emb_dim, seq_len]
        
        char_emb = char_emb.permute(0,2,1)
        # [bs, seq_len, char_emb_dim]
        
        concat_emb = torch.cat([char_emb, word_emb], dim=2)
        # [bs, seq_len, char_emb_dim + word_emb_dim]
        
        emb = self.highway(concat_emb)
        # [bs, seq_len, char_emb_dim + word_emb_dim]
        
        #print("Embedding output: ", emb.shape)
        return emb

### 4.4 Multiheaded Self Attention

Since we have mentioned this a lot in our class, I will skip the explaination here.

In [11]:
class MultiheadAttentionLayer(nn.Module):
    
    def __init__(self, hid_dim, num_heads, device):
        
        super().__init__()
        self.num_heads = num_heads
        self.device = device
        self.hid_dim = hid_dim
        
        self.head_dim = self.hid_dim // self.num_heads
        
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
        
        
    def forward(self, x, mask):
        # x = [bs, len_x, hid_dim]
        # mask = [bs, len_x]
        
        batch_size = x.shape[0]
        
        Q = self.fc_q(x)
        K = self.fc_k(x)
        V = self.fc_v(x)
        # Q = K = V = [bs, len_x, hid_dim]
        
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).permute(0,2,1,3)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).permute(0,2,1,3)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).permute(0,2,1,3)
        # [bs, len_x, num_heads, head_dim ]  => [bs, num_heads, len_x, head_dim]
        
        K = K.permute(0,1,3,2)
        # [bs, num_heads, head_dim, len_x]
        
        energy = torch.matmul(Q, K) / self.scale
        # (bs, num_heads){[len_x, head_dim] * [head_dim, len_x]} => [bs, num_heads, len_x, len_x]
        
        mask = mask.unsqueeze(1).unsqueeze(2)
        # [bs, 1, 1, len_x]
        
        #print("Mask: ", mask)
        #print("Energy: ", energy)
        
        energy = energy.masked_fill(mask == 1, -1e10)
        
        #print("energy after masking: ", energy)
        
        alpha = torch.softmax(energy, dim=-1)
        #  [bs, num_heads, len_x, len_x]
        
        #print("energy after smax: ", alpha)
        alpha = F.dropout(alpha, p=0.1)
        
        a = torch.matmul(alpha, V)
        # [bs, num_heads, len_x, head_dim]
        
        a = a.permute(0,2,1,3)
        # [bs, len_x, num_heads, hid_dim]
        
        a = a.contiguous().view(batch_size, -1, self.hid_dim)
        # [bs, len_x, hid_dim]
        
        a = self.fc_o(a)
        # [bs, len_x, hid_dim]
        
        #print("Multihead output: ", a.shape)
        return a

### 4.5 Positional Embedding

The model so far does not have any idea about the positioning of words in a sentence.

One simple method of doing this is to assign a single number to each token between $[0, 1]$, where first word starts with 0 and the last word corresponds to 1. This solution presents some problems. For different sentence lengths, we'll have different intervals over which tokens are distributed. We would not have a consistent meaning of a particular position across all inputs(of varying lengths).   

Another method is to use learned position embeddings. This is used in BERT, where, the positional embedding a lookup table of size $[512, 768]$ where 512 is the maximum sequence length that BERT can process. This lookup matrix is randomly intialized and trained along with the model.   

Here however, the authors have used another method of encoding position which is same as that proposed in the original transformers paper. The positional embedding can be defined as,
<img src="images/posemb.PNG" width="500" height="400"/>

where $pos$ is the position, $i$ is the dimension of embedding, and $d_{model}$ is the model dimension.  These embeddings are simply added to the word embeddings of the tokens at their respective positions. 


In [12]:
from torch.autograd import Variable
import math

class PositionEncoder(nn.Module):
    
    def __init__(self, model_dim, device, max_length=1000):
        
        super().__init__()
        
        self.device = device
        
        self.model_dim = model_dim
        
        pos_encoding = torch.zeros(max_length, model_dim)
        
        for pos in range(max_length):
            
            for i in range(0, model_dim, 2):
                
                pos_encoding[pos, i]   = math.sin(pos / (10000 ** ((2*i)/model_dim)))
                pos_encoding[pos, i+1] = math.cos(pos / (10000 ** ((2*(i+1))/model_dim)))
        
        pos_encoding = pos_encoding.unsqueeze(0).to(device)
        self.register_buffer('pos_encoding', pos_encoding)  #register_buffer saves the parameters into the state_dict, but not trained by optimizer
    
    def forward(self, x):
        x = x + Variable(self.pos_encoding[:, :x.shape[1]], requires_grad=False)  
        #print("PE output: ", x.shape)
        return x

#### A bit about register_buffer

If you have parameters in your model, which should be saved and restored in the state_dict, but not trained by the optimizer, you should register them as buffers.
Buffers won’t be returned in model.parameters(), so that the optimizer won’t have a change to update them.

Note: You can also do the same with register_parameters.

In [28]:
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.my_tensor = torch.randn(1)
        self.register_buffer('my_buffer', torch.randn(1))
        self.my_param = nn.Parameter(torch.randn(1))
        
    def forward(self, x):
            return x

model = MyModel()
print(model.my_tensor)
print(model.state_dict())

tensor([0.1409])
OrderedDict([('my_param', tensor([0.2987])), ('my_buffer', tensor([-2.7767]))])


### 4.6 Encoder Block

This layer brings together all the components discussed so far. 

<img src="images/encoderblock.PNG" width="250" height="50"/>

The following steps are performed by this layer:  

* A positional embedding is injected into the input.
* This is then passed through a series of convolutional layers. The number of these layers depend upon the layer of which these encoder blocks are a part of. For embedding encoder layer, this number is 4 and for model encoder layer it is 2. The layers of convolution are defined using `nn.Modulelist`. 
* The output of this is then passed to a multiheaded self attention layer and finally to a feedforward network which is simply a linear layer.
* As can be seen in the figure above, the model involves residual connections, layer normalizations and dropouts too. These too are implemented appropriately. An easy way to understand the residual connections in code would be draw 2-3 iterations of the lower block (that involves convolution) and ensure that everything matches.



In [13]:
class EncoderBlock(nn.Module):
    
    def __init__(self, model_dim, num_heads, num_conv_layers, kernel_size, device):
        
        super().__init__()
        
        self.num_conv_layers = num_conv_layers
        
        self.conv_layers = nn.ModuleList([DepthwiseSeparableConvolution(model_dim, model_dim, kernel_size)
                                          for _ in range(num_conv_layers)])
        
        self.multihead_self_attn = MultiheadAttentionLayer(model_dim, num_heads, device)
        
        self.position_encoder = PositionEncoder(model_dim, device)
        
        self.pos_norm = nn.LayerNorm(model_dim)
        
        self.conv_norm = nn.ModuleList([nn.LayerNorm(model_dim) for _ in range(self.num_conv_layers)])
        
        self.feedfwd_norm = nn.LayerNorm(model_dim)
        
        self.feed_fwd = nn.Linear(model_dim, model_dim)
        
    def forward(self, x, mask):
        # x = [bs, len_x, model_dim]
        # mask = [bs, len_x]
        
        out = self.position_encoder(x)
        # [bs, len_x, model_dim]
        
        res = out
        
        out = self.pos_norm(out)
        # [bs, len_x, model_dim]
        
        for i, conv_layer in enumerate(self.conv_layers):
            
            out = F.relu(conv_layer(out))
            out = out + res
            if (i+1) % 2 == 0:
                out = F.dropout(out, p=0.1)
            res = out
            out = self.conv_norm[i](out)
        
        
        out = self.multihead_self_attn(out, mask)
        # [bs, len_x, model_dim]
        
        out = F.dropout(out + res, p=0.1)
        
        res = out
        
        out = self.feedfwd_norm(out)
        
        out = F.relu(self.feed_fwd(out))
        # [bs, len_x, model_dim]
            
        out = F.dropout(out + res, p=0.1)
        # [bs, len_x, model_dim]
        #print("Encoder block output: ", out.shape)
        return out

### 4.7 Context-Query Attention Layer

This layer is very similar to the attention flow layer in BIDAF. It calculates attention in two directions. Context-query attention tells us what query words are the most relevant to each context word.   

Let $C$ and $Q$ represent the encoded context and query respectively. Given that the context length is $n$ and query length is $m$, a similarity matrix is calculated first. The similarity matrix captures the similarity between each pair of context and query words. It is denoted by $S$ and is a $n$-by-$m$ matrix. The similarity matrix is calculated as,
$$ S = f\ (Q,\ C)$$
where $f$ is a trilinear similarity function defined as,
$$ f(q,c) = W_{0}\ [q\ ;\ c\ ;\ q \odot c] $$,
where $W_{0}$ is trainable variable, $;$ denotes concatenation and $\odot$ denotes element wise multiplication.  
Context-to-Query attention can then be calculated as,
$$ A = \overline S\ .\ Q^{T} $$,
where $\overline S$ is obtained by normalizing each row of $S$ using softmax. The computations so far are exactly similar to those in BIDAF. You can refer to the previous notebook for a more detailed explanation.  

Query-to-Context attention is calculated as,
$$B = \overline S\ .\ \overline{\overline S}^{T}\ .\ C^{T}$$,
where $\overline{\overline S}^{T}$ is the column-normalized matrix of $S$ by softmax function.  

The implementation is fairly straightforward and is just about multiplying the said tensors.

In [14]:
class ContextQueryAttentionLayer(nn.Module):
    
    def __init__(self, model_dim):
        
        super().__init__() 
        
        self.W0 = nn.Linear(3*model_dim, 1, bias=False)
        
    def forward(self, C, Q, c_mask, q_mask):
        # C = [bs, ctx_len, model_dim]
        # Q = [bs, qtn_len, model_dim]
        # c_mask = [bs, ctx_len]
        # q_mask = [bs, qtn_len]
        
        c_mask = c_mask.unsqueeze(2)
        # [bs, ctx_len, 1]
        
        q_mask = q_mask.unsqueeze(1)
        # [bs, 1, qtn_len]
        
        ctx_len = C.shape[1]
        qtn_len = Q.shape[1]
        
        C_ = C.unsqueeze(2).repeat(1,1,qtn_len,1)
        # [bs, ctx_len, qtn_len, model_dim] 
        
        Q_ = Q.unsqueeze(1).repeat(1,ctx_len,1,1)
        # [bs, ctx_len, qtn_len, model_dim]
        
        C_elemwise_Q = torch.mul(C_, Q_)
        # [bs, ctx_len, qtn_len, model_dim]
        
        S = torch.cat([C_, Q_, C_elemwise_Q], dim=3)
        # [bs, ctx_len, qtn_len, model_dim*3]
        
        S = self.W0(S).squeeze()
        #print("Simi matrix: ", S.shape)
        # [bs, ctx_len, qtn_len, 1] => # [bs, ctx_len, qtn_len]
        
        S_row = S.masked_fill(q_mask==1, -1e10)
        S_row = F.softmax(S_row, dim=2)
        
        S_col = S.masked_fill(c_mask==1, -1e10)
        S_col = F.softmax(S_col, dim=1)
        
        A = torch.bmm(S_row, Q)
        # (bs)[ctx_len, qtn_len] X [qtn_len, model_dim] => [bs, ctx_len, model_dim]
        
        B = torch.bmm(torch.bmm(S_row,S_col.transpose(1,2)), C)
        # [ctx_len, qtn_len] X [qtn_len, ctx_len] => [bs, ctx_len, ctx_len]
        # [ctx_len, ctx_len] X [ctx_len, model_dim ] => [bs, ctx_len, model_dim]
        
        model_out = torch.cat([C, A, torch.mul(C,A), torch.mul(C,B)], dim=2)
        # [bs, ctx_len, model_dim*4]
        
        #print("C2Q output: ", model_out.shape)
        return F.dropout(model_out, p=0.1)
        
        

### 4.8 Output Layer

The output layer is tasked with predicting the start and end indices of the answer from the context. The input to this layer
$M_{1}$, $M_{2}$ and $M_{3}$ are the outputs of 3 model encoders(explained below), from bottom to top. The start index $p_{1}$ is then calculated as,  

$$ p_{1} = softmax\ (\ W_{1}\ [M_{1}\ ;\ M_{2}])$$
and end as,
$$ p_{2} = softmax\ (\ W_{2}\ [M_{1}\ ;\ M_{3}])$$

where $W_{1}$ and $W_{2}$ are trainable variables.

In [15]:
class OutputLayer(nn.Module):
    
    def __init__(self, model_dim):
        
        super().__init__()
        
        self.W1 = nn.Linear(2*model_dim, 1, bias=False)
        
        self.W2 = nn.Linear(2*model_dim, 1, bias=False)
        
        
    def forward(self, M1, M2, M3, c_mask):
        
        start = torch.cat([M1,M2], dim=2)
        
        start = self.W1(start).squeeze()
        
        p1 = start.masked_fill(c_mask==1, -1e10)
        
        #p1 = F.log_softmax(start.masked_fill(c_mask==1, -1e10), dim=1)
        
        end = torch.cat([M1, M3], dim=2)
        
        end = self.W2(end).squeeze()
        
        p2 = end.masked_fill(c_mask==1, -1e10)
        
        #p2 = F.log_softmax(end.masked_fill(c_mask==1, -1e10), dim=1)
        
        #print("preds: ", [p1.shape,p2.shape])
        return p1, p2
        

## QANet

<img src="images/qanet.PNG" width="500" height="600"/>

Going up the flowchart above, the following module does the following end-to-end:
* The inputs to the `forward` method are word-level and character-level tokens for both the context and the query. These tokens are passed to the embedding layer.  

 > *The word embedding is ﬁxed during training and initialized from the p1 = 300 dimensional pre-trained GloVe word vectors, which are ﬁxed during training.*  

 > *The character embedding is obtained as follows: Each character is represented as a trainable vector of dimension p2 = 200, meaning each word can be viewed as the concatenation of the embedding vectors for each of its characters.*   

 For each word the concatenation of these two embeddings is passed on to a 2-layer highway network. Highway network does not affect the shape of the input. Hence the output shape from the `EmbeddingLayer` defined above would be `[bs, ctx_len, word_emb_dim + char_emb_dim]` = `[batch_size, ctx_len, 500]`. This is then supposed to be passed to `embedding_encoder` or the Embedding Encoder Layer. This layer however requires the input dimension to be 128 which is the `model_dim`
 and not 500. As clearly mentioned in the paper,    
 > *Note that the input of this layer is a vector of dimension p1 + p2 = 500 for each individual word, which is immediately mapped to d = 128 by a one-dimensional convolution. The output of this layer is a also of dimension d = 128.*   
 
 We therefore map the output of embedding to 128 in code using `ctx_resizer` and `qtn_resizer`.  

* The resized tensors are then passed on to the *Embedding Encoding Layer* which is a single encoder block with 4 conv layers. 8 attention heads are used in the self-attention module which is the same for all the encoder blocks in the model.
  
* The output of previous layer is then passed on to the *Contex-Query Attention Layer*.  The output dimension of this layer is `4 * model_dim`. This is again resized using `c2q_resizer` to have a dimension of `model_dim`. 
* Next the encoded representation so far is passed on to the *Model Encoder Layer*. This layer comprises of 7 blocks of encoder, with each block having 2 convolutional layers. 
 > *We share weights between each of the 3 repetitions of the model encoder.*
 This can be seen in code while calculating $M_{1}$, $M_{2}$ and $M_{3}$.
* Finally the shared-weight matrices are passed to the output layer which predicts the start and end index of the answer.


In [16]:
class QANet(nn.Module):
    
    def __init__(self, char_vocab_dim, char_emb_dim, word_emb_dim, kernel_size, model_dim, num_heads, device):
        
        super().__init__()
        
        self.embedding = EmbeddingLayer(char_vocab_dim, char_emb_dim, kernel_size, device)
        
        self.ctx_resizer = DepthwiseSeparableConvolution(char_emb_dim+word_emb_dim, model_dim, 5)
        
        self.qtn_resizer = DepthwiseSeparableConvolution(char_emb_dim+word_emb_dim, model_dim, 5)
        
        self.embedding_encoder = EncoderBlock(model_dim, num_heads, 4, 5, device)
        
        self.c2q_attention = ContextQueryAttentionLayer(model_dim)
        
        self.c2q_resizer = DepthwiseSeparableConvolution(model_dim*4, model_dim, 5)
        
        self.model_encoder_layers = nn.ModuleList([EncoderBlock(model_dim, num_heads, 2, 5, device)
                                                   for _ in range(7)])
        
        self.output = OutputLayer(model_dim)
        
        self.device=device
    
    def forward(self, ctx, qtn, ctx_char, qtn_char):
        
        # ctx : [bs, ctx_len]
        # qtn : [bs, qtn_len]
        # ctx_char : [bs, ctx_len, ctx_word_len]
        # qtn_char : [bs, qtn_len, qtn_word_len]
        
        c_mask = torch.eq(ctx, 1).float().to(self.device)
        q_mask = torch.eq(qtn, 1).float().to(self.device)
        
        ctx_emb = self.embedding(ctx, ctx_char)
        # [bs, ctx_len, ch_emb_dim + word_emb_dim]
            
        ctx_emb = self.ctx_resizer(ctx_emb)
        #  [bs, ctx_len, model_dim]
        
        qtn_emb = self.embedding(qtn, qtn_char)
        # [bs, ctx_len, ch_emb_dim + word_emb_dim]
        
        qtn_emb = self.qtn_resizer(qtn_emb)
        # [bs, qtn_len, model_dim]
        
        C = self.embedding_encoder(ctx_emb, c_mask)
        # [bs, ctx_len, model_dim]
        
        Q = self.embedding_encoder(qtn_emb, q_mask)
        # [bs, qtn_len, model_dim]
            
        C2Q = self.c2q_attention(C, Q, c_mask, q_mask)
        # [bs, ctx_len, model_dim*4]
        
        M1 = self.c2q_resizer(C2Q)
        # [bs, ctx_len, model_dim]
    
        for layer in self.model_encoder_layers:
            M1 = layer(M1, c_mask)
        
        M2 = M1
        # [bs, ctx_len, model_dim]  
        
        for layer in self.model_encoder_layers:
            M2 = layer(M2, c_mask)
        
        M3 = M2
        # [bs, ctx_len, model_dim]
        
        for layer in self.model_encoder_layers:
            M3 = layer(M3, c_mask)
            
        p1, p2 = self.output(M1, M2, M3, c_mask)
        
        return p1, p2

In [17]:
CHAR_VOCAB_DIM = len(char2idx)
CHAR_EMB_DIM = 200
WORD_EMB_DIM = 300
KERNEL_SIZE  = 5
MODEL_DIM = 128
NUM_ATTENTION_HEADS = 8
device = torch.device('cuda')

model = QANet(CHAR_VOCAB_DIM,
              CHAR_EMB_DIM, 
              WORD_EMB_DIM,
              KERNEL_SIZE,
              MODEL_DIM,
              NUM_ATTENTION_HEADS,
              device).to(device)

In [18]:
for batch in train_dataset:
    context, question, char_ctx, char_ques, label, ctx_text, ans, ids = batch
    break

In [19]:
 context, question, char_ctx, char_ques, label = context.to(device), question.to(device),\
                                    char_ctx.to(device), char_ques.to(device), label.to(device)

In [20]:
preds = model(context, question, char_ctx, char_ques)

In [21]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 2,243,296 trainable parameters


## 5. Training

> *We use the ADAM optimizer (Kingma & Ba, 2014) with β1 = 0.8,β2 = 0.999, $\epsilon$ = 10−7. We use a learning rate warm-up scheme with an inverse exponential increase from 0.0 to 0.001 in the ﬁrst 1000 steps, and then maintain a constant learning rate for the remainder of training.*

Note: I have not used learning-rate warm up scheme to keep things simple for initial training. 

In [22]:
import torch.optim as optim
optimizer = optim.Adam(model.parameters(), betas=(0.8,0.999), eps=10e-7, weight_decay=3*10e-7)

In [23]:
def train(model, train_dataset):
    print("Starting training ........")
   

    train_loss = 0.
    batch_count = 0

    for batch in train_dataset:

        if batch_count % 500 == 0:
            print(f"Starting batch: {batch_count}")
        batch_count += 1
        
        context, question, char_ctx, char_ques, label, ctx_text, ans, ids = batch
        
        # place data on GPU
        context, question, char_ctx, char_ques, label = context.to(device), question.to(device),\
                                    char_ctx.to(device), char_ques.to(device), label.to(device)
        
        # forward pass, get predictions
        preds = model(context, question, char_ctx, char_ques)

        start_pred, end_pred = preds
        
        # separate labels for start and end position
        start_label, end_label = label[:,0], label[:,1]
        
        # calculate loss
        loss = F.cross_entropy(start_pred, start_label) + F.cross_entropy(end_pred, end_label)
        
        # backward pass
        loss.backward()
        
        # update the gradients
        optimizer.step()

        # zero the gradients so that they do not accumulate
        optimizer.zero_grad()

        train_loss += loss.item()

    return train_loss/len(train_dataset)

In [24]:
def valid(model, valid_dataset):
    
    print("Starting validation .........")
   
    valid_loss = 0.

    batch_count = 0
    
    f1, em = 0., 0.
    
    predictions = {}
    
    for batch in valid_dataset:

        if batch_count % 500 == 0:
            print(f"Starting batch {batch_count}")
        batch_count += 1

        context, question, char_ctx, char_ques, label, ctx_text, ans, ids = batch

        context, question, char_ctx, char_ques, label = context.to(device), question.to(device),\
                                    char_ctx.to(device), char_ques.to(device), label.to(device)

        with torch.no_grad():

            preds = model(context, question, char_ctx, char_ques)

            p1, p2 = preds

            y1, y2 = label[:,0], label[:,1]

            loss = F.nll_loss(p1, y1) + F.nll_loss(p2, y2)

            valid_loss += loss.item()

            batch_size, c_len = p1.size()
            ls = nn.LogSoftmax(dim=1)
            mask = (torch.ones(c_len, c_len) * float('-inf')).to(device).tril(-1).unsqueeze(0).expand(batch_size, -1, -1)
            score = (ls(p1).unsqueeze(2) + ls(p2).unsqueeze(1)) + mask
            score, s_idx = score.max(dim=1)
            score, e_idx = score.max(dim=1)
            s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze()
            
           
            for i in range(batch_size):
                id = ids[i]
                pred = context[i][s_idx[i]:e_idx[i]+1]
                pred = ' '.join([idx2word[idx.item()] for idx in pred])
                predictions[id] = pred
            
    em, f1 = evaluate(predictions)
    return valid_loss/len(valid_dataset), em, f1           
  

In [25]:
import json

def evaluate(predictions):
    '''
    Gets a dictionary of predictions with question_id as key
    and prediction as value. The validation dataset has multiple 
    answers for a single question. Hence we compare our prediction
    with all the answers and choose the one that gives us
    the maximum metric (em or f1). 
    This method first parses the JSON file, gets all the answers
    for a given id and then passes the list of answers and the 
    predictions to calculate em, f1.
    
    
    :param dict predictions
    Returns
    : exact_match: 1 if the prediction and ground truth 
      match exactly, 0 otherwise.
    : f1_score: 
    '''
    with open('./data/squad_dev.json','r',encoding='utf-8') as f:
        dataset = json.load(f)
        
    dataset = dataset['data']
    f1 = exact_match = total = 0
    for article in dataset:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                total += 1
                if qa['id'] not in predictions:
                    continue
                
                ground_truths = list(map(lambda x: x['text'], qa['answers']))
                
                prediction = predictions[qa['id']]
                
                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)
                
                f1 += metric_max_over_ground_truths(
                    f1_score, prediction, ground_truths)
                
    
    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    
    return exact_match, f1



In [26]:
import string, re
from collections import Counter

def normalize_answer(s):
    '''
    Performs a series of cleaning steps on the ground truth and 
    predicted answer.
    '''
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    '''
    Returns maximum value of metrics for predicition by model against
    multiple ground truths.
    
    :param func metric_fn: can be 'exact_match_score' or 'f1_score'
    :param str prediction: predicted answer span by the model
    :param list ground_truths: list of ground truths against which
                               metrics are calculated. Maximum values of 
                               metrics are chosen.
                            
    
    '''
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
        
    return max(scores_for_ground_truths)


def f1_score(prediction, ground_truth):
    '''
    Returns f1 score of two strings.
    '''
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    '''
    Returns exact_match_score of two strings.
    '''
    return (normalize_answer(prediction) == normalize_answer(ground_truth))

def epoch_time(start_time, end_time):
    '''
    Helper function to record epoch time.
    '''
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [27]:
train_losses = []
valid_losses = []
ems = []
f1s = []
epochs = 3
for epoch in range(epochs):
    print(f"Epoch {epoch+1}")
    start_time = time.time()
    
    train_loss = train(model, train_dataset)
    valid_loss, em, f1 = valid(model, valid_dataset)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    ems.append(em)
    f1s.append(f1)
    
    print(f"Epoch train loss : {train_loss}| Time: {epoch_mins}m {epoch_secs}s")
    print(f"Epoch valid loss: {valid_loss}")
    print(f"Epoch EM: {em}")
    print(f"Epoch F1: {f1}")
    print("====================================================================================")
    

Epoch 1
Starting training ........
Starting batch: 0
Starting batch: 500
Starting validation .........
Starting batch 0
Epoch train loss : 20.102696708741227| Time: 5m 30s
Epoch valid loss: 0.07382533744778423
Epoch EM: 0.04730368968779565
Epoch F1: 1.0085605745957305
Epoch 2
Starting training ........
Starting batch: 0
Starting batch: 500
Starting validation .........
Starting batch 0
Epoch train loss : 9.0079408164916| Time: 5m 23s
Epoch valid loss: -0.3216766810669573
Epoch EM: 0.4162724692526017
Epoch F1: 1.28854482827315
Epoch 3
Starting training ........
Starting batch: 0
Starting batch: 500
Starting validation .........
Starting batch 0
Epoch train loss : 8.308144180173796| Time: 5m 25s
Epoch valid loss: 0.254921169532918
Epoch EM: 0.586565752128666
Epoch F1: 1.440395454776078


## References

* Papers read/ referenced:
    1. The QANet paper: https://arxiv.org/abs/1804.09541
    2. Attention is All You Need https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf
    3. Convolutional Neural Networks for Sentence Classification: https://arxiv.org/abs/1408.5882
    4. Highway Networks: https://arxiv.org/abs/1505.00387
* Other helpful links:
    1. https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html
    2. The Illustrated Transformer:http://jalammar.github.io/illustrated-transformer/. This is an excellent piece of writing with amazing easy-to-understand visualizations. Must read.
    3. https://mccormickml.com/2019/11/11/bert-research-ep-1-key-concepts-and-sources/. Chris McCormick's BERT research series is another great resource to learn about self attention and various other details about BERT. He has a blog as well as youtube video series on the same.
    4. https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
    5. https://nlp.seas.harvard.edu/2018/04/03/attention.html. The annotated Transformer.
    6. https://nlp.seas.harvard.edu/slides/aaai16.pdf. A great resource for character embeddings.
    7. https://www.youtube.com/watch?v=T7o3xvJLuHk. Easy explanation of depthwise separable convolutions.
    8. https://towardsdatascience.com/a-basic-introduction-to-separable-convolutions-b99ec3102728. Another amazing blog for depthwise separable convolutions.
    9. https://github.com/bentrevett/pytorch-seq2seq. A great series of notebooks on Machine Translation using PyTorch.  
Some of the repositories below might be out of date. 
    10. https://github.com/BangLiu/QANet-PyTorch
    11. https://github.com/NLPLearn/QANet
    12. https://github.com/setoidz/QANet-pytorch
    13. https://github.com/hackiey/QAnet-pytorch/tree/master/qanet