In [1]:
import os
import collections
import random
import itertools
import logging

from datetime import datetime
from contextlib import closing

import nltk
import numpy as np

import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector

In [2]:
def timestamp():
    """Return the current local date and time as an ISO 8601 string."""
    now = datetime.now()
    return now.isoformat()

In [3]:
# Logging goes to a file instead of the notebook output: DEBUG and above,
# each record stamped with the time, the logger name and the level.
log_settings = dict(
    filename='notebook.log',
    level=logging.DEBUG,
    format='%(asctime)-6s: %(name)s - %(levelname)s - %(message)s',
)
logging.basicConfig(**log_settings)

# Logger used by the cells of this notebook.
logger = logging.getLogger('02-text')

Reference notebook: https://github.com/wangz10/UdacityDeepLearning/blob/master/5_word2vec.ipynb

In [4]:
# Fetch the Brown corpus with NLTK's downloader so that `nltk.corpus.brown`
# can be read by the cells below.
nltk.download('brown')

[nltk_data] Downloading package brown to /home/dnm11/nltk_data...
[nltk_data]   Package brown is already up-to-date!


True

In [5]:
# Handle to the Brown corpus; the cells below read it through
# `corpus.words()` and `corpus.sents()`.
corpus = nltk.corpus.brown

In [6]:
# Corpus size at a glance: number of word tokens and number of sentences.
total_words = len(corpus.words())
total_sentences = len(corpus.sents())
'Total words: {}, total sentences {}'.format(total_words, total_sentences)

'Total words: 1161192, total sentences 57340'

In [7]:
# Length (in tokens) of the longest sentence in the corpus; it is handed to
# the processor as its first positional argument
# (presumably the maximum document length -- confirm against the
# VocabularyProcessor documentation).
longest_sentence_length = max(map(len, corpus.sents()))

# `iter` is used as the tokenizer: the corpus sentences appear to be
# sequences of word tokens already, so no further splitting is requested.
vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor(
    longest_sentence_length,
    tokenizer_fn=iter,
)

In [8]:
def skip_grams(documents, samples_per_target):
    """Generate (target, context) pairs for SkipGram training.

    Contexts are drawn from the whole document, in other words, the context
    window is equal to the document: every other token position of the same
    document is a candidate context for a target.

    :param documents: an iterable of documents, each of them an iterable of
        tokens. Every document is copied into a list as it is processed, so
        generators and other one-shot iterables are accepted and the caller's
        data is never modified.
    :param int samples_per_target: the number of contexts to be sampled at
        most for each target word. Fewer contexts are drawn when the document
        has fewer other tokens than this.

    :returns: a generator of ``(target, context)`` tuples, targets in document
        order. Documents with fewer than two tokens produce no pairs.

    Contexts are sampled without replacement with the module-level ``random``
    generator; seed it with ``random.seed`` for reproducible output.

    """
    for document in documents:
        tokens = list(document)

        # Number of candidate contexts of any target: all the other tokens.
        context_size = len(tokens) - 1
        if context_size < 1:
            continue

        sample_size = min(samples_per_target, context_size)
        # Positions inside the "document without the target" sequence.
        # Sampling positions instead of a freshly built context list avoids
        # copying the whole document for every single target.
        context_positions = range(context_size)

        for target_index, target in enumerate(tokens):
            for position in random.sample(context_positions, k=sample_size):
                # Positions before the target map to the same token index,
                # positions at or after it are shifted by one to skip the
                # target itself.
                if position >= target_index:
                    position += 1
                yield target, tokens[position]


In [9]:
# Fixed seed so that the sampled pairs shown below are reproducible.
random.seed(2)

# All pairs generated from the first sentence only; the first ten are shown.
first_sentence_pairs = list(skip_grams(corpus.sents()[:1], samples_per_target=2))
first_sentence_pairs[:10]

[('The', 'County'),
 ('The', 'Grand'),
 ('Fulton', 'Grand'),
 ('Fulton', 'primary'),
 ('County', 'Friday'),
 ('County', '.'),
 ('Grand', 'took'),
 ('Grand', "Atlanta's"),
 ('Jury', 'of'),
 ('Jury', 'any')]

In [10]:
def generate_batches(input_, vocabulary, batch_size=128):
    """Yield numpy arrays of vocabulary ids, one array per chunk of ``input_``.

    :param input_: an iterable of word sequences (this notebook passes the
        ``(target, context)`` pairs produced by ``skip_grams``).
    :param vocabulary: an object whose ``vocabulary_.get(word)`` returns the
        id of a word.
    :param int batch_size: the maximum number of items per batch; the last
        batch is shorter when the input runs out.

    :returns: a generator of arrays with one row per input item. Iteration
        ends as soon as a chunk comes back empty.

    """
    remaining = iter(input_)

    while True:
        chunk = itertools.islice(remaining, batch_size)
        ids = [
            [vocabulary.vocabulary_.get(word) for word in item]
            for item in chunk
        ]
        x = np.array(ids)

        if len(x) == 0:
            return

        yield x
        

In [11]:
# Fixed seed so that the sampled pairs, and therefore the batch shown below,
# are reproducible.
random.seed(2)

# NOTE(review): this cell runs before `vocab_processor.fit(...)` in the next
# cell; the displayed ids suggest words receive an id on first lookup --
# confirm that this ordering is intended.
demo_pairs = skip_grams(corpus.sents()[:10], samples_per_target=2)
demo_batches = generate_batches(demo_pairs, batch_size=6, vocabulary=vocab_processor)
next(demo_batches)

array([[1, 2],
       [1, 3],
       [4, 3],
       [4, 5],
       [2, 6],
       [2, 7]])

In [12]:
# Fit the vocabulary processor on every sentence of the corpus.
vocab_processor.fit(corpus.sents())
# Number of entries in the vocabulary after fitting (displayed as the cell output).
len(vocab_processor.vocabulary_)

56058

In [13]:
class Word2Vec:
    """SkipGram word2vec model trained with noise-contrastive estimation (NCE).

    Instantiating the class adds all its ops to the current default
    TensorFlow graph. The graph elements are exposed as attributes:

    * ``embeddings`` -- the ``[vocabulary_size, embedding_size]`` embedding
      matrix, initialised uniformly in ``[-1, 1)``.
    * ``batch_inputs`` -- int32 placeholder for a vector of target word ids.
    * ``batch_labels`` -- int32 placeholder of shape ``[batch, 1]`` for the
      context word ids.
    * ``batch_embeddings`` -- the embeddings looked up for ``batch_inputs``.
    * ``nce_weights`` / ``nce_biases`` -- parameters of the NCE output layer.
    * ``loss`` -- the mean NCE loss over the batch.
    * ``global_step`` -- non-trainable step counter, starting at 1.
    * ``optimizer`` -- the training op; running it performs one gradient
      descent update and increments ``global_step``.
    """

    def __init__(self, vocabulary_size, embedding_size=300, negative_samples=16):
        """Build the model graph.

        :param int vocabulary_size: the number of distinct word ids.
        :param int embedding_size: the dimensionality of the word vectors.
        :param int negative_samples: the number of classes sampled per batch
            by the NCE loss (``num_sampled``).
        """
        self.embeddings = tf.Variable(tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0), name='embeddings')

        with tf.name_scope('input'):
            # One target id per training pair. The rank is fixed to 1 so that
            # the looked-up embeddings have the static shape
            # [batch, embedding_size] that the NCE loss operates on.
            self.batch_inputs = tf.placeholder(tf.int32, shape=[None])
            # One context id per training pair, as a column vector.
            self.batch_labels = tf.placeholder(tf.int32, shape=[None, 1])

            self.batch_embeddings = tf.nn.embedding_lookup(self.embeddings, self.batch_inputs, name='embeddings')

        with tf.name_scope('NCE'):
            self.nce_weights = tf.Variable(
                tf.truncated_normal([vocabulary_size, embedding_size], stddev=1.0 / np.sqrt(embedding_size)),
                name='weights',
            )
            self.nce_biases = tf.Variable(tf.zeros([vocabulary_size]), name='biases')

            self.loss = tf.reduce_mean(
                tf.nn.nce_loss(
                    weights=self.nce_weights,
                    biases=self.nce_biases,
                    labels=self.batch_labels,
                    inputs=self.batch_embeddings,
                    num_sampled=negative_samples,
                    num_classes=vocabulary_size,
                ),
                name='loss',
            )

            self.global_step = tf.Variable(1, name='global_step', trainable=False)
            # Plain gradient descent with a learning rate of 1.0. The op gets
            # its own name instead of reusing 'loss', which already names the
            # loss tensor in this scope.
            self.optimizer = tf.train.GradientDescentOptimizer(1.0).minimize(
                self.loss,
                name='train',
                global_step=self.global_step,
            )

In [None]:
# --- Run configuration -------------------------------------------------------
g = tf.Graph()
ts = timestamp()
log_dir = 'notebook_runs'
metadata_path = os.path.join(log_dir, 'metadata.tsv')

# Event files, the projector config and the checkpoints of this run all live
# in one run-specific directory, so that the embedding projector can pair the
# config with the checkpoint that holds the embedding variable.
run_dir = os.path.join(log_dir, 'w2v-skipgram-{}'.format(ts))
checkpoint_path = os.path.join(run_dir, 'model.ckpt')

batch_size = 1024 * 10
num_epochs = 20
samples_per_target = 16
save_every_steps = 1000

# The output directories have to exist before the metadata file is written.
os.makedirs(run_dir, exist_ok=True)

# --- Projector metadata: one word per line, in vocabulary id order ------------
with open(metadata_path, 'w', encoding='utf-8') as f:
    for w in vocab_processor.vocabulary_._reverse_mapping:
        print(w, file=f, end='\n')

# --- Graph construction ---------------------------------------------------------
with g.as_default():
    w2v = Word2Vec(len(vocab_processor.vocabulary_), embedding_size=300)

    init_op = tf.global_variables_initializer()

    tf.summary.scalar(w2v.loss.op.name, w2v.loss)
    summary_op = tf.summary.merge_all()

    config = projector.ProjectorConfig()
    projector_embedding = config.embeddings.add()
    projector_embedding.tensor_name = w2v.embeddings.name
    # An absolute path does not depend on the directory relative paths are
    # resolved against when the projector loads the metadata.
    projector_embedding.metadata_path = os.path.abspath(metadata_path)

    saver = tf.train.Saver()

# --- Training -------------------------------------------------------------------
with tf.Session(graph=g) as sess, \
     closing(tf.summary.FileWriter(run_dir, sess.graph)) as train_summary_writer:

    sess.run(init_op)
    projector.visualize_embeddings(train_summary_writer, config)

    for epoch in range(num_epochs):
        logger.info('Epoch: %s', epoch)

        # The pairs are sampled anew in every epoch.
        batches = generate_batches(
            skip_grams(corpus.sents(), samples_per_target=samples_per_target),
            vocabulary=vocab_processor,
            batch_size=batch_size,
        )

        for batch in batches:
            if len(batch) != batch_size:
                logger.debug('The batch size of %s is smaller than expected %s.', len(batch), batch_size)

            # Column 0 holds the target ids, column 1 the context ids; the
            # labels are fed as a column vector.
            x = batch[:, 0]
            y = batch[:, 1, np.newaxis]

            _, summary, current_step = sess.run(
                [w2v.optimizer, summary_op, w2v.global_step],
                feed_dict={
                    w2v.batch_inputs: x,
                    w2v.batch_labels: y,
                },
            )
            train_summary_writer.add_summary(summary, current_step)

            if current_step % save_every_steps == 0:
                logger.info('Step: %s, saving the model', current_step)
                saver.save(sess, checkpoint_path, current_step)

    # Keep the final state as well: the last step is usually not a multiple
    # of the periodic saving interval.
    final_step = sess.run(w2v.global_step)
    logger.info('Step: %s, saving the model', final_step)
    saver.save(sess, checkpoint_path, final_step)