<a href="https://colab.research.google.com/github/dksifoua/AbricotJustfy/blob/master/2%20-%20BiDAF%2C%20Bi-Directional%20Attention%20Flow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Load dependencies

In [1]:
!pip install tqdm --upgrade >> /dev/null 2>&1
!pip install spacy --upgrade >> /dev/null 2>&1
!python -m spacy download en >> /dev/null 2>&1

In [2]:
import re
import json
import time
import tqdm
import spacy
import string
import collections
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from IPython.core.display import display, HTML

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence, pad_packed_sequence
from torch.utils.data import Dataset, DataLoader

In [3]:
# One fixed seed is handed to both NumPy and PyTorch (CPU and CUDA generators).
SEED = 546
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)

# Request deterministic cuDNN behaviour.
torch.backends.cudnn.deterministic = True

# Use the GPU whenever PyTorch reports one as available, otherwise the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {DEVICE}')

Device: cpu


## Prepare data

***Download data***

In [4]:
!rm -rf ./data
!mkdir ./data

!wget --no-check-certificate \
    https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json \
    -O ./data/train-v1.1.json

!wget --no-check-certificate \
    https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json \
    -O ./data/dev-v1.1.json

--2020-11-01 18:18:55--  https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.109.153, 185.199.108.153, 185.199.110.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.109.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 30288272 (29M) [application/json]
Saving to: ‘./data/train-v1.1.json’


2020-11-01 18:18:57 (45.0 MB/s) - ‘./data/train-v1.1.json’ saved [30288272/30288272]

--2020-11-01 18:18:58--  https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.109.153, 185.199.108.153, 185.199.110.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.109.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4854279 (4.6M) [application/json]
Saving to: ‘./data/dev-v1.1.json’


2020-11-01 18:18:58 (12.9 MB/s) - ‘./data/dev-v1.1.json’ saved [4854279/485427

***Load JSON data***

In [5]:
def load(path):
    """
    Read a JSON file and return the value stored under its top-level 'data' key.

    :param str path: location of the UTF-8 encoded JSON file
    :return: the object found under the 'data' key
    :raises FileNotFoundError: raised by ``open`` when ``path`` does not exist
    :raises KeyError: when the parsed JSON has no 'data' key
    """
    # ``open`` itself reports a missing file, so no explicit raise is needed.
    with open(path, mode='r', encoding='utf-8') as file:
        return json.load(file)['data']

In [6]:
# Read both downloaded splits; the dev file serves as the validation set.
train_raw_data, valid_raw_data = map(load, ('./data/train-v1.1.json', './data/dev-v1.1.json'))
print(
    f'Length of raw train data: {len(train_raw_data):,}',
    f'Length of raw valid data: {len(valid_raw_data):,}',
    sep='\n',
)

Length of raw train data: 442
Length of raw valid data: 48


***Parse JSON data***

In [7]:
def parse(data, nlp=spacy.load('en')):
    """
    Flatten raw articles into one record per (question, answer) pair.

    Each context is run through ``nlp`` with the parser disabled; questions and
    answers are run with the parser, tagger and NER disabled.

    :param data: iterable of articles, each holding a 'paragraphs' list
    :param nlp: callable turning text into a processed document
    :return: list of dicts with keys 'id', 'context', 'question', 'answer'
        and 'answer_start'
    """
    records = []
    for article in tqdm.tqdm(data):
        for paragraph in article['paragraphs']:
            # The context is processed once and shared by all of its questions.
            context_doc = nlp(paragraph['context'], disable=['parser'])
            for qa in paragraph['qas']:
                qa_id = qa['id']
                question_doc = nlp(qa['question'], disable=['parser', 'tagger', 'ner'])
                records.extend(
                    {
                        'id': qa_id,
                        'context': context_doc,
                        'question': question_doc,
                        'answer': nlp(answer['text'], disable=['parser', 'tagger', 'ner']),
                        'answer_start': answer['answer_start'],
                    }
                    for answer in qa['answers']
                )
    return records

In [8]:
# Tokenise both splits (train first, then valid) into flat question/answer records.
train_qas, valid_qas = map(parse, (train_raw_data, valid_raw_data))
print(
    '',
    f'Length of train qa pairs: {len(train_qas):,}',
    f'Length of valid qa pairs: {len(valid_qas):,}',
    '==================== Example ====================',
    sep='\n',
)
# Show every field of the first training record, one per line.
print(
    *(
        f'{label} {train_qas[0][field]}'
        for label, field in (
            ('Id:', 'id'),
            ('Context:', 'context'),
            ('Question:', 'question'),
            ('Answer starts at:', 'answer_start'),
            ('Answer:', 'answer'),
        )
    ),
    sep='\n',
)

100%|██████████| 442/442 [05:44<00:00,  1.28it/s]
100%|██████████| 48/48 [00:44<00:00,  1.08it/s]


Length of train qa pairs: 87,599
Length of valid qa pairs: 34,726
Id: 5733be284776f41900661182
Context: Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
Question: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Answer starts at: 515
Answer: Saint Bernadette Soubirous





In [9]:
def test_answer_start(qas):
    """Test answer_start are correct in train set"""
    for qa in tqdm.tqdm(qas):
        answer = qa['answer'].text
        context = qa['context'].text
        answer_start = qa['answer_start']
        assert answer == context[answer_start:answer_start + len(answer)]

In [10]:
# Validate the character offsets of both splits, train first.
for qa_split in (train_qas, valid_qas):
    test_answer_start(qa_split)

100%|██████████| 87599/87599 [00:11<00:00, 7645.42it/s]
100%|██████████| 34726/34726 [00:03<00:00, 10591.13it/s]


***Add targets***

In [11]:
def add_targets(qas):
    """
    Attach a 'target' entry holding the first and last answer token indices.

    The answer is located by finding the context token whose character offset
    equals 'answer_start'. Records without such a token are left untouched.

    :param qas: iterable of records with 'context', 'answer' and 'answer_start'
    """
    for record in qas:
        doc = record['context']
        char_start = record['answer_start']
        # Position of the first token starting exactly at the answer offset, if any.
        first = next((pos for pos in range(len(doc)) if doc[pos].idx == char_start), None)
        if first is None:
            continue
        span = doc[first:first + len(record['answer'])]
        record['target'] = [span[0].i, span[-1].i]

In [12]:
%%time
# Annotate both splits in place with token-level answer spans.
for qa_split in (train_qas, valid_qas):
    add_targets(qa_split)
print(
    f'Length of train qa pairs: {len(train_qas):,}',
    f'Length of valid qa pairs: {len(valid_qas):,}',
    sep='\n',
)

Length of train qa pairs: 87,599
Length of valid qa pairs: 34,726
CPU times: user 1.61 s, sys: 26.9 ms, total: 1.64 s
Wall time: 1.64 s


In [13]:
def filter_qas(qa):
    """
    Tell whether a record has a usable token-level target.

    :param dict qa: record with 'context', 'answer' and optionally 'target'
    :return: True when 'target' exists and the context tokens it spans have
        the same text as the answer; False otherwise
    """
    if 'target' not in qa:
        return False
    first, last = qa['target']
    return qa['context'][first:last + 1].text == qa['answer'].text

In [14]:
%%time
# Keep only the records whose token span reproduces the answer text.
train_qas = [qa for qa in train_qas if filter_qas(qa)]
valid_qas = [qa for qa in valid_qas if filter_qas(qa)]
print(
    f'Length of train qa pairs after filtering out bad qa pairs: {len(train_qas):,}',
    f'Length of valid qa pairs after filtering out bad qa pairs: {len(valid_qas):,}',
    sep='\n',
)

Length of train qa pairs after filtering out bad qa pairs: 86,597
Length of valid qa pairs after filtering out bad qa pairs: 34,295
CPU times: user 1.33 s, sys: 7 ms, total: 1.34 s
Wall time: 1.34 s


In [15]:
def test_targets(qas):
    for qa in qas:
        if 'target' in [*qa.keys()]:
            start, end = qa['target']
            assert qa['context'][start:end + 1].text == qa['answer'].text

In [16]:
%%time
# Re-check the token spans of the filtered splits, train first.
for qa_split in (train_qas, valid_qas):
    test_targets(qa_split)

CPU times: user 1.31 s, sys: 4.97 ms, total: 1.31 s
Wall time: 1.32 s


***Build vocabularies***

In [17]:
class Vocab:
    """
    Two-way mapping between tokens and integer indices.

    The padding and unknown tokens are always part of the vocabulary. Indices
    follow the sorted order of the kept tokens.
    """

    def __init__(self, pad_token, unk_token):
        """
        :param str pad_token: token used for padding
        :param str unk_token: token returned for out-of-vocabulary lookups
        """
        self.pad_token = pad_token
        self.unk_token = unk_token
        # The attributes below stay None until ``build`` is called.
        self.vocab = None       # sorted list of kept tokens
        self.word2count = None  # collections.Counter of every token seen
        self.word2index = None  # token -> index
        self.index2word = None  # index -> token
    
    def build(self, data, min_freq):
        """
        Count the tokens found in ``data`` and keep the frequent ones.

        The type of the first element decides how every element is read:
        spaCy docs contribute their lower-cased token texts, strings are
        counted as they are, and tuples contribute their lower-cased items.
        Any other element type contributes nothing. Empty ``data`` yields a
        vocabulary holding only the padding and unknown tokens.

        :param List[Union[spacy.tokens.doc.Doc, str, Tuple]] data
        :param int min_freq: minimum count a token needs to be kept; the
            padding and unknown tokens are kept regardless
        """
        # Tokens are streamed into the counter instead of being collected in
        # one large intermediate list first.
        counts = collections.Counter([self.pad_token, self.unk_token])
        if data:
            first = data[0]
            if isinstance(first, spacy.tokens.doc.Doc): # context and question
                for doc in data:
                    counts.update(token.text.lower() for token in doc)
            elif isinstance(first, str): # id
                counts.update(data)
            elif isinstance(first, tuple): # pos and ner
                for tags in data:
                    counts.update(tag.lower() for tag in tags)
        self.word2count = counts
        specials = {self.pad_token, self.unk_token}
        self.vocab = sorted(
            word for word, count in counts.items()
            if count >= min_freq or word in specials
        )
        self.word2index = {word: index for index, word in enumerate(self.vocab)}
        self.index2word = {index: word for index, word in enumerate(self.vocab)}
    
    def __len__(self):
        """Return the number of tokens kept in the vocabulary."""
        return len(self.vocab)
    
    def stoi(self, word):
        """
        Return the index of ``word``, or the unknown token's index if absent.

        ``word`` is converted with ``str`` before the lookup; no lower-casing
        is applied here.
        """
        return self.word2index.get(str(word), self.word2index[self.unk_token])

    def itos(self, index):
        """
        Return the token stored at ``index``.

        :raises KeyError: when ``index`` is not in the vocabulary
        """
        return self.index2word[index]

In [18]:
%%time
PAD_TOKEN = '<pad>'
UNK_TOKEN = '<unk>'

# Two independent vocabularies: one for question ids, one for words.
ID, TEXT = (Vocab(pad_token=PAD_TOKEN, unk_token=UNK_TOKEN) for _ in range(2))

# Ids come from both splits; words come from the training split only.
ids = [qa['id'] for qa_split in (train_qas, valid_qas) for qa in qa_split]
contexts, questions = zip(*((qa['context'], qa['question']) for qa in train_qas))

# Duplicates are dropped with set() before counting.
ID.build(data=list(set(ids)), min_freq=0)
TEXT.build(data=list(set(contexts)) + list(set(questions)), min_freq=5)

print(
    f'Length of ID vocabulary: {len(ID):,}',
    f'Length of TEXT vocabulary: {len(TEXT):,}',
    sep='\n',
)

Length of ID vocabulary: 97,108
Length of TEXT vocabulary: 26,885
CPU times: user 6.78 s, sys: 503 ms, total: 7.28 s
Wall time: 7.29 s


***Build datasets***

In [19]:
class SQuADV1Dataset(Dataset):
    """Dataset that turns question/answer records into index tensors."""

    def __init__(self, data, id_vocab, text_vocab):
        """
        :param data: sequence of records with 'id', 'context', 'question' and 'target'
        :param id_vocab: object whose ``stoi`` maps a record id to an index
        :param text_vocab: object whose ``stoi`` maps a lower-cased token text to an index
        """
        self.data, self.id_vocab, self.text_vocab = data, id_vocab, text_vocab
    
    def __len__(self):
        """Return the number of records."""
        return len(self.data)
    
    def __getitem__(self, idx):
        """
        Encode the record at ``idx``.

        :return: tuple of LongTensors (id, context indices, question indices, target)
        """
        record = self.data[idx]
        lookup = self.text_vocab.stoi
        qa_id = torch.LongTensor([self.id_vocab.stoi(record['id'])])
        # Token texts are lower-cased before the vocabulary lookup.
        context_ids = torch.LongTensor([lookup(token.text.lower()) for token in record['context']])
        question_ids = torch.LongTensor([lookup(token.text.lower()) for token in record['question']])
        span = torch.LongTensor(record['target'])
        return qa_id, context_ids, question_ids, span

In [21]:
# Both splits share the same id and word vocabularies.
train_dataset, valid_dataset = (
    SQuADV1Dataset(data=qa_split, id_vocab=ID, text_vocab=TEXT)
    for qa_split in (train_qas, valid_qas)
)

# Inspect the tensors produced for the first training record.
id, ctx, qst, trg = train_dataset[0]
print(
    f'id shape: {id.shape}',
    f'ctx shape: {ctx.shape}',
    f'qst shape: {qst.shape}',
    f'trg shape: {trg.shape}',
    sep='\n',
)

id shape: torch.Size([1])
ctx shape: torch.Size([142])
qst shape: torch.Size([14])
trg shape: torch.Size([2])


***Build data loaders***

In [22]:
class DotDict(dict):
    """dict whose keys can also be read as attributes; absent keys read as None."""

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        return dict.get(self, name)

In [23]:
def add_padding(batch, pad_token=PAD_TOKEN, text_vocab=TEXT, include_lengths=True, device=DEVICE):
    """
    Collate (id, context, question, target) samples into padded batch tensors.

    Contexts and questions are padded with the vocabulary index of
    ``pad_token``; ids and targets use ``pad_sequence``'s default padding
    value. Every tensor is moved to ``device``.

    :param batch: sequence of (id, ctx, qst, trg) tensor tuples
    :param pad_token: token whose index fills the padded positions
    :param text_vocab: vocabulary providing ``stoi`` for ``pad_token``
    :param bool include_lengths: when True, 'ctx' and 'qst' are
        (padded tensor, lengths tensor) pairs instead of bare tensors
    :param device: device the returned tensors are moved to
    :return: DotDict with keys 'id', 'ctx', 'qst' and 'trg'
    """
    ids, contexts, questions, targets = zip(*batch)
    pad_index = text_vocab.stoi(pad_token)

    def pad(sequences, **kwargs):
        # Batch-first padding followed by the device transfer.
        return pad_sequence(sequences, batch_first=True, **kwargs).to(device)

    def lengths(sequences):
        # Unpadded length of every sequence in the batch.
        return torch.LongTensor([sequence.size(0) for sequence in sequences]).to(device)

    padded_ctx = pad(contexts, padding_value=pad_index)
    padded_qst = pad(questions, padding_value=pad_index)
    if include_lengths:
        padded_ctx = (padded_ctx, lengths(contexts))
        padded_qst = (padded_qst, lengths(questions))
    return DotDict({
        'id': pad(ids),
        'ctx': padded_ctx,
        'qst': padded_qst,
        'trg': pad(targets)
    })

In [24]:
BATCH_SIZE = 64

# Both loaders use the same batch size and the padding collate function.
train_dataloader, valid_dataloader = (
    DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=add_padding)
    for dataset in (train_dataset, valid_dataset)
)

# Print the tensor shapes of the first training batch only.
for batch in train_dataloader:
    print('batch.id shape:', batch.id.shape)
    print('batch.ctx shape:', *(part.shape for part in batch.ctx))
    print('batch.qst shape:', *(part.shape for part in batch.qst))
    print('batch.trg shape:', batch.trg.shape)
    break

batch.id shape: torch.Size([64, 1])
batch.ctx shape: torch.Size([64, 253]) torch.Size([64])
batch.qst shape: torch.Size([64, 19]) torch.Size([64])
batch.trg shape: torch.Size([64, 2])


***TODO: Download pretrained GloVe embedding***

## Modeling

***Character Embedding Layer***