# Preprocessing HotpotQA for BERT

In [1]:
from collections import Counter
import string
import re
import argparse
import json
import sys
import numpy as np
import nltk
import random
import math
import os
import pickle
from tqdm import tqdm, trange

In [2]:
from pytorch_pretrained_bert import BertTokenizer

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [4]:
def pickler(path,pkl_name,obj):
    with open(os.path.join(path, pkl_name), 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def unpickler(path,pkl_name):
    with open(os.path.join(path, pkl_name) ,'rb') as f:
        obj = pickle.load(f)
    return obj

In [5]:
TRAINING = True

out_pkl_path = "./"

if(TRAINING):
    file_path = "/home/bhargav/data/hotpotqa/hotpot_train_v1.json"
    out_pkl_name = "preprocessed_train.pkl"
    problem_indices = [8437, 25197, 34122, 46031, 52955, 63867, 82250]
else:
    file_path = "/home/bhargav/data/hotpotqa/hotpot_dev_distractor_v1.json"
    out_pkl_name = "preprocessed_dev.pkl"
    problem_indices = [5059]

In [6]:
with open(file_path, encoding='utf8') as file:
    dataset = json.load(file)

In [7]:
def normalize(text):
#     exclude = set(string.punctuation)
#     clean = ''.join(ch for ch in text if ch not in exclude)
#     clean = clean.lower().strip()
    text = re.sub(
            r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " ", 
            str(text))
    text = re.sub(r"[ ]+", " ", text)
    text = re.sub(r"\!+", "!", text)
    text = re.sub(r"\,+", ",", text)
    text = re.sub(r"\?+", "?", text)
    text = text.lower().strip()
    return text

In [8]:
def tokenize(text):
    return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))

In [9]:
questions = []
paragraphs = [] 
supporting_facts = []


for item_index, item in enumerate(tqdm(dataset)):
    if(item_index in problem_indices):
        continue
    question = tokenize(item["question"])
    questions.append(question)
    paragraph_names = []
    paragraph_text = []
    for i,para in enumerate(item["context"]):
        para_name = para[0]
        para_sents = para[1]
        paragraph_names.append(para_name)
        paragraph_text.append([tokenize(s) for s in para_sents])
    paragraphs.append(paragraph_text)
    supp_fact_list = []
    for sup_fact in item["supporting_facts"]:
        para_name = sup_fact[0]
        supporting_fact_index = sup_fact[1] 
        para_index = paragraph_names.index(para_name)
        supp_fact_list.append([para_index, supporting_fact_index])
    supporting_facts.append(supp_fact_list)

100%|██████████| 90564/90564 [45:26<00:00, 32.53it/s]


In [10]:
print(questions[0])

[2073, 2001, 1996, 2364, 2839, 1999, 1996, 2143, 22953, 15551, 14620, 1999, 2152, 1011, 3036, 13691, 8323, 1029]


In [11]:
print(paragraphs[0][0])

[[2798, 4300, 1000, 4918, 1000, 10582, 1006, 2141, 2745, 5146, 12001, 2036, 2124, 2004, 2798, 22953, 15551, 1025, 1020, 2285, 3999, 1007, 2003, 2019, 2394, 4735, 2040, 2003, 2411, 3615, 2000, 1999, 1996, 2329, 2811, 2004, 1996, 1000, 2087, 6355, 7267, 1999, 3725, 1000, 1998, 1000, 3725, 1005, 1055, 2087, 12536, 7267, 1000, 1012], [2002, 2038, 2985, 6993, 14620, 1999, 1996, 13276, 2669, 1010, 5041, 17622, 1998, 6683, 5172, 2152, 1011, 3036, 13691, 8323, 1012]]


In [12]:
print(supporting_facts[0])

[[7, 1], [0, 1]]


In [13]:
print(len(questions))
print(len(paragraphs))
print(len(paragraphs[0]))
print(len(supporting_facts))

90557
90557
10
90557


In [14]:
def compute_paragraph_lengths(document):
    lengths = []
    for para in document:
        lengths.append(len(para))
    return lengths, sum(lengths)


# returns supporting fact indices so that it can be used later while trimming documents.
def expand_supporting_facts(supporting_facts, paragraphs):
    supporting_facts_expanded = []
    problem_indices = []
    supporting_fact_indices = []
    for i,supp_facts in enumerate(tqdm(supporting_facts)):
        s_f_indices = []
        paragraph_lengths, total_num_sentences = compute_paragraph_lengths(paragraphs[i])
        s_f_expanded = [0] * total_num_sentences
        for para_idx, sentence_idx in supp_facts:
            fact_idx = sum(paragraph_lengths[:para_idx])+ sentence_idx
            if(fact_idx >= total_num_sentences):
                problem_indices.append(i)
            else:
                s_f_indices.append(fact_idx)
        for s_f_idx in s_f_indices:
            s_f_expanded[s_f_idx] = 1.0
        supporting_facts_expanded.append(s_f_expanded)
        supporting_fact_indices += s_f_indices
    return supporting_facts_expanded, set(problem_indices), s_f_indices
        

In [15]:
supporting_facts_expanded, problem_indices, supporting_fact_indices = expand_supporting_facts(supporting_facts, paragraphs)

100%|██████████| 90557/90557 [00:00<00:00, 114732.30it/s]


In [16]:
supporting_fact_indices = np.array(supporting_fact_indices)

print("Avg supporting fact index:{}".format(supporting_fact_indices.mean()))
print("min supporting fact index:{}".format(supporting_fact_indices.min()))
print("max supporting fact index:{}".format(supporting_fact_indices.max()))

max_supporting_fact_index = 35
np.sum(np.greater(supporting_fact_indices,max_supporting_fact_index))/supporting_fact_indices.shape[0]

Avg supporting fact index:18.0
min supporting fact index:3
max supporting fact index:33


0.0

In [17]:
problem_indices

set()

In [18]:
print(supporting_facts_expanded[0])

[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [19]:
def flatten_documents(documents):
    flattened_documents = []
    for doc in tqdm(documents):
        f_d = []
        for para in doc:
            for sent in para:
                f_d.append(sent)
        flattened_documents.append(f_d)
    return flattened_documents

In [20]:
flattened_documents = flatten_documents(paragraphs)

100%|██████████| 90557/90557 [00:00<00:00, 95061.99it/s]


In [21]:
assert(len(flattened_documents[0]) == len(supporting_facts_expanded[0]))

In [22]:
num_sentences_per_document = []

for item in supporting_facts_expanded:
    num_sentences_per_document.append(len(item))

num_sentences_per_document = np.array(num_sentences_per_document)

print("Avg document len:{}".format(num_sentences_per_document.mean()))
print("min document len:{}".format(num_sentences_per_document.min()))
print("max document len:{}".format(num_sentences_per_document.max()))

Avg document len:40.926985213732785
min document len:2
max document len:144


In [23]:
max_sentences_per_document = 65
np.sum(np.greater(num_sentences_per_document,max_sentences_per_document))/num_sentences_per_document.shape[0]

0.025884249699084556

In [24]:
sentence_lengths = []

for doc in flattened_documents:
    for sent in doc:
        sentence_lengths.append(len(sent))

sentence_lengths = np.array(sentence_lengths)

print("Avg sentence len:{}".format(sentence_lengths.mean()))
print("min sentence len:{}".format(sentence_lengths.min()))
print("max sentence len:{}".format(sentence_lengths.max()))

Avg sentence len:28.858022381263954
min sentence len:0
max sentence len:949


In [25]:
max_sent_len = 55
np.sum(np.greater(sentence_lengths,max_sent_len))/sentence_lengths.shape[0]

0.05598472839614432

In [26]:
question_lengths = np.array([len(q) for q in questions])

print("Avg question len:{}".format(question_lengths.mean()))
print("min question len:{}".format(question_lengths.min()))
print("max question len:{}".format(question_lengths.max()))

Avg question len:22.418277990657817
min question len:1
max question len:141


In [27]:
max_question_len = 45
np.sum(np.greater(question_lengths,max_question_len))/question_lengths.shape[0]

0.05349117130646996

In [28]:
max_question_plus_sentence_len = 103

In [29]:
cls_index = tokenizer.convert_tokens_to_ids(["[CLS]"])[0]
sep_index = tokenizer.convert_tokens_to_ids(["[SEP]"])[0]
pad_index = tokenizer.convert_tokens_to_ids(["[PAD]"])[0]

In [30]:
def merge_trim_pad(sent_1, sent_2, max_len, cls_index, sep_index, pad_index):
    merged_seq = [cls_index] + sent_1 + [sep_index] + sent_2
    merged_seq = merged_seq[:max_len-1]
    merged_seq.append(sep_index)
    merged_seq = merged_seq + [pad_index] * (max_len - len(merged_seq))
    num_zeros = merged_seq.index(sep_index) + 1 # index of [SEP] + 1 =  number of zeros to add
    segment_id = [0]*num_zeros + [1]*(len(merged_seq)-num_zeros)
    return merged_seq, segment_id

In [31]:
# This is a messy function. I have wrapped it in a function because i have lifted this code from my squad sentence selector
# and the variable names inside and outside the functions conflict.
def do_a_lot_of_work(questions_tokenized, sentences_tokenized, supporting_facts, 
                     max_question_plus_sentence_len, cls_index, sep_index, pad_index):
    # init data dict
    data_out = {"sequences":[], "segment_ids":[], "supporting_fact":[], "document_lengths":[]}
    
    for i in trange(len(questions_tokenized)):
        question = questions_tokenized[i]
        sentences = sentences_tokenized[i]
        supp_fact = supporting_facts[i]
        
        data_out["document_lengths"].append(len(supp_fact))
        
        for j,sent in enumerate(sentences):
            merged_seq, segment_id = merge_trim_pad(sent_1=question, sent_2=sent, 
                                        max_len=max_question_plus_sentence_len, 
                                        cls_index=cls_index, sep_index=sep_index, pad_index=pad_index)
            data_out["sequences"].append(merged_seq)
            data_out["segment_ids"].append(segment_id)
            data_out["supporting_fact"].append(supp_fact[j])
    
    return data_out

In [32]:
data_out = do_a_lot_of_work(questions, flattened_documents, supporting_facts_expanded, 
                           max_question_plus_sentence_len, cls_index, sep_index, pad_index)

100%|██████████| 90557/90557 [01:01<00:00, 1479.03it/s]


In [33]:
for key, value in data_out.items():
    print("Key:{}, length:{}".format(key, len(value)))

Key:sequences, length:3706225
Key:segment_ids, length:3706225
Key:supporting_fact, length:3706225
Key:document_lengths, length:90557


In [34]:
seq_len = []
for seq in data_out["sequences"]:
    seq_len.append(len(seq))
for seq in data_out["segment_ids"]:
    seq_len.append(len(seq))

print(min(seq_len))
print(max(seq_len))

103
103


In [35]:
pickler(out_pkl_path,out_pkl_name,data_out)
print("Done")

Done
