In [1]:
import numpy as np
import os
import pickle
import re

In [2]:
train_paths = []
val_paths = []
test_paths= []
for dirpath, dirnames, filenames in os.walk("./tasks_1-20_v1-2/en-valid-10k/"):
    for filename in filenames:
        if 'train' in filename:
            train_paths.append(os.path.join(dirpath, filename))
        elif 'val' in filename:
            val_paths.append(os.path.join(dirpath, filename))
        else:
            test_paths.append(os.path.join(dirpath, filename))

In [3]:
def split_paragraphs(path_to_file):
    """
    split into paragraphs
    
    Attributes: 
        file_path: path of the data
    
    Returns:
        paragraphs: list of paragraph
    """
    with open(path_to_file, 'r') as f:
        babi = f.readlines()
    
    paragraph = []
    paragraphs = []
    alphabet = re.compile('[a-zA-Z]')
    for d in babi:
        if d.startswith('1 '):
            if paragraph:
                paragraphs.append(paragraph)
            paragraph = []
        mark = re.search(alphabet, d).span()[0]
        paragraph.append(d[mark:])
        
    return paragraphs

In [4]:
def split_clqa(paragraphs):
    """
    for each paragraph, split into context, label, question and answer

    Attributes:
        paragraphs: list of paragraphs 
    
    Returns:
        context: list of contexts
        label: list of labels
        question: list of questions
        answer: list of answers
        answer_hint: list of answer_hints
    """
    context = []
    label = []
    question = []
    answer = []
    answer_hint = []
    for paragraph in paragraphs:
        for i, sent in enumerate(paragraph):
            if '?' in sent:
                related_para = [para.strip() for para in paragraph[:i] if '?' not in para][::-1]
                if len(related_para) > 20:
                    related_para = related_para[:20]
                context.append(related_para)
                label.append([i for i in range(len(related_para))])
                q_a_ah = sent.split('\t')
                question.append(q_a_ah[0].strip())
                answer.append(q_a_ah[1].strip())
                answer_hint.append(q_a_ah[2].strip())
    # check
    if (len(question) == len(answer)) & (len(answer) == len(context)) & (len(context) == len(label)):
        print("Data is well prepared!")
        print("total: {}".format(len(label)))
    else:
        print("Something is missing! check again")
        print("the number of questions: {}".format(len(question)))
        print("the number of answers: {}".format(len(answer)))
        print("the number of contexts: {}".format(len(context)))
        print("the number of labels: {}".format(len(label)))
        
    return context, label, question, answer, answer_hint

In [5]:
# prepare training data 
train_context = []
train_label = []
train_question = []
train_answer = []
train_answer_hint = []
for train_path in train_paths:
    print('=================')    
    paragraphs = split_paragraphs(train_path)
    print("data: {}".format(os.path.basename(train_path)))
    context, label, question, answer, answer_hint = split_clqa(paragraphs)
    train_context.extend(context)
    train_label.extend(label)
    train_question.extend(question)
    train_answer.extend(answer)
    train_answer_hint.extend(answer_hint)

data: qa13_train.txt
Data is well prepared!
total: 8995
data: qa3_train.txt
Data is well prepared!
total: 8995
data: qa14_train.txt
Data is well prepared!
total: 8995
data: qa20_train.txt
Data is well prepared!
total: 8998
data: qa16_train.txt
Data is well prepared!
total: 8999
data: qa1_train.txt
Data is well prepared!
total: 8995
data: qa12_train.txt
Data is well prepared!
total: 8995
data: qa17_train.txt
Data is well prepared!
total: 8992
data: qa7_train.txt
Data is well prepared!
total: 8995
data: qa8_train.txt
Data is well prepared!
total: 8995
data: qa5_train.txt
Data is well prepared!
total: 8995
data: qa2_train.txt
Data is well prepared!
total: 8995
data: qa15_train.txt
Data is well prepared!
total: 8996
data: qa6_train.txt
Data is well prepared!
total: 8995
data: qa18_train.txt
Data is well prepared!
total: 8991
data: qa4_train.txt
Data is well prepared!
total: 8999
data: qa9_train.txt
Data is well prepared!
total: 8995
data: qa19_train.txt
Data is well prepared!
total: 8999
d

In [6]:
# prepare validation data
val_context = []
val_label = []
val_question = []
val_answer = []
val_answer_hint = []
for val_path in val_paths:
    print('=================')    
    paragraphs = split_paragraphs(val_path)
    print("data: {}".format(os.path.basename(val_path)))
    context, label, question, answer, answer_hint = split_clqa(paragraphs)
    val_context.extend(context)
    val_label.extend(label)
    val_question.extend(question)
    val_answer.extend(answer)
    val_answer_hint.extend(answer_hint)

data: qa19_valid.txt
Data is well prepared!
total: 999
data: qa15_valid.txt
Data is well prepared!
total: 996
data: qa20_valid.txt
Data is well prepared!
total: 988
data: qa1_valid.txt
Data is well prepared!
total: 995
data: qa7_valid.txt
Data is well prepared!
total: 995
data: qa4_valid.txt
Data is well prepared!
total: 999
data: qa12_valid.txt
Data is well prepared!
total: 995
data: qa8_valid.txt
Data is well prepared!
total: 995
data: qa10_valid.txt
Data is well prepared!
total: 995
data: qa2_valid.txt
Data is well prepared!
total: 995
data: qa14_valid.txt
Data is well prepared!
total: 995
data: qa13_valid.txt
Data is well prepared!
total: 995
data: qa18_valid.txt
Data is well prepared!
total: 999
data: qa11_valid.txt
Data is well prepared!
total: 995
data: qa3_valid.txt
Data is well prepared!
total: 995
data: qa17_valid.txt
Data is well prepared!
total: 992
data: qa16_valid.txt
Data is well prepared!
total: 999
data: qa9_valid.txt
Data is well prepared!
total: 995
data: qa5_valid.t

In [7]:
# prepare test data
test_context = []
test_label = []
test_question = []
test_answer = []
test_answer_hint = []
for test_path in test_paths:
    print('=================')    
    paragraphs = split_paragraphs(test_path)
    print("data: {}".format(os.path.basename(test_path)))
    context, label, question, answer, answer_hint = split_clqa(paragraphs)
    test_context.extend(context)
    test_label.extend(label)
    test_question.extend(question)
    test_answer.extend(answer)
    test_answer_hint.extend(answer_hint)

data: qa7_test.txt
Data is well prepared!
total: 995
data: qa2_test.txt
Data is well prepared!
total: 995
data: qa6_test.txt
Data is well prepared!
total: 995
data: qa11_test.txt
Data is well prepared!
total: 995
data: qa3_test.txt
Data is well prepared!
total: 995
data: qa4_test.txt
Data is well prepared!
total: 999
data: qa8_test.txt
Data is well prepared!
total: 995
data: qa16_test.txt
Data is well prepared!
total: 999
data: qa10_test.txt
Data is well prepared!
total: 995
data: qa17_test.txt
Data is well prepared!
total: 992
data: qa20_test.txt
Data is well prepared!
total: 993
data: qa1_test.txt
Data is well prepared!
total: 995
data: qa5_test.txt
Data is well prepared!
total: 995
data: qa19_test.txt
Data is well prepared!
total: 999
data: qa18_test.txt
Data is well prepared!
total: 997
data: qa9_test.txt
Data is well prepared!
total: 995
data: qa12_test.txt
Data is well prepared!
total: 995
data: qa13_test.txt
Data is well prepared!
total: 995
data: qa15_test.txt
Data is well prep

In [8]:
# convert to index
# word set
cq_word_set = set()
list_of_context = [train_context, val_context, test_context]
list_of_question = [train_question, val_question, test_question]

for list_ in list_of_context:
    for para in list_:
        for sent in para:
            sent = sent.replace(".", "")
            sent = sent.replace("?", "")
            sent = sent.split(" ")
            cq_word_set.update(sent)
            
for list_ in list_of_question:
    for sent in list_:
        sent = sent.replace(".", "")
        sent = sent.replace("?", "")
        sent = sent.split()
        cq_word_set.update(sent)

In [9]:
answer_word_set = set()
list_of_answer = [train_answer, val_answer, test_answer]
for answers in list_of_answer:
    for answer in answers:
            answer_word_set.add(answer)

In [10]:
print("context and question words: {}".format(len(cq_word_set)))
print("answer words: {}".format(len(answer_word_set)))

context and question words: 161
answer words: 60


In [11]:
cq_word_index = {}
for i, word in enumerate(cq_word_set):
    cq_word_index[word] = i

answer_word_index = {}
for i, word in enumerate(answer_word_set):
    answer_one_hot = np.zeros([len(answer_word_set)])
    answer_one_hot[i] = 1
    answer_word_index[word] = answer_one_hot

In [12]:
train_context_index = []
val_context_index = []
test_context_index = []

for para in train_context:
    indexed_para = []
    for sent in para:
        sent = sent.replace(".", "")
        sent = sent.replace("?", "")
        sent = sent.split(" ")
        indexed_sent = []
        for word in sent:
            indexed_sent.append(cq_word_index[word])
        indexed_para.append(indexed_sent)
    train_context_index.append(np.array(indexed_para))
    
for para in val_context:
    indexed_para = []
    for sent in para:
        sent = sent.replace(".", "")
        sent = sent.replace("?", "")
        sent = sent.split(" ")
        indexed_sent = []
        for word in sent:
            indexed_sent.append(cq_word_index[word])
        indexed_para.append(indexed_sent)
    val_context_index.append(np.array(indexed_para))
    
for para in test_context:
    indexed_para = []
    for sent in para:
        sent = sent.replace(".", "")
        sent = sent.replace("?", "")
        sent = sent.split(" ")
        indexed_sent = []
        for word in sent:
            indexed_sent.append(cq_word_index[word])
        indexed_para.append(indexed_sent)
    test_context_index.append(np.array(indexed_para))

In [13]:
if (len(train_context_index) + len(val_context_index) + len(test_context_index)) == (len(train_context) + len(test_context) + len(val_context)):
    print("context encoding is completed!")
else:
    print("Something is missing! Check again")

context encoding is completed!


In [14]:
train_question_index = []
val_question_index = []
test_question_index = []

for sent in train_question:
    sent = sent.replace(".", "")
    sent = sent.replace("?", "")
    sent = sent.split()
    indexed_sent = []
    for word in sent:
        indexed_sent.append(cq_word_index[word])
    train_question_index.append(np.array(indexed_sent))
    
for sent in val_question:
    sent = sent.replace(".", "")
    sent = sent.replace("?", "")
    sent = sent.split()
    indexed_sent = []
    for word in sent:
        indexed_sent.append(cq_word_index[word])
    val_question_index.append(np.array(indexed_sent))
    
for sent in test_question:
    sent = sent.replace(".", "")
    sent = sent.replace("?", "")
    sent = sent.split()
    indexed_sent = []
    for word in sent:
        indexed_sent.append(cq_word_index[word])
    test_question_index.append(np.array(indexed_sent))

In [15]:
if (len(train_question_index) + len(val_question_index) + len(test_question_index)) == (len(train_question) + len(test_question) + len(val_question)):
    print("question encoding is completed!")
else:
    print("Something is missing! Check again")

question encoding is completed!


In [16]:
train_answer_index = []
val_answer_index = []
test_answer_index = []

for answer in train_answer:
    indexed_answer = answer_word_index[answer]
    train_answer_index.append(indexed_answer)
    
for answer in val_answer:
    indexed_answer = answer_word_index[answer]
    val_answer_index.append(indexed_answer)
    
for answer in test_answer:
    indexed_answer = answer_word_index[answer]
    test_answer_index.append(indexed_answer)

In [17]:
if (len(train_answer_index)+ len(val_answer_index) + len(test_answer_index)) == (len(train_answer) + len(test_answer) + len(val_answer)):
    print("answer encoding is completed!")
else:
    print("Something is missing! Check again")

answer encoding is completed!


In [18]:
train_label_index = []
val_label_index = []
test_label_index = []

for label in train_label:
    train_label_index.append(np.eye(20)[label])
    
for label in val_label:
    val_label_index.append(np.eye(20)[label])

for label in test_label:
    test_label_index.append(np.eye(20)[label])

In [19]:
if (len(train_label_index)+ len(val_label_index) + len(test_label_index)) == (len(train_label) + len(test_label) + len(val_label)):
    print("label encoding is completed!")
else:
    print("Something is missing! Check again")

label encoding is completed!


In [20]:
train_dataset = (train_question_index, train_answer_index, train_context_index, train_label_index)
val_dataset = (val_question_index, val_answer_index, val_context_index, val_label_index)
test_dataset = (test_question_index, test_answer_index, test_context_index, test_label_index)

In [21]:
with open('babi_preprocessd/train_dataset.pkl', 'wb') as f:
    pickle.dump(train_dataset, f)

with open('babi_preprocessd/val_dataset.pkl', 'wb') as f:
    pickle.dump(val_dataset, f)
    
with open('babi_preprocessd/test_dataset.pkl', 'wb') as f:
    pickle.dump(test_dataset, f)