In [1]:
from collections import Counter
import string
import re
import argparse
import json
import sys
import numpy as np
import nltk
import random
import math
import os
import dill as pickle
from tqdm import tqdm, trange

In [2]:
import pdb

In [3]:
def pickler(path,pkl_name,obj):
    """Serialize ``obj`` into the file ``pkl_name`` located in directory ``path``.

    The file is opened in binary write mode and the object is dumped with the
    module bound to ``pickle`` (``dill`` in this notebook) using
    ``pickle.HIGHEST_PROTOCOL``.
    """
    destination = os.path.join(path, pkl_name)
    with open(destination, 'wb') as sink:
        pickle.dump(obj, sink, pickle.HIGHEST_PROTOCOL)

def unpickler(path,pkl_name):
    """Load and return the object stored in the file ``pkl_name`` inside ``path``.

    The file is opened in binary read mode and deserialized with the module
    bound to ``pickle`` (``dill`` in this notebook).
    """
    source = os.path.join(path, pkl_name)
    # NOTE: unpickling can execute arbitrary code -- only load files you trust.
    with open(source, 'rb') as stream:
        return pickle.load(stream)

# Formatting 

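The formatted dataset is four parallel lists with one entry per question: question ids, questions, context sentences (a list of sentences per question), and supporting-fact indicators (a one-hot list per question marking the sentence that contains the answer).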

In [4]:
# Which SQuAD v1.1 split to preprocess: True -> train split, False -> dev split.
TRAIN = False

# Directory the preprocessed pickle is written to.
out_pkl_path = "./"

# Input JSON and output pickle name for the selected split.
# NOTE(review): the dataset locations are absolute, machine-specific paths.
dataset_file_path, out_pkl_name = (
    ("/home/bhargav/data/squad/train-v1.1.json", "preprocessed_train.pkl")
    if TRAIN else
    ("/home/bhargav/data/squad/dev-v1.1.json", "preprocessed_dev.pkl")
)
    

In [5]:
def sent_has_answer(sent,ans):
    """Return True when every distinct token of ``ans`` also occurs in ``sent``.

    Both strings are tokenized with ``nltk.word_tokenize`` and compared as
    sets, so token order and multiplicity are ignored.
    """
    sentence_vocab = set(nltk.word_tokenize(sent))
    answer_vocab = set(nltk.word_tokenize(ans))
    return answer_vocab.issubset(sentence_vocab)

def put_space_before_period(sents):
    """Re-join each sentence from its ``nltk.word_tokenize`` tokens.

    Every sentence in ``sents`` is tokenized and its tokens are joined back
    with single spaces. Returns a new list with one string per input sentence.
    """
    return [" ".join(nltk.word_tokenize(sentence)) for sentence in sents]

# Input: number of characters covered by each sentence, start pointer of the answer (given in the dataset)
# Output: one-hot vector indicating the sentence containing the answer
def find_answer_sentence(sentence_lengths, answer_start):
    """Locate the sentence that contains character offset ``answer_start``.

    Args:
        sentence_lengths: number of characters each sentence accounts for, in
            passage order (separators included). Sentence ``i`` covers the
            half-open character range
            ``[sum(lengths[:i]), sum(lengths[:i + 1]))``.
        answer_start: character offset of the answer within the passage.

    Returns:
        A list with one entry per sentence holding a single 1 at the position
        of the sentence containing ``answer_start`` and 0 elsewhere. An empty
        ``sentence_lengths`` yields an empty list. An offset at or past the
        total length is attributed to the last sentence; a negative offset is
        attributed to the first one.
    """
    if not sentence_lengths:
        return []

    # Fallback for offsets beyond the summed lengths: the last sentence.
    sentence_index = len(sentence_lengths) - 1
    consumed = 0
    for i, length in enumerate(sentence_lengths):
        consumed += length
        # Strict comparison: offset == consumed is the first character of the
        # NEXT sentence, so it must not match the current one.
        if answer_start < consumed:
            sentence_index = i
            break

    out_vector = [0] * len(sentence_lengths)
    out_vector[sentence_index] = 1
    return out_vector

In [6]:
# with open(file_path, encoding='utf8') as file:
#     dataset = json.load(file)

In [7]:
def normalize(text):
    """Return a lower-cased copy of ``text`` with simplified punctuation.

    ``text`` is first converted with ``str()``. Every character matched by the
    first pattern (quotes, brackets, dashes, slashes, colons, semicolons,
    exclamation marks, newlines, ...) becomes a space, runs of spaces are
    collapsed, repeated ',' and '?' are squeezed to a single character, and
    the result is lower-cased and stripped.
    """
    cleaned = re.sub(
            r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " ",
            str(text))
    squeeze_rules = (
        (r"[ ]+", " "),
        # No-op in practice: every '!' was already turned into a space above.
        (r"\!+", "!"),
        (r"\,+", ","),
        (r"\?+", "?"),
    )
    for pattern, replacement in squeeze_rules:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.lower().strip()

In [8]:
def _sentence_char_lengths(context_raw, context_sents):
    """Return how many characters of ``context_raw`` each sentence accounts for.

    Each sentence is located in ``context_raw`` (searching forward from the end
    of the previous one), and its length is measured from its own start to the
    start of the next sentence, so whatever whitespace separates two sentences
    is counted exactly instead of being assumed to be a single character.
    The first entry also absorbs any leading text before the first sentence,
    and the last entry extends one past the end of the context.
    """
    boundaries = []
    cursor = 0
    for sent in context_sents:
        start = context_raw.find(sent, cursor)
        if start < 0:
            # Sentence text not found verbatim: fall back to assuming a single
            # separator character after the previous sentence.
            start = cursor + 1 if boundaries else 0
        boundaries.append(start)
        cursor = start + len(sent)
    if not boundaries:
        return []
    boundaries[0] = 0
    boundaries.append(max(cursor, len(context_raw)) + 1)
    return [end - begin for begin, end in zip(boundaries, boundaries[1:])]


def format_dataset(dataset):
    """Flatten SQuAD-style articles into four parallel lists.

    Args:
        dataset: iterable of articles, each with a ``'paragraphs'`` list whose
            items hold a ``'context'`` string and a ``'qas'`` list. Every qa
            needs ``'id'``, ``'question'`` and a non-empty ``'answers'`` list
            whose items carry ``'answer_start'`` (an empty ``'answers'`` list
            raises ``IndexError``).

    Returns:
        ``[ids, questions, sentences, answers]`` with one entry per question:
        the question id, the question text, the list of context sentences
        (the same list object is shared by all questions of a paragraph) and
        the indicator vector returned by ``find_answer_sentence``.
    """
    ids = []
    questions = []
    sentences = []
    answers = []
    for article in tqdm(dataset):
        for paragraph in article['paragraphs']:
            context_raw = paragraph['context']
            context_sents = nltk.sent_tokenize(context_raw)
            # Measure sentence extents from their real positions in the
            # context so answer offsets do not drift when sentences are
            # separated by more than one whitespace character.
            context_sent_lengths = _sentence_char_lengths(context_raw, context_sents)
            for qa in paragraph['qas']:
                gt_start_pointers = [answer['answer_start'] for answer in qa['answers']]
                # Only the first ground-truth answer decides the label.
                answer_indicators = find_answer_sentence(sentence_lengths=context_sent_lengths, answer_start=gt_start_pointers[0])
                sentences.append(context_sents)
                questions.append(qa['question'])
                ids.append(qa['id'])
                answers.append(answer_indicators)
    return [ids, questions, sentences, answers]


  

In [9]:
def run(dataset_file_path):
    """Load a SQuAD-style JSON file and format it with ``format_dataset``.

    Args:
        dataset_file_path: path of a JSON file whose top-level ``'data'`` key
            holds the list of articles.

    Returns:
        Whatever ``format_dataset`` returns for the ``'data'`` list.
    """
    # Read explicitly as UTF-8: the passages contain non-ASCII characters and
    # the platform default encoding is not guaranteed to be UTF-8.
    with open(dataset_file_path, encoding='utf8') as dataset_file:
        dataset = json.load(dataset_file)['data']
    return format_dataset(dataset)

In [10]:
# Load the configured SQuAD file; run() returns four parallel lists with one entry per question.
ids, questions, sentences, supporting_facts = run(dataset_file_path)

100%|██████████| 48/48 [00:01<00:00, 44.14it/s]


In [11]:
print(len(questions))
print(len(sentences))
print(len(supporting_facts))

10570
10570
10570


In [12]:
questions[0]

'Which NFL team represented the AFC at Super Bowl 50?'

In [13]:
sentences[0]

['Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season.',
 'The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title.',
 "The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California.",
 'As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.']

In [14]:
supporting_facts[0]

[0, 1, 0, 0]

In [15]:
def get_index_with_id(dataset, idx = "5ae61bfd5542992663a4f261"):
    """Return the position of the first item whose ``"_id"`` equals ``idx``.

    ``dataset`` is an iterable of mappings; -1 is returned when nothing
    matches.
    NOTE(review): items are looked up under the key ``"_id"``, whereas
    ``format_dataset`` reads question ids from ``'id'`` -- confirm which
    record layout this helper is meant for.
    """
    matches = (position for position, item in enumerate(dataset)
               if item["_id"] == idx)
    return next(matches, -1)
    

In [16]:
def print_formatted_example(index, question_ids, questions, sentences, 
                            supporting_facts):
    """Pretty-print example number ``index`` from the four parallel lists.

    Prints the question id, the question, the numbered context sentences and
    the supporting-fact vector, each section followed by a separator line.
    """
    rule = "--xx--xx--xx--xx--xx--xx--"

    # Header sections: one labelled value each, followed by a separator.
    for label, column in (("Question id:", question_ids),
                          ("Question:", questions)):
        print(label, column[index])
        print(rule)

    print("sentences:")
    for position, sentence in enumerate(sentences[index]):
        print("{} :{}".format(position, sentence))
    print(rule)

    print("supporting_facts:", supporting_facts[index])
    print(rule)

In [17]:
print_formatted_example(10, ids, questions, sentences, 
                            supporting_facts)

Question id: 56bea9923aeaaa14008c91bb
--xx--xx--xx--xx--xx--xx--
Question: What day was the Super Bowl played on?
--xx--xx--xx--xx--xx--xx--
sentences:
0 :Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season.
1 :The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title.
2 :The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California.
3 :As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.
--xx--xx--xx--xx--xx--xx--
supporting_facts: [0, 0, 1, 0]
--

# Tokenization    

In [18]:
from pytorch_pretrained_bert import BertTokenizer

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [19]:
# Pretrained 'bert-base-uncased' tokenizer; the tokenization helpers in this notebook read this global.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [20]:
def foo():
    """Print the tokens produced for a messy sample string (quick tokenizer sanity check)."""
    sample = "     hello, world! I don't like you :-)  !!!! (YOLO)"
    print(tokenizer.tokenize(sample))

foo()

['hello', ',', 'world', '!', 'i', 'don', "'", 't', 'like', 'you', ':', '-', ')', '!', '!', '!', '!', '(', 'yo', '##lo', ')']


In [21]:
def tokenize(text):
    """Split ``text`` into tokens using the notebook-level ``tokenizer``."""
    tokens = tokenizer.tokenize(text)
    return tokens

In [22]:
def batch_tokenize(text_in):
    """Convert every string of ``text_in`` into a list of vocabulary ids.

    Each line is tokenized with ``tokenize`` and the tokens are mapped to ids
    with ``tokenizer.convert_tokens_to_ids``. Returns one id list per input
    line, in input order.
    """
    return [tokenizer.convert_tokens_to_ids(tokenize(line)) for line in text_in]

In [23]:
batch_tokenize(["tokenize this","and this!"])

[[19204, 4697, 2023], [1998, 2023, 999]]

In [24]:
questions_tokenized = batch_tokenize(questions)

In [25]:
print(questions_tokenized[0])

[2029, 5088, 2136, 3421, 1996, 10511, 2012, 3565, 4605, 2753, 1029]


In [26]:
# Tokenize every passage: sentences_tokenized[i] holds one list of token ids
# per sentence of sentences[i].
sentences_tokenized = []
for sent_list in tqdm(sentences):
    sents_tokenized = batch_tokenize(sent_list)
    sentences_tokenized.append(sents_tokenized)

100%|██████████| 10570/10570 [00:49<00:00, 213.45it/s]


In [27]:
print(sentences_tokenized[0])

[[3565, 4605, 2753, 2001, 2019, 2137, 2374, 2208, 2000, 5646, 1996, 3410, 1997, 1996, 2120, 2374, 2223, 1006, 5088, 1007, 2005, 1996, 2325, 2161, 1012], [1996, 2137, 2374, 3034, 1006, 10511, 1007, 3410, 7573, 14169, 3249, 1996, 2120, 2374, 3034, 1006, 22309, 1007, 3410, 3792, 12915, 2484, 1516, 2184, 2000, 7796, 2037, 2353, 3565, 4605, 2516, 1012], [1996, 2208, 2001, 2209, 2006, 2337, 1021, 1010, 2355, 1010, 2012, 11902, 1005, 1055, 3346, 1999, 1996, 2624, 3799, 3016, 2181, 2012, 4203, 10254, 1010, 2662, 1012], [2004, 2023, 2001, 1996, 12951, 3565, 4605, 1010, 1996, 2223, 13155, 1996, 1000, 3585, 5315, 1000, 2007, 2536, 2751, 1011, 11773, 11107, 1010, 2004, 2092, 2004, 8184, 28324, 2075, 1996, 4535, 1997, 10324, 2169, 3565, 4605, 2208, 2007, 3142, 16371, 28990, 2015, 1006, 2104, 2029, 1996, 2208, 2052, 2031, 2042, 2124, 2004, 1000, 3565, 4605, 1048, 1000, 1007, 1010, 2061, 2008, 1996, 8154, 2071, 14500, 3444, 1996, 5640, 16371, 28990, 2015, 2753, 1012]]


In [28]:
# Fixed length of each merged "[CLS] question [SEP] sentence [SEP]" id sequence; passed as max_len to merge_trim_pad.
max_question_plus_sentence_len = 103
# Number of sentences kept per passage: longer passages are truncated, shorter ones padded with empty sentences.
max_sentences_per_passage = 10

In [29]:
# Vocabulary ids of the special tokens used when assembling the merged sequences.
cls_index = tokenizer.convert_tokens_to_ids(["[CLS]"])[0]
sep_index = tokenizer.convert_tokens_to_ids(["[SEP]"])[0]
pad_index = tokenizer.convert_tokens_to_ids(["[PAD]"])[0]

In [30]:
def merge_trim_pad(sent_1, sent_2, max_len, cls_index, sep_index, pad_index):
    """Build a fixed-length ``[CLS] sent_1 [SEP] sent_2 [SEP] [PAD]...`` input.

    Args:
        sent_1: list of ids for the first segment (the question).
        sent_2: list of ids for the second segment (the sentence).
        max_len: target length of the returned sequences.
        cls_index, sep_index, pad_index: ids of the special tokens.

    Returns:
        ``(merged_seq, segment_id, mask)`` -- three lists of equal length:
        * ``merged_seq``: the merged ids, truncated so that the final
          separator always fits, then padded with ``pad_index``.
        * ``segment_id``: 0 for ``[CLS]``, ``sent_1`` and the first ``[SEP]``,
          1 for every later position (padding positions included).
        * ``mask``: 1 for real tokens, 0 for padding positions.
    """
    merged_seq = [cls_index] + sent_1 + [sep_index] + sent_2
    # Keep room for the closing separator.
    merged_seq = merged_seq[:max_len-1]
    merged_seq.append(sep_index)
    num_real = len(merged_seq)
    merged_seq += [pad_index] * (max_len - num_real)

    # Clamp the first-segment span so segment_id never outgrows merged_seq
    # when sent_1 alone is longer than the available room.
    num_zeros = min(len(sent_1) + 2, len(merged_seq))
    segment_id = [0]*num_zeros + [1]*(len(merged_seq)-num_zeros)

    # Derive the mask from positions rather than values, so a real token whose
    # id happens to equal pad_index is still attended to.
    mask = [1]*num_real + [0]*(len(merged_seq)-num_real)
    return merged_seq, segment_id, mask

In [31]:
merge_trim_pad([1,1], [2,2,2], 10, "[CLS]", "[SEP]", ["PAD"])

(['[CLS]', 1, 1, '[SEP]', 2, 2, 2, '[SEP]', ['PAD'], ['PAD']],
 [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 0, 0])

In [32]:
merge_trim_pad([1,1], [2,2,2], 5, "[CLS]", "[SEP]", ["PAD"])

(['[CLS]', 1, 1, '[SEP]', '[SEP]'], [0, 0, 0, 0, 1], [1, 1, 1, 1, 1])

In [33]:
merge_trim_pad([1,1], [2,2,2], 8, "[CLS]", "[SEP]", ["PAD"])

(['[CLS]', 1, 1, '[SEP]', 2, 2, 2, '[SEP]'],
 [0, 0, 0, 0, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1])

In [34]:
merge_trim_pad([], [], 5, "[CLS]", "[SEP]", ["PAD"])

(['[CLS]', '[SEP]', '[SEP]', ['PAD'], ['PAD']],
 [0, 0, 1, 1, 1],
 [1, 1, 1, 0, 0])

In [35]:
merge_trim_pad([1,1], [2,2,2], 10, cls_index, sep_index, pad_index)

([101, 1, 1, 102, 2, 2, 2, 102, 0, 0],
 [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 0, 0])

In [36]:
# Assemble the model inputs. For every question, pair it with each of the
# (at most) max_sentences_per_passage sentences of its passage and store:
#   sequence_<j> / sequence_segment_id_<j> : merged ids and segment ids for sentence slot j
#   passage_mask                           : 1 for real sentence slots, 0 for padded slots
#   supporting_fact                        : answer-sentence indicator, padded to the slot count
data_out = {}

for i in range(max_sentences_per_passage):
    data_out["sequence_{}".format(i)] = []
    data_out["sequence_segment_id_{}".format(i)] = []
    #data_out["sequence_mask_{}".format(i)] = []

data_out["passage_mask"] = []
data_out["supporting_fact"] = []


for i in trange(len(questions_tokenized)):
    question = questions_tokenized[i]
    # Use a dedicated name: rebinding `sentences` here would overwrite the raw
    # sentence lists loaded earlier and break re-running the cells that use them.
    passage_sentences = sentences_tokenized[i][:max_sentences_per_passage]
    num_pad_sentences = max_sentences_per_passage - len(passage_sentences)
    # Missing sentence slots are filled with empty sentences.
    passage_sentences = passage_sentences + [[]]*(num_pad_sentences)

    passage_mask = [1] * (max_sentences_per_passage-num_pad_sentences) + [0]*num_pad_sentences

    supp_fact = supporting_facts[i]
    supp_fact = supp_fact[:max_sentences_per_passage]

    # Don't append anything to data_out before this check.
    # Skip the training questions whose passage loses its supporting fact due to trimming.
    if TRAIN and sum(supp_fact) == 0:
        continue

    data_out["passage_mask"].append(passage_mask)

    supp_fact = supp_fact + [0]*(num_pad_sentences)
    data_out["supporting_fact"].append(supp_fact)

    for j,sent in enumerate(passage_sentences):
        merged_seq, segment_id, mask = merge_trim_pad(sent_1=question, sent_2=sent,
                                    max_len=max_question_plus_sentence_len,
                                    cls_index=cls_index, sep_index=sep_index, pad_index=pad_index)
        data_out["sequence_{}".format(j)].append(merged_seq)
        data_out["sequence_segment_id_{}".format(j)].append(segment_id)
        #data_out["sequence_mask_{}".format(j)].append(mask)
        
        


        
    

100%|██████████| 10570/10570 [00:03<00:00, 2682.39it/s]


In [37]:
data_out.keys()

dict_keys(['sequence_0', 'sequence_segment_id_0', 'sequence_1', 'sequence_segment_id_1', 'sequence_2', 'sequence_segment_id_2', 'sequence_3', 'sequence_segment_id_3', 'sequence_4', 'sequence_segment_id_4', 'sequence_5', 'sequence_segment_id_5', 'sequence_6', 'sequence_segment_id_6', 'sequence_7', 'sequence_segment_id_7', 'sequence_8', 'sequence_segment_id_8', 'sequence_9', 'sequence_segment_id_9', 'passage_mask', 'supporting_fact'])

In [38]:
print(data_out['sequence_0'][0])

[101, 2029, 5088, 2136, 3421, 1996, 10511, 2012, 3565, 4605, 2753, 1029, 102, 3565, 4605, 2753, 2001, 2019, 2137, 2374, 2208, 2000, 5646, 1996, 3410, 1997, 1996, 2120, 2374, 2223, 1006, 5088, 1007, 2005, 1996, 2325, 2161, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [39]:
print(data_out['sequence_segment_id_0'][0])

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [40]:
#print(data_out['sequence_mask_0'][0])

In [41]:
data_out['sequence_0'][0].count(0)

64

In [42]:
#data_out['sequence_mask_0'][0].count(0)

In [43]:
# Sanity check: record how many examples are stored under every data_out key.
# all_lengths keeps just the counts; both keeps [key, count] pairs.
all_lengths = []
both = []
for key,value in data_out.items():
    both.append([key, len(value)])
    all_lengths.append(len(value))

In [44]:
print(min(all_lengths))
print(max(all_lengths))

10570
10570


In [45]:
both

[['sequence_0', 10570],
 ['sequence_segment_id_0', 10570],
 ['sequence_1', 10570],
 ['sequence_segment_id_1', 10570],
 ['sequence_2', 10570],
 ['sequence_segment_id_2', 10570],
 ['sequence_3', 10570],
 ['sequence_segment_id_3', 10570],
 ['sequence_4', 10570],
 ['sequence_segment_id_4', 10570],
 ['sequence_5', 10570],
 ['sequence_segment_id_5', 10570],
 ['sequence_6', 10570],
 ['sequence_segment_id_6', 10570],
 ['sequence_7', 10570],
 ['sequence_segment_id_7', 10570],
 ['sequence_8', 10570],
 ['sequence_segment_id_8', 10570],
 ['sequence_9', 10570],
 ['sequence_segment_id_9', 10570],
 ['passage_mask', 10570],
 ['supporting_fact', 10570]]

In [46]:
# Write the assembled data_out dict to the file out_pkl_name inside out_pkl_path.
pickler(out_pkl_path,out_pkl_name,data_out)
print("Done")

Done
