# Purpose: Read the data, tokenize it, and store it in a format that can be massaged as needed later

In [1]:
from collections import Counter
import string
import re
import argparse
import json
import sys
import numpy as np
import nltk
import random
import math
import os
import pickle
from tqdm import tqdm, trange

import pdb

from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
                                                  BertTokenizer,
                                                  whitespace_tokenize)

In [2]:
def pickler(path,pkl_name,obj):
    """Pickle ``obj`` into the file ``pkl_name`` located in directory ``path``.

    The file is opened in binary write mode (an existing file is replaced)
    and the object is written with the highest available pickle protocol.
    """
    target = os.path.join(path, pkl_name)
    with open(target, 'wb') as sink:
        pickle.dump(obj, sink, protocol=pickle.HIGHEST_PROTOCOL)

def unpickler(path,pkl_name):
    """Load and return the object pickled in file ``pkl_name`` under ``path``.

    NOTE: ``pickle.load`` can execute arbitrary code while deserializing, so
    this must only be pointed at files from a trusted source.
    """
    source = os.path.join(path, pkl_name)
    with open(source, 'rb') as stream:
        return pickle.load(stream)

In [3]:
# Pretrained 'bert-base-uncased' tokenizer; handed to tokenize() to turn text into vocabulary ids.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [4]:
# True -> preprocess the training split; False -> preprocess the dev (distractor) split.
TRAINING = True

# Directory that receives the pickled output.
out_pkl_path = "./"

# Input JSON and output pickle name are selected together by the TRAINING switch.
file_path, out_pkl_name = (
    ("../../../hotpotqa/hotpot_train_v1.1.json", "preproc_train_1.pkl")
    if TRAINING
    else ("../../../hotpotqa/hotpot_dev_distractor_v1.json", "preproc_dev_1.pkl")
)

# max_seq_len = 510
# max_num_paragraphs = 10

In [5]:
# Parse the selected JSON file fully into memory as `dataset`.
with open(file_path, encoding='utf8') as json_file:
    dataset = json.load(json_file)

In [6]:
def tokenize(text, tokens_to_text_mapping, bert_tokenizer):
    """Convert ``text`` into a flat list of vocabulary ids.

    The text is first split with ``whitespace_tokenize``; each resulting word
    is then tokenized and converted to ids by ``bert_tokenizer``.

    Args:
        text: String to tokenize.
        tokens_to_text_mapping: Dict updated in place with
            ``tuple(ids_of_word) -> word`` for every word seen. A later word
            producing the same id tuple overwrites the earlier entry.
        bert_tokenizer: Object providing ``tokenize`` and
            ``convert_tokens_to_ids``.

    Returns:
        The ids of all words, concatenated in reading order.
    """
    flat_ids = []
    for word in whitespace_tokenize(text):
        pieces = bert_tokenizer.tokenize(word)
        piece_ids = bert_tokenizer.convert_tokens_to_ids(pieces)
        # Remember which surface word produced this id group so it can be
        # restored later by un_tokenize.
        tokens_to_text_mapping[tuple(piece_ids)] = word
        flat_ids.extend(piece_ids)
    return flat_ids

def un_tokenize(ids, tokens_to_text_mapping, bert_tokenizer):
    """Rebuild space-joined text from a flat sequence of ids.

    Scans ``ids`` left to right. At each position the longest slice whose
    tuple is a key of ``tokens_to_text_mapping`` is replaced by the mapped
    word and the scan jumps past it. If no slice starting at the current
    position is a known key, that single id is skipped.

    Args:
        ids: Sequence of ids (sliceable, with a length).
        tokens_to_text_mapping: Dict of ``tuple(ids) -> word``, as filled in
            by ``tokenize``.
        bert_tokenizer: Unused; kept so the signature mirrors ``tokenize``.

    Returns:
        The recovered words joined by single spaces ("" if nothing matched).
    """
    # A slice longer than the longest stored key can never match, so the
    # search window is capped instead of always starting at len(ids).
    longest_key = max(map(len, tokens_to_text_mapping), default=0)
    total = len(ids)
    words = []
    start = 0
    while start < total:
        # Greedy: try the longest candidate first, shrinking down to length 1.
        for stop in range(min(total, start + longest_key), start, -1):
            key = tuple(ids[start:stop])
            if key in tokens_to_text_mapping:
                words.append(tokens_to_text_mapping[key])
                start = stop
                break
        else:
            # Nothing beginning at this id is known: drop it and move on.
            start += 1
    return " ".join(words)

In [7]:
# Parallel per-question lists: position k in every list describes dataset[k].
question_ids = []
questions = []
paragraphs = []
paragraph_names = []
answers = []
answers_string = []
question_indices = []
yes_no_span = []
supporting_facts = []
ids_to_word_mappings = []
supporting_facts_raw = []
skipped = []  # never filled in this cell; kept so the name stays defined

for item_index, item in enumerate(tqdm(dataset)):
    answer_str = item["answer"]
    answers_string.append(answer_str)

    # One id-tuple -> word mapping per question, shared by its context
    # paragraphs and its question text.
    id_to_word = {}
    para_names = []
    para_text = []
    for para in item["context"]:
        p_name = para[0]
        p_sents = para[1]
        # Prefix the paragraph title to the first sentence on a *copy*.
        # Writing back into p_sents would modify `dataset` itself, so
        # re-running this cell would prepend the title a second time.
        titled_sents = [p_name + ". " + p_sents[0]] + p_sents[1:]
        para_names.append(p_name)
        para_text.append([tokenize(s, id_to_word, tokenizer) for s in titled_sents])
    paragraphs.append(para_text)
    paragraph_names.append(para_names)

    # Replace each supporting fact's paragraph title by the position of that
    # paragraph within this item's context; the second element is kept as is.
    supporting_facts_raw.append(item["supporting_facts"])
    supp_fact_list = [
        [para_names.index(sup_fact[0]), sup_fact[1]]
        for sup_fact in item["supporting_facts"]
    ]
    supporting_facts.append(supp_fact_list)

    question_indices.append(item_index)
    question_ids.append(item["_id"])
    questions.append(tokenize(item["question"], id_to_word, tokenizer))

    # Answer type label: 0 = "yes", 1 = "no", 2 = any other answer string.
    if answer_str == "yes":
        yes_no_span.append(0)
    elif answer_str == "no":
        yes_no_span.append(1)
    else:
        yes_no_span.append(2)

    # The answer gets a throwaway mapping so it does not affect id_to_word.
    answers.append(tokenize(answer_str, {}, tokenizer))
    ids_to_word_mappings.append(id_to_word)

100%|██████████| 90447/90447 [33:03<00:00, 45.59it/s]


In [8]:
# Sanity check: all per-question lists must have exactly the same length.
assert len(set(map(len, (
    question_ids,
    questions,
    paragraphs,
    paragraph_names,
    answers,
    question_indices,
    yes_no_span,
    supporting_facts,
    ids_to_word_mappings,
    supporting_facts_raw,
    answers_string,
)))) == 1

In [9]:
# Number of questions processed (shown as the cell output).
len(question_ids)

90447

In [10]:
# Bundle every per-question list under its own name for pickling.
out_dict = dict(
    question_ids=question_ids,
    questions=questions,
    paragraphs=paragraphs,
    paragraph_names=paragraph_names,
    answers=answers,
    question_indices=question_indices,
    yes_no_span=yes_no_span,
    supporting_facts=supporting_facts,
    ids_to_word_mappings=ids_to_word_mappings,
    answers_string=answers_string,
    supporting_facts_raw=supporting_facts_raw,
)

In [11]:
# Histogram: number of context paragraphs -> how many questions have that many.
num_paras = Counter(map(len, paragraphs))

In [12]:
# Display the paragraph-count histogram computed above.
num_paras

Counter({10: 89609, 8: 60, 2: 262, 6: 53, 3: 156, 5: 88, 4: 94, 7: 77, 9: 48})

In [13]:
# Peek at the first question's raw supporting facts: [paragraph title, supporting-fact index] pairs.
supporting_facts_raw[0]

[["Arthur's Magazine", 0], ['First for Women', 0]]

In [14]:
# Write the bundled preprocessing result to <out_pkl_path>/<out_pkl_name>.
pickler(path=out_pkl_path, pkl_name=out_pkl_name, obj=out_dict)
print("Done")

Done
