## Question Answering from Wikipedia articles

Pre-requisites to run - 
Running this notebook requires certain packages to be pre-installed. 
* spacy
* pandas
* numpy
* torch
* sklearn

Instructions to run - 
1. Download this Jupyter Notebook in a root folder
2. Download and save the squad_train.json and squad_dev.json files for the SQuAD dataset in a folder called "data" and place this folder alongside the notebook in the root folder.
3. Download the glove embeddings with 300 dimension for 6B tokens in the root folder.
4. Run all cells in sequence.

In [1]:
import os
import string
import json
import re
import spacy
import time
from collections import Counter
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score
import nltk
from nltk.corpus import stopwords
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence


In [2]:
EMBED_DIM = 300
BATCH_SIZE = 32

In [3]:
## Load data json files and preprocess data in pandas dataframes

squad_train_path = './data/squad_train.json'
squad_dev_path = './data/squad_dev.json'
glove_path = './glove.6B.300d.txt'
glove = {}
with open(squad_train_path, 'r', encoding='utf-8') as file:
        train_json = json.load(file)
file.close()
with open(squad_dev_path, 'r', encoding='utf-8') as file:
        dev_json = json.load(file)
file.close()
with open(glove_path, 'r', encoding='utf-8') as file:
        lines = file.readlines()
        for embed in lines:
            word = embed.split()[0]
            emb = np.asarray(embed.split()[1:],dtype='float32')
            glove[word] = emb
file.close()
print("Length of data: ", len(dev_json['data']))

Length of data:  48


In [4]:
textList = []
def parse_json(json_obj: dict) -> list :
    data = []
    total_context = 0
    for article in json_obj['data']:
        for paragraph in article['paragraphs']:
            ctx = " ".join(paragraph['context'].split())
            total_context += 1
            textList.append(ctx)
            for ques in paragraph['qas']:
                ques_text = " ".join(ques['question'].split())
                textList.append(ques_text)
                for ans in ques['answers']:
                    ans_start_tag = ans['answer_start']
                    ans_text = " ".join(ans['text'].split())
                    ans_end_tag = ans_start_tag + len(ans_text)
                    new_example = {}
                    new_example['context'] = ctx
                    new_example['question'] = ques_text
                    new_example['ques_id'] = ques['id']
                    new_example['ans_text'] = ans_text
                    new_example['ans_tags'] = [ans_start_tag,ans_end_tag]
                    data.append(new_example)
    print(total_context)
    return data

In [5]:
train = pd.DataFrame(parse_json(train_json))
val = pd.DataFrame(parse_json(dev_json))

18896
2067


In [6]:
textList = list(set(textList))
len(textList)

118825

In [7]:
def process_text(text):
    text = text.lower()
    text = ''.join(ch for ch in text if ch not in set(string.punctuation))
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return text

In [8]:
nlp = spacy.load('en',disable=['parser','tagger','ner'])

In [9]:
tokens = []
for doc in nlp.pipe(textList,batch_size=50):
    for token in doc:
        tokens.append(token.text)
print(len(tokens))

3989780


In [10]:
vocab = Counter(tokens)
vocab = sorted(vocab, key=vocab.get, reverse=True)
vocab.insert(0,'<unk>')
vocab.insert(1,'<pad>')
print(len(vocab))

111082


In [11]:
word_embed = np.zeros((len(vocab),EMBED_DIM))
for i,word in enumerate(vocab):
    if word in glove.keys():
        word_embed[i] = glove[word]

In [12]:
word_embed.shape

(111082, 300)

In [13]:
idxInVocab = {w:i for i,w in enumerate(vocab)}
wordAtIdx = {i:w for w,i in idxInVocab.items()}

In [14]:
# Map each token in context and question text to corresponding index in vocab
def map_text2idx(text):
    indices = []
    for token in nlp(text):
        indices.append(idxInVocab[token.text])
    return indices

In [15]:
train['mapped_ctx'] = train.context.apply(map_text2idx)
train['mapped_ques'] = train.question.apply(map_text2idx)
val['mapped_ctx'] = val.context.apply(map_text2idx)
val['mapped_ques'] = val.question.apply(map_text2idx)

In [16]:
# Check for examples which have wrong start and end tags
## REFERENCE: https://github.com/kushalj001/pytorch-question-answering

def test_indices(data):
    start_tag_err = []
    end_tag_err = []
    assert_error = []
    for idx, row in data.iterrows():
        answer_tokens = [w.text for w in nlp(row['ans_text'])]
        start_token = answer_tokens[0]
        end_token = answer_tokens[-1]
        context_span  = [(word.idx, word.idx + len(word.text)) for word in nlp(row['context'])]
        starts, ends = zip(*context_span)
        answer_start, answer_end = row['ans_tags']
        try:
            start_idx = starts.index(answer_start)
        except:
            start_tag_err.append(idx)
        try:
            end_idx  = ends.index(answer_end)
        except:
            end_tag_err.append(idx)
        try:
            assert wordAtIdx[row['mapped_ctx'][start_idx]] == answer_tokens[0]
            assert wordAtIdx[row['mapped_ctx'][end_idx]] == answer_tokens[-1]
        except:
            assert_error.append(idx)
    return start_tag_err, end_tag_err, assert_error

def get_error_indices(data):
    start_value_error, end_value_error, assert_error = test_indices(data)
    err_idx = start_value_error + end_value_error + assert_error
    err_idx = set(err_idx)
    return err_idx

def index_answer(row):
    context_span = [(word.idx, word.idx + len(word.text)) for word in nlp(row.context)]
    starts, ends = zip(*context_span)
    
    answer_start, answer_end = row.ans_tags
    start_idx = starts.index(answer_start)
 
    end_idx  = ends.index(answer_end)
    
    ans_tokens = [w.text for w in nlp(row.ans_text)]
    ans_start = ans_tokens[0]
    ans_end = ans_tokens[-1]
    assert wordAtIdx[row.mapped_ctx[start_idx]] == ans_start
    assert wordAtIdx[row.mapped_ctx[end_idx]] == ans_end
    
    return [start_idx, end_idx]

train_err = get_error_indices(train)
valid_err = get_error_indices(val)

train.drop(train_err, inplace=True)
val.drop(valid_err, inplace=True)

train_tag_idx = train.apply(index_answer, axis=1)
valid_tag_idx = val.apply(index_answer, axis=1)

train['tag_idx'] = train_tag_idx
val['tag_idx'] = valid_tag_idx

train.head()

Unnamed: 0,context,question,ques_id,ans_text,ans_tags,mapped_ctx,mapped_ques,tag_idx
0,"Architecturally, the school has a Catholic cha...",To whom did the Virgin Mary allegedly appear i...,5733be284776f41900661182,Saint Bernadette Soubirous,"[515, 541]","[60168, 3, 2, 209, 42, 10, 551, 822, 5, 97837,...","[401, 582, 25, 2, 3432, 856, 6288, 1063, 8, 85...","[102, 104]"
1,"Architecturally, the school has a Catholic cha...",What is in front of the Notre Dame Main Building?,5733be284776f4190066117f,a copper statue of Christ,"[188, 213]","[60168, 3, 2, 209, 42, 10, 551, 822, 5, 97837,...","[11, 12, 8, 1507, 4, 2, 1240, 1198, 5650, 2748...","[37, 41]"
2,"Architecturally, the school has a Catholic cha...",The Basilica of the Sacred heart at Notre Dame...,5733be284776f41900661180,the Main Building,"[279, 296]","[60168, 3, 2, 209, 42, 10, 551, 822, 5, 97837,...","[16, 6079, 4, 2, 10508, 2005, 36, 1240, 1198, ...","[57, 59]"
3,"Architecturally, the school has a Catholic cha...",What is the Grotto at Notre Dame?,5733be284776f41900661181,a Marian place of prayer and reflection,"[381, 420]","[60168, 3, 2, 209, 42, 10, 551, 822, 5, 97837,...","[11, 12, 2, 23646, 36, 1240, 1198, 6]","[76, 82]"
4,"Architecturally, the school has a Catholic cha...",What sits on top of the Main Building at Notre...,5733be284776f4190066117e,a golden statue of the Virgin Mary,"[92, 126]","[60168, 3, 2, 209, 42, 10, 551, 822, 5, 97837,...","[11, 9367, 26, 454, 4, 2, 5650, 2748, 36, 1240...","[17, 23]"


In [17]:
val.head()

Unnamed: 0,context,question,ques_id,ans_text,ans_tags,mapped_ctx,mapped_ques,tag_idx
0,Super Bowl 50 was an American football game to...,Which NFL team represented the AFC at Super Bo...,56be4db0acb8001400a502ec,Denver Broncos,"[177, 191]","[645, 743, 791, 13, 35, 113, 705, 361, 9, 1784...","[68, 3251, 296, 1270, 2, 11507, 36, 645, 743, ...","[33, 34]"
1,Super Bowl 50 was an American football game to...,Which NFL team represented the AFC at Super Bo...,56be4db0acb8001400a502ec,Denver Broncos,"[177, 191]","[645, 743, 791, 13, 35, 113, 705, 361, 9, 1784...","[68, 3251, 296, 1270, 2, 11507, 36, 645, 743, ...","[33, 34]"
2,Super Bowl 50 was an American football game to...,Which NFL team represented the AFC at Super Bo...,56be4db0acb8001400a502ec,Denver Broncos,"[177, 191]","[645, 743, 791, 13, 35, 113, 705, 361, 9, 1784...","[68, 3251, 296, 1270, 2, 11507, 36, 645, 743, ...","[33, 34]"
3,Super Bowl 50 was an American football game to...,Which NFL team represented the NFC at Super Bo...,56be4db0acb8001400a502ed,Carolina Panthers,"[249, 266]","[645, 743, 791, 13, 35, 113, 705, 361, 9, 1784...","[68, 3251, 296, 1270, 2, 9450, 36, 645, 743, 7...","[44, 45]"
4,Super Bowl 50 was an American football game to...,Which NFL team represented the NFC at Super Bo...,56be4db0acb8001400a502ed,Carolina Panthers,"[249, 266]","[645, 743, 791, 13, 35, 113, 705, 361, 9, 1784...","[68, 3251, 296, 1270, 2, 9450, 36, 645, 743, 7...","[44, 45]"


In [18]:
is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')

In [19]:
## Build a dataloader

class dataloader:
    def __init__(self, data, batch_size):
        self.batch_size = batch_size
        i = 0
        self.data = []
        while i+batch_size < len(data):
            self.data.append(data[i:i+batch_size])
            i = i+batch_size
        if i<len(data):
            self.data.append(data[i:])
    def __len__(self):
        return len(self.data)
    def __iter__(self):
        for batch in self.data:
            ctx_pad_len = max([len(ctx) for ctx in batch.mapped_ctx])
            ques_pad_len = max([len(ques) for ques in batch.mapped_ques])
            pad_ques = torch.LongTensor(len(batch),ques_pad_len).fill_(1)
            pad_ctx = torch.LongTensor(len(batch), ctx_pad_len).fill_(1)
            for i,ctx in enumerate(batch.mapped_ctx):
                pad_ctx[i,:len(ctx)] = torch.LongTensor(ctx)
            for i,ques in enumerate(batch.mapped_ques):
                pad_ques[i,:len(ques)] = torch.LongTensor(ques)
            pad_ctx_mask = torch.eq(pad_ctx,1)
            pad_ques_mask = torch.eq(pad_ques,1)
            yield (pad_ctx,
                   pad_ctx_mask,
                   pad_ques,
                   pad_ques_mask,
                   torch.LongTensor(list(batch.tag_idx)),
                   list(batch.context),
                   list(batch.ans_text),
                   list(batch.ques_id))

In [20]:
class ReaderModel(nn.Module):
    def __init__(self,hidden_dim,embed_dim,num_lstm,dropout):
        super().__init__()
        self.num_lstm = num_lstm
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.encoding_dim = hidden_dim*num_lstm*2
        def tune_embedding(grad, words=1000):
            grad[words:] = 0
            return grad
        self.dropout_layer = nn.Dropout(self.dropout)
        self.embed_layer = nn.Embedding.from_pretrained(torch.FloatTensor(word_embed).to(device),freeze=False)
        self.embed_layer.weight.register_hook(tune_embedding)
        self.f_align_layer = F_align(self.embed_dim)
        self.ques_encod_layer = Ques_Encode(self.encoding_dim)
        self.ctx_rnn = RNNLayer(self.embed_dim*2,self.hidden_dim,self.num_lstm,self.dropout)
        self.ques_rnn = RNNLayer(self.embed_dim,self.hidden_dim,self.num_lstm,self.dropout)
        self.start_classifier_layer = Classifier(self.encoding_dim)
        self.end_classifier_layer = Classifier(self.encoding_dim)
        
    def forward(self,ctx,ctx_mask,ques,ques_mask,flag):
        ctx_embed = self.embed_layer(ctx)
#         if flag:
#             print(f"got ctx_embed {ctx_embed.shape}")
        ques_embed = self.embed_layer(ques)
#         if flag:
#             print(f"got ques_embed {ques_embed.shape}")
        f_align = self.f_align_layer(ctx_embed,ques_embed,ques_mask)
#         if flag:
#             print(f"got f_align {f_align.shape}")
        ctx_features = torch.cat([ctx_embed,f_align],dim=2)
#         if flag:
#             print(f"got ctx_features {ctx_features.shape}")
        ctx_encoding = self.ctx_rnn(ctx_features)
#         if flag:
#             print(f"got ctx_encoding {ctx_encoding.shape}")
        ques_features = self.ques_rnn(ques_embed)
#         if flag:
#             print(f"got ques_features {ques_features.shape}")
        ques_features_bi = self.ques_encod_layer(ques_features,ques_mask)
#         if flag:
#             print(f"got ques_features_bi {ques_features_bi.shape}")
        ques_encoding = ques_features_bi.unsqueeze(1).bmm(ques_features).squeeze(1)
#         if flag:
#             print(f"got ques_encoding {ques_encoding.shape}")
        start_class = self.start_classifier_layer(ctx_encoding,ques_encoding,ctx_mask)
#         if flag:
#             print(f"got start_class {start_class.shape}")
        end_class = self.end_classifier_layer(ctx_encoding,ques_encoding,ctx_mask)
#         if flag:
#             print(f"got end_class {end_class.shape}")
        return start_class,end_class

In [21]:
class F_align(nn.Module):
    def __init__(self,embed_dim):
        super().__init__()
        self.dense = nn.Linear(embed_dim,embed_dim)
        self.relu = nn.ReLU()
    def forward(self,ctx,ques,mask):
        alpha_ctx = self.relu(self.dense(ctx))
        alpha_ques = self.relu(self.dense(ques)).permute(0,2,1)
        aij = torch.bmm(alpha_ctx,alpha_ques)
        mask_reshaped = mask.unsqueeze(1).expand(aij.size())
        aij = aij.masked_fill(mask_reshaped == 1, -float('inf')).view(-1, ques.shape[1])
        align_weights = F.softmax(aij,dim=1).view(-1,ctx.shape[1],ques.shape[1])
        return torch.bmm(align_weights,ques)

In [22]:
class Ques_Encode(nn.Module):
    def __init__(self,encoding_dim):
        super().__init__()
        self.dense = nn.Linear(encoding_dim,1)
    def forward(self,features,mask):
        bs,features_len,feature_dim = features.shape
        bj = self.dense(features.view(-1,feature_dim)).view(bs,features_len)
        bj = bj.masked_fill(mask == 1, -float('inf'))
        return F.softmax(bj,dim=1)

In [23]:
class RNNLayer(nn.Module):
    def __init__(self,embed_dim,hidden_dim,num_lstm,dropout):
        super().__init__()
        self.input_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.nlayers = num_lstm
        self.dropout = dropout
        self.blstm = nn.ModuleList()
        self.blstm.append(nn.LSTM(self.input_dim,self.hidden_dim,batch_first=True,bidirectional=True))
        n = 1
        while (n<num_lstm):
            self.blstm.append(nn.LSTM(self.hidden_dim*2,self.hidden_dim,batch_first=True,bidirectional=True))
            n += 1
    def forward(self,features):
        first_layer_out, (h,c) = self.blstm[0](features)
        all_out = [first_layer_out]
        i = 1
        while i<self.nlayers :
            out,(h,c) = self.blstm[i](all_out[-1])
            all_out.append(out)
            i += 1
        all_out = torch.cat(all_out,dim=2)
        return F.dropout(all_out,p=self.dropout)

In [24]:
class Classifier(nn.Module):
    def __init__(self,encoding_dim):
        super().__init__()
        self.dense = nn.Linear(encoding_dim,encoding_dim)
    def forward(self,ctx,ques,mask):
        q = self.dense(ques)
        class_score = (ctx.bmm(q.unsqueeze(2))).squeeze(2)
        class_score = class_score.masked_fill(mask == 1, -float('inf'))
        return class_score

In [25]:
DROPOUT = 0.3
HIDDEN_DIM = 128
NUM_LSTM = 3
EPOCHS = 4

In [26]:
model = ReaderModel(HIDDEN_DIM,EMBED_DIM,NUM_LSTM,DROPOUT)
opt = torch.optim.Adamax(model.parameters())

In [27]:
def training(model,data):
    model.train()
    batch_ind = 0
    total_loss = 0
    for batch in data:
        flag = False
        if batch_ind%500 == 0:
            print(f"Batch {batch_ind}")
            flag = True
        batch_ind +=1
        ctx,ctx_mask,ques,ques_mask,tags,ctx_text,ans_text,ques_id = batch
        opt.zero_grad()
        pred = model(ctx,ctx_mask,ques,ques_mask,flag)
        start_pred,end_pred = pred
#         if flag:
#             print(f" got preds {start_pred.shape} {end_pred.shape}")
        start_tags = tags[:,0]
        end_tags = tags[:,1]
        start_pred_loss = F.cross_entropy(start_pred,start_tags)
        end_pred_loss = F.cross_entropy(end_pred,end_tags)
        batch_loss = start_pred_loss + end_pred_loss
        batch_loss.backward()
#         if flag:
#             print(f"batch loss {batch_loss}")
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        opt.step()
        total_loss += batch_loss.item()
    return total_loss/batch_ind

In [28]:
def validation(model,data):
    model.eval()
    batch_ind = 0
    total_loss = 0
    predictions = {}
    for batch in data:
        flag = False
        if batch_ind%500 == 0:
            print(f"Batch {batch_ind}")
            flag = True
        batch_ind +=1
        ctx,ctx_mask,ques,ques_mask,tags,ctx_text,ans_text,ques_id = batch
        bs, len_c = ctx.shape
        lsoftmax = nn.LogSoftmax(dim=1)
        with torch.no_grad():
            start_pred,end_pred = model(ctx,ctx_mask,ques,ques_mask,flag)
            start_tags = tags[:,0]
            end_tags = tags[:,1]
            start_pred_loss = F.cross_entropy(start_pred,start_tags)
            end_pred_loss = F.cross_entropy(end_pred,end_tags)
            batch_loss = start_pred_loss + end_pred_loss
            total_loss += batch_loss.item()
            # Get predictions
            mask = torch.ones(len_c, len_c) * float('-inf')
            mask = mask.tril(-1).unsqueeze(0).expand(bs, -1, -1)
            score = (lsoftmax(start_pred).unsqueeze(2) + lsoftmax(end_pred).unsqueeze(1)) + mask
            score, s_idx = score.max(dim=1)
            score, e_idx = score.max(dim=1)
            s_idx = torch.gather(s_idx, 1, e_idx.view(-1, 1)).squeeze()
            for i in range(bs):
                q_id = ques_id[i]
                pred = ctx[i][s_idx[i]:e_idx[i]+1]
                pred = ' '.join([wordAtIdx[idx.item()] for idx in pred])
                predictions[q_id] = pred
                        
    f1 = 0
    total = 0
    ground_truths = {}
    for it,row in val.iterrows():
        q_id = row['ques_id']
        if q_id in predictions:
            total +=1
            if q_id not in ground_truths:
                ground_truths[q_id] = [row['ans_text']]
            else:
                ground_truths[q_id].append(row['ans_text'])
    for q_id in predictions:
        scores = []
        for gt in ground_truths[q_id]:
            pred = process_text(predictions[q_id]).split()
            ans = process_text(gt).split()
            num_common_words = sum((Counter(pred) & Counter(ans)).values())
            if num_common_words ==0:
                score = 0
            else:
                p = num_common_words / len(pred)
                r = num_common_words / len(ans)
                score = (2 * p * r) / (p + r)
            scores.append(score)
        f1_score = max(scores)
        f1 += f1_score
    return total_loss/batch_ind, f1*100/total

In [29]:
train_data = dataloader(train,BATCH_SIZE)
val_data = dataloader(val,BATCH_SIZE)
for i in range(EPOCHS) :
    print(f"EPOCH {i+1}")
    start = time.time()
    print("starting training")
    train_loss = training(model,train_data)
    print("starting validation")
    val_loss, f1 = validation(model,val_data)
    end = time.time()
    total_time = end-start
    minutes = int(total_time/60)
    secs = int(total_time-(minutes*60))
#     print(f"Time spent {minutes} min {secs} sec.")
    print(f"F1 score : {f1}")
    print("----------- End of epoch ",i+1)

EPOCH 1
starting training
Batch 0
Batch 500
Batch 1000
Batch 1500
Batch 2000
Batch 2500
starting validation
Batch 0
Batch 500
Batch 1000
F1 score : 15.72443297317382
----------- End of epoch  1
EPOCH 2
starting training
Batch 0
Batch 500
Batch 1000
Batch 1500
Batch 2000
Batch 2500
starting validation
Batch 0
Batch 500
Batch 1000
F1 score : 17.55806947400931
----------- End of epoch  2
EPOCH 3
starting training
Batch 0
Batch 500
Batch 1000
Batch 1500
Batch 2000
Batch 2500
starting validation
Batch 0
Batch 500
Batch 1000
F1 score : 18.15879685181556
----------- End of epoch  3
EPOCH 4
starting training
Batch 0
Batch 500
Batch 1000
Batch 1500
Batch 2000
Batch 2500
starting validation
Batch 0
Batch 500
Batch 1000
F1 score : 18.24309980073122
----------- End of epoch  4
