<a href="https://colab.research.google.com/github/dksifoua/Question-Answering/blob/master/1%20-%20DrQA%2C%20Document%20reader%20Question%20Answering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Load dependencies

In [1]:
!pip install tqdm --upgrade >> /dev/null 2>&1
!pip install spacy --upgrade >> /dev/null 2>&1
!python -m spacy download en >> /dev/null 2>&1

In [185]:
import tqdm
import json
import spacy
import collections

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

## Prepare data

***Download data***

In [3]:
!rm -rf ./data
!mkdir ./data

!wget --no-check-certificate \
    https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json \
    -O ./data/train-v1.1.json

!wget --no-check-certificate \
    https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json \
    -O ./data/dev-v1.1.json

--2020-10-31 01:11:35--  https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 30288272 (29M) [application/json]
Saving to: ‘./data/train-v1.1.json’


2020-10-31 01:11:36 (26.3 MB/s) - ‘./data/train-v1.1.json’ saved [30288272/30288272]

--2020-10-31 01:11:36--  https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
Resolving rajpurkar.github.io (rajpurkar.github.io)... 185.199.108.153, 185.199.109.153, 185.199.110.153, ...
Connecting to rajpurkar.github.io (rajpurkar.github.io)|185.199.108.153|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4854279 (4.6M) [application/json]
Saving to: ‘./data/dev-v1.1.json’


2020-10-31 01:11:36 (26.0 MB/s) - ‘./data/dev-v1.1.json’ saved [4854279/485427

***Load JSON data***

In [4]:
def load(path):
    """Read a SQuAD-format JSON file and return the list stored under its ``'data'`` key.

    :param str path: path to a SQuAD JSON file (read as UTF-8)
    :return: the top-level ``'data'`` list (one entry per article, each with ``'paragraphs'``)
    :raises FileNotFoundError: if ``path`` does not exist (raised by ``open`` itself)
    :raises KeyError: if the JSON object has no ``'data'`` key
    """
    with open(path, mode='r', encoding='utf-8') as file:
        return json.load(file)['data']

In [5]:
# Each file's 'data' list holds one entry per article (each with its 'paragraphs').
train_raw_data, valid_raw_data = map(load, ('./data/train-v1.1.json', './data/dev-v1.1.json'))
print('Length of raw train data: {:,}'.format(len(train_raw_data)))
print('Length of raw valid data: {:,}'.format(len(valid_raw_data)))

Length of raw train data: 442
Length of raw valid data: 48


***Parse JSON data***

In [6]:
def parse(data, nlp=None):
    """Tokenize/annotate SQuAD articles into a flat list of question-answer records.

    One record is emitted per answer, so a question listing several answers
    yields several records that share ``id``, ``context`` and ``question``.

    :param list data: SQuAD ``'data'`` list (articles, each with ``'paragraphs'``)
    :param nlp: spaCy-style callable accepting ``(text, disable=[...])``; when None,
        ``spacy.load('en_core_web_sm')`` is loaded on each call
    :return: list of dicts with keys ``id``, ``context``, ``question``, ``answer``
        (processed docs) and ``answer_start`` (character offset into the context)
    """
    if nlp is None:
        # Loaded lazily rather than as a default argument (which would run at
        # definition time). 'en_core_web_sm' is the full package name; the old
        # 'en' shortcut is no longer resolved by spaCy >= 3.0.
        nlp = spacy.load('en_core_web_sm')
    qas = []
    for article in tqdm.tqdm(data):
        for para in article['paragraphs']:
            # Only the parser is skipped for contexts: tags and entities feed add_features.
            context = nlp(para['context'], disable=['parser'])
            for qa in para['qas']:
                qa_id = qa['id']
                question = nlp(qa['question'], disable=['parser', 'tagger', 'ner'])
                for ans in qa['answers']:
                    qas.append({
                        'id': qa_id,
                        'context': context,
                        'question': question,
                        'answer': nlp(ans['text'], disable=['parser', 'tagger', 'ner']),
                        'answer_start': ans['answer_start'],
                    })
    return qas

In [7]:
# Tokenize/annotate both splits (the slow step: spaCy runs over every context), then show one record.
train_qas, valid_qas = parse(train_raw_data), parse(valid_raw_data)
print()
print('Length of train qa pairs: {:,}'.format(len(train_qas)))
print('Length of valid qa pairs: {:,}'.format(len(valid_qas)))
print('==================== Example ====================')
print(
    f"Id: {train_qas[0]['id']}",
    f"Context: {train_qas[0]['context']}",
    f"Question: {train_qas[0]['question']}",
    f"Answer starts at: {train_qas[0]['answer_start']}",
    f"Answer: {train_qas[0]['answer']}",
    sep='\n',
)

100%|██████████| 442/442 [05:47<00:00,  1.27it/s]
100%|██████████| 48/48 [00:39<00:00,  1.23it/s]


Length of train qa pairs: 87,599
Length of valid qa pairs: 34,726
Id: 5733be284776f41900661182
Context: Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
Question: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Answer starts at: 515
Answer: Saint Bernadette Soubirous





In [8]:
def test_answer_start(qas):
    """Test answer_start are correct in train set"""
    for qa in tqdm.tqdm(qas):
        answer = qa['answer'].text
        context = qa['context'].text
        answer_start = qa['answer_start']
        assert answer == context[answer_start:answer_start + len(answer)]

In [9]:
# Sanity-check both splits: every answer_start offset must point at its answer text.
test_answer_start(train_qas)
test_answer_start(valid_qas)

100%|██████████| 87599/87599 [00:09<00:00, 9338.46it/s]
100%|██████████| 34726/34726 [00:03<00:00, 10738.37it/s]


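***Add targets***

Map each answer's character offset (`answer_start`) to the indices of its first and last tokens in the tokenized context. Records whose answer does not start on a token boundary get no target and are dropped by the filtering step below.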

In [10]:
def add_targets(qas):
    """Attach a token-level answer span to each record, in place, as ``qa['target']``.

    The span starts at the first context token whose character offset (``.idx``)
    equals ``answer_start``. It covers as many tokens as the separately tokenized
    answer, or fewer if the context ends first. ``target`` is ``[first.i, last.i]``,
    with the last index inclusive. Records where no token starts exactly at
    ``answer_start`` are left without a ``target`` key.
    """
    for record in qas:
        doc = record['context']
        answer_doc = record['answer']
        char_start = record['answer_start']
        start_pos = next(
            (pos for pos in range(len(doc)) if doc[pos].idx == char_start),
            None,
        )
        if start_pos is None:
            continue
        window = doc[start_pos:start_pos + len(answer_doc)]
        record['target'] = [window[0].i, window[-1].i]

In [11]:
%%time
# Attach token-level answer spans ('target') to the records of both splits, in place.
add_targets(train_qas)
add_targets(valid_qas)
# Counts are unchanged here; records without a usable target are dropped by the filter step below.
print(f'Length of train qa pairs: {len(train_qas):,}')
print(f'Length of valid qa pairs: {len(valid_qas):,}')

Length of train qa pairs: 87,599
Length of valid qa pairs: 34,726
CPU times: user 2 s, sys: 17 ms, total: 2.02 s
Wall time: 2.02 s


In [12]:
def filter_qas(qa):
    """Return True if ``qa`` has a ``target`` whose token span reproduces the answer text.

    Meant for ``filter``: it rejects records that ``add_targets`` could not align
    (no ``target`` key) and records whose aligned span text differs from the answer.
    """
    if 'target' not in qa:
        return False
    start, end = qa['target']
    # ``end`` is an inclusive token index, hence the +1.
    return qa['context'][start:end + 1].text == qa['answer'].text

In [13]:
%%time
train_qas = [*filter(filter_qas, train_qas)]
valid_qas = [*filter(filter_qas, valid_qas)]
print(f'Length of train qa pairs after filtering out bad qa pairs: {len(train_qas):,}')
print(f'Length of valid qa pairs after filtering out bad qa pairs: {len(valid_qas):,}')

Length of train qa pairs after filtering out bad qa pairs: 86,597
Length of valid qa pairs after filtering out bad qa pairs: 34,295
CPU times: user 1.34 s, sys: 0 ns, total: 1.34 s
Wall time: 1.34 s


In [14]:
def test_targets(qas):
    for qa in qas:
        if 'target' in [*qa.keys()]:
            start, end = qa['target']
            assert qa['context'][start:end + 1].text == qa['answer'].text

In [15]:
%%time
# Re-check both splits after filtering: each target span must reproduce its answer text.
test_targets(train_qas)
test_targets(valid_qas)

CPU times: user 1.29 s, sys: 928 µs, total: 1.29 s
Wall time: 1.3 s


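***Add features***

For every context token, compute the DrQA-style features: exact match with a question token, POS tag, named-entity type, and normalized term frequency.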

In [174]:
def add_features(qas):
    """Attach per-context-token features to each record, in place.

    Adds four tuples, each with one entry per context token:
      - ``em``:  whether the token's exact text appears among the question tokens
      - ``pos``: the token's fine-grained tag (``token.tag_``)
      - ``ner``: the token's entity type (``token.ent_type_``), or ``'None'`` if empty
      - ``ntf``: normalized term frequency, i.e. the count of the lowercased token
                 in the context divided by the number of context tokens
    An empty context gets four empty tuples.
    """
    for qa in tqdm.tqdm(qas):
        # A set makes the exact-match lookup O(1) per context token.
        question = {token.text for token in qa['question']}
        context = qa['context']
        counts = collections.Counter(token.text.lower() for token in context)
        # Normalize by context length so each value is a relative frequency.
        # Summing per-position counts instead would divide by sum(count ** 2).
        n_tokens = len(context)
        features = [
            (
                token.text in question,
                token.tag_,
                token.ent_type_ or 'None',
                counts[token.text.lower()] / n_tokens,
            )
            for token in context
        ]
        if features:
            qa['em'], qa['pos'], qa['ner'], qa['ntf'] = zip(*features)
        else:
            qa['em'] = qa['pos'] = qa['ner'] = qa['ntf'] = ()

In [175]:
# Compute the per-token em/pos/ner/ntf features for every record of both splits, in place.
add_features(train_qas)
add_features(valid_qas)

100%|██████████| 86597/86597 [01:02<00:00, 1382.68it/s]
100%|██████████| 34295/34295 [00:23<00:00, 1468.85it/s]


***Build vocabularies***

In [181]:
class Vocab:
    """Bidirectional mapping between vocabulary entries and integer indices.

    Populate it with :meth:`build`. Special tokens take the first indices in the
    order given, so ``'<pad>'`` passed first gets index 0. All other entries
    follow in sorted order. ``stoi`` falls back to ``unk_token`` for unknown words
    when that token is in the vocabulary.
    """

    # Token returned by ``stoi`` for out-of-vocabulary words, if it was given as a special.
    unk_token = '<unk>'

    def __init__(self):
        self.vocab = None        # list of entries; list position == index
        self.word2count = None   # collections.Counter of entry frequencies (specials included)
        self.word2index = None   # entry -> index
        self.index2word = None   # index -> entry

    def build(self, data, specials):
        """Populate the vocabulary from ``data``.

        :param List[Union[spacy.tokens.doc.Doc, str, Tuple]] data: all items are treated
            like ``data[0]``. ``str`` items are used verbatim. Elements of ``tuple`` items
            and token texts of spaCy ``Doc`` items are lowercased. Empty ``data`` yields
            a vocabulary of the specials only.
        :param List[str] specials: special tokens placed first, in this order; the list
            itself is not modified
        :raises TypeError: if ``data[0]`` is none of the supported types
        """
        words = []
        if data:
            first = data[0]
            # str/tuple are checked before Doc so spaCy is only consulted for Doc input.
            if isinstance(first, str):
                words.extend(data)
            elif isinstance(first, tuple):
                for item in data:
                    words.extend(word.lower() for word in item)
            elif isinstance(first, spacy.tokens.doc.Doc):
                for item in data:
                    words.extend(token.text.lower() for token in item)
            else:
                raise TypeError(f'Unsupported item type for Vocab.build: {type(first).__name__}')
        # Build a new list: the caller's ``specials`` must not be mutated.
        self.word2count = collections.Counter([*specials, *words])
        leading = list(dict.fromkeys(specials))
        reserved = set(leading)
        self.vocab = leading + sorted(word for word in self.word2count if word not in reserved)
        self.word2index = {word: index for index, word in enumerate(self.vocab)}
        self.index2word = dict(enumerate(self.vocab))

    def __len__(self):
        return len(self.vocab)

    def stoi(self, word: str) -> int:
        """Return the index of ``word``, or that of ``unk_token`` if ``word`` is unknown.

        :raises KeyError: if ``word`` is unknown and ``unk_token`` is not in the vocabulary
        """
        index = self.word2index.get(word)
        if index is not None:
            return index
        if self.unk_token in self.word2index:
            return self.word2index[self.unk_token]
        raise KeyError(word)

    def itos(self, index: int) -> str:
        """Return the entry stored at ``index``; raises ``KeyError`` for an unknown index."""
        return self.index2word[index]

In [182]:
%%time
ID, POS, NER, TEXT = Vocab(), Vocab(), Vocab(), Vocab()

ids = [*map(lambda qa: qa['id'], train_qas)] + [*map(lambda qa: qa['id'], valid_qas)]
pos, ner, contexts, questions = zip(*map(lambda qa: (qa['pos'], qa['ner'], qa['context'], qa['question']), train_qas))

ID.build(data=[*set(ids)], specials=[])
POS.build(data=[*set(pos)], specials=[])
NER.build(data=[*set(ner)], specials=[])
TEXT.build(data=[*set(contexts)] + [*set(questions)], specials=['<pad>', '<unk>'])

print(f'Length of ID vocabulary: {len(ID):,}')
print(f'Length of POS vocabulary: {len(POS):,}')
print(f'Length of NER vocabulary: {len(NER):,}')
print(f'Length of TEXT vocabulary: {len(TEXT):,}')

Length of ID vocabulary: 97,106
Length of POS vocabulary: 50
Length of NER vocabulary: 19
Length of TEXT vocabulary: 91,446
CPU times: user 8.27 s, sys: 57.6 ms, total: 8.33 s
Wall time: 8.35 s


***Build datasets***

In [None]:
class SquadDataset(Dataset):
    """PyTorch ``Dataset`` over parsed SQuAD records (work in progress).

    Stores the records together with the ID/POS/NER/TEXT vocabularies. Presumably
    the vocabularies are meant to numericalize each record's fields — TODO confirm
    once ``__getitem__`` is implemented.
    """

    def __init__(self, data, id_vocab, pos_vocab, ner_vocab, text_vocab):
        # assumes ``data`` is a list of record dicts like train_qas/valid_qas — verify at call site
        self.data = data
        self.id_vocab = id_vocab
        self.pos_vocab = pos_vocab
        self.ner_vocab = ner_vocab
        self.text_vocab = text_vocab
    
    def __len__(self):
        """Number of records in the dataset."""
        return len(self.data)
    
    def __getitem__(self, idx):
        """Return the record at ``idx`` — NOT IMPLEMENTED: currently returns None."""
        item = self.data[idx]
        # TODO: convert ``item`` into tensors; as written the method falls through and returns None.
        pass