In [89]:
from __future__ import absolute_import
from __future__ import print_function

import tensorflow as tf
import numpy as np

import os
import codecs
import nltk.data
import collections
from glob import glob
from nltk.tokenize import RegexpTokenizer
from nltk.corpus import stopwords as nltk_stopwords

punctuation_remover = RegexpTokenizer(r'\w+')
stopwords = nltk_stopwords.words('english')
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

sess = tf.InteractiveSession()

Exception AssertionError: AssertionError() in <bound method InteractiveSession.__del__ of <tensorflow.python.client.session.InteractiveSession object at 0x7f950ec9ded0>> ignored


In [93]:
vocab_size = 50000

def stopword_filter(text):
    return " ".join([word for word in text.split() if word not in stopwords])

def read_name_data(data_dir):
    with open(os.path.join(data_dir, 'name.txt')) as f:
        name_lists = f.readlines()
        
    names = [name.lower().strip().split('\t') for name in name_lists]

    #name_counter = collections.Counter([word for name_set in names for name in name_set for word in name.split()])
    #print(name_counter.most_common(100))

    name2idx = {}
    nameword2idx = {}

    for idx, name_set in enumerate(names):
        for name in name_set:
            name2idx[name] = idx
            word_in_name = name.split()
            name_without_punctuation = " ".join(punctuation_remover.tokenize(name))

            for name in [name, name_without_punctuation] + word_in_name:
                try:
                    if idx not in name_dict[name]:
                        name2idx[name].append(idx)
                except:
                    nameword2idx[name] = idx

    idx2name = dict(zip(name2idx.values(), name2idx.keys()))
    
    return names, name2idx, idx2name, nameword2idx

def read_data_as_words(data_dir):
    tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
    
    text = ""
    for filename in glob(os.path.join(data_dir, "*.txt")):
        if 'name.txt' in filename:
            continue
        with open(filename) as f:
            text += f.read()
    return text.split()

def read_data_as_sentences(data_dir, nameword2idx):
    sentences = []

    for filename in glob(os.path.join(data_dir, "*.txt")):
        if 'name.txt' in filename:
            continue

        with codecs.open(filename, 'r', 'utf-8') as f:
            nltk_splited_sentences = tokenizer.tokenize(stopword_filter(f.read().encode('ascii','ignore').lower()))
            current_sentences = [" ".join(punctuation_remover.tokenize(sentence)) for sentence in nltk_splited_sentences
                                 if any(word in nameword2idx.keys() for word in sentence.split())]
            sentences.extend(current_sentences)

            print(" [*] %s finished: %d / %d" % (filename, len(current_sentences), len(nltk_splited_sentences)))

    idx2sentences = {}
    for sentence in sentences:
        for idx in [idx for nameword, idx in nameword2idx.items() if nameword in sentence]:
            idx2sentences.setdefault(idx, []).append(" ".join([word for word in sentence.split() if word not in nameword2idx.keys()]))
    
    new_sentences = []
    name_idx_of_sentence = []
    for idx in idx2sentences.keys():
        for sentence in idx2sentences[idx]:
            new_sentences.append(sentence)
            name_idx_of_sentence.append(idx)
    
    print(" [*] Total sentences : %d" % (len(sentences)))
    return sentences, name_idx_of_sentence

In [123]:
def build_dataset_from_sentences(sentences, name_idx_of_sentence):
    words = [word for sentence in sentences for word in sentence.split()]
    
    count = [['UNK', -1]]
    count.extend(collections.Counter(words).most_common(vocab_size - 1))

    word2idx = dict()
    for word, _ in count:
        word2idx[word] = len(word2idx)

    data = list()
    label_data = list()
    unk_count = 0
    for sentence, name_idx in zip(sentences, name_idx_of_sentence):
        for word in sentence.split():
            if word in word2idx:
                index = word2idx[word]
            else:
                index = 0
                unk_count = unk_count + 1
            data.append(index)
            label_data.append(name_idx)

    count[0][1] = unk_count
    idx2word = dict(zip(word2idx.values(), word2idx.keys()))

    return word2idx, idx2word, data, label_data, count

def generate_batch(data, label_data, batch_size, num_skips, skip_window):
    global data_index
    assert batch_size % num_skips == 0
    assert num_skips <= 2 * skip_window
    
    batch = np.ndarray(shape=(batch_size), dtype=np.int32)
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)

    for i in xrange(batch_size):
        batch[i] = data[data[data_index]]
        labels[i] = label_data[label_data[data_index]]
        
        data_index = (data_index + 1) % len(data)

    return batch, labels

In [94]:
data_dir = './data'

words = read_data_as_words(data_dir)
print('Data size :', len(words))

names, name2idx, idx2name, nameword2idx = read_name_data(data_dir)
print('# of names :', len(names))

sentences, name_idx_of_sentence = read_data_as_sentences(data_dir, nameword2idx)

Data size : 1081571
# of names : 189
 [*] ./data/1.txt finished: 2263 / 6358
 [*] ./data/6.txt finished: 4416 / 11431
 [*] ./data/2.txt finished: 2704 / 6496
 [*] ./data/5.txt finished: 6956 / 17300
 [*] ./data/7.txt finished: 4917 / 14304
 [*] ./data/4.txt finished: 5466 / 13804
 [*] ./data/3.txt finished: 3483 / 8660
 [*] Total sentences : 30205


In [133]:
print("Unkown names : %s" % (set(name2idx.values()) - set(name_idx_of_sentence)))

{25, 78, 101, 113, 179, 183}

In [124]:
word2idx, idx2word, data, label_data, count = build_dataset_from_sentences(sentences, name_idx_of_sentence)
print('Most common words (+UNK) :', count[:5])

Most common words (+UNK) : [['UNK', 0], ('harry', 13177), ('said', 6757), ('ron', 4293), ('hermione', 3792)]


In [139]:
data_index = 0
batch_size = 128
skip_window = 4
num_skips = 2

embed_size = 200
neg_sample_size = 64

x = tf.placeholder(tf.int32, [batch_size])
y = tf.placeholder(tf.int32, [batch_size, 1])
neg_y = tf.placeholder(tf.int32, [neg_sample_size])

init_width = 0.5 / embed_size

embed = tf.Variable(tf.random_uniform([len(idx2name), embed_size], -init_width, init_width))
w = tf.Variable(tf.zeros([vocab_size, embed_size]))
b = tf.Variable(tf.zeros([vocab_size]))

pos_embed = tf.nn.embedding_lookup(embed, x) # [batch_size x embed_size]
pos_w = tf.nn.embedding_lookup(w, y)         # [batch_size x embed_size]
pos_b = tf.nn.embedding_lookup(b, y)         # [batch_size x 1]

pos_y_ = tf.add(tf.reduce_sum(tf.mul(pos_embed, pos_w), 1), pos_b) # [batch_size]

# neg_embed = pos_embed
neg_w = tf.nn.embedding_lookup(w, neg_y)     # [neg_sample_size x embed_size]
neg_b = tf.nn.embedding_lookup(b, neg_y)     # [neg_sample_size]

neg_y_ = tf.matmul(pos_embed, neg_w, transpose_b=True) + neg_b # [batch_size x neg_sample_size]

pos_y = tf.ones_like(pos_y_)
neg_y = tf.ones_like(neg_y_)

pos_loss = tf.nn.sigmoid_cross_entropy_with_logits(pos_y_, tf.ones_like(pos_y_))
neg_loss = tf.nn.sigmoid_cross_entropy_with_logits(neg_y_, tf.zeros_like(neg_y_))

loss = tf.reduce_mean(
      tf.nn.nce_loss(w, b, pos_embed, y, neg_sample_size, vocab_size)
)

##################
# Optimizer
##################

global_step = tf.Variable(0, name="global_step")
inc = global_step.assign_add(1)

# total_word_processed = float(word_per_epoch * epochs_to_train)

learning_rate = 0.01
lr = learning_rate * tf.maximum(
    0.001,
    1.0 - tf.cast(global_step, tf.float32) / num_steps
)

# loss = (tf.reduce_sum(pos_loss) + tf.reduce_sum(neg_loss))/batch_size

with tf.control_dependencies([inc]):
    train = tf.train.GradientDescentOptimizer(lr).minimize(loss, global_step=global_step)

In [142]:
num_steps = 100001

tf.initialize_all_variables().run()
average_loss = 0
for step in xrange(num_steps):
    batch_inputs, batch_labels = generate_batch(
        data, label_data, batch_size, num_skips, skip_window
    )
    feed_dict = {x: batch_inputs, y: batch_labels}
    _, loss_val = sess.run([train, loss], feed_dict=feed_dict)
    average_loss += loss_val

    if step % 2000 == 0:
        if step > 0:
            average_loss = average_loss / 2000

        print("Average loss at step ", step, ": ", average_loss)
        average_loss = 0

InvalidArgumentError: Index 312 at offset 0 in Tindices is out of range
	 [[Node: embedding_lookup_10 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_6, _recv_Placeholder_9_0)]]
Caused by op u'embedding_lookup_10', defined at:
  File "/usr/lib/python2.7/runpy.py", line 162, in _run_module_as_main
    "__main__", fname, loader, pkg_name)
  File "/usr/lib/python2.7/runpy.py", line 72, in _run_code
    exec code in run_globals
  File "/usr/local/lib/python2.7/dist-packages/ipykernel/__main__.py", line 3, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python2.7/dist-packages/traitlets/config/application.py", line 592, in launch_instance
    app.start()
  File "/usr/local/lib/python2.7/dist-packages/ipykernel/kernelapp.py", line 403, in start
    ioloop.IOLoop.instance().start()
  File "/usr/lib/python2.7/dist-packages/zmq/eventloop/ioloop.py", line 160, in start
    super(ZMQIOLoop, self).start()
  File "/usr/local/lib/python2.7/dist-packages/tornado/ioloop.py", line 883, in start
    handler_func(fd_obj, events)
  File "/usr/local/lib/python2.7/dist-packages/tornado/stack_context.py", line 275, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/lib/python2.7/dist-packages/zmq/eventloop/zmqstream.py", line 433, in _handle_events
    self._handle_recv()
  File "/usr/lib/python2.7/dist-packages/zmq/eventloop/zmqstream.py", line 465, in _handle_recv
    self._run_callback(callback, msg)
  File "/usr/lib/python2.7/dist-packages/zmq/eventloop/zmqstream.py", line 407, in _run_callback
    callback(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tornado/stack_context.py", line 275, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/ipykernel/kernelbase.py", line 260, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/usr/local/lib/python2.7/dist-packages/ipykernel/kernelbase.py", line 212, in dispatch_shell
    handler(stream, idents, msg)
  File "/usr/local/lib/python2.7/dist-packages/ipykernel/kernelbase.py", line 370, in execute_request
    user_expressions, allow_stdin)
  File "/usr/local/lib/python2.7/dist-packages/ipykernel/ipkernel.py", line 175, in do_execute
    shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/interactiveshell.py", line 2902, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/interactiveshell.py", line 3006, in run_ast_nodes
    if self.run_code(code, result):
  File "/usr/local/lib/python2.7/dist-packages/IPython/core/interactiveshell.py", line 3066, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-139-3aad513e457d>", line 19, in <module>
    pos_embed = tf.nn.embedding_lookup(embed, x) # [batch_size x embed_size]
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/embedding_ops.py", line 50, in embedding_lookup
    return array_ops.gather(params[0], ids, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 302, in gather
    name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 639, in apply_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1757, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1008, in __init__
    self._traceback = _extract_stack()


In [None]:
flags = tf.app.flags

flags.DEFINE_string("data_dir", './data/', "Directory which contains data files")

FLAGS = flags.FLAGS

class Options(object):
    def __init__(self):
        self.data_dir = FLAGS.data_dir

def main():
    if not FLAGS.data_dir:
        print("--data_dir must be specified")
        sys.exit(1)

    opts = Options()
    read_data(opts.data_dir)
