In [61]:
#import utils
import os
from collections import Counter
import pickle
import numpy as np
import config

ImportError: No module named 'config'

In [3]:
# Root of the DeepMind CNN "questions" directory. Set the CNN_QUESTIONS_DIR
# environment variable to point elsewhere instead of editing this cell.
data_path = os.environ.get(
    "CNN_QUESTIONS_DIR",
    '/Users/kellyzhang/Documents/ReadingComprehension/DeepMindDataset/cnn/questions',
)
train_path = os.path.join(data_path, "training")  # 380298 examples
validation_path = os.path.join(data_path, "validation")  # 3924 examples
test_path = os.path.join(data_path, "test")  # 3198 examples

In [4]:
def create_data_files(in_file_path, max_example=None, relabeling=True, write_file=None):
    """
    Load CNN / Daily Mail examples from a {train | dev | test} question directory.

    Every file in `in_file_path` is read as one example: line 2 (0-indexed) is
    the document, line 4 the question and line 6 the answer. Document and
    question are stripped and lowercased.

    Args:
        in_file_path: directory holding one question file per example.
        max_example: stop after this many examples (None reads them all).
        relabeling: if True, rename entities to @entity0, @entity1, ... by
            their first occurrence (document tokens first, then question).
        write_file: optional output path; each example is written as three
            lines: document, question, answer.

    Returns:
        (documents, questions, answers): parallel lists of strings.
    """
    documents = []
    questions = []
    answers = []

    # os.listdir order is arbitrary; sorting makes the example order, and
    # therefore which files max_example keeps, reproducible.
    for file_name in sorted(os.listdir(in_file_path)):
        # Checked before reading so that max_example=0 yields no examples.
        if (max_example is not None) and (len(documents) >= max_example):
            break

        # Close each file immediately instead of leaking one handle per example.
        with open(os.path.join(in_file_path, file_name), 'r') as f:
            content = f.read().splitlines()
        document = content[2].strip().lower()
        question = content[4].strip().lower()
        answer = content[6].strip()

        if relabeling:
            q_words = question.split(' ')
            d_words = document.split(' ')
            assert answer in d_words

            # Renumber entities in order of first occurrence (document, then question).
            entity_map = {}
            for word in d_words + q_words:
                if word.startswith('@entity') and word not in entity_map:
                    entity_map[word] = '@entity' + str(len(entity_map))

            q_words = [entity_map.get(w, w) for w in q_words]
            d_words = [entity_map.get(w, w) for w in d_words]
            answer = entity_map[answer]

            question = ' '.join(q_words)
            document = ' '.join(d_words)

        documents.append(document)
        questions.append(question)
        answers.append(answer)

    print("#Examples: {}".format(len(documents)))

    if write_file:
        with open(write_file, 'w') as out:
            for document, question, answer in zip(documents, questions, answers):
                out.write(document + "\n")
                out.write(question + "\n")
                out.write(answer + "\n")

    return (documents, questions, answers)


In [41]:
# Convert the test split into a single test.txt (document / question / answer per example).
documents, questions, answers = create_data_files(
    test_path, write_file=os.path.join(data_path, "test.txt"))

#Examples: 3198


In [30]:
def load_data(in_file_path, max_examples=None):
    """
    Load CNN / Daily Mail examples from a {train | dev | test}.txt file.

    The file holds one example per three consecutive lines (document,
    question, answer), the layout that `create_data_files` writes.

    Args:
        in_file_path: path of the text file to read.
        max_examples: stop after this many examples (None reads them all).

    Returns:
        (documents, questions, answers): parallel lists of stripped, lowercased
        strings. A trailing incomplete example (fewer than three lines) is ignored.
    """
    documents = []
    questions = []
    answers = []

    with open(in_file_path, 'r') as f:
        while (max_examples is None) or (len(documents) < max_examples):
            document_line = f.readline()
            question_line = f.readline()
            answer_line = f.readline()
            # readline() returns '' only at EOF; this ends the loop and also
            # drops a truncated final example instead of padding it with ''.
            if not (document_line and question_line and answer_line):
                break
            # All three fields are lowercased so the vocabulary built from
            # documents + questions does not split on case.
            documents.append(document_line.strip().lower())
            questions.append(question_line.strip().lower())
            answers.append(answer_line.strip().lower())

    print("#Examples: {}".format(len(documents)))
    return (documents, questions, answers)

In [77]:
# train.txt is produced by create_data_files. Build it here when missing so a
# fresh Restart & Run All does not depend on a file generated outside this
# notebook (one-time cost: reads every file under train_path).
train_file = os.path.join(data_path, "train.txt")
if not os.path.exists(train_file):
    create_data_files(train_path, write_file=train_file)
documents, questions, answers = load_data(train_file, max_examples=100)

#Examples: 100


In [78]:
def build_dict(sentences, max_words=50000):
    """
    Assign integer ids to the most frequent words in `sentences`.

    Sentences are tokenized with a single-space split. The `max_words` most
    common words receive ids from 2 upward in descending frequency order
    (equal counts keep first-seen order); all other words are meant to map
    to <UNK>.
    """
    word_count = Counter(token for sentence in sentences for token in sentence.split(' '))
    # id 0 is left for <UNK> and id 1 for the ||| delimiter, so numbering starts at 2
    return {
        token: token_id
        for token_id, (token, _freq) in enumerate(word_count.most_common(max_words), start=2)
    }

In [79]:
# Vocabulary over every loaded document and question.
vocabulary_dict = build_dict([*documents, *questions])

In [80]:
# Candidate answer labels: every @entity token in the vocabulary plus every
# gold answer. Sorted rather than left in set order, which can change between
# interpreter runs (string hash randomization), so label ids are reproducible.
entity_markers = sorted(
    set(w for w in vocabulary_dict if w.startswith('@entity')) | set(answers))
# '<unk_entity>' takes label id 0.
entity_markers = ['<unk_entity>'] + entity_markers
entity_dict = {w: index for (index, w) in enumerate(entity_markers)}
num_labels = len(entity_dict)

In [81]:
# TODO: gen_embeddings — build embeddings from pretrained word vectors (not implemented yet)

In [91]:
def vectorize(documents, questions, answers, vocabulary_dict, entity_dict,
              sort_by_len=True, verbose=True):
    """
    Convert parallel lists of documents, questions and answers into id sequences.

    Args:
        documents, questions, answers: parallel lists of strings. Tokens are
            split on single spaces, and each answer must appear as a token of
            its document.
        vocabulary_dict: word -> id; words missing from it map to 0.
        entity_dict: entity marker -> label id; answers missing from it get label 0.
        sort_by_len: if True, reorder every output by ascending document length.
        verbose: print progress every 10000 examples.

    Returns:
        in_d, in_q: lists of id sequences for documents and questions respectively.
        in_l: float array of shape (num_kept, len(entity_dict)); row i is 1.0 at
            every entity label that occurs in document i.
        in_y: list of answer label ids.
        Examples whose document or question yields an empty sequence are dropped
        from all four outputs, so they stay aligned.
    """
    in_d = []
    in_q = []
    in_y = []
    in_l = np.zeros((len(answers), len(entity_dict)))
    kept = []  # input indices of the examples appended to in_d / in_q / in_y
    for idx in range(len(answers)):
        d_words = documents[idx].split(' ')
        q_words = questions[idx].split(' ')
        assert (answers[idx] in d_words)
        seq1 = [vocabulary_dict.get(w, 0) for w in d_words]
        seq2 = [vocabulary_dict.get(w, 0) for w in q_words]
        if (len(seq1) > 0) and (len(seq2) > 0):
            in_d.append(seq1)
            in_q.append(seq2)
            in_l[idx, [entity_dict[w] for w in d_words if w in entity_dict]] = 1.0
            in_y.append(entity_dict.get(answers[idx], 0))
            kept.append(idx)
        if verbose and (idx % 10000 == 0):
            print('Vectorization: processed {} / {}'.format(idx, len(answers)))

    # Keep only rows of accepted examples so in_l lines up with in_d / in_q / in_y
    # (otherwise the length sort below would index in_l with compacted positions).
    in_l = in_l[kept]

    def len_argsort(seq):
        return sorted(range(len(seq)), key=lambda x: len(seq[x]))

    if sort_by_len:
        # sort by the document length
        sorted_index = len_argsort(in_d)
        in_d = [in_d[i] for i in sorted_index]
        in_q = [in_q[i] for i in sorted_index]
        in_l = in_l[sorted_index]
        in_y = [in_y[i] for i in sorted_index]

    return in_d, in_q, in_l, in_y

In [92]:
# a = (in_d, in_q, in_l, in_y), sorted by document length (vectorize's default).
a = vectorize(documents, questions, answers,
              vocabulary_dict=vocabulary_dict, entity_dict=entity_dict)

Vectorization: processed 0 / 100


In [93]:
a[1][0]  # token ids of the first question (the example with the shortest document, since a is length-sorted)

[52, 64, 10, 648, 5, 42, 359, 310, 3, 42, 96, 181, 388, 116]

100

96