In [1]:
from collections import Counter
import string
import re
import argparse
import json
import sys
import numpy as np
import nltk
import random
import math
import os
import pickle
from tqdm import tqdm

In [2]:
import spacy
# Blank English pipeline with no trained components; this notebook only uses
# `nlp.tokenizer` (see word_tokenize).
nlp = spacy.blank("en")

In [3]:
def pickler(path, pkl_name, obj):
    """Serialize ``obj`` to ``<path>/<pkl_name>`` with the highest pickle protocol.

    Any existing file at that location is overwritten.
    """
    target = os.path.join(path, pkl_name)
    with open(target, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

def unpickler(path, pkl_name):
    """Load and return the object pickled at ``<path>/<pkl_name>``.

    Only use this on files written by this notebook: unpickling can execute
    arbitrary code, so never point it at untrusted data.
    """
    source = os.path.join(path, pkl_name)
    with open(source, 'rb') as handle:
        return pickle.load(handle)

In [4]:
# Select which CoQA split to preprocess: True -> train, False -> dev.
TRAINING = False

# Directory the formatted pickle is written to.
out_pkl_path = "./"

# Number of previous questions referenced as history by each data point.
context_history_size = 2

# Directory holding the CoQA json files. Override it with the COQA_DATA_DIR
# environment variable instead of editing this machine-specific default.
coqa_data_dir = os.environ.get("COQA_DATA_DIR", "/home/bhargav/data/coqa")

if TRAINING:
    file_path = os.path.join(coqa_data_dir, "coqa-train-v1.0.json")
    out_pkl_name = "dataset_formatted_train.pkl"
else:
    file_path = os.path.join(coqa_data_dir, "coqa-dev-v1.0.json")
    out_pkl_name = "dataset_formatted_dev.pkl"

In [5]:
# Symbols that normalize() replaces with a space (this includes every "!").
_NORMALIZE_STRIP_RE = re.compile(r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]")
_MULTI_SPACE_RE = re.compile(r"[ ]+")
_MULTI_COMMA_RE = re.compile(r"\,+")
_MULTI_QMARK_RE = re.compile(r"\?+")


def normalize(text):
    """Lowercase ``text`` and remove punctuation noise.

    Replaces quotes, brackets, slashes, newlines, ``!``, ``;``, ``:`` and
    similar symbols with spaces. It then collapses runs of spaces and squeezes
    repeated ``,`` and ``?`` to a single character, and finally lowercases and
    trims the result. Non-string input goes through ``str()`` first, so
    ``None`` becomes ``"none"``. Tabs and carriage returns inside the text are
    left as-is. The function is idempotent.
    """
    text = _NORMALIZE_STRIP_RE.sub(" ", str(text))
    text = _MULTI_SPACE_RE.sub(" ", text)
    # No "!" survives the first substitution, so no "!+" squeeze is needed.
    text = _MULTI_COMMA_RE.sub(",", text)
    text = _MULTI_QMARK_RE.sub("?", text)
    return text.lower().strip()

In [6]:
def word_tokenize(text):
    """Normalize ``text`` and split it into a list of token strings.

    Uses the module-level spaCy tokenizer ``nlp``. Whitespace-only tokens are
    dropped. ``normalize`` collapses runs of spaces but keeps tabs, carriage
    returns and non-breaking spaces, and spaCy emits those as standalone
    tokens, so the filter must cover any whitespace, not just ``" "``.
    """
    return [tok.text for tok in nlp.tokenizer(normalize(text)) if not tok.is_space]

In [7]:
def sent_tokenize(text):
    """Split ``text`` into sentences with NLTK and word-tokenize each one.

    Returns one token list per sentence, in order. Needs NLTK's punkt
    sentence-tokenizer data to be installed.
    """
    return [word_tokenize(sentence) for sentence in nltk.sent_tokenize(text)]

In [8]:
def score_overlap(sent, ans):
    """Score how well a tokenized sentence covers a tokenized answer.

    The score is 100 minus the number of distinct answer tokens missing from
    the sentence, so higher is better. A sentence containing every answer
    token scores exactly 100, and the score can go negative for answers with
    more than 100 distinct tokens. Callers only compare scores (argmax), so
    the 100 offset does not affect which sentence wins.
    """
    missing = set(ans) - set(sent)
    return 100 - len(missing)

In [9]:
def find_matching_sentence(passage, span):
    """Return a one-hot list marking the passage sentence that best covers ``span``.

    Args:
        passage: list of sentences, each a list of tokens.
        span: list of answer tokens.

    Returns:
        A list of ``len(passage)`` ints. The sentence with the highest
        ``score_overlap`` is set to 1; ties go to the earliest sentence. An
        empty ``passage`` returns an empty list.

    Note: an empty ``span`` scores every sentence equally, so the first
    sentence is marked.
    """
    matching_sentence = [0] * len(passage)
    if not passage:
        # np.argmax raises ValueError on an empty sequence; an empty passage
        # simply has no sentence to mark.
        return matching_sentence
    matching_scores = [score_overlap(sent, span) for sent in passage]
    best_match_index = int(np.argmax(matching_scores))
    matching_sentence[best_match_index] = 1
    return matching_sentence

def get_prev_ids(current_index, context_history_size, turn_id):
    """Return the question ids used as conversational history for one turn.

    Args:
        current_index: index of the current question in the global
            ``questions`` table.
        context_history_size: number of history slots to return.
        turn_id: turn number of the current question within its passage,
            starting at 1.

    Returns:
        A list of ``context_history_size`` ids, oldest first, holding the ids
        of the immediately preceding questions. Slots that would reach back
        before the passage's first turn are set to 0, the index of the blank
        placeholder question.

    Raises:
        ValueError: if ``turn_id`` is less than 1.
    """
    if turn_id < 1:
        # An explicit check survives ``python -O`` (unlike ``assert``) and
        # rejects negative ids instead of failing with an opaque IndexError.
        raise ValueError("turn_id must be >= 1, got {}".format(turn_id))
    # Only the previous ``turn_id - 1`` questions belong to the same passage,
    # so any id below ``first_valid_id`` is padding.
    first_valid_id = current_index - turn_id + 1
    return [
        idx if idx >= first_valid_id else 0
        for idx in range(current_index - context_history_size, current_index)
    ]

def make_tables(dataset_in, context_history_size):
    """Flatten a CoQA-style dataset into index tables.

    Args:
        dataset_in: parsed CoQA json. Reads ``dataset_in['data']`` and, for
            each passage, its ``'story'``, its ``'questions'`` (``'input_text'``,
            ``'turn_id'``) and its ``'answers'`` (``'span_text'``).
        context_history_size: number of previous questions referenced by each
            data point.

    Returns:
        A dict with these keys:
            passages: one entry per story, each a list of tokenized sentences.
            questions: tokenized questions. Index 0 is a blank placeholder
                used as history when no earlier turn is available.
            answer_spans: tokenized answer spans, aligned with ``questions``
                (index 0 is likewise blank).
            answer_sentences: one one-hot list per question, marking the
                passage sentence that best matches its answer span.
            data_points: ``[passage_id, question_id, prev_question_ids]`` per
                question. ``question_id`` indexes ``questions`` and
                ``answer_spans``; ``question_id - 1`` indexes
                ``answer_sentences``.

    Raises:
        ValueError: if a passage has different numbers of questions and answers.
    """
    passages = []
    questions = [[]]     # blank question at index 0: history padding when nothing is available
    answer_spans = [[]]  # blank answer at index 0, kept aligned with ``questions``
    answer_sentences = []
    data_points = []
    for passage in tqdm(dataset_in['data']):
        # Split the raw story, not a normalized copy. normalize() deletes "!"
        # and newlines and lowercases the text, which hides sentence boundaries
        # from NLTK. sent_tokenize() normalizes each sentence itself. Sentences
        # that normalize to nothing are dropped.
        passage_sents = [sent for sent in sent_tokenize(passage['story']) if sent]
        passages.append(passage_sents)
        passage_id = len(passages) - 1

        if len(passage['questions']) != len(passage['answers']):
            raise ValueError(
                "passage {} has {} questions but {} answers".format(
                    passage_id, len(passage['questions']), len(passage['answers'])))

        for question_obj, answer_obj in zip(passage['questions'], passage['answers']):
            # word_tokenize() already normalizes its input.
            question = word_tokenize(question_obj['input_text'])
            ans_span = word_tokenize(answer_obj['span_text'])

            questions.append(question)
            answer_spans.append(ans_span)
            answer_sentences.append(find_matching_sentence(passage_sents, ans_span))

            question_id = len(questions) - 1
            prev_qa_ids = get_prev_ids(current_index=question_id,
                                       context_history_size=context_history_size,
                                       turn_id=question_obj['turn_id'])
            # data_point format: [passage_id, question_id, [prev question ids]]
            data_points.append([passage_id, question_id, prev_qa_ids])
    return {"passages": passages, "questions": questions, "answer_spans": answer_spans,
            "answer_sentences": answer_sentences, "data_points": data_points}

In [10]:
# Parse the raw CoQA json for the split selected in the config cell.
with open(file_path, 'r', encoding='utf8') as coqa_file:
    raw_json = coqa_file.read()
dataset_original = json.loads(raw_json)

In [11]:
# Build the flattened index tables; tqdm shows per-passage progress.
dataset_formatted = make_tables(dataset_in=dataset_original,
                                context_history_size=context_history_size)

100%|██████████| 500/500 [00:13<00:00, 42.78it/s]


Expectation: `questions` and `answer_spans` each have one more entry than `answer_sentences` and `data_points`, because index 0 of both holds the blank placeholder used as history padding.

In [12]:
table_keys = ("passages", "questions", "answer_spans",
              "answer_sentences", "data_points")
table_sizes = {key: len(dataset_formatted[key]) for key in table_keys}
for key in table_keys:
    print(table_sizes[key])

# Enforce the stated expectation: questions/answer_spans carry one extra blank
# entry at index 0, and every real question has exactly one answer sentence
# and one data point.
assert table_sizes["questions"] == table_sizes["answer_spans"]
assert table_sizes["answer_sentences"] == table_sizes["data_points"]
assert table_sizes["questions"] == table_sizes["answer_sentences"] + 1

500
7984
7984
7983
7983


In [13]:
# print(dataset_formatted['passages'][0])

In [14]:
# print(dataset_formatted['questions'][9])

In [15]:
# print(dataset_formatted['answer_spans'][9])

In [16]:
# print(dataset_formatted['answer_sentences'][0])

In [17]:
# print(dataset_formatted['data_points'][10])

In [18]:
# Persist the formatted tables to <out_pkl_path>/<out_pkl_name>.
pickler(path=out_pkl_path, pkl_name=out_pkl_name, obj=dataset_formatted)
print("Done")

Done
