<a href="https://colab.research.google.com/github/dksifoua/Question-Answering/blob/master/2%20-%20BiDAF%2C%20Bi-Directional%20Attention%20Flow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!nvidia-smi

Mon Nov  2 00:15:26 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P8     9W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Load dependencies

In [None]:
!pip install tqdm --upgrade >> /dev/null 2>&1
!pip install spacy --upgrade >> /dev/null 2>&1
!python -m spacy download en >> /dev/null 2>&1

In [None]:
import re
import json
import time
import tqdm
import spacy
import string
import itertools
import collections
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from IPython.display import display, HTML

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader

In [None]:
SEED = 546
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {DEVICE}')

Device: cuda


## Prepare data

***Download data***

In [None]:
!rm -rf ./data
!mkdir ./data

!wget --no-check-certificate \
    https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json \
    -O ./data/train-v1.1.json

!wget --no-check-certificate \
    https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json \
    -O ./data/dev-v1.1.json

--2020-11-02 00:15:51--  https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.110.153, 185.199.111.153, 185.199.108.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.110.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 30288272 (29M) [application/json]
Saving to: ‘./data/train-v1.1.json’


2020-11-02 00:15:53 (60.8 MB/s) - ‘./data/train-v1.1.json’ saved [30288272/30288272]

--2020-11-02 00:15:53--  https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.110.153, 185.199.111.153, 185.199.108.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.110.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4854279 (4.6M) [application/json]
Saving to: ‘./data/dev-v1.1.json’


2020-11-02 00:15:54 (22.3 MB/s) - ‘./data/dev-v1.1.json’ saved [4854279/485427

***Load JSON data***

In [None]:
def load(path):
    """Load a SQuAD JSON file and return its list of articles."""
    with open(path, mode='r', encoding='utf-8') as file:
        return json.load(file)['data']

In [None]:
train_raw_data = load('./data/train-v1.1.json')
valid_raw_data = load('./data/dev-v1.1.json')
print(f'Length of raw train data: {len(train_raw_data):,}')
print(f'Length of raw valid data: {len(valid_raw_data):,}')

Length of raw train data: 442
Length of raw valid data: 48


***Parse JSON data***

In [None]:
def parse(data, nlp=spacy.load('en')):
    qas = []
    for paragraphs in tqdm.tqdm(data):
        for para in paragraphs['paragraphs']:
            context = nlp(para['context'], disable=['parser'])
            for qa in para['qas']:
                id = qa['id']
                question = nlp(qa['question'], disable=['parser', 'tagger', 'ner'])
                for ans in qa['answers']:
                    qas.append({
                        'id': id,
                        'context': context,
                        'question': question,
                        'answer': nlp(ans['text'], disable=['parser', 'tagger', 'ner']),
                        'answer_start': ans['answer_start'],
                    })
    return qas

In [None]:
train_qas = parse(train_raw_data)
valid_qas = parse(valid_raw_data)
print()
print(f'Length of train qa pairs: {len(train_qas):,}')
print(f'Length of valid qa pairs: {len(valid_qas):,}')
print('==================== Example ====================')
print('Id:', train_qas[0]['id'])
print('Context:', train_qas[0]['context'])
print('Question:', train_qas[0]['question'])
print('Answer starts at:', train_qas[0]['answer_start'])
print('Answer:', train_qas[0]['answer'])

100%|██████████| 442/442 [04:43<00:00,  1.56it/s]
100%|██████████| 48/48 [00:34<00:00,  1.40it/s]


Length of train qa pairs: 87,599
Length of valid qa pairs: 34,726
Id: 5733be284776f41900661182
Context: Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
Question: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Answer starts at: 515
Answer: Saint Bernadette Soubirous





In [None]:
def test_answer_start(qas):
    """Check that every answer_start offset points at the answer text in its context."""
    for qa in tqdm.tqdm(qas):
        answer = qa['answer'].text
        context = qa['context'].text
        answer_start = qa['answer_start']
        assert answer == context[answer_start:answer_start + len(answer)]

In [None]:
test_answer_start(train_qas)
test_answer_start(valid_qas)

100%|██████████| 87599/87599 [00:07<00:00, 11019.19it/s]
100%|██████████| 34726/34726 [00:02<00:00, 11754.15it/s]


***Add targets***

In [None]:
def add_targets(qas):
    """Add the start and end token indices of each answer as its target."""
    for qa in qas:
        context = qa['context']
        answer = qa['answer']
        ans_start = qa['answer_start']
        for i in range(len(context)):
            if context[i].idx == ans_start:
                ans = context[i:i + len(answer)]
                qa['target'] = [ans[0].i, ans[-1].i]
                break
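
The mapping from SQuAD's character offset to a token span can be illustrated without spaCy. The sketch below is only an illustration: `simple_tokenize` and `char_to_token_span` are hypothetical helpers, and a whitespace split stands in for the spaCy tokenizer (the stored offset plays the role of `token.idx`). It also shows why some pairs end up without a target: if no token starts exactly at `answer_start`, nothing is assigned.

```python
def simple_tokenize(text):
    """Hypothetical stand-in for spaCy: split on spaces, keep (offset, token) pairs."""
    tokens, offset = [], 0
    for piece in text.split(' '):
        if piece:
            tokens.append((offset, piece))
        offset += len(piece) + 1  # +1 for the separating space
    return tokens


def char_to_token_span(context, answer, answer_start):
    """Return [start_token, end_token], or None when no token starts at answer_start."""
    ctx_tokens = simple_tokenize(context)
    n_answer_tokens = len(simple_tokenize(answer))
    for i, (offset, _) in enumerate(ctx_tokens):
        if offset == answer_start:
            return [i, i + n_answer_tokens - 1]
    return None


context = 'The grotto is a replica of the grotto at Lourdes'
print(char_to_token_span(context, 'replica of the', context.index('replica')))
print(char_to_token_span(context, 'ourdes', context.index('ourdes')))  # starts mid-token
```

Pairs for which the lookup fails, or for which the recovered span does not reproduce the answer text, are the ones removed by the filtering step that follows.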

In [None]:
%%time
add_targets(train_qas)
add_targets(valid_qas)
print(f'Length of train qa pairs: {len(train_qas):,}')
print(f'Length of valid qa pairs: {len(valid_qas):,}')

Length of train qa pairs: 87,599
Length of valid qa pairs: 34,726
CPU times: user 1.5 s, sys: 22 ms, total: 1.52 s
Wall time: 1.53 s


In [None]:
def filter_qas(qa):
    """Keep only pairs whose target token span reproduces the answer text."""
    if 'target' in qa:
        start, end = qa['target']
        return qa['context'][start:end + 1].text == qa['answer'].text
    return False

In [None]:
%%time
train_qas = [*filter(filter_qas, train_qas)]
valid_qas = [*filter(filter_qas, valid_qas)]
print(f'Length of train qa pairs after filtering out bad qa pairs: {len(train_qas):,}')
print(f'Length of valid qa pairs after filtering out bad qa pairs: {len(valid_qas):,}')

Length of train qa pairs after filtering out bad qa pairs: 86,597
Length of valid qa pairs after filtering out bad qa pairs: 34,295
CPU times: user 1.16 s, sys: 2.01 ms, total: 1.16 s
Wall time: 1.16 s


In [None]:
def test_targets(qas):
    """Check that every target token span reproduces its answer text."""
    for qa in qas:
        if 'target' in qa:
            start, end = qa['target']
            assert qa['context'][start:end + 1].text == qa['answer'].text

In [None]:
%%time
test_targets(train_qas)
test_targets(valid_qas)

CPU times: user 1.09 s, sys: 957 µs, total: 1.09 s
Wall time: 1.1 s


***Build vocabularies***

In [None]:
class Vocab:

    def __init__(self, pad_token, unk_token):
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.vocab = None
        self.word2count = None
        self.word2index = None
        self.index2word = None
    
    def build(self, data, min_freq):
        """
        :param List[Union[spacy.tokens.doc.Doc, str]] data
        :param int min_freq
        """
        words = [self.pad_token, self.unk_token]
        type_0 = type(data[0])
        if type_0 == spacy.tokens.doc.Doc:
            for item in data: # context and question
                words += [word.text.lower() for word in item]
        elif type_0 == str: # id
            words += data
        self.word2count = collections.Counter(words)
        self.vocab = sorted(filter(
            lambda word: self.word2count[word] >= min_freq or word == self.pad_token or word == self.unk_token, self.word2count
        ))
        self.word2index = {word: index for index, word in enumerate(self.vocab)}
        self.index2word = {index: word for index, word in enumerate(self.vocab)}
    
    def __len__(self):
        return len(self.vocab)
    
    def stoi(self, word):
        return self.word2index.get(str(word), self.word2index[self.unk_token])

    def itos(self, index):
        return self.index2word[index]
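
As a small illustration of what `build` and `stoi` do, the sketch below repeats the same steps on a toy sentence: count the words, keep those that reach `min_freq` plus the two special tokens, sort, and fall back to the unknown index at lookup time. The sentence and the standalone `stoi` function are made up for the example.

```python
import collections

PAD, UNK = '<pad>', '<unk>'
words = [PAD, UNK] + 'the cat sat on the mat the end'.split()
counts = collections.Counter(words)

min_freq = 2
# Special tokens are always kept, whatever their count.
vocab = sorted(w for w in counts if counts[w] >= min_freq or w in (PAD, UNK))
word2index = {w: i for i, w in enumerate(vocab)}


def stoi(word):
    """Map a word to its index, or to the index of the unknown token."""
    return word2index.get(word, word2index[UNK])


print(vocab)
print(stoi('the'), stoi('cat'))
```

Because the vocabulary is sorted alphabetically, the special tokens are not guaranteed to sit at indices 0 and 1 once the data contains strings that sort before `<`. That is presumably why the padding index is looked up with `stoi` later on instead of being hard-coded.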

In [None]:
class CharVocab:

    def __init__(self, pad_token, unk_token):
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.vocab = None
        self.char2count = None
        self.char2index = None
        self.index2char = None
    
    def build(self, data, min_freq):
        """
        :param List[Union[spacy.tokens.doc.Doc, str, Tuple]] data
        :param int min_freq
        """
        chars = [self.pad_token, self.unk_token]
        type_0 = type(data[0])
        if type_0 == spacy.tokens.doc.Doc:
            for item in data: # context and question
                for word in item:
                    chars += [*word.text.lower().strip()]
        else:
            raise TypeError(f'Unsupported data type: {type_0}')
        self.char2count = collections.Counter(chars)
        self.vocab = sorted(filter(
            lambda char: self.char2count[char] >= min_freq or char == self.pad_token or char == self.unk_token, self.char2count
        ))
        self.char2index = {char: index for index, char in enumerate(self.vocab)}
        self.index2char = {index: char for index, char in enumerate(self.vocab)}
    
    def __len__(self):
        return len(self.vocab)
    
    def stoi(self, char):
        return self.char2index.get(str(char), self.char2index[self.unk_token])

    def itos(self, index):
        return self.index2char[index]

In [None]:
%%time
PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'

ID = Vocab(pad_token=PAD_TOKEN, unk_token=UNK_TOKEN)
TEXT = Vocab(pad_token=PAD_TOKEN, unk_token=UNK_TOKEN)
CHAR = CharVocab(pad_token=PAD_TOKEN, unk_token=UNK_TOKEN)

ids = [*map(lambda qa: qa['id'], train_qas)] + [*map(lambda qa: qa['id'], valid_qas)]
contexts, questions = zip(*map(lambda qa: (qa['context'], qa['question']), train_qas))

ID.build(data=[*set(ids)], min_freq=0)
TEXT.build(data=[*set(contexts)] + [*set(questions)], min_freq=5)
CHAR.build(data=[*set(contexts)] + [*set(questions)], min_freq=0)

print(f'Length of ID vocabulary: {len(ID):,}')
print(f'Length of TEXT vocabulary: {len(TEXT):,}')
print(f'Length of CHAR vocabulary: {len(CHAR):,}')

Length of ID vocabulary: 97,108
Length of TEXT vocabulary: 26,885
Length of CHAR vocabulary: 1,261
CPU times: user 9.44 s, sys: 271 ms, total: 9.71 s
Wall time: 9.71 s


***Build datasets***

In [None]:
class SQuADV1Dataset(Dataset):

    def __init__(self, data, id_vocab, text_vocab, char_vocab):
        self.data = data
        self.id_vocab = id_vocab
        self.text_vocab = text_vocab
        self.char_vocab = char_vocab
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        id = torch.LongTensor([self.id_vocab.stoi(item['id'])])

        context = item['context']
        ctx = torch.LongTensor([*map(lambda word: self.text_vocab.stoi(word.text.lower()), context)])
        char_ctx = [*map(lambda word: [*map(self.char_vocab.stoi, [*word.text.lower().strip()])], context)]
        char_len_ctx = torch.LongTensor([*map(len, char_ctx)])
        char_ctx = torch.LongTensor([*itertools.chain.from_iterable(char_ctx)])
        
        question = item['question']
        qst = torch.LongTensor([*map(lambda word: self.text_vocab.stoi(word.text.lower()), question)])
        char_qst = [*map(lambda word: [*map(self.char_vocab.stoi, [*word.text.lower().strip()])], question)]
        char_len_qst = torch.LongTensor([*map(len, char_qst)])
        char_qst = torch.LongTensor([*itertools.chain.from_iterable(char_qst)])
        
        trg = torch.LongTensor(item['target'])
        
        return id, ctx, char_ctx, char_len_ctx, qst, char_qst, char_len_qst, trg

In [None]:
train_dataset = SQuADV1Dataset(data=train_qas, id_vocab=ID, text_vocab=TEXT, char_vocab=CHAR)
valid_dataset = SQuADV1Dataset(data=valid_qas, id_vocab=ID, text_vocab=TEXT, char_vocab=CHAR)

id, ctx, char_ctx, char_len_ctx, qst, char_qst, char_len_qst, trg = train_dataset[0]
print(f'id shape: {id.shape}')
print(f'ctx shape: {ctx.shape}')
print(f'char_ctx shape: {char_ctx.shape}')
print(f'char_len_ctx shape: {char_len_ctx.shape}')
print(f'qst shape: {qst.shape}')
print(f'char_qst shape: {char_qst.shape}')
print(f'char_len_qst shape: {char_len_qst.shape}')
print(f'trg shape: {trg.shape}')

id shape: torch.Size([1])
ctx shape: torch.Size([142])
char_ctx shape: torch.Size([572])
char_len_ctx shape: torch.Size([142])
qst shape: torch.Size([14])
char_qst shape: torch.Size([59])
char_len_qst shape: torch.Size([14])
trg shape: torch.Size([2])


***Build data loaders***

In [None]:
class DotDict(dict):
    __getattr__ = dict.get
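
A short usage sketch of `DotDict`, with made-up keys. One consequence of binding `__getattr__` to `dict.get` is worth keeping in mind: a misspelled field silently returns `None` instead of raising an `AttributeError`.

```python
class DotDict(dict):
    __getattr__ = dict.get


batch = DotDict({'ctx': [1, 2, 3]})
print(batch.ctx)      # same object as batch['ctx']
print(batch.missing)  # None, because dict.get returns None for absent keys
```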

In [None]:
def pad_char_sequence(batch_padded_seq, batch_char_seq, batch_char_len, pad_index):
    """
    :param Tensor[batch_size, seq_len] batch_padded_seq
    :param List[Tensor[n_total_chars]] batch_char_seq: flattened character ids of each sequence
    :param List[Tensor[n_words]] batch_char_len: number of characters of each word
    :param int pad_index
    :return LongTensor[batch_size, seq_len, max_char_len]
    """
    max_char_len = max(char_len.max().item() for char_len in batch_char_len)
    batch_size, seq_len = batch_padded_seq.size(0), batch_padded_seq.size(1)
    batch_padded_char = torch.full((batch_size, seq_len, max_char_len), pad_index, dtype=torch.long)
    for i in range(batch_size):
        char_seq = batch_char_seq[i] # [n_total_chars]
        offset = 0 # position of the current word's first character in char_seq
        for k, length in enumerate(batch_char_len[i].tolist()):
            # Copy the k-th word's characters to the start of row k; the rest stays padded
            batch_padded_char[i, k, :length] = char_seq[offset:offset + length]
            offset += length
    return batch_padded_char
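
The intended layout is easier to see on plain lists. The sketch below is a pure-Python illustration for a single sequence (the helper name `unflatten_and_pad` and the toy ids are made up): the flat list of character ids is cut into one row per word using the word lengths, and each row is right-padded to the longest word.

```python
def unflatten_and_pad(flat_chars, word_lengths, pad_index):
    """Split a flat list of char ids into one row per word, right-padded to the longest word."""
    max_len = max(word_lengths)
    rows, offset = [], 0
    for length in word_lengths:
        word = flat_chars[offset:offset + length]
        rows.append(word + [pad_index] * (max_len - length))
        offset += length  # the next word starts where this one ended
    return rows


flat = [5, 6, 7, 8, 9, 10]  # three words of lengths 2, 1 and 3
lengths = [2, 1, 3]
for row in unflatten_and_pad(flat, lengths, pad_index=0):
    print(row)
```

In the batched version, rows belonging to padded words (positions beyond a sequence's real length) consist of padding only.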

In [None]:
def add_padding(batch, text_vocab=TEXT, char_vocab=CHAR, include_lengths=True, device=DEVICE):
    """Pad batch of sequence with different lengths"""
    batch_id, batch_ctx, batch_char_ctx, batch_char_len_ctx, batch_qst, batch_char_qst, batch_char_len_qst, batch_trg = zip(*batch)
    if include_lengths:
        len_ctx = torch.LongTensor([ctx.size(0) for ctx in batch_ctx]).to(device)
        len_qst = torch.LongTensor([qst.size(0) for qst in batch_qst]).to(device)
    batch_padded_id = pad_sequence(batch_id, batch_first=True).to(device)
    batch_padded_ctx = pad_sequence(batch_ctx, batch_first=True, padding_value=text_vocab.stoi(text_vocab.pad_token)).to(device)
    batch_padded_qst = pad_sequence(batch_qst, batch_first=True, padding_value=text_vocab.stoi(text_vocab.pad_token)).to(device)
    batch_padded_trg = pad_sequence(batch_trg, batch_first=True).to(device)

    pad_index = char_vocab.stoi(char_vocab.pad_token)
    batch_padded_char_ctx = pad_char_sequence(batch_padded_seq=batch_padded_ctx,
                                              batch_char_seq=batch_char_ctx,
                                              batch_char_len=batch_char_len_ctx,
                                              pad_index=pad_index).to(device)
    batch_padded_char_qst = pad_char_sequence(batch_padded_seq=batch_padded_qst,
                                              batch_char_seq=batch_char_qst,
                                              batch_char_len=batch_char_len_qst,
                                              pad_index=pad_index).to(device)
    return DotDict({
        'id': batch_padded_id,
        'ctx': (batch_padded_ctx, len_ctx) if include_lengths else batch_padded_ctx,
        'ctx_char': batch_padded_char_ctx,
        'qst': (batch_padded_qst, len_qst) if include_lengths else batch_padded_qst,
        'qst_char': batch_padded_char_qst,
        'trg': batch_padded_trg
    })

In [None]:
BATCH_SIZE = 60

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=add_padding)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, collate_fn=add_padding)

for batch in train_dataloader:
    print('batch.id shape:', batch.id.shape)
    print('batch.ctx shape:', batch.ctx[0].shape, batch.ctx[1].shape)
    print('batch.ctx_char shape:', batch.ctx_char.shape)
    print('batch.qst shape:', batch.qst[0].shape, batch.qst[1].shape)
    print('batch.qst_char shape:', batch.qst_char.shape)
    print('batch.trg shape:', batch.trg.shape)
    break

batch.id shape: torch.Size([60, 1])
batch.ctx shape: torch.Size([60, 253]) torch.Size([60])
batch.ctx_char shape: torch.Size([60, 253, 15])
batch.qst shape: torch.Size([60, 19]) torch.Size([60])
batch.qst_char shape: torch.Size([60, 19, 14])
batch.trg shape: torch.Size([60, 2])


***TODO: Download pretrained GloVe embedding***

## Modeling

***Character Embedding Layer***

In [None]:
class CharacterEmbeddingLayer(nn.Module):

    def __init__(self, char_vocab_size, char_embedding_size, token_embedding_size, kernel_size, pad_index):
        super(CharacterEmbeddingLayer, self).__init__()
        self.char_vocab_size = char_vocab_size
        self.char_embedding_size = char_embedding_size
        self.token_embedding_size = token_embedding_size
        self.kernel_size = kernel_size
        self.pad_index = pad_index
        self.embedding = nn.Embedding(char_vocab_size, char_embedding_size, padding_idx=pad_index)
        self.conv2d = nn.Conv2d(1, token_embedding_size, kernel_size=(char_embedding_size, kernel_size))
    
    def forward(self, char_sequences):
        """
        :param LongTensor[batch_size, seq_len, char_len] char_sequences
        :return FloatTensor[batch_size, seq_len, token_embedding_size]
        """
        embedded = self.embedding(char_sequences) # [batch_size, seq_len, char_len, char_embedding_size]
        embedded = embedded.transpose(-1, -2) # [batch_size, seq_len, char_embedding_size, char_len]
        embedded = embedded.view(-1, self.char_embedding_size, embedded.size(-1)) # [batch_size * seq_len, char_embedding_size, char_len]
        embedded = embedded.unsqueeze(1) # [batch_size * seq_len, 1, char_embedding_size, char_len]
        conved = F.relu(self.conv2d(embedded)) # [batch_size * seq_len, token_embedding_size, 1, char_len - kernel_size + 1]
        conved = conved.squeeze(2) # [batch_size * seq_len, token_embedding_size, char_len - kernel_size + 1]
        out = F.max_pool1d(conved, kernel_size=conved.size(-1)) # [batch_size * seq_len, token_embedding_size, 1]
        out = out.squeeze(-1) # [batch_size * seq_len, token_embedding_size]
        out = out.view(char_sequences.size(0), -1, out.size(-1)) # [batch_size, seq_len, token_embedding_size]
        # TODO: maybe apply a bias and a tanh non-linearity here
        return out

***Highway Network Layer***

[Highway Networks paper](https://arxiv.org/pdf/1505.00387.pdf)
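
For reference, one highway layer as implemented in the cell below computes the following, where $W_h, b_h$ and $W_g, b_g$ are notation chosen here for the parameters of `fc_flow` and `fc_gate`:

$$
\begin{aligned}
h &= \mathrm{ReLU}(W_h x + b_h) \\
g &= \sigma(W_g x + b_g) \\
y &= g \odot h + (1 - g) \odot x
\end{aligned}
$$

The gate $g$ decides, per dimension, how much of the transformed input $h$ replaces the original input $x$. With `n_layers` layers, the output $y$ of one layer is the input $x$ of the next.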

In [None]:
class HighwayNetworkLayer(nn.Module):

    def __init__(self, hidden_size, n_layers):
        super(HighwayNetworkLayer, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.fc_flow = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(n_layers)])
        self.fc_gate = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(n_layers)])
    
    def forward(self, x):
        """
        :param FloatTensor[batch_size, seq_len, hidden_size] x
        :return FloatTensor[batch_size, seq_len, hidden_size]
        """
        for i in range(self.n_layers):
            flow = F.relu(self.fc_flow[i](x))
            gate = torch.sigmoid(self.fc_gate[i](x))
            x = gate * flow + (1 - gate) * x
        return x

***Contextual Embedding Layer***

In [None]:
class ContextualEmbeddingLayer(nn.Module):

    def __init__(self, hidden_size, n_layers):
        super(ContextualEmbeddingLayer, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.highway_network_layer = HighwayNetworkLayer(hidden_size=hidden_size, n_layers=n_layers)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True)

    def forward(self, inputs, sequence_lengths):
        """
        :param FloatTensor[batch_size, seq_len, hidden_size] inputs
        :param LongTensor[batch_size,] sequence_lengths
        :return FloatTensor[batch_size, seq_len, hidden_size * 2]
        """
        highway = self.highway_network_layer(inputs) # [batch_size, seq_len, hidden_size]
        # pack_padded_sequence expects the lengths on the CPU in recent PyTorch versions
        packed = pack_padded_sequence(highway, sequence_lengths.cpu(), batch_first=True, enforce_sorted=False)
        out_packed, _ = self.lstm(packed)
        out_padded, _ = pad_packed_sequence(out_packed, batch_first=True) # [batch_size, seq_len, hidden_size * 2]
        return out_padded

***Attention Flow Layer***

In [None]:
class AttentionFlowLayer(nn.Module):

    def __init__(self, hidden_size):
        super(AttentionFlowLayer, self).__init__()
        self.hidden_size = hidden_size
        self.W = nn.Linear(hidden_size * 3, 1, bias=False)
    
    def context_query_similarity(self, context, query):
        """
        :param FloatTensor[batch_size, context_len, hidden_size] context
        :param FloatTensor[batch_size, query_len, hidden_size] query
        :return FloatTensor[batch_size, context_len, query_len]
        """
        context_len, query_len = context.size(1), query.size(1)
        context = context.unsqueeze(2).repeat(1, 1, query_len, 1) # [batch_size, context_len, query_len, hidden_size]
        query = query.unsqueeze(1).repeat(1, context_len, 1, 1) # [batch_size, context_len, query_len, hidden_size]
        concat = torch.cat([context, query, context * query], dim=-1) # [batch_size, context_len, query_len, hidden_size * 3]
        return self.W(concat).squeeze(-1)
    
    def context_query_attention(self, query, scores):
        """
        :param FloatTensor[batch_size, query_len, hidden_size] query
        :param FloatTensor[batch_size, context_len, query_len] scores
        :return FloatTensor[batch_size, context_len, hidden_size]
        """
        attention_weights = F.softmax(scores, dim=-1) # [batch_size, context_len, query_len]
        return torch.bmm(attention_weights, query)
    
    def query_context_attention(self, context, scores):
        """
        :param FloatTensor[batch_size, context_len, hidden_size] context
        :param FloatTensor[batch_size, context_len, query_len] scores
        :return FloatTensor[batch_size, 1, hidden_size]
        """
        attention_weights = F.softmax(scores.max(dim=-1).values, dim=-1) # [batch_size, context_len]
        return torch.bmm(attention_weights.unsqueeze(1), context)

    def query_aware_representation(self, context, context2query, query2context):
        """
        :param FloatTensor[batch_size, context_len, hidden_size] context
        :param FloatTensor[batch_size, context_len, hidden_size] context2query
        :param FloatTensor[batch_size, 1, hidden_size] query2context
        :return FloatTensor[batch_size, context_len, hidden_size * 4]
        """
        query2context = query2context.repeat(1, context.size(1), 1) # [batch_size, context_len, hidden_size]
        return torch.cat([context, context2query, context * context2query, context * query2context], dim=-1)

    def forward(self, context, query):
        """
        :param FloatTensor[batch_size, context_len, hidden_size] context
        :param FloatTensor[batch_size, query_len, hidden_size] query
        :return FloatTensor[batch_size, context_len, hidden_size * 4]
        """
        scores = self.context_query_similarity(context=context, query=query)
        context2query = self.context_query_attention(query=query, scores=scores)
        query2context = self.query_context_attention(context=context, scores=scores)
        return self.query_aware_representation(context=context, context2query=context2query, query2context=query2context)
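
Written out, the layer above computes the following. The symbols are notation chosen here: $c_t$ is the contextual embedding of context token $t$, $q_j$ that of question token $j$, and $w$ the weight vector of `self.W`.

$$
\begin{aligned}
S_{tj} &= w^\top [\, c_t ;\; q_j ;\; c_t \odot q_j \,] \\
a_t &= \mathrm{softmax}_j (S_{tj}), \qquad \tilde{q}_t = \sum_j a_{tj}\, q_j \\
b &= \mathrm{softmax}_t \big( \max_j S_{tj} \big), \qquad \tilde{c} = \sum_t b_t\, c_t \\
G_t &= [\, c_t ;\; \tilde{q}_t ;\; c_t \odot \tilde{q}_t ;\; c_t \odot \tilde{c} \,]
\end{aligned}
$$

$\tilde{q}_t$ is the context-to-query attention (one attended question vector per context token), $\tilde{c}$ is the query-to-context attention (a single vector, repeated over all context positions), and $G_t$ is the query-aware representation, four times the size of $c_t$.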

***Modeling Layer***

In [None]:
class ModelingLayer(nn.Module):

    def __init__(self, input_size, hidden_size, n_layers):
        super(ModelingLayer, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=n_layers, batch_first=True, bidirectional=True)
    
    def forward(self, inputs, sequence_lengths):
        """
        :param FloatTensor[batch_size, seq_len, input_size] inputs
        :param LongTensor[batch_size,] sequence_lengths
        :return FloatTensor[batch_size, seq_len, hidden_size * 2]
        """
        # pack_padded_sequence expects the lengths on the CPU in recent PyTorch versions
        packed = pack_padded_sequence(inputs, sequence_lengths.cpu(), batch_first=True, enforce_sorted=False)
        out_packed, _ = self.lstm(packed)
        out_padded, _ = pad_packed_sequence(out_packed, batch_first=True) # [batch_size, seq_len, hidden_size * 2]
        return out_padded

***Bi-Directional Attention Flow Model***

In [None]:
class BiDAF(nn.Module):

    def __init__(self, char_vocab_size, token_vocab_size, char_embedding_size, token_embedding_size, kernel_size, char_pad_index, token_pad_index, highway_n_layers, modeling_n_layers):
        super(BiDAF, self).__init__()
        self.char_vocab_size = char_vocab_size
        self.token_vocab_size = token_vocab_size
        self.char_embedding_size = char_embedding_size
        self.token_embedding_size = token_embedding_size
        self.kernel_size = kernel_size
        self.char_pad_index = char_pad_index
        self.token_pad_index = token_pad_index
        self.highway_n_layers = highway_n_layers
        self.modeling_n_layers = modeling_n_layers
        new_hidden_size = token_embedding_size + token_embedding_size
        self.token_embedding = nn.Embedding(token_vocab_size, token_embedding_size, padding_idx=token_pad_index)
        self.char_embedding = CharacterEmbeddingLayer(char_vocab_size=char_vocab_size,
                                                      char_embedding_size=char_embedding_size,
                                                      token_embedding_size=token_embedding_size,
                                                      kernel_size=kernel_size,
                                                      pad_index=char_pad_index)
        self.contextual_embedding = ContextualEmbeddingLayer(hidden_size=new_hidden_size, n_layers=highway_n_layers)
        self.attention_flow = AttentionFlowLayer(hidden_size=new_hidden_size * 2)
        self.modeling = ModelingLayer(input_size=new_hidden_size * 8, hidden_size=new_hidden_size, n_layers=modeling_n_layers)
        self.fc_start = nn.Linear(new_hidden_size * 10, 1, bias=False)
        self.lstm_end = nn.LSTM(new_hidden_size * 2, new_hidden_size, batch_first=True, bidirectional=True)
        self.fc_end = nn.Linear(new_hidden_size * 10, 1, bias=False)
    
    def feed_lstm(self, inputs, sequence_lengths):
        """
        :param FloatTensor[batch_size, seq_len, hidden_size * 2] inputs
        :param LongTensor[batch_size,] sequence_lengths
        :return FloatTensor[batch_size, seq_len, hidden_size * 2]
        """
        # pack_padded_sequence expects the lengths on the CPU in recent PyTorch versions
        packed = pack_padded_sequence(inputs, sequence_lengths.cpu(), batch_first=True, enforce_sorted=False)
        out_packed, _ = self.lstm_end(packed)
        out_padded, _ = pad_packed_sequence(out_packed, batch_first=True) # [batch_size, seq_len, hidden_size * 2]
        return out_padded
    
    @staticmethod
    def decode(starts, ends):
        """
        :param Tensor[batch_size, ctx_len] starts: start scores or probabilities
        :param Tensor[batch_size, ctx_len] ends: end scores or probabilities
        :return list(int) start_indexes
        :return list(int) end_indexes
        :return list(float) pred_probas
        """
        start_indexes, end_indexes, pred_probas = [], [], []
        for i in range(starts.size(0)):
            probas = torch.ger(starts[i], ends[i]) # [ctx_len, ctx_len], probas[s, e] = starts[s] * ends[e]
            proba, index = torch.topk(probas.view(-1), k=1)
            # Convert the flat index back to (row, column) = (start, end)
            start_index, end_index = divmod(index.item(), probas.size(1))
            start_indexes.append(start_index)
            end_indexes.append(end_index)
            pred_probas.append(proba.item())
        return start_indexes, end_indexes, pred_probas
    
    def forward(self, ctx, ctx_lengths, ctx_char, qst, qst_lengths, qst_char):
        """
        :param LongTensor[batch_size, ctx_len] ctx
        :param LongTensor[batch_size,] ctx_lengths
        :param LongTensor[batch_size, ctx_len, ctx_char_len] ctx_char
        :param LongTensor[batch_size, qst_len] qst
        :param LongTensor[batch_size,] qst_lengths
        :param LongTensor[batch_size, qst_len, qst_char_len] qst_char
        :return Tensor[batch_size, ctx_len] starts
        :return Tensor[batch_size, ctx_len] ends
        """
        # Embeddings + Highway Networks
        ctx_token_embedded = self.token_embedding(ctx) # [batch_size, ctx_len, token_embedding_size]
        qst_token_embedded = self.token_embedding(qst) # [batch_size, qst_len, token_embedding_size]
        ctx_char_embedded = self.char_embedding(char_sequences=ctx_char) # [batch_size, ctx_len, token_embedding_size]
        qst_char_embedded = self.char_embedding(char_sequences=qst_char) # [batch_size, qst_len, token_embedding_size]
        ctx_embedded = torch.cat([ctx_token_embedded, ctx_char_embedded], dim=-1) # [batch_size, ctx_len, hidden_size=token_embedding_size * 2]
        qst_embedded = torch.cat([qst_token_embedded, qst_char_embedded], dim=-1) # [batch_size, qst_len, hidden_size=token_embedding_size * 2]

        # Contextualized Embeddings
        ctx_contextual_embedded = self.contextual_embedding(inputs=ctx_embedded, sequence_lengths=ctx_lengths) # [batch_size, ctx_len, hidden_size * 2]
        qst_contextual_embedded = self.contextual_embedding(inputs=qst_embedded, sequence_lengths=qst_lengths) # [batch_size, qst_len, hidden_size * 2]

        # Attention Flow
        representation = self.attention_flow(context=ctx_contextual_embedded, query=qst_contextual_embedded) # [batch_size, ctx_len, hidden_size * 8]
        
        # Modeling
        modeled = self.modeling(inputs=representation, sequence_lengths=ctx_lengths) # [batch_size, ctx_len, hidden_size * 2]

        # Outputs
        x = torch.cat([representation, modeled], dim=-1) # [batch_size, ctx_len, hidden_size * 10]
        starts = self.fc_start(x) # [batch_size, ctx_len, 1]
        x = self.feed_lstm(inputs=modeled, sequence_lengths=ctx_lengths) # [batch_size, ctx_len, hidden_size * 2]
        x = torch.cat([representation, x], dim=-1) # [batch_size, ctx_len, hidden_size * 10]
        ends = self.fc_end(x) # [batch_size, ctx_len, 1]

        return starts.squeeze(-1), ends.squeeze(-1)
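
The shape comments in `forward` can be checked with plain arithmetic. The sketch below is only an illustration: the variable names are hypothetical (they are not attributes of the `BiDAF` class), and it assumes the character CNN emits as many channels as the token embedding size, as the shape comments suggest.

```python
# Feature widths through the forward pass for this notebook's configuration.
# Names are illustrative only; they are not attributes of the BiDAF class.
token_embedding_size = 100   # TOKEN_EMBED_SIZE
char_feature_size = 100      # assumed equal to the token embedding size

hidden_size = token_embedding_size + char_feature_size  # token + char concatenation
contextual_size = hidden_size * 2    # bidirectional LSTM: forward + backward states
attention_size = hidden_size * 8     # width of the attention flow representation
modeling_size = hidden_size * 2      # bidirectional modeling LSTM
output_size = attention_size + modeling_size  # input width of fc_start and fc_end

widths = {
    'embedded': hidden_size,
    'contextual': contextual_size,
    'attention flow': attention_size,
    'modeling': modeling_size,
    'fc_start / fc_end input': output_size,
}
for name, width in widths.items():
    print(f'{name:>24}: {width}')
```

With these settings the last value is `hidden_size * 10 = 2000`, which agrees with the comments on the two concatenations before `fc_start` and `fc_end`.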

***Training routines***

In [None]:
class AverageMeter:
    
    def __init__(self):
        self.reset()
        
    def reset(self):
        self.value = 0.
        self.sum = 0.
        self.count = 0
        self.average = 0.
        
    def update(self, value, n=1):
        self.value = value
        self.sum += value * n
        self.count += n
        self.average = self.sum / self.count
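
The `n` argument of `update` matters when batches differ in size. The sketch below uses a compact copy of the meter so that it runs on its own; the loss values and batch sizes are made up for illustration.

```python
class AverageMeter:
    """Compact copy of the meter above so this sketch is self-contained."""

    def __init__(self):
        self.sum, self.count, self.average = 0., 0, 0.

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.average = self.sum / self.count


# Hypothetical per-batch mean losses; the last batch is smaller than the others.
batch_losses = [2.0, 2.0, 5.0]
batch_sizes = [64, 64, 8]

unweighted, weighted = AverageMeter(), AverageMeter()
for loss, size in zip(batch_losses, batch_sizes):
    unweighted.update(loss)        # every batch counts equally
    weighted.update(loss, n=size)  # every example counts equally

print(f'unweighted: {unweighted.average:.3f}')
print(f'weighted:   {weighted.average:.3f}')
```

Calling `update(loss.item())` without `n` gives the small last batch the same weight as a full one. Passing the batch size would report a per-example average instead.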

In [None]:
def normalize(answer: str):
    """Performs a series of cleaning steps on the ground truth and predicted answer."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)  # build the lookup set once, not once per character
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(answer))))
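
To see what the cleaning steps do, here is a small standalone sketch. `normalize_answer` is a hypothetical flat rewrite of the same four steps (lowercase, strip punctuation, drop articles, squeeze whitespace), and the sample strings are made up.

```python
import re
import string


def normalize_answer(answer: str) -> str:
    """Lowercase, strip punctuation, drop English articles, squeeze whitespace."""
    text = answer.lower()
    text = ''.join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())


examples = ['The Eiffel Tower!', 'an  apple, a day', 'Denver Broncos']
for example in examples:
    print(repr(example), '->', repr(normalize_answer(example)))
```

Punctuation is removed before articles, so an article that touches a comma or a quotation mark is still recognised as a standalone word.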

In [None]:
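def get_scores(prediction: str, ground_truth: str):
    """Returns the exact-match flag and the token-level F1 score of a prediction against one ground truth."""
    prediction, ground_truth = normalize(prediction), normalize(ground_truth)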
    em_score = prediction == ground_truth

    prediction_tokens, ground_truth_tokens = prediction.split(), ground_truth.split()
    common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        f1_score = 0
    else:
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        f1_score = (2 * precision * recall) / (precision + recall)

    return em_score, f1_score
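
A worked example of the token-level F1 may help. `token_f1` below is a hypothetical helper that repeats the arithmetic of `get_scores` on tokens that are assumed to be already normalized.

```python
import collections


def token_f1(prediction_tokens, ground_truth_tokens):
    """Token-overlap F1; the Counter intersection keeps the minimum count per token."""
    common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(prediction_tokens)
    recall = num_same / len(ground_truth_tokens)
    return 2 * precision * recall / (precision + recall)


# Two shared tokens: precision = 2/2, recall = 2/4, so F1 = 2/3.
score = token_f1('cat sat'.split(), 'cat sat on mat'.split())
print(f'F1 = {score:.4f}')

# A repeated token is only credited as often as it occurs in the ground truth.
repeated = token_f1('york new york'.split(), 'new york'.split())
print(f'F1 = {repeated:.4f}')
```

In the second case the prediction has three tokens but only two can be matched, so precision is 2/3, recall is 1 and F1 is 0.8.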

In [None]:
def max_metrics_over_ground_truths(prediction: str, ground_truths: list):
    scores = [get_scores(prediction, ground_truth) for ground_truth in ground_truths]
    em_score = max(scores, key=lambda score: score[0])[0]
    f1_score = max(scores, key=lambda score: score[1])[1]
    return em_score, f1_score
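
The maximum is taken per metric, and `max` raises `ValueError` on an empty list, which would happen for a prediction whose id has no ground truth. A minimal sketch of a guarded variant follows; `best_scores` is a hypothetical helper and the score pairs are made up.

```python
def best_scores(scores):
    """Best EM and best F1 over (em, f1) pairs; (False, 0.0) when there are none."""
    em_score = max((score[0] for score in scores), default=False)
    f1_score = max((score[1] for score in scores), default=0.0)
    return em_score, f1_score


# Hypothetical scores of one prediction against three reference answers.
print(best_scores([(False, 0.5), (True, 1.0), (False, 0.0)]))
# No reference answer at all: falls back to the defaults instead of raising.
print(best_scores([]))
```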

In [None]:
def metrics(predictions: dict, qas=valid_qas):
    """Computes the EM and F1 scores (in percent) of the predictions against all their ground truths."""
    ground_truths = collections.defaultdict(list)
    for qa in qas:
        if qa['id'] in predictions:
            ground_truths[qa['id']].append(qa['answer'].text)

    em_scores, f1_scores = [], []
    for qa_id, prediction in predictions.items():
        em_score, f1_score = max_metrics_over_ground_truths(prediction, ground_truths[qa_id])
        em_scores.append(em_score)
        f1_scores.append(f1_score)

    total = len(predictions)
    em_score = 100.0 * sum(em_scores) / total
    f1_score = 100.0 * sum(f1_scores) / total
    return em_score, f1_score

In [None]:
class Trainer:
    
    def __init__(self, model, optimizer, criterion, id_field, text_field):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.id_field = id_field
        self.text_field = text_field
        
    def train_step(self, loader, epoch, grad_clip):
        loss_tracker = AverageMeter()
        self.model.train()
        progress_bar = tqdm.tqdm(enumerate(loader), total=len(loader))
        for i, batch in progress_bar:
            self.optimizer.zero_grad()
            starts, ends = self.model(*batch.ctx, batch.ctx_char, *batch.qst, batch.qst_char) # [batch_size, ctx_len]
            loss = self.criterion(starts, batch.trg[:, 0]) + self.criterion(ends, batch.trg[:, 1])
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
            self.optimizer.step()
            loss_tracker.update(loss.item())
            progress_bar.set_description(f'Epoch: {epoch+1:02d} -     loss: {loss_tracker.average:.3f}')
        return loss_tracker.average
    
    def validate(self, loader, epoch):
        loss_tracker, predictions = AverageMeter(), {}
        self.model.eval()
        with torch.no_grad():
            progress_bar = tqdm.tqdm(enumerate(loader), total=len(loader))
            for i, batch in progress_bar:
                starts, ends = self.model(*batch.ctx, batch.ctx_char, *batch.qst, batch.qst_char) # [batch_size, ctx_len]
                loss = self.criterion(starts, batch.trg[:, 0]) + self.criterion(ends, batch.trg[:, 1])
                start_indexes, end_indexes, _ = BiDAF.decode(starts=F.softmax(starts, dim=-1), ends=F.softmax(ends, dim=-1))
                for j in range(starts.size(0)):  # j avoids shadowing the batch counter i
                    qa_id = self.id_field.itos(batch.id[j].item())
                    prediction = batch.ctx[0][j][start_indexes[j]:end_indexes[j]+1]
                    predictions[qa_id] = ' '.join([self.text_field.itos(index.item()) for index in prediction])
                loss_tracker.update(loss.item())
                progress_bar.set_description(f'Epoch: {epoch+1:02d} - val_loss: {loss_tracker.average:.3f}')
        return loss_tracker.average, predictions
    
    def train(self, train_loader, valid_loader, n_epochs, grad_clip):
        history, best_loss = {'loss': [], 'val_loss': [], 'em': [], 'f1': []}, float('inf')
        for epoch in range(n_epochs):
            loss = self.train_step(train_loader, epoch, grad_clip)
            val_loss, predictions = self.validate(valid_loader, epoch)
            em_score, f1_score = metrics(predictions)
            history['loss'].append(loss)
            history['val_loss'].append(val_loss)
            history['em'].append(em_score)
            history['f1'].append(f1_score)
            if best_loss > val_loss:
                best_loss = val_loss
                torch.save(self.model.state_dict(), './checkpoints/BiDAF.pth')
            time.sleep(1)
            print(f'\nEM={em_score:.3f}% - F1={f1_score:.3f}%')
        return history
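
`validate` relies on `BiDAF.decode`, which is defined elsewhere in the notebook and not shown here. As an assumption about what such a step typically does, the sketch below picks the span that maximises the product of start and end probabilities subject to `start <= end`. `decode_span` and the probabilities are hypothetical, and the notebook's own method may differ.

```python
def decode_span(start_probs, end_probs):
    """Return (start, end, score) maximising start_probs[s] * end_probs[e] with s <= e."""
    best_start, best_end, best_score = 0, 0, 0.0
    for s, p_start in enumerate(start_probs):
        for e in range(s, len(end_probs)):  # only spans that end at or after their start
            score = p_start * end_probs[e]
            if score > best_score:
                best_start, best_end, best_score = s, e, score
    return best_start, best_end, best_score


# Hypothetical distributions over a four-token context.
start_probs = [0.1, 0.6, 0.2, 0.1]
end_probs = [0.5, 0.1, 0.3, 0.1]

start, end, score = decode_span(start_probs, end_probs)
print(start, end, round(score, 2))
```

Taking the two argmaxes independently would give start 1 and end 0 here, which is not a valid span. The joint search returns the span from token 1 to token 2 instead.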

***Train the model***

In [None]:
CHAR_EMBED_SIZE = 8
TOKEN_EMBED_SIZE = 100
KERNEL_SIZE = 5
HIGHWAY_N_LAYERS = 2
MODELING_N_LAYERS = 2
N_EPOCHS = 10
GRAD_CLIP = 1.0

In [None]:
bidaf = BiDAF(char_vocab_size=len(CHAR),
              token_vocab_size=len(TEXT),
              char_embedding_size=CHAR_EMBED_SIZE,
              token_embedding_size=TOKEN_EMBED_SIZE,
              kernel_size=KERNEL_SIZE,
              char_pad_index=CHAR.stoi(PAD_TOKEN),
              token_pad_index=TEXT.stoi(PAD_TOKEN),
              highway_n_layers=HIGHWAY_N_LAYERS,
              modeling_n_layers=MODELING_N_LAYERS)
bidaf.to(DEVICE)
optimizer = optim.Adadelta(params=bidaf.parameters())
# Targets are start/end positions in the context, not token ids, so there is no padding class to ignore.
criterion = nn.CrossEntropyLoss()
print(f'Number of parameters of the model: {sum(p.numel() for p in bidaf.parameters() if p.requires_grad):,}')
print(bidaf)
trainer = Trainer(model=bidaf, optimizer=optimizer, criterion=criterion, id_field=ID, text_field=TEXT)

Number of parameters of the model: 8,321,488
BiDAF(
  (token_embedding): Embedding(26885, 100, padding_idx=1318)
  (char_embedding): CharacterEmbeddingLayer(
    (embedding): Embedding(1261, 8, padding_idx=28)
    (cond2d): Conv2d(1, 100, kernel_size=(8, 5), stride=(1, 1))
  )
  (contextual_embedding): ContextualEmbeddingLayer(
    (high_network_layer): HighwayNetworkLayer(
      (fc_flow): ModuleList(
        (0): Linear(in_features=200, out_features=200, bias=True)
        (1): Linear(in_features=200, out_features=200, bias=True)
      )
      (fc_gate): ModuleList(
        (0): Linear(in_features=200, out_features=200, bias=True)
        (1): Linear(in_features=200, out_features=200, bias=True)
      )
    )
    (lstm): LSTM(200, 200, batch_first=True, bidirectional=True)
  )
  (attention_flow): AttentionFlowLayer(
    (W): Linear(in_features=1200, out_features=1, bias=False)
  )
  (modeling): ModelingLayer(
    (lstm): LSTM(1600, 200, num_layers=2, batch_first=True, bidirectional=T

In [None]:
!mkdir -p ./checkpoints
history = trainer.train(train_loader=train_dataloader, valid_loader=valid_dataloader, n_epochs=N_EPOCHS, grad_clip=GRAD_CLIP)

Epoch: 01 -     loss: 9.496:   3%|▎         | 45/1444 [00:37<16:01,  1.46it/s]