In [1]:
from collections import Counter
import string
import re
import argparse
import json
import sys
import numpy as np
import nltk
import random
import math
import os
import pickle
from tqdm import tqdm, trange

import pdb

from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
                                                  BertTokenizer,
                                                  whitespace_tokenize)

In [2]:
def pickler(path,pkl_name,obj):
    """Serialize ``obj`` with pickle into the file ``pkl_name`` under ``path``.

    Args:
        path: Directory that receives the file.
        pkl_name: File name, joined onto ``path``.
        obj: Any picklable object.

    The file is opened in binary write mode, so an existing file is replaced.
    """
    destination = os.path.join(path, pkl_name)
    with open(destination, 'wb') as sink:
        pickle.dump(obj, sink, protocol=pickle.HIGHEST_PROTOCOL)

def unpickler(path,pkl_name):
    """Load and return the object pickled in the file ``pkl_name`` under ``path``.

    Args:
        path: Directory containing the file.
        pkl_name: File name, joined onto ``path``.

    Returns:
        The unpickled object.
    """
    source = os.path.join(path, pkl_name)
    # NOTE: unpickling can execute arbitrary code -- only load files you trust.
    with open(source, 'rb') as stream:
        return pickle.load(stream)

In [3]:
# Selects which split is preprocessed: True -> training files, False -> dev files.
TRAINING = True

in_pkl_path = "./"
out_pkl_path = "./"

# Per split: input pickle, full output pickle, reduced output pickle, and the
# number of leading examples kept in the reduced output.
if TRAINING:
    in_pkl_name, out_pkl_name, small_out_pkl_name, small_dataset_size = (
        "preproc_train_1.pkl",
        "preprocessed_train.pkl",
        "preprocessed_train_small.pkl",
        5000,
    )
else:
    in_pkl_name, out_pkl_name, small_out_pkl_name, small_dataset_size = (
        "preproc_dev_1.pkl",
        "preprocessed_dev.pkl",
        "preprocessed_dev_small.pkl",
        500,
    )

max_seq_len = 512
max_question_len = 35
# Context budget per sequence: total length minus the question slot and the
# two extra positions taken by the [CLS] and [SEP] ids.
max_context_chunk_length = max_seq_len - (max_question_len + 2)
max_num_chunks = 4

In [4]:
# Load the uncased BERT vocabulary and look up the ids of the special tokens
# that are used further down to assemble and pad the model inputs.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

cls_id, sep_id, pad_id = tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", "[PAD]"])

print("[CLS] : {}".format(cls_id))
print("[SEP] : {}".format(sep_id))
print("[PAD] : {}".format(pad_id))

[CLS] : 101
[SEP] : 102
[PAD] : 0


In [5]:
data_in = unpickler(in_pkl_path, in_pkl_name)

In [6]:
data_in.keys()

dict_keys(['question_ids', 'questions', 'paragraphs', 'paragraph_names', 'answers', 'question_indices', 'yes_no_span', 'supporting_facts', 'ids_to_word_mappings', 'answers_string', 'supporting_facts_raw'])

In [7]:
def compute_paragraph_length(paragraph):
    """Return the total number of items over all sentences of ``paragraph``.

    Args:
        paragraph: Iterable of sentences; each sentence must support ``len``.

    Returns:
        Sum of the sentence lengths (0 for an empty paragraph).
    """
    return sum(len(sentence) for sentence in paragraph)

def make_chunks(paragraphs, max_num_chunks, max_chuck_length):
    """Greedily pack paragraphs, shortest first, into at most ``max_num_chunks`` chunks.

    Paragraphs are visited in order of increasing length. A paragraph goes
    into the current chunk unless it would push that chunk past
    ``max_chuck_length``; in that case a new chunk is started, as long as one
    is left. The last chunk absorbs everything that remains, so it may exceed
    ``max_chuck_length``. A paragraph longer than ``max_chuck_length`` is
    counted as ``max_chuck_length`` for the fit test only.

    Args:
        paragraphs: List of paragraphs; each paragraph is a list of sentences
            and each sentence supports ``len``.
        max_num_chunks: Number of chunks to produce (unused ones stay empty).
        max_chuck_length: Target maximum summed sentence length per chunk.

    Returns:
        Tuple ``(chunks, paragraph_chunk_indices, num_sentences_in_paragraphs)``:
        ``chunks[c]`` is the concatenated sentence list of chunk ``c``;
        ``paragraph_chunk_indices[c]`` lists, in packing order, the indices of
        the paragraphs placed in chunk ``c`` as plain Python ``int``;
        ``num_sentences_in_paragraphs[p]`` is the sentence count of paragraph ``p``.
    """
    chunks = [[] for _ in range(max_num_chunks)]
    paragraph_chunk_indices = [[] for _ in range(max_num_chunks)]
    paragraph_lengths = [sum(len(sentence) for sentence in paragraph) for paragraph in paragraphs]
    num_sentences_in_paragraphs = [len(paragraph) for paragraph in paragraphs]

    current_chunk = 0
    chunk_len_so_far = 0
    # argsort yields numpy integer scalars; .tolist() converts them to plain
    # ints so the returned (and later pickled) indices carry no numpy types.
    for i in np.argsort(paragraph_lengths).tolist():
        needed = min(max_chuck_length, paragraph_lengths[i])
        overflows = chunk_len_so_far + needed > max_chuck_length
        if overflows and current_chunk < max_num_chunks - 1:
            current_chunk += 1
            chunk_len_so_far = 0
        paragraph_chunk_indices[current_chunk].append(i)
        chunks[current_chunk] += paragraphs[i]
        chunk_len_so_far += paragraph_lengths[i]

    return chunks, paragraph_chunk_indices, num_sentences_in_paragraphs
    
def reorganize_supporting_fact_labels(paragraph_chunk_indices, num_sentences_in_paragraphs, supporting_facts):
    """Translate supporting facts from paragraph coordinates to chunk coordinates.

    Args:
        paragraph_chunk_indices: For each chunk, the paragraph indices it
            contains, in the order the paragraphs were concatenated.
        num_sentences_in_paragraphs: Sentence count of every paragraph,
            indexed by paragraph index.
        supporting_facts: Iterable of ``[paragraph_index, sentence_index]``.

    Returns:
        List of ``[chunk_index, sentence_index_within_chunk]``, one per fact,
        where the sentence index is offset by the sentence counts of the
        paragraphs that precede the fact's paragraph inside its chunk.
    """
    def locate(chunk_layout, fact):
        # Only the first chunk that contains the paragraph is considered.
        for chunk_no, members in enumerate(chunk_layout):
            if fact[0] not in members:
                continue
            position = members.index(fact[0])
            sentences_before = sum([num_sentences_in_paragraphs[p] for p in members[:position]])
            return [chunk_no, sentences_before + fact[1]]

    relabelled = []
    for fact in supporting_facts:
        chunk_no, sentence_no = locate(paragraph_chunk_indices, fact)
        relabelled.append([chunk_no, sentence_no])
    return relabelled

In [8]:
def pad_trim(sequences, max_len, pad_symbol=0):
    """Bring every sequence to exactly ``max_len`` items.

    Longer sequences are cut at ``max_len``; shorter ones are extended on the
    right with ``pad_symbol``. The input sequences are not modified.

    Args:
        sequences: Iterable of lists.
        max_len: Target length.
        pad_symbol: Value used for right padding.

    Returns:
        A new list of new lists, each of length ``max_len``.
    """
    fixed = []
    for original in sequences:
        clipped = original[:max_len]
        shortfall = max_len - len(clipped)
        fixed.append(clipped + [pad_symbol] * shortfall)
    return fixed

def trim_paragraph(paragraph, max_seq_len):
    """Keep leading sentences of ``paragraph`` up to a budget of ``max_seq_len`` items.

    Empty sentences are skipped. Sentences are copied over while they fit; the
    first sentence that does not fit is cut to the remaining budget and ends
    the scan. The scan also ends as soon as the budget is used up exactly.

    Args:
        paragraph: Iterable of sentences (sliceable, supporting ``len``).
        max_seq_len: Non-negative item budget.

    Returns:
        New list of sentences whose lengths sum to at most ``max_seq_len``.
    """
    assert max_seq_len >= 0
    kept = []
    budget = max_seq_len
    for sentence in paragraph:
        size = len(sentence)
        if size == 0:
            continue
        if size > budget:
            # Does not fit: keep only what the remaining budget allows.
            kept.append(sentence[:budget])
            break
        kept.append(sentence)
        budget -= size
        if budget == 0:
            break
    return kept
    
def pad_paragraph(paragraph, max_sequence_len, pad_index):
    """Return ``paragraph`` followed by one padding sentence that fills the length.

    The padding sentence holds ``pad_index`` repeated as often as needed for
    the summed sentence lengths to reach ``max_sequence_len``. It is always
    appended, so it is an empty list when the paragraph already has exactly
    that length. The caller's list is left untouched; a new list is returned.

    Args:
        paragraph: List of sentences (each supporting ``len``).
        max_sequence_len: Non-negative target for the summed sentence lengths.
        pad_index: Value the padding sentence is filled with.

    Returns:
        New list: the original sentences plus the trailing padding sentence.

    Raises:
        AssertionError: if ``max_sequence_len`` is negative or the paragraph
            is already longer than ``max_sequence_len``.
    """
    assert max_sequence_len >= 0
    used = sum(len(sentence) for sentence in paragraph)
    assert used <= max_sequence_len
    # Build a new list instead of appending in place, so the argument is not
    # silently modified for the caller.
    return list(paragraph) + [[pad_index] * (max_sequence_len - used)]

def merge_trim_pad_paragraphs(paragraph, paragraph_index, supporting_facts_in, max_seq_len,
                              max_sentences, pad_index=0):
    """Flatten one chunk of sentences into a fixed-length sequence with sentence spans.

    At most ``max_sentences - 1`` sentences are kept, so that one slot stays
    free for a trailing padding sentence. If the kept sentences hold
    ``max_seq_len`` items or more, they are trimmed to ``max_seq_len - 1``
    items and a one-item padding sentence is appended; otherwise a padding
    sentence fills the sequence up to ``max_seq_len``. Either way the last
    sentence of the result is padding.

    Args:
        paragraph: List of sentences (lists) forming the chunk.
        paragraph_index: Index of this chunk; only supporting facts whose
            first element equals it are used.
        supporting_facts_in: Iterable of ``[chunk_index, sentence_index]``.
        max_seq_len: Exact length of the produced sequence.
        max_sentences: Length of the supporting-fact vector.
        pad_index: Value used for padding.

    Returns:
        Dict with keys ``'sequence'`` (flat list of length ``max_seq_len``),
        ``'sentence_start_index'`` / ``'sentence_end_index'`` (inclusive
        positions of every sentence in ``'sequence'``, the padding sentence
        included) and ``'supporting_fact'`` (0/1 list of length
        ``max_sentences``).

    Raises:
        AssertionError: if the assembled sequence does not have length
            ``max_seq_len``.
    """
    # Keep one sentence slot free for the trailing padding sentence.
    paragraph = paragraph[:max_sentences-1]

    total_para_len_words = sum(len(s) for s in paragraph)

    if total_para_len_words >= max_seq_len:
        # -1 leaves room for the one-item padding sentence appended below.
        paragraph = trim_paragraph(paragraph, max_seq_len - 1)
        paragraph.append([pad_index])
    else:
        paragraph = pad_paragraph(paragraph, max_seq_len, pad_index)

    # The last entry is the padding sentence in both branches.
    num_real_sentences = len(paragraph) - 1

    # Concatenate the sentences and record where each one starts and ends.
    sentence_start_indices = []
    sentence_end_indices = []
    out_sequence = []
    for sent in paragraph:
        sentence_start_indices.append(len(out_sequence))
        out_sequence += sent
        sentence_end_indices.append(len(out_sequence)-1)

    assert(len(sentence_start_indices) == len(sentence_end_indices))

    # Mark supporting facts, but only for sentences that are still present:
    # a fact whose sentence was removed by the sentence cap or by trimming
    # would otherwise label the padding sentence (or a slot with no sentence).
    # NOTE(review): trim_paragraph skips empty sentences, which shifts the
    # positions of later sentences in the trimmed case -- confirm whether
    # empty sentences can occur in the input.
    supporting_facts = [0] * max_sentences
    for s_f in supporting_facts_in:
        if(s_f[0] == paragraph_index and 0 <= s_f[1] < num_real_sentences):
            supporting_facts[s_f[1]] = 1

    # sanity check
    assert(len(out_sequence) == max_seq_len)
    assert(len(supporting_facts) == max_sentences)

    return {'sequence': out_sequence,
            'sentence_start_index': sentence_start_indices, 'sentence_end_index': sentence_end_indices,
            'supporting_fact': supporting_facts}

In [9]:
def find_all_in_sequence(sequence, key):
    """Find every (possibly overlapping) occurrence of ``key`` inside ``sequence``.

    Args:
        sequence: Sliceable sequence to search in.
        key: Sub-sequence to look for; it is compared with slices of
            ``sequence`` using ``==``, so both should be of the same type
            (e.g. both lists).

    Returns:
        Tuple ``(start_indices, end_indices)`` of equal length with the
        inclusive first and last position of each occurrence. An empty
        ``key`` yields no occurrences.
    """
    key_len = len(key)
    if key_len == 0:
        # An empty key would otherwise "match" at every position and produce
        # end indices smaller than their start indices.
        return [], []
    # Stop where the key no longer fits; the range is empty for short sequences.
    start_indices = [i for i in range(len(sequence) - key_len + 1)
                     if sequence[i:i+key_len] == key]
    end_indices = [start + key_len - 1 for start in start_indices]
    return start_indices, end_indices

def find_answer_locations(passages, answers, yes_no_span):
    '''
    Locate every answer inside its passage chunks.

    Input: passages[i] = [passage_chunk_0, ..., passage_chunk_j] for question i.
    For questions with yes_no_span[i] == 2 the chunks are concatenated and all
    occurrences of answers[i] are searched; the returned indices refer to that
    concatenation. Every other question gets the placeholder [0] / [0].
    (Presumably 2 marks span-type answers -- confirm against the code that
    produces yes_no_span.)

    Returns (answer_start_indices, answer_end_indices): two lists holding one
    list of positions per question.
    '''
    assert len(passages) == len(answers)
    all_starts = []
    all_ends = []
    for q_idx, chunk_list in enumerate(passages):
        if yes_no_span[q_idx] == 2:
            concatenated = [token for chunk in chunk_list for token in chunk]
            starts, ends = find_all_in_sequence(concatenated, answers[q_idx])
            assert len(starts) == len(ends)
        else:
            starts, ends = [0], [0]
        all_starts.append(starts)
        all_ends.append(ends)
    return all_starts, all_ends

In [10]:
chunks, paragraph_chunk_indices, num_sentences_in_paragraphs = make_chunks(paragraphs=data_in['paragraphs'][0], 
                                                                           max_num_chunks=max_num_chunks, 
                                                                           max_chuck_length=max_context_chunk_length)

In [11]:
num_sentences_in_paragraphs

[7, 4, 9, 4, 1, 3, 5, 4, 4, 5]

In [12]:
paragraph_chunk_indices

[[4, 7, 5, 8, 9, 6], [3, 1, 0], [2], []]

In [13]:
len(chunks)

4

In [14]:
len(chunks[0][0])

33

In [15]:
# print(len(chunks[0]))
# print(len(chunks[1]))
# print(len(chunks[2]))

In [16]:
reorganize_supporting_fact_labels(paragraph_chunk_indices=[[0,2],[1],[3]], num_sentences_in_paragraphs=[5,2,3,2], 
                                  supporting_facts = [[0,2],[0,3],[2,2],[3,1]])

[[0, 2], [0, 3], [0, 7], [2, 1]]

In [17]:
# Chunk the paragraphs of every question; the three lists below are aligned
# with data_in['paragraphs'] (one entry per question).
paragraphs_chunked = []
paragraph_chunk_indices = []
num_sentences_in_paragraphs = []

for question_paragraphs in tqdm(data_in['paragraphs']):
    chunked = make_chunks(paragraphs=question_paragraphs,
                          max_num_chunks=max_num_chunks,
                          max_chuck_length=max_context_chunk_length)
    paragraphs_chunked.append(chunked[0])
    paragraph_chunk_indices.append(chunked[1])
    num_sentences_in_paragraphs.append(chunked[2])

100%|██████████| 90447/90447 [00:18<00:00, 4765.69it/s]


In [18]:
# Re-express each question's supporting facts as [chunk index, sentence index
# inside that chunk], using the chunk layout computed for the same question.
supporting_facts_in_paragraph_chunks = [
    reorganize_supporting_fact_labels(
        paragraph_chunk_indices=paragraph_chunk_indices[q_idx],
        num_sentences_in_paragraphs=num_sentences_in_paragraphs[q_idx],
        supporting_facts=data_in['supporting_facts'][q_idx])
    for q_idx in trange(len(data_in['paragraphs']))
]

100%|██████████| 90447/90447 [00:00<00:00, 136028.33it/s]


In [19]:
supporting_facts_in_paragraph_chunks[4000]

[[0, 0], [0, 1], [0, 11]]

In [20]:
len(paragraphs_chunked[0][0])

22

In [21]:
# Sentence count of every chunk of every question, flattened into one array,
# to help choose a per-chunk sentence cap.
num_sentences_per_chunk = np.array(
    [len(chunk) for question_chunks in paragraphs_chunked for chunk in question_chunks]
)

print("Avg number of sentences per chunk :{}".format(num_sentences_per_chunk.mean()))
print("min number of sentences per chunk :{}".format(num_sentences_per_chunk.min()))
print("max number of sentences per chunk :{}".format(num_sentences_per_chunk.max()))

Avg number of sentences per chunk :10.236226740522074
min number of sentences per chunk :0
max number of sentences per chunk :103


In [22]:
# Per-chunk sentence cap used by the cells below; the displayed value is the
# fraction of chunks that hold more sentences than the cap.
max_num_sentences_per_chunk = 18
np.mean(num_sentences_per_chunk > max_num_sentences_per_chunk)

0.04600760666467655

In [23]:
i=990
print(data_in['supporting_facts'][i])
print(supporting_facts_in_paragraph_chunks[i])
print(paragraph_chunk_indices[i])

[[5, 0], [9, 0]]
[[1, 4], [1, 7]]
[[0, 2, 6, 1, 7, 4, 3], [8, 5, 9], [], []]


In [24]:
# Cut or right-pad every question to max_question_len ids. The padding value
# is the tokenizer's [PAD] id (pad_id) rather than a hard-coded 0, so it always
# agrees with the padding used for the context chunks.
questions_fixed_length = pad_trim(sequences = data_in['questions'], max_len=max_question_len, pad_symbol=pad_id)

In [25]:
for i in range(len(data_in['questions'])):
    assert(len(questions_fixed_length[i]) == max_question_len)

In [26]:
# For every question, turn each of its chunks into a fixed-length sequence and
# collect, per question: the sequences, the sentence start/end positions per
# chunk, and the supporting-fact vectors of all chunks joined into one list.
paragraph_chunks_fixed_length = []
sentence_start_indices = []
sentence_end_indices = []
supporting_facts_expanded = []

for q_idx in trange(len(questions_fixed_length)):
    merged_chunks = [
        merge_trim_pad_paragraphs(paragraph=chunk, paragraph_index=chunk_idx,
                                  supporting_facts_in=supporting_facts_in_paragraph_chunks[q_idx],
                                  max_seq_len=max_context_chunk_length,
                                  max_sentences=max_num_sentences_per_chunk, pad_index=pad_id)
        for chunk_idx, chunk in enumerate(paragraphs_chunked[q_idx])
    ]
    paragraph_chunks_fixed_length.append([m['sequence'] for m in merged_chunks])
    sentence_start_indices.append([m['sentence_start_index'] for m in merged_chunks])
    sentence_end_indices.append([m['sentence_end_index'] for m in merged_chunks])
    supporting_facts_expanded.append([flag for m in merged_chunks for flag in m['supporting_fact']])

assert len(paragraph_chunks_fixed_length) == len(sentence_start_indices)
assert len(sentence_end_indices) == len(sentence_start_indices)
assert len(supporting_facts_expanded) == len(sentence_start_indices)

100%|██████████| 90447/90447 [00:14<00:00, 6419.27it/s]


In [27]:
len(supporting_facts_expanded[0])

72

In [28]:
idx = 3601
print(supporting_facts_in_paragraph_chunks[idx])
print(np.where(np.array(supporting_facts_expanded[idx]) == 1)[0])

[[0, 12], [0, 0]]
[ 0 12]


In [29]:
for i in range(len(supporting_facts_expanded)):
    assert(len(supporting_facts_expanded[i]) == max_num_sentences_per_chunk * max_num_chunks)

In [30]:
print(len(paragraph_chunks_fixed_length))

90447


In [31]:
print(max_seq_len-max_question_len-2)
print(len(paragraph_chunks_fixed_length[0][0]))

475
475


In [32]:
answer_start_indices, answer_end_indices = find_answer_locations(passages = paragraph_chunks_fixed_length, 
                                                                 answers = data_in['answers'], 
                                                                 yes_no_span = data_in['yes_no_span'])

In [33]:
len(supporting_facts_expanded[0])

72

In [34]:
assert(len(answer_start_indices) == len(answer_end_indices))

In [35]:
for i in range(len(answer_start_indices)):
    assert(len(answer_start_indices[i]) == len(answer_end_indices[i]))

In [36]:
i = 344
print(answer_start_indices[i])
print(answer_end_indices[i])

[63, 369, 1158]
[65, 371, 1160]


In [37]:
i = 3
print(sentence_start_indices[i][0])
print(sentence_end_indices[i][0])

[0, 31, 65, 104, 155, 178, 214, 239, 264, 279, 351, 401, 423, 449]
[30, 64, 103, 154, 177, 213, 238, 263, 278, 350, 400, 422, 448, 474]


In [38]:
assert len(paragraph_chunks_fixed_length) == len(questions_fixed_length)

# One model input per (question, chunk): [CLS] id, padded question, [SEP] id,
# then the fixed-length chunk.
question_context_sequences = []
for i in trange(len(paragraph_chunks_fixed_length)):
    question_prefix = [cls_id] + questions_fixed_length[i] + [sep_id]
    question_context_sequences.append(
        [question_prefix + chunk for chunk in paragraph_chunks_fixed_length[i]]
    )

100%|██████████| 90447/90447 [00:04<00:00, 19813.86it/s]


In [39]:
for i in range(len(question_context_sequences)):
    for j in range(len(question_context_sequences[i])):
        assert(len(question_context_sequences[i][j]) == max_seq_len)

In [40]:
# Segment ids shared by every sequence: 0 for [CLS], the question slot and the
# [SEP] that closes it, 1 for the context chunk. BERT assigns the separator
# that ends the first segment to segment 0, so the leading run of zeros covers
# max_question_len + 2 positions.
segment_id = [0] * (max_question_len + 2) + [1] * (max_seq_len - max_question_len - 2)

In [41]:
assert(len(segment_id) == max_seq_len)

Things to pickle (the keys of `out_dict` below):
- question_context_sequences
- segment_id
- sentence_start_indices
- sentence_end_indices
- answer_start_indices
- answer_end_indices
- supporting_facts_expanded
- question_ids
- question_indices
- yes_no_span
- ids_to_word_mappings
- max_seq_len
- max_question_len
- max_num_sentences_per_chunk
- num_chunks
- paragraph_chunk_indices
- num_sentences_in_paragraphs
- paragraph_names
- answers_string
- supporting_facts_raw

In [42]:
# Every per-question collection that gets saved must have the same number of
# entries (one per question).
_per_question_collections = (
    question_context_sequences,
    sentence_start_indices,
    sentence_end_indices,
    answer_start_indices,
    answer_end_indices,
    data_in['question_ids'],
    data_in['question_indices'],
    data_in['yes_no_span'],
    data_in['ids_to_word_mappings'],
    paragraph_chunk_indices,
    data_in['paragraph_names'],
    data_in['answers_string'],
    data_in['supporting_facts_raw'],
    supporting_facts_expanded,
    num_sentences_in_paragraphs,
)
assert len({len(collection) for collection in _per_question_collections}) == 1

In [43]:
# Everything that is written to the full output pickle. Entries are either
# per-question lists (aligned by question) or values shared by all questions.
out_dict = {
    # computed in this notebook, one entry per question
    'question_context_sequences': question_context_sequences,
    # shared: one segment-id list valid for every sequence
    'segment_id' : segment_id,
    'sentence_start_indices': sentence_start_indices,
    'sentence_end_indices': sentence_end_indices,
    'answer_start_indices': answer_start_indices,
    'answer_end_indices':answer_end_indices,
    'supporting_facts_expanded': supporting_facts_expanded,
    # passed through unchanged from the input pickle
    'question_ids': data_in['question_ids'],
    'question_indices': data_in['question_indices'],
    'yes_no_span': data_in['yes_no_span'],
    'ids_to_word_mappings': data_in['ids_to_word_mappings'],
    # shared configuration values
    'max_seq_len': max_seq_len,
    'max_question_len': max_question_len,
    'max_num_sentences_per_chunk': max_num_sentences_per_chunk,
    # number of sequences of the first question, stored as the chunk count
    'num_chunks': len(question_context_sequences[0]),
    'paragraph_chunk_indices': paragraph_chunk_indices,
    'num_sentences_in_paragraphs': num_sentences_in_paragraphs,
    # passed through unchanged from the input pickle
    'paragraph_names': data_in['paragraph_names'],
    'answers_string': data_in['answers_string'],
    'supporting_facts_raw': data_in['supporting_facts_raw']
}

In [44]:
# Reduced copy of out_dict: per-question entries are cut to the first
# small_dataset_size questions, shared values are taken over unchanged.
_shared_out_keys = ('segment_id', 'max_seq_len', 'max_question_len',
                    'max_num_sentences_per_chunk', 'num_chunks')
small_out_dict = {
    key: (value if key in _shared_out_keys else value[:small_dataset_size])
    for key, value in out_dict.items()
}

In [45]:
# Write the full preprocessed dataset to out_pkl_path/out_pkl_name.
pickler(out_pkl_path, out_pkl_name, out_dict)
print("done")

done


In [46]:
# Write the reduced dataset (first small_dataset_size questions) next to the full one.
pickler(out_pkl_path, small_out_pkl_name, small_out_dict)
print("done")

done
