# Preprocessing HotpotQA for BERT

In [1]:
from collections import Counter
import string
import re
import argparse
import json
import sys
import numpy as np
import nltk
import random
import math
import os
import pickle
from tqdm import tqdm, trange

In [2]:
from pytorch_pretrained_bert import BertTokenizer

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [4]:
def pickler(path,pkl_name,obj):
    with open(os.path.join(path, pkl_name), 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def unpickler(path,pkl_name):
    with open(os.path.join(path, pkl_name) ,'rb') as f:
        obj = pickle.load(f)
    return obj

In [5]:
TRAINING = False

out_pkl_path = "./"

if(TRAINING):
    file_path = "/home/bhargav/data/hotpotqa/hotpot_train_v1.json"
    out_pkl_name = "preprocessed_train.pkl"
    problem_indices = [8437, 25197, 34122, 46031, 52955, 63867, 82250]
else:
    file_path = "/home/bhargav/data/hotpotqa/hotpot_dev_distractor_v1.json"
    out_pkl_name = "preprocessed_dev.pkl"
    problem_indices = [5059]

In [6]:
with open(file_path, encoding='utf8') as file:
    dataset = json.load(file)

In [7]:
def normalize(text):
#     exclude = set(string.punctuation)
#     clean = ''.join(ch for ch in text if ch not in exclude)
#     clean = clean.lower().strip()
    text = re.sub(
            r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " ", 
            str(text))
    text = re.sub(r"[ ]+", " ", text)
    text = re.sub(r"\!+", "!", text)
    text = re.sub(r"\,+", ",", text)
    text = re.sub(r"\?+", "?", text)
    text = text.lower().strip()
    return text

In [8]:
def tokenize(text):
    return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))

In [9]:
questions = []
paragraphs = [] 
supporting_facts = []


for item_index, item in enumerate(tqdm(dataset)):
    if(item_index in problem_indices):
        continue
    question = tokenize(item["question"])
    questions.append(question)
    paragraph_names = []
    paragraph_text = []
    for i,para in enumerate(item["context"]):
        para_name = para[0]
        para_sents = para[1]
        paragraph_names.append(para_name)
        paragraph_text.append([tokenize(s) for s in para_sents])
    paragraphs.append(paragraph_text)
    supp_fact_list = []
    for sup_fact in item["supporting_facts"]:
        para_name = sup_fact[0]
        supporting_fact_index = sup_fact[1] 
        para_index = paragraph_names.index(para_name)
        supp_fact_list.append([para_index, supporting_fact_index])
    supporting_facts.append(supp_fact_list)

100%|██████████| 7405/7405 [03:50<00:00, 32.15it/s]


In [10]:
print(questions[0])

[2020, 3660, 18928, 3385, 1998, 3968, 3536, 1997, 1996, 2168, 10662, 1029]


In [11]:
print(paragraphs[0][0])

[[3968, 3536, 2003, 1037, 2807, 2137, 16747, 2558, 4038, 1011, 3689, 2143, 2856, 1998, 2550, 2011, 5199, 9658, 1010, 1998, 4626, 5206, 2139, 9397, 2004, 8754, 12127, 3968, 3536, 1012], [1996, 2143, 5936, 1996, 2558, 1999, 3536, 1005, 1055, 2166, 2043, 2002, 2081, 2010, 2190, 1011, 2124, 3152, 2004, 2092, 2004, 2010, 3276, 2007, 3364, 20252, 11320, 12333, 2072, 1010, 2209, 2011, 3235, 28570, 1012], [4532, 8201, 6262, 1010, 10717, 12098, 29416, 1010, 10799, 3557, 1010, 7059, 5032, 1010, 1998, 3021, 6264, 2024, 2426, 1996, 4637, 3459, 1012]]


In [12]:
print(supporting_facts[0])

[[1, 0], [4, 0]]


In [13]:
print(len(questions))
print(len(paragraphs))
print(len(paragraphs[0]))
print(len(supporting_facts))

7404
7404
10
7404


In [14]:
def compute_paragraph_lengths(document):
    lengths = []
    for para in document:
        lengths.append(len(para))
    return lengths, sum(lengths)


# returns supporting fact indices so that it can be used later while trimming documents.
def expand_supporting_facts(supporting_facts, paragraphs):
    supporting_facts_expanded = []
    problem_indices = []
    supporting_fact_indices = []
    for i,supp_facts in enumerate(tqdm(supporting_facts)):
        s_f_indices = []
        paragraph_lengths, total_num_sentences = compute_paragraph_lengths(paragraphs[i])
        s_f_expanded = [0] * total_num_sentences
        for para_idx, sentence_idx in supp_facts:
            fact_idx = sum(paragraph_lengths[:para_idx])+ sentence_idx
            if(fact_idx >= total_num_sentences):
                problem_indices.append(i)
            else:
                s_f_indices.append(fact_idx)
        for s_f_idx in s_f_indices:
            s_f_expanded[s_f_idx] = 1
        supporting_facts_expanded.append(s_f_expanded)
        supporting_fact_indices += s_f_indices
    return supporting_facts_expanded, set(problem_indices), s_f_indices
        

In [15]:
supporting_facts_expanded, problem_indices, supporting_fact_indices = expand_supporting_facts(supporting_facts, paragraphs)

100%|██████████| 7404/7404 [00:00<00:00, 105082.20it/s]


In [16]:
supporting_fact_indices = np.array(supporting_fact_indices)

print("Avg supporting fact index:{}".format(supporting_fact_indices.mean()))
print("min supporting fact index:{}".format(supporting_fact_indices.min()))
print("max supporting fact index:{}".format(supporting_fact_indices.max()))

max_supporting_fact_index = 35
np.sum(np.greater(supporting_fact_indices,max_supporting_fact_index))/supporting_fact_indices.shape[0]

Avg supporting fact index:24.0
min supporting fact index:17
max supporting fact index:31


0.0

In [17]:
problem_indices

set()

In [18]:
print(supporting_facts_expanded[0])

[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [19]:
def flatten_documents(documents):
    flattened_documents = []
    for doc in tqdm(documents):
        f_d = []
        for para in doc:
            for sent in para:
                f_d.append(sent)
        flattened_documents.append(f_d)
    return flattened_documents

In [20]:
flattened_documents = flatten_documents(paragraphs)

100%|██████████| 7404/7404 [00:00<00:00, 84650.73it/s]


In [21]:
assert(len(flattened_documents[0]) == len(supporting_facts_expanded[0]))

In [22]:
num_sentences_per_document = []

for item in supporting_facts_expanded:
    num_sentences_per_document.append(len(item))

num_sentences_per_document = np.array(num_sentences_per_document)

print("Avg document len:{}".format(num_sentences_per_document.mean()))
print("min document len:{}".format(num_sentences_per_document.min()))
print("max document len:{}".format(num_sentences_per_document.max()))

Avg document len:41.38614262560778
min document len:2
max document len:147


In [23]:
max_sentences_per_document = 65
np.sum(np.greater(num_sentences_per_document,max_sentences_per_document))/num_sentences_per_document.shape[0]

0.027552674230145867

In [24]:
sentence_lengths = []

for doc in flattened_documents:
    for sent in doc:
        sentence_lengths.append(len(sent))

sentence_lengths = np.array(sentence_lengths)

print("Avg sentence len:{}".format(sentence_lengths.mean()))
print("min sentence len:{}".format(sentence_lengths.min()))
print("max sentence len:{}".format(sentence_lengths.max()))

Avg sentence len:28.85067047839098
min sentence len:0
max sentence len:453


In [25]:
max_sent_len = 55
np.sum(np.greater(sentence_lengths,max_sent_len))/sentence_lengths.shape[0]

0.05610871246610078

In [26]:
question_lengths = np.array([len(q) for q in questions])

print("Avg question len:{}".format(question_lengths.mean()))
print("min question len:{}".format(question_lengths.min()))
print("max question len:{}".format(question_lengths.max()))

Avg question len:19.590626688276608
min question len:7
max question len:65


In [27]:
max_question_len = 45
np.sum(np.greater(question_lengths,max_question_len))/question_lengths.shape[0]

0.0035116153430578066

In [28]:
max_question_plus_sentence_len = 103

In [29]:
cls_index = tokenizer.convert_tokens_to_ids(["[CLS]"])[0]
sep_index = tokenizer.convert_tokens_to_ids(["[SEP]"])[0]
pad_index = tokenizer.convert_tokens_to_ids(["[PAD]"])[0]

In [30]:
def merge_trim_pad(sent_1, sent_2, max_len, cls_index, sep_index, pad_index):
    merged_seq = [cls_index] + sent_1 + [sep_index] + sent_2
    merged_seq = merged_seq[:max_len-1]
    merged_seq.append(sep_index)
    merged_seq = merged_seq + [pad_index] * (max_len - len(merged_seq))
    num_zeros = merged_seq.index(sep_index) + 1 # index of [SEP] + 1 =  number of zeros to add
    segment_id = [0]*num_zeros + [1]*(len(merged_seq)-num_zeros)
    return merged_seq, segment_id

In [31]:
# This is a messy function. I have wrapped it in a function because i have lifted this code from my squad sentence selector
# and the variable names inside and outside the functions conflict.
def do_a_lot_of_work(questions_tokenized, sentences_tokenized, supporting_facts, max_sentences_per_passage, 
                     max_question_plus_sentence_len, cls_index, sep_index, pad_index):
    # init data dict
    data_out = {}

    for i in range(max_sentences_per_passage):
        data_out["sequence_{}".format(i)] = []
        data_out["sequence_segment_id_{}".format(i)] = []

    data_out["passage_mask"] = []
    data_out["supporting_fact"] = []


    for i in trange(len(questions_tokenized)):
        question = questions_tokenized[i]
        sentences = sentences_tokenized[i][:max_sentences_per_passage]
        num_pad_sentences = max_sentences_per_passage - len(sentences)
        sentences = sentences + [[]]*(num_pad_sentences)

        passage_mask = [1] * (max_sentences_per_passage-num_pad_sentences) + [0]*num_pad_sentences
        data_out["passage_mask"].append(passage_mask)
        
        supp_fact = supporting_facts[i]
        supp_fact = supp_fact[:max_sentences_per_passage]
        supp_fact = supp_fact + [0]*(num_pad_sentences)
        data_out["supporting_fact"].append(supp_fact)

        for j,sent in enumerate(sentences):
            merged_seq, segment_id = merge_trim_pad(sent_1=question, sent_2=sent, 
                                        max_len=max_question_plus_sentence_len, 
                                        cls_index=cls_index, sep_index=sep_index, pad_index=pad_index)
            data_out["sequence_{}".format(j)].append(merged_seq)
            data_out["sequence_segment_id_{}".format(j)].append(segment_id)

    return data_out

In [32]:
data_out = do_a_lot_of_work(questions, flattened_documents, supporting_facts_expanded, 
                           max_sentences_per_document, max_question_plus_sentence_len, cls_index, sep_index, pad_index)

100%|██████████| 7404/7404 [00:08<00:00, 826.14it/s]


In [33]:
data_out.keys()

dict_keys(['sequence_0', 'sequence_segment_id_0', 'sequence_1', 'sequence_segment_id_1', 'sequence_2', 'sequence_segment_id_2', 'sequence_3', 'sequence_segment_id_3', 'sequence_4', 'sequence_segment_id_4', 'sequence_5', 'sequence_segment_id_5', 'sequence_6', 'sequence_segment_id_6', 'sequence_7', 'sequence_segment_id_7', 'sequence_8', 'sequence_segment_id_8', 'sequence_9', 'sequence_segment_id_9', 'sequence_10', 'sequence_segment_id_10', 'sequence_11', 'sequence_segment_id_11', 'sequence_12', 'sequence_segment_id_12', 'sequence_13', 'sequence_segment_id_13', 'sequence_14', 'sequence_segment_id_14', 'sequence_15', 'sequence_segment_id_15', 'sequence_16', 'sequence_segment_id_16', 'sequence_17', 'sequence_segment_id_17', 'sequence_18', 'sequence_segment_id_18', 'sequence_19', 'sequence_segment_id_19', 'sequence_20', 'sequence_segment_id_20', 'sequence_21', 'sequence_segment_id_21', 'sequence_22', 'sequence_segment_id_22', 'sequence_23', 'sequence_segment_id_23', 'sequence_24', 'sequence_

In [34]:
num_sequences_in_each_position = []
for i in range(max_sentences_per_document):
    num_sequences_in_each_position.append(len(data_out["sequence_{}".format(i)]))
    num_sequences_in_each_position.append(len(data_out["sequence_segment_id_{}".format(i)]))

num_sequences_in_each_position.append(len(data_out["passage_mask"]))
num_sequences_in_each_position.append(len(data_out["supporting_fact"]))
    
print(min(num_sequences_in_each_position))
print(max(num_sequences_in_each_position))

7404
7404


In [35]:
all_lengths = []

for i in range(max_sentences_per_document):
#     for item in data_out["sequence_{}".format(i)]:
#         all_lengths.append(len(item))
    for item in data_out["sequence_segment_id_{}".format(i)]:
        all_lengths.append(len(item))
    
print(min(all_lengths))
print(max(all_lengths))

103
103


In [36]:
from collections import Counter

c = Counter(all_lengths)
c

Counter({103: 481260})

In [37]:
passage_mask_and_sf_lengths = []

for item in data_out["passage_mask"]:
    passage_mask_and_sf_lengths.append(len(item))
    
for item in data_out["supporting_fact"]:
    passage_mask_and_sf_lengths.append(len(item))
    
print(min(passage_mask_and_sf_lengths))
print(max(passage_mask_and_sf_lengths))

65
65


In [38]:
pickler(out_pkl_path,out_pkl_name,data_out)
print("Done")

Done
