In [None]:
import os
import json
import numpy as np
import pickle
import re
import string 

# Load SQuAD v1.1. The official dev file is used as the test split; the train
# file is split into train/validation further below. JSON is UTF-8, so the
# encoding is given explicitly instead of relying on the platform default
# (which is not UTF-8 on e.g. Windows and would fail on non-ASCII text).
with open("dev-v1.1.json", encoding="utf-8") as f:
    test = json.load(f)
with open("train-v1.1.json", encoding="utf-8") as f:
    train = json.load(f)

In [None]:
# Sentence boundary: a whitespace char that follows '.' or '?', except after
# patterns such as "e.g." / "U.S." (word char, '.', word char, any char) or a
# single capital letter plus '.' (initials such as "J.").
_SENTENCE_BOUNDARY = re.compile(r'(?<!\w\.\w.)(?<![A-Z]\.)(?<=\.|\?)\s')


def split_into_sentences(text):
    """Split ``text`` into sentences.

    Splitting happens on a single whitespace character that follows '.' or '?'.
    Fragments that are empty or whitespace-only (e.g. produced by trailing
    whitespace after the final sentence) are dropped so they do not become
    spurious empty sentences with their own label downstream.

    Returns:
        (sentences, count): the list of sentence strings and its length.
    """
    sentences = [piece for piece in _SENTENCE_BOUNDARY.split(text) if piece.strip()]
    return sentences, len(sentences)

In [None]:
def split_clqa(topics):
    """
    Flatten SQuAD-style topics into parallel context/label/question/answer lists.

    For each topic, iterate over its paragraphs; for each paragraph, split the
    context into sentences and emit one entry per (question, answer) pair. A
    paragraph with several questions (or a question with several reference
    answers) therefore contributes several entries that share the same
    context and label list objects.

    Tokenization: every ASCII punctuation character except '$' is surrounded
    by spaces so that a later ``str.split()`` yields it as its own token; '$'
    is deleted. The same tokenization is applied to questions as to context
    sentences, so e.g. "Denver?" in a question yields the token "Denver"
    (plus "?") and matches the context vocabulary. Answers are kept raw.

    Attributes:
        topics: list of topic dicts (1 topic has n paragraphs); each paragraph
            has a 'context' string and a 'qas' list whose items carry a
            'question' string and an 'answers' list of {'text': ...} dicts.

    Returns:
        context: list of contexts, each a list of tokenized sentence strings
        label: list of labels, each [n-1, ..., 1, 0] for an n-sentence context
        question: list of tokenized question strings
        answer: list of raw answer strings
    """
    result_context = []
    result_label = []
    result_question = []
    result_answer = []

    # Translation table: punctuation -> " <char> ", except '$' which stays
    # mapped to None (deleted). Built once, reused for sentences and questions.
    table = str.maketrans("", "", string.punctuation)
    for code in table:
        if code != ord("$"):
            table[code] = " " + chr(code) + " "

    for topic in topics:
        for paragraph in topic['paragraphs']:
            # remove bracketed alphabetic markers, e.g. '[i]', '[k]'
            raw_context = re.sub(r'\[[a-zA-Z]+\]', '', paragraph['context'])
            sentences, length = split_into_sentences(raw_context)
            context = [sentence.translate(table) for sentence in sentences]
            # label[k] is the distance of sentence k from the last sentence
            label = list(reversed(range(length)))

            for qa in paragraph['qas']:
                question = qa['question'].translate(table)
                for answer in qa['answers']:
                    result_context.append(context)
                    result_label.append(label)
                    result_question.append(question)
                    result_answer.append(answer['text'])

    # check: the four lists are filled in lockstep, so their lengths must agree
    if len(result_context) == len(result_label) == len(result_question) == len(result_answer):
        print("Data is well prepared!")
        print("total: {}".format(len(result_context)))
    else:
        print("Something is missing! check again")
        print("the number of questions: {}".format(len(result_question)))
        print("the number of answers: {}".format(len(result_answer)))
        print("the number of contexts: {}".format(len(result_context)))
        print("the number of labels: {}".format(len(result_label)))

    return result_context, result_label, result_question, result_answer

In [None]:
# Split by topic: the first 400 training topics are used for training, the
# remaining ones (42) for validation, and the official dev set (48 topics)
# serves as the test set.
train_topics, val_topics = train['data'][:400], train['data'][400:]
test_topics = test['data']

(train_context, train_label,
 train_question, train_answer) = split_clqa(train_topics)
(val_context, val_label,
 val_question, val_answer) = split_clqa(val_topics)
(test_context, test_label,
 test_question, test_answer) = split_clqa(test_topics)

In [None]:
# The raw JSON is no longer needed once the topic lists have been extracted.
del train, test

In [None]:
# convert to index
# word set: shared vocabulary of lower-cased whitespace tokens drawn from
# every context sentence and every question across all splits
cq_word_set = set()
list_of_context = [train_context, val_context, test_context]
list_of_question = [train_question, val_question, test_question]

cq_word_set.update(
    token.lower()
    for split_contexts in list_of_context
    for paragraph in split_contexts
    for sentence in paragraph
    for token in sentence.split()
)
cq_word_set.update(
    token.lower()
    for split_questions in list_of_question
    for question in split_questions
    for token in question.split()
)

In [None]:
# Answer vocabulary: each distinct raw answer string is one class
# (answers are not tokenized).
list_of_answer = [train_answer, val_answer, test_answer]
answer_word_set = {text for split_answers in list_of_answer for text in split_answers}

In [None]:
# Report vocabulary sizes; an "answer word" is a whole answer string.
n_cq_words, n_answer_words = len(cq_word_set), len(answer_word_set)
print("context and question words: {}".format(n_cq_words))
print("answer words: {}".format(n_answer_words))

In [None]:
# Topic lists are fully consumed by split_clqa; free them.
del train_topics, val_topics, test_topics

In [None]:
# Build deterministic vocabularies. Iterating a set of strings yields an order
# that depends on the per-process string hash seed (PYTHONHASHSEED), so
# enumerating the raw sets assigned different indices on every run and the
# pickled datasets were not reproducible. Sorting fixes the order.
cq_word_index = {word: i for i, word in enumerate(sorted(cq_word_set))}

# Each answer string maps to a one-hot float vector of length
# len(answer_word_set).
# NOTE(review): this stores len(answer_word_set)**2 floats in total; for a
# large answer vocabulary an integer class id would be far cheaper.
n_answer_classes = len(answer_word_set)
answer_word_index = {}
for i, word in enumerate(sorted(answer_word_set)):
    answer_one_hot = np.zeros([n_answer_classes])
    answer_one_hot[i] = 1
    answer_word_index[word] = answer_one_hot

In [None]:
# Encode every context as word indices: one row per sentence, one index per
# lower-cased whitespace token.
def _context_to_index(para):
    """Map one context (list of sentence strings) to an array of index rows."""
    rows = [[cq_word_index[word] for word in map(str.lower, sent.split())]
            for sent in para]
    return _rows_to_array(rows)


def _rows_to_array(rows):
    """Stack index rows into a NumPy array, tolerating ragged rows.

    Equal-length rows give a regular 2-D array. Sentences usually differ in
    length, and NumPy >= 1.24 raises ValueError for ragged nested lists instead
    of building an object array; in that case build the 1-D object array of
    per-sentence lists explicitly, matching what older NumPy versions produced.
    """
    try:
        return np.array(rows)
    except ValueError:
        packed = np.empty(len(rows), dtype=object)
        for k, row in enumerate(rows):
            packed[k] = row
        return packed


train_context_index = [_context_to_index(para) for para in train_context]
val_context_index = [_context_to_index(para) for para in val_context]
test_context_index = [_context_to_index(para) for para in test_context]

In [None]:
# Compare split by split: equal grand totals could hide a surplus in one split
# offset by a deficit in another.
if all(len(encoded) == len(source) for encoded, source in (
        (train_context_index, train_context),
        (val_context_index, val_context),
        (test_context_index, test_context))):
    print("context encoding is completed!")
else:
    print("Something is missed! Check again")

In [None]:
def _question_to_index(question):
    """Map a question string to an array of indices of its lower-cased tokens."""
    return np.array([cq_word_index[token.lower()] for token in question.split()])


train_question_index = [_question_to_index(q) for q in train_question]
val_question_index = [_question_to_index(q) for q in val_question]
test_question_index = [_question_to_index(q) for q in test_question]

In [None]:
# Compare split by split: equal grand totals could hide a surplus in one split
# offset by a deficit in another.
if all(len(encoded) == len(source) for encoded, source in (
        (train_question_index, train_question),
        (val_question_index, val_question),
        (test_question_index, test_question))):
    print("question encoding is completed!")
else:
    print("Something is missed! Check again")

In [None]:
# Replace every answer string by its one-hot vector; the entries reference the
# vector objects stored in answer_word_index (no copies are made).
train_answer_index = [answer_word_index[text] for text in train_answer]
val_answer_index = [answer_word_index[text] for text in val_answer]
test_answer_index = [answer_word_index[text] for text in test_answer]

In [None]:
# Compare split by split: equal grand totals could hide a surplus in one split
# offset by a deficit in another.
if all(len(encoded) == len(source) for encoded, source in (
        (train_answer_index, train_answer),
        (val_answer_index, val_answer),
        (test_answer_index, test_answer))):
    print("answer encoding is completed!")
else:
    print("Something is missed! Check again")

In [None]:
# One-hot encode each sentence label. A hard-coded width of 20 raises
# IndexError for any paragraph with more than 20 sentences (its labels run up
# to n-1). Keep 20 as the minimum width, but widen it to fit the largest label
# in any split so that every split shares one width.
MIN_LABEL_WIDTH = 20
label_width = max(
    [MIN_LABEL_WIDTH]
    + [max(label) + 1
       for labels in (train_label, val_label, test_label)
       for label in labels
       if label]
)
label_eye = np.eye(label_width)

# Fancy indexing returns a fresh (len(label), label_width) array per context.
train_label_index = [label_eye[label] for label in train_label]
val_label_index = [label_eye[label] for label in val_label]
test_label_index = [label_eye[label] for label in test_label]

In [None]:
# Compare split by split: equal grand totals could hide a surplus in one split
# offset by a deficit in another.
if all(len(encoded) == len(source) for encoded, source in (
        (train_label_index, train_label),
        (val_label_index, val_label),
        (test_label_index, test_label))):
    print("label encoding is completed!")
else:
    print("Something is missed! Check again")

In [None]:
def _pack_split(questions, answers, contexts, labels):
    """Bundle one split as a (questions, answers, contexts, labels) tuple.

    Note the order differs from split_clqa's return order
    (context, label, question, answer).
    """
    return (questions, answers, contexts, labels)


train_dataset = _pack_split(train_question_index, train_answer_index,
                            train_context_index, train_label_index)
val_dataset = _pack_split(val_question_index, val_answer_index,
                          val_context_index, val_label_index)
test_dataset = _pack_split(test_question_index, test_answer_index,
                           test_context_index, test_label_index)

In [None]:
# Write each split to its own pickle. Create the output directory first:
# open() raises FileNotFoundError when it does not exist yet.
# NOTE(review): the directory name is kept as-is ("babi_preprocessd");
# downstream loaders presumably read from this path.
OUTPUT_DIR = 'babi_preprocessd'
os.makedirs(OUTPUT_DIR, exist_ok=True)

for split_name, dataset in (('train', train_dataset),
                            ('val', val_dataset),
                            ('test', test_dataset)):
    out_path = os.path.join(OUTPUT_DIR, '{}_dataset.pkl'.format(split_name))
    with open(out_path, 'wb') as f:
        pickle.dump(dataset, f)