<a href="https://colab.research.google.com/github/dksifoua/Question-Answering/blob/master/1.%20DrQA%20-%20Document%20reader%20Question%20Answering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!nvidia-smi

Mon Oct 19 07:12:14 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.23.05    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   66C    P8    12W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Load Dependencies

In [2]:
!pip install tqdm --upgrade >> /dev/null 2>&1
!pip install torchtext --upgrade >> /dev/null 2>&1
!pip install spacy --upgrade >> /dev/null 2>&1
!python -m spacy download en >> /dev/null 2>&1

In [3]:
import re
import json
import tqdm
import spacy
import warnings
import collections
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchtext.data import Dataset, Example, Field
from torchtext.data.iterator import BucketIterator

In [4]:
warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
warnings.simplefilter(action='ignore', category=DeprecationWarning)

SEED = 546
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {DEVICE}')

Device: cuda


## Prepare Data

***Donwload data***

In [5]:
%%time
!mkdir -p ./data

!wget --no-check-certificate \
    https://worksheets.codalab.org/rest/bundles/0x7e0a0a21057c4d989aa68da42886ceb9/contents/blob/ \
    -O ./data/train.json

!wget --no-check-certificate \
    https://worksheets.codalab.org/rest/bundles/0x8f29fe78ffe545128caccab74eb06c57/contents/blob/ \
    -O ./data/valid.json

--2020-10-19 07:12:43--  https://worksheets.codalab.org/rest/bundles/0x7e0a0a21057c4d989aa68da42886ceb9/contents/blob/
Resolving worksheets.codalab.org (worksheets.codalab.org)... 40.114.41.203
Connecting to worksheets.codalab.org (worksheets.codalab.org)|40.114.41.203|:443... connected.
HTTP request sent, awaiting response... 200 OK
Syntax error in Set-Cookie: codalab_session=""; expires=Thu, 01 Jan 1970 00:00:00 GMT; Max-Age=-1; Path=/ at position 70.
Length: unspecified [application/json]
Saving to: ‘./data/train.json’

./data/train.json       [                <=> ]  28.88M  9.07MB/s    in 3.2s    

2020-10-19 07:12:47 (9.07 MB/s) - ‘./data/train.json’ saved [30288272]

--2020-10-19 07:12:47--  https://worksheets.codalab.org/rest/bundles/0x8f29fe78ffe545128caccab74eb06c57/contents/blob/
Resolving worksheets.codalab.org (worksheets.codalab.org)... 40.114.41.203
Connecting to worksheets.codalab.org (worksheets.codalab.org)|40.114.41.203|:443... connected.
HTTP request sent, awaiting r

***Load data***

In [6]:
def load_json(path):
    with open(path, mode='r', encoding='utf-8') as file:
        return json.load(file)['data']

In [7]:
train_raw_data = load_json(path='./data/train.json')
valid_raw_data = load_json(path='./data/valid.json')
print(f'Length of raw train data: {len(train_raw_data):,}')
print(f'Length of raw valid data: {len(valid_raw_data):,}')

Length of raw train data: 442
Length of raw valid data: 48


***Parse data***

In [8]:
def parse_json(data):
    qas = []
    for paragraphs in data:
        for para in paragraphs['paragraphs']:
            for qa in para['qas']:
                for ans in qa['answers']:
                    qas.append({
                        'id': qa['id'],
                        'context': para['context'],
                        'question': qa['question'],
                        'answer': ans['text'],
                        'answer_start': ans['answer_start'],
                    })
    return qas

In [9]:
train_qas = parse_json(train_raw_data)
valid_qas = parse_json(valid_raw_data)
print(f'Length of train qa pairs: {len(train_qas):,}')
print(f'Length of valid qa pairs: {len(valid_qas):,}')

print('Id:', train_qas[0]['id'])
print('Context:', train_qas[0]['context'])
print('Question:', train_qas[0]['question'])
print('Answer starts at:', train_qas[0]['answer_start'])
print('Answer:', train_qas[0]['answer'])

Length of train qa pairs: 87,599
Length of valid qa pairs: 34,726
Id: 5733be284776f41900661182
Context: Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
Question: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Answer starts at: 515
Answer: Saint Bernadette Soubirous


In [10]:
for i in range(len(train_qas)): # Test answer_start are correct in train set
    assert train_qas[i]['answer'] == train_qas[i]['context'][train_qas[i]['answer_start']:train_qas[i]['answer_start'] + len(train_qas[i]['answer'])]

for i in range(len(valid_qas)): # Test answer_start are correct in validation set
    assert valid_qas[i]['answer'] == valid_qas[i]['context'][valid_qas[i]['answer_start']:valid_qas[i]['answer_start'] + len(valid_qas[i]['answer'])]

***Add targets***

In [11]:
def add_targets(qas, nlp=spacy.load('en')):
    for qa in tqdm.tqdm(qas):
        context = nlp(qa['context'], disable=['parser','tagger','ner'])
        answer = nlp(qa['answer'], disable=['parser','tagger','ner'])
        ans_start = qa['answer_start']
        for i in range(len(context)):
            if context[i].idx == ans_start:
                ans = context[i:i + len(answer)]
                qa['target'] = (ans[0].i, ans[-1].i)
                break

In [12]:
add_targets(train_qas)
add_targets(valid_qas)
print()
print(f'Length of train qa pairs: {len(train_qas):,}')
print(f'Length of valid qa pairs: {len(valid_qas):,}')

100%|██████████| 87599/87599 [01:26<00:00, 1017.92it/s]
100%|██████████| 34726/34726 [00:36<00:00, 962.94it/s]


Length of train qa pairs: 87,599
Length of valid qa pairs: 34,726





In [13]:
def filter_qas(qa, nlp=spacy.load('en')):
    if 'target' in [*qa.keys()]:
        context = nlp(qa['context'], disable=['parser','tagger','ner'])
        start, end = qa['target']
        return context[start:end+1].text == qa['answer']
    return False

In [14]:
%%time
train_qas = [*filter(filter_qas, train_qas)]
valid_qas = [*filter(filter_qas, valid_qas)]
print(f'Length of train qa pairs after filtering out bad qa pairs: {len(train_qas):,}')
print(f'Length of valid qa pairs after filtering out bad qa pairs: {len(valid_qas):,}')

Length of train qa pairs after filtering out bad qa pairs: 86,597
Length of valid qa pairs after filtering out bad qa pairs: 34,295
CPU times: user 1min 46s, sys: 108 ms, total: 1min 46s
Wall time: 1min 46s


In [15]:
def test_targets(qas, nlp=spacy.load('en')):
    for qa in tqdm.tqdm(qas):
        if 'target' in [*qa.keys()]:
            context = nlp(qa['context'], disable=['parser','tagger','ner'])
            start, end = qa['target']
            assert context[start:end + 1].text == qa['answer']

In [16]:
test_targets(train_qas)
test_targets(valid_qas)

100%|██████████| 86597/86597 [01:21<00:00, 1066.05it/s]
100%|██████████| 34295/34295 [00:33<00:00, 1016.24it/s]


***Add features***

In [17]:
def add_features(qas, nlp=spacy.load('en')):
    for qa in tqdm.tqdm(qas):
        question = [token.text for token in nlp(qa['question'], disable=['parser','tagger','ner'])]
        context = nlp(qa['context'], disable=['parser'])
        qa['pos'], qa['ner'] = zip(*map(lambda token: (token.tag_, token.ent_type_ if token.ent_type_ != '' else 'None'), context))

In [18]:
add_features(train_qas)
add_features(valid_qas)

100%|██████████| 86597/86597 [20:35<00:00, 70.09it/s]
100%|██████████| 34295/34295 [08:39<00:00, 66.04it/s]


***Build datasets***

In [19]:
ID = Field(tokenize=None, batch_first=True)
POS = Field(tokenize=None, batch_first=True)
NER = Field(tokenize=None, batch_first=True)
TEXT = Field(lower=True, tokenizer_language='en', tokenize='spacy', batch_first=True)
TARGET = Field(sequential=False, use_vocab=False, batch_first=True)

In [70]:
train_dataset = Dataset([Example.fromdict(data=qa, fields={
    'id': ('id', ID),
    'context': ('cxt', TEXT),
    'question': ('qst', TEXT),
    'answer': ('ans', TEXT),
    'pos': ('pos', POS),
    'ner': ('ner', NER),
    'target': ('trg', TARGET)
}) for qa in tqdm.tqdm(train_qas)], fields={
    'id': ID,
    'cxt': TEXT,
    'qst': TEXT,
    'ans': TEXT,
    'pos': POS,
    'ner': NER,
    'trg': TARGET
})

valid_dataset = Dataset([Example.fromdict(data=qa, fields={
    'id': ('id', ID),
    'context': ('cxt', TEXT),
    'question': ('qst', TEXT),
    'answer': ('ans', TEXT),
    'pos': ('pos', POS),
    'ner': ('ner', NER),
    'target': ('trg', TARGET)
}) for qa in tqdm.tqdm(valid_qas)], fields={
    'id': ID,
    'cxt': TEXT,
    'qst': TEXT,
    'ans': TEXT,
    'pos': POS,
    'ner': NER,
    'trg': TARGET
})
print()
print(f'Length of train dataset: {len(train_dataset.examples):,}')
print(f'Length of valid dataset: {len(valid_dataset.examples):,}')

100%|██████████| 86597/86597 [02:05<00:00, 689.39it/s]
100%|██████████| 34295/34295 [00:50<00:00, 674.37it/s]


Length of train dataset: 86,597
Length of valid dataset: 34,295





In [71]:
ID.build_vocab([*map(lambda x: x.id, train_dataset.examples)] + [*map(lambda x: x.id, valid_dataset.examples)])
TEXT.build_vocab(train_dataset)
POS.build_vocab(train_dataset)
NER.build_vocab(train_dataset)
print(f'Length of ID vocabulary: {len(ID.vocab):,}')
print(f'Length of TEXT vocabulary: {len(TEXT.vocab):,}')
print(f'Length of POS vocabulary: {len(POS.vocab):,}')
print(f'Length of NER vocabulary: {len(NER.vocab):,}')

Length of ID vocabulary: 97,108
Length of TEXT vocabulary: 91,446
Length of POS vocabulary: 52
Length of NER vocabulary: 21


***Download pretrained GloVe embedding***

In [22]:
%%time
!wget --no-check-certificate \
    http://nlp.stanford.edu/data/glove.840B.300d.zip \
    -O ./data/glove.840B.300d.zip
!unzip -q ./data/glove.840B.300d.zip -d ./data
!rm -r ./data/glove.840B.300d.zip

--2020-10-19 07:51:00--  http://nlp.stanford.edu/data/glove.840B.300d.zip
Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://nlp.stanford.edu/data/glove.840B.300d.zip [following]
--2020-10-19 07:51:00--  https://nlp.stanford.edu/data/glove.840B.300d.zip
Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: http://downloads.cs.stanford.edu/nlp/data/glove.840B.300d.zip [following]
--2020-10-19 07:51:00--  http://downloads.cs.stanford.edu/nlp/data/glove.840B.300d.zip
Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22
Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2176768927 (2.0G) [application/zip

In [23]:
def load_glove(path):
    glove = {}
    with open(path, mode='r', encoding='utf-8') as file:
        for line in tqdm.tqdm(file):
            values = line.split(' ')
            glove[values[0]] = np.asarray(values[1:], dtype='float32')
        return glove

In [24]:
%%time
glove = load_glove(path='./data/glove.840B.300d.txt')

2196017it [02:50, 12848.44it/s]

CPU times: user 2min 44s, sys: 6.39 s, total: 2min 50s
Wall time: 2min 50s





In [25]:
def load_embeddings(glove, field, embedding_size=300, most_common=1000):
    most_common_words, most_common_indexes = [*map(lambda x: x[0], field.vocab.freqs.most_common(most_common))], []
    embedding_matrix = np.zeros((len(field.vocab), embedding_size))
    n_words = 0
    for index, word in tqdm.tqdm(enumerate(field.vocab.freqs), total=len(field.vocab)):
        if word in most_common_words:
            most_common_indexes.append(index)
        try:
            embedding_matrix[index] = glove[word]
            n_words += 1
        except KeyError:
            pass
    return embedding_matrix, n_words, most_common_indexes

In [26]:
embedding_matrix, n_words, most_common_indexes = load_embeddings(glove, TEXT)
print(f'\nWords found: {n_words}/{len(TEXT.vocab)}')
np.save('./data/GloVe_DrQA.npy', embedding_matrix)

100%|█████████▉| 91444/91446 [00:01<00:00, 49300.96it/s]



Words found: 64754/91446


In [27]:
# Free up the RAM
del glove
del embedding_matrix

## Modeling

***Stacked Bidirectional LSTM Layer***

In [28]:
class StackedBiLSTMsLayer(nn.Module):

    def __init__(self, input_size, hidden_size, n_layers, dropout):
        super(StackedBiLSTMsLayer, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = nn.Dropout(p=dropout)
        self.lstms = nn.ModuleList([nn.LSTM(input_size if i == 0 else hidden_size * 2, hidden_size,
                                            num_layers=n_layers, bidirectional=True) for i in range(n_layers)])
    
    def forward(self, inputs):
        """
        :param Tensor[batch_size, seq_len, input_size] inputs
        :return Tensor[batch_size, seq_len, hidden_size * n_layers * 2]
        """
        outputs = [inputs]
        for lstm in self.lstms:
            out, _ = lstm(self.dropout(outputs[-1])) # [batch_size, seq_len, hidden_size * 2]
            outputs.append(out)
        return self.dropout(torch.cat(outputs[1:], dim=-1))

***Aligned Question Embedding Layer***

In [29]:
class AlignQuestionEmbeddingLayer(nn.Module):

    def __init__(self, hidden_size):
        super(AlignQuestionEmbeddingLayer, self).__init__()
        self.hidden_size = hidden_size
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, cxt_embed, qst_embed, qst_mask):
        """
        :param Tensor[batch_size, cxt_len, embedding_size] cxt_embded
        :param Tensor[batch_size, qst_len, embedding_size] qst_embed
        :param Tensor[batch_size, qst_len] qst_mask
        :return Tensor[batch_size, cxt_len, hidden_size]
        """
        cxt_embed = F.relu(self.linear(cxt_embed)) # [batch_size, cxt_len, hidden_size]
        qst_embed = F.relu(self.linear(qst_embed)) # [batch_size, qst_len, hidden_size]
        scores = torch.bmm(cxt_embed, qst_embed.transpose(-1, -2)) # [batch_size, cxt_len, qst_len]
        scores = scores.masked_fill(qst_mask.unsqueeze(1) == 0, 1e-18)
        attention_weights = F.softmax(scores, dim=-1) # [batch_size, cxt_len, qst_len]
        return torch.bmm(attention_weights, qst_embed)

***Question Encoding Layer***

In [30]:
class QuestionEncodingLayer(nn.Module):

    def __init__(self, input_size, hidden_size, dropout, n_layers):
        super(QuestionEncodingLayer, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.stacked_bilstms_layer = StackedBiLSTMsLayer(input_size=input_size, hidden_size=hidden_size, n_layers=n_layers, dropout=dropout)
        self.linear = nn.Linear(input_size, 1)
    
    def linear_self_attention(self, qst_embed, qst_mask):
        """
        :param Tensor[batch_size, qst_len, embedding_size] qst_embed
        :param Tensor[batch_size, qst_len] qst_mask
        :return Tensor[batch_size, qst_len]
        """
        scores = self.linear(qst_embed).squeeze(-1) # [batch_size, qst_len]
        scores = scores.masked_fill(qst_mask == 0, 1e-18)
        return F.softmax(scores, dim=-1)

    
    def forward(self, qst_embed, qst_mask):
        """
        :param Tensor[batch_size, qst_len, embedding_size] qst_embed
        :param Tensor[batch_size, qst_len] qst_mask
        :return Tensor[batch_size, hidden_size * n_layers * 2]
        """
        attention_weights = self.linear_self_attention(qst_embed=qst_embed, qst_mask=qst_mask) # [batch_size, qst_len]
        lstm_outputs = self.stacked_bilstms_layer(inputs=qst_embed) # [batch_size, qst_len, hidden_size * n_layers * 2]
        return torch.bmm(attention_weights.unsqueeze(1), lstm_outputs).squeeze(1)

***BiLinear Attention Layer***

In [60]:
class BiLinearAttentionLayer(nn.Module):

    def __init__(self, cxt_size, qst_size):
        super(BiLinearAttentionLayer, self).__init__()
        self.cxt_size = cxt_size
        self.qst_size = qst_size
        self.linear = nn.Linear(qst_size, cxt_size)
    
    def forward(self, cxt_encoded, qst_encoded, cxt_mask):
        """
        :param Tensor[batch_size, cxt_len, cxt_size] cxt_encoded
        :param Tensor[batch_size, qst_size] qst_encoded
        :param Tensor[batch_size, cxt_len] cxt_mask
        :return Tensor[batch_size, cxt_len, hidden_size]
        """
        qst_encoded = self.linear(qst_encoded) # [batch_size, cxt_size]
        scores = torch.bmm(cxt_encoded, qst_encoded.unsqueeze(-1)) # [batch_size, cxt_len, 1]
        scores = scores.squeeze(-1).masked_fill(cxt_mask == 0, 1e-18) # [batch_size, cxt_len]
        return F.log_softmax(scores, dim=-1) if self.training else F.softmax(scores, dim=-1)

***Document reader Question Answering Model***

In [83]:
class DrQA(nn.Module):

    def __init__(self, vocab_size, embedding_size, n_extra_features, hidden_size, n_layers, dropout, pad_index):
        super(DrQA, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.n_extra_features = n_extra_features
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = nn.Dropout(p=dropout)
        self.pad_index = pad_index
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=pad_index)
        self.align_question_embedding_layer = AlignQuestionEmbeddingLayer(hidden_size=embedding_size)
        self.cxt_stacked_bi_lstm_layer = StackedBiLSTMsLayer(input_size=embedding_size * 2 + n_extra_features,
                                                             hidden_size=hidden_size, n_layers=n_layers, dropout=dropout)
        self.qst_encoding_layer = QuestionEncodingLayer(input_size=embedding_size, hidden_size=hidden_size, dropout=dropout, n_layers=n_layers)
        self.bilinear_attention_layer_start = BiLinearAttentionLayer(cxt_size=hidden_size * n_layers * 2, qst_size=hidden_size * n_layers * 2)
        self.bilinear_attention_layer_end = BiLinearAttentionLayer(cxt_size=hidden_size * n_layers * 2, qst_size=hidden_size * n_layers * 2)
    
    def load_glove_embeddings(self, path, most_common_indexes):
        def tune_embeddings(grad, words=most_common_indexes):
            grad[most_common_indexes] = 0
            return grad
        
        self.embedding.weight = nn.Parameter(torch.FloatTensor(np.load(path)))
        self.embedding.weight.register_hook(tune_embeddings) # Only fine-tune the 1000 most frequent question words
    
    def make_cxt_mask(self, cxt_sequences):
        """
        :param Tensor[batch_size, cxt_len]
        :return Tensor[batch_size, cxt_len]
        """
        return cxt_sequences != self.pad_index
    
    def make_qst_mask(self, qst_sequences):
        """
        :param Tensor[batch_size, qst_len]
        :return Tensor[batch_size, qst_len]
        """
        return qst_sequences != self.pad_index
    
    @staticmethod
    def decode(starts, ends):
        """
        :param Tensor[batch_size, cxt_len] starts
        :param Tensor[batch_size, cxt_len] ends
        :return list(int) start_indexes
        :return list(int) end_indexes
        :return list(float) pred_probas
        """
        start_indexes, end_indexes, pred_probas = [], [], []
        for i in range(starts.size(0)):
            probas = torch.ger(starts[i], ends[i]) # [cxt_len, cxt_len]
            proba, index = torch.topk(probas.view(-1), k=1)
            start_indexes.append(index.tolist()[0] // probas.size(0))
            end_indexes.append(index.tolist()[0] % probas.size(1))
            pred_probas.append(proba.tolist()[0])
        return start_indexes, end_indexes, pred_probas

    @staticmethod
    def build_exact_match(cxt_sequences, qst_sequences):
        """
        :param Tensor[batch_size, cxt_len] cxt_sequences
        :param Tensor[batch_size, qst_len] qst_sequences
        :return Tensor[batch_size, cxt_len]
        """
        em_sequences = []
        for i, cxt_sequence in enumerate(cxt_sequences):
            em = [word in qst_sequences[i] for word in cxt_sequence]
            em_sequences.append(em)
        return torch.tensor(em_sequences, dtype=torch.float, device=cxt_sequences.device)

    @staticmethod
    def build_normalized_term_frequency(cxt_sequences):
        """
        :param Tensor[batch_size, cxt_len] cxt_sequences
        :return Tensor[batch_size, cxt_len]
        """
        ntfs = []
        for i, cxt_sequence in enumerate(cxt_sequences):
            counts = collections.Counter(cxt_sequence.tolist())
            count_norm = sum(counts.values())
            ntfs.append([counts[indice] / count_norm for indice in cxt_sequence.tolist()])
        return torch.tensor(ntfs, dtype=torch.float, device=cxt_sequences.device)
    
    def forward(self, cxt_sequences, pos_sequences, ner_sequences, qst_sequences):
        """
        :param Tensor[batch_size, cxt_len] cxt_sequences
        :param Tensor[batch_size, cxt_len] pos_sequences
        :param Tensor[batch_size, cxt_len] ner_sequences
        :param Tensor[batch_size, qst_len] qst_sequences
        :param Tensor[batch_size, cxt_len] cxt_mask
        :param Tensor[batch_size, qst_len] qst_mask
        :return Tensor[batch_size, cxt_len] starts
        :return Tensor[batch_size, cxt_len] ends
        """
        cxt_mask = self.make_cxt_mask(cxt_sequences) # [batch_size, cxt_len]
        qst_mask = self.make_qst_mask(qst_sequences) # [batch_size, qst_len]
        cxt_embedded = self.dropout(self.embedding(cxt_sequences)) # [batch_size, cxt_len, embedding_size]
        qst_embedded = self.dropout(self.embedding(qst_sequences)) # [batch_size, cxt_len, embedding_size]
        cxt_aligned = self.align_question_embedding_layer(cxt_embed=cxt_embedded, qst_embed=qst_embedded,
                                                          qst_mask=qst_mask) # [batch_size, cxt_len, embedding_size]
        # em_sequences = self.build_exact_match(cxt_sequences=cxt_sequences, qst_sequences=qst_sequences)
        # tf_sequences = self.build_normalized_term_frequency(cxt_sequences=cxt_sequences)
        cxt_encoded = torch.cat([cxt_embedded, # [batch_size, cxt_len, embedding_size]
                                #  em_sequences.unsqueeze(-1), # [batch_size, cxt_len, 1]
                                 pos_sequences.unsqueeze(-1), # [batch_size, cxt_len, 1]
                                 ner_sequences.unsqueeze(-1), # [batch_size, cxt_len, 1]
                                #  tf_sequences.unsqueeze(-1), # [batch_size, cxt_len, 1]
                                 cxt_aligned # [batch_size, cxt_len, embedding_size]
                                 ], dim=-1) # [batch_size, cxt_len, embedding_size * 2 + 4]
        cxt_encoded = self.cxt_stacked_bi_lstm_layer(inputs=cxt_encoded) # [batch_size, cxt_len, hidden_size * n_layers * 2]
        qst_encoded = self.qst_encoding_layer(qst_embed=qst_embedded, qst_mask=qst_mask) # [batch_size, hidden_size * n_layers * 2]
        starts = self.bilinear_attention_layer_start(cxt_encoded=cxt_encoded, qst_encoded=qst_encoded, cxt_mask=cxt_mask)
        ends = self.bilinear_attention_layer_end(cxt_encoded=cxt_encoded, qst_encoded=qst_encoded, cxt_mask=cxt_mask)
        return starts, ends

***Training routines***

In [84]:
class AverageMeter:
    
    def __init__(self):
        self.value = 0.
        self.sum = 0.
        self.count = 0
        self.average = 0.
        
    def reset(self):
        self.value = 0.
        self.sum = 0.
        self.count = 0
        self.average = 0.
        
    def update(self, value, n=1):
        self.value = value
        self.sum += value * n
        self.count += n
        self.average = self.sum / self.count

In [85]:
def normalize_answer(s):
    """Performs a series of cleaning steps on the ground truth and predicted answer."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Returns maximum value of metrics for predicition by model against
    multiple ground truths.
    
    :param func metric_fn: can be 'exact_match_score' or 'f1_score'
    :param str prediction: predicted answer span by the model
    :param list ground_truths: list of ground truths against which
                               metrics are calculated. Maximum values of 
                               metrics are chosen.
    """
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
        
    return max(scores_for_ground_truths)


def f1_score(prediction, ground_truth):
    """Returns f1 score of two strings."""
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = collectons.Counter(prediction_tokens) & collectons.Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    """Returns exact_match_score of two strings."""
    return (normalize_answer(prediction) == normalize_answer(ground_truth))

In [86]:
def evaluate(predictions, path='./data/valid.json'):
    """Gets a dictionary of predictions with question_id as key
    and prediction as value. The validation dataset has multiple 
    answers for a single question. Hence we compare our prediction
    with all the answers and choose the one that gives us
    the maximum metric (em or f1).

    :param dict predictions
    :return float exact_match: 1 if the prediction and ground truth match exactly, 0 otherwise.
    :return float f1_score
    """
    dataset = load_json(path)
    f1 = exact_match = total = 0
    for article in dataset:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                total += 1
                if qa['id'] not in predictions:
                    continue
                ground_truths = [*map(lambda x: x['text'], qa['answers'])]
                prediction = predictions[qa['id']]
                exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
                f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)
    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    return exact_match, f1

In [91]:
class Trainer:
    
    def __init__(self, model, optimizer, criterion, id_field, text_field):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.id_field = id_field
        self.text_field = text_field
        
    def train_step(self, loader, epoch, grad_clip):
        loss_tracker = AverageMeter()
        self.model.train()
        progress_bar = tqdm.tqdm(enumerate(loader), total=len(loader))
        for i, batch in progress_bar:
            self.optimizer.zero_grad()
            starts, ends = self.model(cxt_sequences=batch.cxt, pos_sequences=batch.pos, ner_sequences=batch.ner,
                                      qst_sequences=batch.qst) # [batch_size, cxt_len]
            loss = self.criterion(starts, batch.trg[:, 0]) + self.criterion(ends, batch.trg[:, 1])
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
            self.optimizer.step()
            loss_tracker.update(loss.item())
            progress_bar.set_description(f'Epoch: {epoch+1:02d} -     loss: {loss_tracker.average:.3f}')
        return loss_tracker.average
    
    def validate(self, loader, epoch):
        loss_tracker, predictions = AverageMeter(), {}
        self.model.eval()
        with torch.no_grad():
            progress_bar = tqdm.tqdm(enumerate(loader), total=len(loader))
            for i, batch in progress_bar:
                starts, ends = self.model(cxt_sequences=batch.cxt, pos_sequences=batch.pos, ner_sequences=batch.ner,
                                          qst_sequences=batch.qst) # [batch_size, cxt_len]
                loss = self.criterion(starts, batch.trg[:, 0]) + self.criterion(ends, batch.trg[:, 1])
                start_indexes, end_indexes, _ = self.model.decode(starts=starts, ends=ends)
                for i in range(starts.size(0)):
                    id = self.id_field.vocab.itos[batch.id[i].item()]
                    prediction = batch.cxt[i][start_indexes[i]:end_indexes[i]+1]
                    predictions[id] = ' '.join([self.text_field.vocab.itos[indice.item()] for indice in prediction])
                loss_tracker.update(loss.item())
                progress_bar.set_description(f'Epoch: {epoch+1:02d} - val_loss: {loss_tracker.average:.3f}')
        return loss_tracker.average, predictions
    
    def train(self, train_loader, valid_loader, n_epochs, grad_clip):
        history, best_loss = {'loss': [], 'val_loss': [], 'em': [], 'f1': []}, np.inf
        for epoch in range(n_epochs):
            loss = self.train_step(train_loader, epoch, grad_clip)
            val_loss, predictions = self.validate(valid_loader, epoch)
            em, f1 = evaluate(predictions)
            print(f'\nF1={f1:.3f} - EM={em:.3f}')
            history['loss'].append(loss); history['val_loss'].append(val_loss)
            history['em'].append(em); history['f1'].append(f1)
            if best_loss > val_loss:
                best_loss = val_loss
                torch.save(self.model.state_dict(), './checkpoints/DrQA.pth')
        return history

***Train the model***

In [92]:
N_LAYERS = 3
HIDDEN_SIZE = 128
EMBED_SIZE = 300
DROPOUT = 0.3
N_EPOCHS = 5
BATCH_SIZE = 32
GRAD_CLIP = 10.0

In [93]:
drqa = DrQA(vocab_size=len(TEXT.vocab), embedding_size=EMBED_SIZE, n_extra_features=2, hidden_size=HIDDEN_SIZE, n_layers=N_LAYERS,
            dropout=DROPOUT, pad_index=TEXT.vocab.stoi[TEXT.pad_token])
drqa.load_glove_embeddings('./data/GloVe_DrQA.npy', most_common_indexes)
drqa.to(DEVICE)
optimizer = optim.Adamax(params=drqa.parameters())
criterion = nn.NLLLoss(ignore_index=TEXT.vocab.stoi[TEXT.pad_token])
print(f'Number of parameters of the model: {sum(p.numel() for p in drqa.parameters() if p.requires_grad):,}')
print(drqa)
trainer = Trainer(model=drqa, optimizer=optimizer, criterion=criterion, id_field=ID, text_field=TEXT)

Number of parameters of the model: 36,219,697
DrQA(
  (dropout): Dropout(p=0.3, inplace=False)
  (embedding): Embedding(91446, 300, padding_idx=1)
  (align_question_embedding_layer): AlignQuestionEmbeddingLayer(
    (linear): Linear(in_features=300, out_features=300, bias=True)
  )
  (cxt_stacked_bi_lstm_layer): StackedBiLSTMsLayer(
    (dropout): Dropout(p=0.3, inplace=False)
    (lstms): ModuleList(
      (0): LSTM(602, 128, num_layers=3, bidirectional=True)
      (1): LSTM(256, 128, num_layers=3, bidirectional=True)
      (2): LSTM(256, 128, num_layers=3, bidirectional=True)
    )
  )
  (qst_encoding_layer): QuestionEncodingLayer(
    (stacked_bilstms_layer): StackedBiLSTMsLayer(
      (dropout): Dropout(p=0.3, inplace=False)
      (lstms): ModuleList(
        (0): LSTM(300, 128, num_layers=3, bidirectional=True)
        (1): LSTM(256, 128, num_layers=3, bidirectional=True)
        (2): LSTM(256, 128, num_layers=3, bidirectional=True)
      )
    )
    (linear): Linear(in_features=3

In [None]:
!mkdir -p ./checkpoints
train_iterator, valid_iterator =  BucketIterator.splits((train_dataset, valid_dataset), batch_size=BATCH_SIZE, sort=False, device=DEVICE)
history = trainer.train(train_loader=train_iterator, valid_loader=valid_iterator, n_epochs=N_EPOCHS, grad_clip=GRAD_CLIP)

Epoch: 01 -     loss: 8.513:  80%|███████▉  | 2165/2707 [05:28<01:21,  6.62it/s]