In [None]:
from collections import Counter
import string
import re
import argparse
import json
import sys
import numpy as np
import nltk
import random
import math
import os
import dill as pickle
from tqdm import tqdm, trange

In [None]:
import pdb

In [None]:
def pickler(path,pkl_name,obj):
    """Serialize ``obj`` to the file ``path/pkl_name``.

    Uses the module's ``pickle`` binding (``dill`` in this notebook) with its
    highest protocol, overwriting any existing file.
    """
    target = os.path.join(path, pkl_name)
    with open(target, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

def unpickler(path,pkl_name):
    """Load and return the object stored in ``path/pkl_name``.

    Unpickling can execute arbitrary code: only load files produced by this
    pipeline (e.g. via ``pickler``), never untrusted ones.
    """
    source = os.path.join(path, pkl_name)
    with open(source, 'rb') as handle:
        return pickle.load(handle)

# Formatting 

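Each data point consists of [question_id, question, [context sentences], [supporting fact indicators]]. `format_dataset` returns these as four parallel lists (one entry per question). The indicator list is one-hot over the context sentences and marks the sentence that contains the answer.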

In [None]:
# Set TRAIN = False to preprocess the dev split instead of the train split.
TRAIN = True

out_pkl_path = "./"

# Directory holding the SQuAD v1.1 JSON files. Override it with the SQUAD_DIR
# environment variable instead of editing this machine-specific absolute path.
squad_dir = os.environ.get("SQUAD_DIR", "/home/bhargav/data/squad")

if TRAIN:
    dataset_file_path = os.path.join(squad_dir, "train-v1.1.json")
    out_pkl_name = "preprocessed_train.pkl"
else:
    dataset_file_path = os.path.join(squad_dir, "dev-v1.1.json")
    out_pkl_name = "preprocessed_dev.pkl"
    

In [None]:
def sent_has_answer(sent,ans):
    """Return True if every ``nltk.word_tokenize`` token of ``ans`` occurs in ``sent``.

    Tokens are compared as sets, so order and repetition are ignored.
    """
    answer_tokens = set(nltk.word_tokenize(ans))
    return answer_tokens.issubset(nltk.word_tokenize(sent))

def put_space_before_period(sents):
    """Re-join each sentence's ``nltk.word_tokenize`` tokens with single spaces.

    This separates punctuation such as a trailing period from the preceding word.
    Returns a new list in the same order as ``sents``.
    """
    return [" ".join(nltk.word_tokenize(sentence)) for sentence in sents]

def find_answer_sentence(sentence_lengths, answer_start):
    """Return a one-hot list marking the sentence that contains ``answer_start``.

    Args:
        sentence_lengths: number of context characters attributed to each
            sentence, in order. The running sum gives each sentence's end offset.
        answer_start: character offset of the answer in the context (the
            dataset's ``answer_start``).

    Returns:
        A list the same length as ``sentence_lengths`` with a single 1 at the
        containing sentence, or ``[]`` if ``sentence_lengths`` is empty.

    Sentence i covers the half-open range ``[start_i, start_i + length_i)``, so
    an offset equal to a sentence's end (the first character of the next
    sentence) belongs to the NEXT sentence. Offsets past the last end (possible
    when the lengths under-count the context) go to the last sentence rather
    than silently to the first.
    """
    out_vector = [0] * len(sentence_lengths)
    if not out_vector:
        return out_vector
    sentence_index = len(sentence_lengths) - 1
    length_so_far = 0
    for i, length in enumerate(sentence_lengths):
        # Exclusive upper bound: length_so_far + length is where the next sentence starts.
        if answer_start < length_so_far + length:
            sentence_index = i
            break
        length_so_far += length
    out_vector[sentence_index] = 1
    return out_vector

In [None]:
# with open(file_path, encoding='utf8') as file:
#     dataset = json.load(file)

In [None]:
def normalize(text):
    """Lowercase ``text`` and replace a fixed set of punctuation/symbols with spaces.

    ``text`` is passed through ``str()`` first, so non-string input is accepted
    (e.g. ``None`` becomes ``"none"``). Runs of spaces, commas and question
    marks are then collapsed to one, and the result is stripped.
    """
    # Quotes, brackets, slashes, newlines, '!', ';', etc. each become a space.
    text = re.sub(
            r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " ",
            str(text))
    text = re.sub(r"[ ]+", " ", text)
    # '!' was already replaced above, so no separate '!'-collapsing step is needed.
    text = re.sub(r"\,+", ",", text)
    text = re.sub(r"\?+", "?", text)
    return text.lower().strip()

In [None]:
def _sentence_char_lengths(context_raw, context_sents):
    """Return how many characters of ``context_raw`` each sentence spans.

    Sentence i covers ``context_raw`` from its own start offset up to the start
    of sentence i+1. The last sentence runs to the end of the context, and the
    first also absorbs any leading whitespace. The running totals therefore
    line up with raw character offsets such as ``answer_start``, whatever
    whitespace separates the sentences. If a sentence is not a verbatim
    substring of the context, falls back to the "sentence length + one
    separating space" approximation.
    """
    if not context_sents:
        return []
    starts = []
    cursor = 0
    for sent in context_sents:
        pos = context_raw.find(sent, cursor)
        if pos < 0:
            # Cannot locate the sentence verbatim; use the single-space approximation.
            return [len(x) + 1 for x in context_sents]
        starts.append(pos)
        cursor = pos + len(sent)
    boundaries = [0] + starts[1:] + [len(context_raw)]
    return [end - begin for begin, end in zip(boundaries, boundaries[1:])]


def format_dataset(dataset):
    """Flatten SQuAD-style articles into parallel per-question lists.

    Args:
        dataset: list of articles (the ``"data"`` entry of the SQuAD JSON). Each
            article has ``"paragraphs"`` whose items carry a ``"context"``
            string and a ``"qas"`` list of questions with ``"id"``,
            ``"question"`` and ``"answers"`` (each answer has ``"answer_start"``).

    Returns:
        ``[ids, questions, sentences, answers]``. For every question:
        ``sentences`` holds its context split by ``nltk.sent_tokenize``, and
        ``answers`` is a one-hot list marking the sentence that contains the
        first ground-truth answer's ``answer_start`` offset.
    """
    ids = []
    questions = []
    sentences = []
    answers = []
    for article in tqdm(dataset):
        for paragraph in article['paragraphs']:
            context_raw = paragraph['context']
            context_sents = nltk.sent_tokenize(context_raw)
            # Measure sentence spans against the raw text. Assuming one space between
            # sentences made offsets drift on double spaces or newlines and
            # mislabel later sentences.
            context_sent_lengths = _sentence_char_lengths(context_raw, context_sents)
            for qa in paragraph['qas']:
                # Only the first annotated answer is used to label the sentence.
                first_answer_start = qa['answers'][0]['answer_start']
                answer_indicators = find_answer_sentence(
                    sentence_lengths=context_sent_lengths,
                    answer_start=first_answer_start)
                sentences.append(context_sents)
                questions.append(qa['question'])
                ids.append(qa['id'])
                answers.append(answer_indicators)
    return [ids, questions, sentences, answers]


  

In [None]:
def run(dataset_file_path):
    """Load a SQuAD JSON file and return ``format_dataset`` of its ``"data"`` list.

    Returns ``[ids, questions, sentences, answers]`` as produced by ``format_dataset``.
    """
    # JSON is UTF-8 (RFC 8259) and SQuAD contains non-ASCII text, so do not rely
    # on the locale-dependent default encoding of open().
    with open(dataset_file_path, encoding='utf8') as dataset_file:
        dataset = json.load(dataset_file)['data']
    return format_dataset(dataset)

In [None]:
# Parse the SQuAD JSON into four parallel per-question lists.
ids, questions, sentences, supporting_facts = run(dataset_file_path)

In [None]:
# Sanity check: all three counts should be equal (one entry per question).
print(len(questions))
print(len(sentences))
print(len(supporting_facts))

In [None]:
# First question text.
questions[0]

In [None]:
# Its context, split into sentences.
sentences[0]

In [None]:
# Its one-hot vector marking the answer-bearing sentence.
supporting_facts[0]

In [None]:
def get_index_with_id(dataset, idx = "5ae61bfd5542992663a4f261"):
    """Return the position of the first item whose ``"_id"`` equals ``idx``, else -1.

    NOTE(review): items are expected to be dicts with an ``"_id"`` key, which
    differs from the SQuAD records built in this notebook (they use ``"id"``);
    confirm which dataset this helper is meant for.
    """
    matches = (position for position, entry in enumerate(dataset) if entry["_id"] == idx)
    return next(matches, -1)
    

In [None]:
def print_formatted_example(index, question_ids, questions, sentences, 
                            supporting_facts):
    """Pretty-print example ``index`` from the parallel lists, one section per field.

    Sections (question id, question, numbered sentences, supporting-fact
    vector) are each followed by a separator line.
    """
    separator = "--xx--xx--xx--xx--xx--xx--"
    for label, column in (("Question id:", question_ids), ("Question:", questions)):
        print(label, column[index])
        print(separator)
    print("sentences:")
    for position, sentence in enumerate(sentences[index]):
        print("{} :{}".format(position, sentence))
    print(separator)
    print("supporting_facts:", supporting_facts[index])
    print(separator)

In [None]:
# Print one example (index 10) to eyeball the sentence split and its label.
print_formatted_example(10, ids, questions, sentences, 
                            supporting_facts)

In [None]:
# formatted_dataset = [questions, sentences, supporting_facts]

# Tokenization    

In [None]:
from pytorch_pretrained_bert import BertTokenizer

In [None]:
# Uncased BERT-base WordPiece tokenizer, used by tokenize() / batch_tokenize() and
# for looking up the [CLS]/[SEP]/[PAD] ids.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
def foo():
    """Debug helper: print the BERT word pieces for a deliberately noisy sample string."""
    sample = "     hello, world! I don't like you :-)  !!!! (YOLO)"
    print(tokenizer.tokenize(sample))

foo()

In [None]:
def tokenize(text):
    """Split ``text`` into BERT word pieces using the module-level ``tokenizer``."""
    return tokenizer.tokenize(text)

In [None]:
def batch_tokenize(text_in):
    """Word-piece tokenize each string in ``text_in`` and map the pieces to vocab ids.

    Returns one list of token ids per input string, in input order.
    """
    return [tokenizer.convert_tokens_to_ids(tokenize(line)) for line in text_in]

In [None]:
# Quick check of batch_tokenize on two short strings.
batch_tokenize(["tokenize this","and this!"])

In [None]:
# Token-id list for every question, aligned with `ids` / `questions`.
questions_tokenized = batch_tokenize(questions)

In [None]:
print(questions_tokenized[0])

In [None]:
# One list of token-id lists per passage (one inner list per sentence),
# aligned with `questions_tokenized`.
sentences_tokenized = [batch_tokenize(sent_list) for sent_list in tqdm(sentences)]

In [None]:
print(sentences_tokenized[0])

In [None]:
# Length (in BERT tokens) of every [CLS] question [SEP] sentence [SEP] pair after
# truncation/padding by merge_trim_pad.
# NOTE(review): the reason for choosing 103 is not recorded here — confirm.
max_question_plus_sentence_len = 103
# Passages are truncated / padded to this many sentences.
max_sentences_per_passage = 10

In [None]:
# Vocabulary ids of BERT's special tokens, passed to merge_trim_pad.
cls_index = tokenizer.convert_tokens_to_ids(["[CLS]"])[0]
sep_index = tokenizer.convert_tokens_to_ids(["[SEP]"])[0]
pad_index = tokenizer.convert_tokens_to_ids(["[PAD]"])[0]

In [None]:
def merge_trim_pad(sent_1, sent_2, max_len, cls_index, sep_index, pad_index):
    """Build one fixed-length BERT sentence-pair input.

    Layout: ``[CLS] sent_1 [SEP] sent_2 [SEP] [PAD]...``. The merged sequence is
    cut from the right to ``max_len - 1`` items before the closing ``[SEP]`` is
    appended. Overlong pairs therefore lose the tail of ``sent_2`` first (then
    of ``sent_1``). The result is padded with ``pad_index`` up to ``max_len``.

    Args:
        sent_1: token ids of the first segment (the question, in this notebook).
        sent_2: token ids of the second segment (a context sentence).
        max_len: target length of the returned lists.
        cls_index, sep_index, pad_index: ids of the special tokens.

    Returns:
        ``(merged_seq, segment_id, mask)``, three lists of equal length.
        ``segment_id`` is 0 over ``[CLS] sent_1 [SEP]`` (clipped to the sequence
        length if truncation reached ``sent_1``) and 1 for everything after it,
        padding included. ``mask`` is 1 for real tokens and 0 for padding.
    """
    merged_seq = [cls_index] + sent_1 + [sep_index] + sent_2
    merged_seq = merged_seq[:max_len-1]
    merged_seq.append(sep_index)
    num_tokens = len(merged_seq)
    merged_seq += [pad_index] * (max_len - num_tokens)
    # Clip so segment_id never outgrows merged_seq when truncation cut into sent_1.
    num_zeros = min(len(sent_1) + 2, len(merged_seq))
    segment_id = [0] * num_zeros + [1] * (len(merged_seq) - num_zeros)
    # Derive the mask from the padding boundary instead of comparing ids, so a
    # real token whose id happens to equal pad_index is not masked out.
    mask = [1] * num_tokens + [0] * (len(merged_seq) - num_tokens)
    return merged_seq, segment_id, mask

In [None]:
# Demo calls with placeholder special tokens (strings / a list, not real vocab
# ids) to eyeball trimming and padding. Here the pair fits with room to pad.
merge_trim_pad([1,1], [2,2,2], 10, "[CLS]", "[SEP]", ["PAD"])

In [None]:
# max_len too small: the second sequence is truncated (here dropped entirely)
# before the closing [SEP].
merge_trim_pad([1,1], [2,2,2], 5, "[CLS]", "[SEP]", ["PAD"])

In [None]:
# Exact fit: no truncation and no padding.
merge_trim_pad([1,1], [2,2,2], 8, "[CLS]", "[SEP]", ["PAD"])

In [None]:
# Both inputs empty: [CLS] [SEP] [SEP] followed by padding.
merge_trim_pad([], [], 5, "[CLS]", "[SEP]", ["PAD"])

In [None]:
# Same as the first demo, but with the real special-token ids.
merge_trim_pad([1,1], [2,2,2], 10, cls_index, sep_index, pad_index)

In [None]:
# init data dict
data_out = {}

for i in range(max_sentences_per_passage):
    data_out["sequence_{}".format(i)] = []
    data_out["sequence_segment_id_{}".format(i)] = []
    #data_out["sequence_mask_{}".format(i)] = []

data_out["passage_mask"] = []
data_out["supporting_fact"] = []


for i in trange(len(questions_tokenized)):
    question = questions_tokenized[i]
    sentences = sentences_tokenized[i][:max_sentences_per_passage]
    num_pad_sentences = max_sentences_per_passage - len(sentences)
    sentences = sentences + [[]]*(num_pad_sentences)
    
    passage_mask = [1] * (max_sentences_per_passage-num_pad_sentences) + [0]*num_pad_sentences
    
    
    supp_fact = supporting_facts[i]
    
    supp_fact = supp_fact[:max_sentences_per_passage]
    
    # Skip the training questions who's passage loses supporting fact due to trimming
    if(TRAIN):
        if(sum(supp_fact) == 0):
            continue
            
    data_out["passage_mask"].append(passage_mask)
    
    supp_fact = supp_fact + [0]*(num_pad_sentences)
    data_out["supporting_fact"].append(supp_fact)
    
    for j,sent in enumerate(sentences):
        merged_seq, segment_id, mask = merge_trim_pad(sent_1=question, sent_2=sent, 
                                    max_len=max_question_plus_sentence_len, 
                                    cls_index=cls_index, sep_index=sep_index, pad_index=pad_index)
        data_out["sequence_{}".format(j)].append(merged_seq)
        data_out["sequence_segment_id_{}".format(j)].append(segment_id)
        #data_out["sequence_mask_{}".format(j)].append(mask)
        
        


        
    

In [None]:
# Keys created when building data_out: sequence_{j} / sequence_segment_id_{j} per
# passage slot, plus passage_mask and supporting_fact.
data_out.keys()

In [None]:
# First example's token ids for passage slot 0.
print(data_out['sequence_0'][0])

In [None]:
# ...and its segment ids.
print(data_out['sequence_segment_id_0'][0])

In [None]:
# The "sequence_mask_*" keys exist only if attention masks were stored when
# building data_out. Guard so this cell does not raise KeyError otherwise.
if 'sequence_mask_0' in data_out:
    print(data_out['sequence_mask_0'][0])
else:
    print("sequence_mask_0 not stored in data_out")

In [None]:
# Number of [PAD] tokens in the first sequence_0 example (use pad_index rather
# than assuming the pad id is 0).
data_out['sequence_0'][0].count(pad_index)

In [None]:
# Padded positions in the first sequence_0 example. Read them from the stored mask
# when present; otherwise count pad tokens directly, since the mask keys may
# not have been stored.
(data_out['sequence_mask_0'][0].count(0) if 'sequence_mask_0' in data_out
 else data_out['sequence_0'][0].count(pad_index))

In [None]:
# [key, number of entries] for every column of data_out; all counts should match.
both = [[key, len(value)] for key, value in data_out.items()]
all_lengths = [count for _, count in both]

In [None]:
# min == max means every column of data_out has one entry per kept question.
print(min(all_lengths))
print(max(all_lengths))

In [None]:
# [key, length] pairs, to spot a mismatching column.
both

In [None]:
# Persist data_out with pickler (the `pickle` name is bound to dill) as out_pkl_path/out_pkl_name.
pickler(out_pkl_path,out_pkl_name,data_out)
print("Done")