In [9]:
import gzip
import pickle
import numpy as np
import h5py
import argparse
import sys
import re
import codecs
from itertools import izip

In [10]:
# Helper functions

# Converts indices to words 
def convert_to_words(indices, indices_to_word):
    return (' '.join([indices_to_word[ind] for ind in indices]))

In [11]:
directory = '../data/MovieTriples/'

# Loading all the possible files into memory
with open(directory + 'Training.triples.pkl') as f:
    train_set = pickle.load(f)
    
with open(directory + 'Validation.triples.pkl') as f:
    valid_set = pickle.load(f)
    
with open(directory + 'Test.triples.pkl') as f:
    test_set = pickle.load(f)
    
with open(directory + 'Word2Vec_WordEmb.pkl') as f:
    emb_wordvec = pickle.load(f)
    
with open(directory + 'MT_WordEmb.pkl') as f:
    emb_mt = pickle.load(f)

In [12]:
# Implement the word indices according to the format for the seq2seq model

# Make sure that the word_indices are 1 indexed for lua

# Do a swap with the embeddings and word_indices so it follows the conventions for indices 
# self.d = {self.PAD: 1, self.UNK: 2, self.BOS: 3, self.EOS: 4}

with open(directory + 'Training.dict.pkl') as f:
    word_mappings = pickle.load(f)

# All the swaps necessary to make the formatting consistent with seq2seq (sorry, it's so messy)
del_ind = []
for i in range(len(word_mappings)):
    word_mapping = word_mappings[i]
    if word_mapping[0] == '<unk>' or word_mapping[0] == '<s>' or \
        word_mapping[0] == '</s>' or word_mapping[0] == '.' or \
        word_mapping[0] == "'":
            print(word_mapping, i)
            del_ind.append(i)

del_ind.sort(reverse=True)
for ind in del_ind:
    del word_mappings[ind]
        
word_mappings.append(('<blank>', 1, 0, 0))
word_mappings.append(('<unk>', 2, 190588, 89059))
word_mappings.append(('<s>', 3, 588827, 785135))
word_mappings.append(('</s>', 4, 588827, 785135))
word_mappings.append(('.', 10003, 855616, 192250))
word_mappings.append(("'", 10004, 457542, 160249))
word_mappings.append(('<t>', 10005, 0, 0))

# Sanity check
check_mappings = range(1, len(word_mappings)+1)
for word_mapping in word_mappings:
    check_mappings.remove(word_mapping[1])
assert check_mappings == []

(('<s>', 1, 588827, 785135), 90)
(('</s>', 2, 588827, 785135), 378)
(('.', 3, 855616, 192250), 426)
(('<unk>', 0, 190588, 89059), 604)
(("'", 4, 457542, 160249), 2115)


In [13]:
# The changes that need to occur in the actual text examples are: 
# ., 3 -> 10003
# ', 4 -> 10004
# <unk>, 0 -> 2
# <s>, 1 -> 3
# </s>, 2 -> 4

print(train_set[0])
data_sets = [train_set, valid_set, test_set]
for i in range(len(data_sets)):
    for j in range(len(data_sets[i])):
        line = data_sets[i][j]
        for k in range(len(line)):
            ind = line[k]
            if ind == 3:
                line[k] = 10003
            elif ind == 4:
                line[k] = 10004
            elif ind == 0:
                line[k] = 2
            elif ind == 1:
                line[k] = 3
            elif ind == 2:
                line[k] = 4
        data_sets[i][j] = line
print(train_set[0])

[1, 6, 1577, 11, 22, 52, 300, 413, 28, 2, 1, 5433, 28, 497, 22, 308, 121, 28, 190, 3, 2, 1, 43, 7, 112, 194, 6, 27, 90, 5, 8, 9, 2]
[3, 6, 1577, 11, 22, 52, 300, 413, 28, 4, 3, 5433, 28, 497, 22, 308, 121, 28, 190, 10003, 4, 3, 43, 7, 112, 194, 6, 27, 90, 5, 8, 9, 4]


In [14]:
# Not entirely sure what the other two numbers reprsent in the word index table
# Maybe corresponds to the counts in train... or something?

print(len(word_mappings))
word_mappings[0:5]

10005


[('raining', 4959, 53, 48),
 ('writings', 9977, 18, 15),
 ('yellow', 2155, 175, 142),
 ('four', 341, 2299, 2081),
 ('prices', 5660, 43, 40)]

In [15]:
# Move through the list of words and indices and generate a dictionary
# matching the indices to words

# indices -> word
indices_to_word = {}
for word_ex in word_mappings: 
    indices_to_word[word_ex[1]] = word_ex[0]
    
# word -> indices
word_to_indices = {}
for word_ex in word_mappings: 
    word_to_indices[word_ex[0]] = word_ex[1]

In [16]:
# It looks like the </s> <s> denotes different speakers
# We want to break out the first to examples and then generate the 
# third as output

# For now we can join the first two sentences and assume that the encoder will figure it out with the </s><s>
# Afterwards, we can think about ways to incorporate the three uttterances

line = ' '.join([indices_to_word[ind] for ind in train_set[0]])
line = line.split('</s> <s>')
context = line[0] + '</s> <s>' + line[1]
output = line[2]

# So our input would be
print(context)
# And our output would be
print(output)

# I'll now generate matrices with that format for the rest of the data. 
# Everything will be padded with a 10003 character at the end

<s> you lied to me so many times -- </s> <s> reggie -- trust me once more -- please . 
 can i really believe you this time , <person> ? </s>


In [17]:
line = ' '.join([indices_to_word[ind] for ind in train_set[0]])
line

'<s> you lied to me so many times -- </s> <s> reggie -- trust me once more -- please . </s> <s> can i really believe you this time , <person> ? </s>'

In [18]:
pattern = [word_to_indices['</s>'], word_to_indices['<s>']]

for ind in range(len(train_set[0]))[::-1]:
    if pattern == train_set[0][ind:ind+2]:
        break_pt = ind
        break
        
context = train_set[0][:break_pt]
output = train_set[0][break_pt+2:]

print(convert_to_words(context, indices_to_word))
print(convert_to_words(output, indices_to_word))

<s> you lied to me so many times -- </s> <s> reggie -- trust me once more -- please .
can i really believe you this time , <person> ? </s>


In [32]:
# Apply above basic parsing to all contexts and outputs

PADDING = word_to_indices['<blank>']
END_OF_CONV = word_to_indices['<t>']

full_context = []
full_output = []
max_len_context = 0
max_len_output = 0 

for i in range(len(train_set)):
    break_pt = []
    for ind in range(len(train_set[i]))[::-1]:
        if pattern == train_set[i][ind:ind+2]:
            break_pt.append(ind)

    context = train_set[i][:break_pt[0]]
    output = train_set[i][break_pt[0]+2:]
    
    context = context + [word_to_indices['</s>']]
    output = [word_to_indices['<s>']] + output
    
    # Start of sentence and end of sentence is ONLY used at the end
    # We create a new character that represents the start and end of a conversation
    context = context[:break_pt[1]] + [END_OF_CONV] + context[break_pt[1]+2:]
    
    
    # Cap the target and src length at 302 words to make computation simpler, goes up to ~1500
    if len(context) > 52:
        continue
    if len(output) > 52:
        continue
    
    max_len_output = max(max_len_output, len(output))
    max_len_context = max(max_len_context, len(context))
    max_len_output = 52
    max_len_context = 52
        
    full_context.append(context)
    full_output.append(output)
    
# Add padding to all contexts and outputs
for i in range(len(full_context)):
    full_context[i] = full_context[i] + [PADDING] * (max_len_context - len(full_context[i]))
    full_output[i] = full_output[i] + [PADDING] * (max_len_output - len(full_output[i]))
    
full_context = (full_context)
full_output = (full_output)

# TODO: split randomly rather than at set index
# NB: this data is already shuffled so it's not _that_ big a deal
ind = int(0.8*len(full_context))
train_full_context = full_context[:ind]
train_full_output = full_output[:ind]
valid_full_context = full_context[ind+1:]
valid_full_output = full_output[ind+1:]

# Create micro datasets for quick n dirty code checks
m_ind = 5000
train_micro_context = full_context[:m_ind]
train_micro_output = full_output[:m_ind]
valid_micro_context = full_context[m_ind+1:m_ind+1+(m_ind/10)]
valid_micro_output = full_output[m_ind+1:m_ind+1+(m_ind/10)]

print(len(train_full_context))
print(len(valid_full_context))

127753
31938


In [33]:
# This is super inefficient, put it together last minute. Don't judge :)
def write_context_to_file(filename, full_context):
    f = open(filename, 'w')
    for context in full_context: 
        for ind in context:
            f.write(str(ind) + ' ')
        f.write('\n')
    f.close()
    
# Full training/validation
write_context_to_file('../seq2seq-attn/data/train_full_context.txt', train_full_context)
write_context_to_file('../seq2seq-attn/data/train_full_output.txt', train_full_output)
write_context_to_file('../seq2seq-attn/data/dev_full_context.txt', valid_full_context)
write_context_to_file('../seq2seq-attn/data/dev_full_output.txt', valid_full_output)

# Micro training/validation
write_context_to_file('../seq2seq-attn/data/train_micro_context.txt', train_micro_context)
write_context_to_file('../seq2seq-attn/data/train_micro_output.txt', train_micro_output)
write_context_to_file('../seq2seq-attn/data/dev_micro_context.txt', valid_micro_context)
write_context_to_file('../seq2seq-attn/data/dev_micro_output.txt', valid_micro_output)

with open('../seq2seq-attn/data/targ.dict', 'w') as f:
    for i in range(1, len(indices_to_word)+1):
        f.write(indices_to_word[i] + ' ' + str(i) + '\n')
        
with open('../seq2seq-attn/data/src.dict', 'w') as f:
    for i in range(1, len(indices_to_word)+1):
        f.write(indices_to_word[i] + ' ' + str(i) + '\n')

In [149]:
# # Embeddings map to the generated word_dict 
# print(emb_wordvec)
# print(emb_wordvec[0].shape)
# print(emb_mt)
# print(emb_mt[0].shape)