In [None]:
# Word2Vec skip-gram model trained with NCE loss on the text8 corpus (TensorFlow 1.x graph API)

In [None]:
import tensorflow as tf
import zipfile
import collections
import numpy as np
import os
import time

In [None]:
# --- Training hyper-parameters ------------------------------------------
batch_size = 128        # (center, context) pairs per optimizer step
num_sampled = 64        # sampled classes per example in tf.nn.nce_loss
learning_rate = 1.0     # step size of plain gradient descent
embed_size = 300        # embedding dimensionality
vocab_size = 50000      # vocabulary size including the UNK token (id 0)
n_epochs = 100001       # number of optimizer steps (one batch each), not passes over the corpus
num_skips = 4           # context words sampled per center word
skip_window = 2         # words considered on each side of the center word

# --- Data ----------------------------------------------------------------
# Set the TEXT8_ZIP environment variable to point elsewhere instead of editing this path.
file_name = os.environ.get('TEXT8_ZIP', '/home/aeros/GitHubRepos/tensorzone/datasets/text8.zip')

# --- Validation words for the nearest-neighbour printout -----------------
valid_window = 100      # validation ids are drawn from ids below this (the most frequent words)
valid_size = 16         # number of validation words whose neighbours are printed

# Fail fast here instead of after the (slow) corpus read.
assert batch_size % num_skips == 0, "batch_size must be a multiple of num_skips"
assert num_skips <= 2 * skip_window, "num_skips cannot exceed the 2 * skip_window context slots"

In [None]:
# Validation word ids for the nearest-neighbour printout: `valid_size` ids drawn
# from [0, valid_window), i.e. among the most frequent words (id 0 is UNK).
# replace=False keeps the ids distinct so no word is reported twice.
valid_examples = np.random.choice(valid_window, valid_size, replace=False)
valid_dataset = tf.constant(valid_examples, dtype=tf.int32)

In [None]:
class SkipGram(object):
    """TensorFlow 1.x graph for a word2vec skip-gram model trained with NCE loss.

    After build_graph() the instance exposes:
      center_words   -- int32 placeholder [batch_size] of center-word ids
      target_labels  -- int32 placeholder [batch_size, 1] of context-word ids
      embed_matrix   -- [vocab_size, embed_size] embedding variable
      loss           -- mean NCE loss over the batch
      optimizer      -- gradient-descent minimize op
      summary_op     -- merged loss summaries
      similarity     -- cosine similarity of the validation ids vs. every word
    """

    def __init__(self, batch_size, num_sampled, learning_rate, embed_size, vocab_size):
        self.batch_size = batch_size
        self.num_sampled = num_sampled
        self.learning_rate = learning_rate
        self.embed_size = embed_size
        self.vocab_size = vocab_size

    def _create_placeholders(self):
        """Create the input placeholders for center words and their context labels."""
        with tf.device("/cpu:0"):
            with tf.name_scope('data'):
                self.center_words = tf.placeholder(shape=[self.batch_size], dtype=tf.int32,
                                                   name='center_words')
                # nce_loss takes labels shaped [batch_size, num_true]; here num_true == 1.
                self.target_labels = tf.placeholder(shape=[self.batch_size, 1], dtype=tf.int32,
                                                    name='target_labels')

    def _create_embeddings(self):
        """Create the embedding matrix, initialised uniformly in [-1, 1)."""
        with tf.device("/cpu:0"):
            with tf.name_scope('embed_matrix'):
                self.embed_matrix = tf.Variable(tf.random_uniform(shape=[self.vocab_size, self.embed_size],
                                                                  minval=-1.0,
                                                                  maxval=1.0),
                                                name='embed_matrix')

    def _create_loss(self):
        """Look up the center-word embeddings and build the mean NCE loss."""
        with tf.device("/cpu:0"):
            with tf.name_scope('loss'):
                embed = tf.nn.embedding_lookup(self.embed_matrix, self.center_words, name='embed')
                nce_weight = tf.Variable(tf.truncated_normal(shape=[self.vocab_size, self.embed_size],
                                                             stddev=1.0 / np.sqrt(self.embed_size)),
                                         name='nce_weight')
                nce_bias = tf.Variable(tf.zeros(shape=[self.vocab_size]), name='nce_bias')

                self.loss = tf.reduce_mean(tf.nn.nce_loss(weights=nce_weight,
                                                          biases=nce_bias,
                                                          labels=self.target_labels,
                                                          inputs=embed,
                                                          num_sampled=self.num_sampled,
                                                          num_classes=self.vocab_size),
                                           name='nce_loss')

    def _create_similarities(self, valid_dataset):
        """Build `self.similarity`: cosine similarity of `valid_dataset` ids vs. all words.

        For a 1-D `valid_dataset` the result has shape [len(valid_dataset), vocab_size].
        """
        # l2_normalize divides by max(norm, eps), so an all-zero row cannot produce NaN.
        normalized_embeddings = tf.nn.l2_normalize(self.embed_matrix, 1)
        valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings, valid_dataset)
        self.similarity = tf.matmul(valid_embeddings, normalized_embeddings, transpose_b=True)

    def _create_optimizer(self):
        """Create the plain gradient-descent op that minimises `self.loss`."""
        with tf.device("/cpu:0"):
            self.optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate).minimize(self.loss)

    def _create_summaries(self):
        """Create scalar and histogram summaries of the loss and merge them into `summary_op`.

        Only this model's summaries are merged (not tf.summary.merge_all()).
        If the default graph already holds summaries from an earlier build, for
        example after re-running the notebook in the same kernel, merge_all would
        also require feeding that stale model's placeholders.
        """
        with tf.name_scope('summary'):
            loss_scalar = tf.summary.scalar('loss', self.loss)
            loss_histogram = tf.summary.histogram('histogram_loss', self.loss)
            self.summary_op = tf.summary.merge([loss_scalar, loss_histogram])

    def build_graph(self, valid_ids=None):
        """Create every op of the model in the default graph.

        Args:
            valid_ids: word ids whose similarities `self.similarity` reports.
                Defaults to the notebook-level `valid_dataset` constant.
        """
        self._create_placeholders()
        self._create_embeddings()
        self._create_loss()
        self._create_optimizer()
        self._create_summaries()
        if valid_ids is None:
            valid_ids = valid_dataset
        self._create_similarities(valid_ids)
        
        
    
class GenerateBatch(object):
    """Read the text8 zip corpus and serve skip-gram (center, context) batches.

    The corpus is read and indexed once, on the first batch request. Each later
    request continues where the previous batch stopped and wraps around at the
    end of the corpus.
    """

    def __init__(self, file_name, vocab_size, batch_size, num_skips, skip_window):
        self.filename = file_name
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.num_skips = num_skips
        self.skip_window = skip_window
        # Indexed corpus; filled in lazily by _ensure_loaded().
        self.data_list = None
        self.count = None
        self.dictionary = None
        self.reversed_dictionary = None
        # Position in data_list where the next batch's window starts.
        self._data_index = 0

    def _read_data(self):
        """Return the first member of the zip archive split into whitespace-separated tokens."""
        with zipfile.ZipFile(self.filename) as f:
            data_words = tf.compat.as_str(f.read(f.namelist()[0])).split()
        return data_words

    def _build_dataset(self, words):
        """Map `words` to integer ids over a vocabulary of `vocab_size` entries.

        Known words get ids 1, 2, ... in descending frequency order. Id 0 is the
        'UNK' token, shared by every word outside the `vocab_size - 1` most common.

        Returns:
            data: list with one id per input word.
            count: [['UNK', n_unknown_tokens], (word, frequency), ...].
            dictionary: word -> id.
            reversed_dictionary: id -> word.
        """
        count = [['UNK', -1]]
        count.extend(collections.Counter(words).most_common(self.vocab_size - 1))
        dictionary = {}
        for word, _ in count:
            dictionary[word] = len(dictionary)
        data = [dictionary.get(word, 0) for word in words]
        # Store the out-of-vocabulary token count (it used to be computed and dropped).
        count[0][1] = data.count(0)
        reversed_dictionary = {index: word for word, index in dictionary.items()}
        return data, count, dictionary, reversed_dictionary

    def _ensure_loaded(self):
        """Read and index the corpus on first use; a no-op afterwards."""
        if self.data_list is None:
            words = self._read_data()
            (self.data_list, self.count,
             self.dictionary, self.reversed_dictionary) = self._build_dataset(words)
            self._data_index = 0

    def next_batch(self):
        """Return the next (batch, labels) pair of skip-gram training examples.

        batch: int32 array [batch_size] of center-word ids. Each center appears
            `num_skips` times in a row.
        labels: int32 array [batch_size, 1] of context-word ids. For each center
            they are drawn without replacement from the `skip_window` words on
            either side of it.
        """
        self._ensure_loaded()
        assert self.batch_size % self.num_skips == 0
        assert self.num_skips <= 2 * self.skip_window
        data = self.data_list
        n_words = len(data)
        if n_words == 0:
            raise ValueError("corpus %r contains no words" % self.filename)
        span = 2 * self.skip_window + 1
        # Window slots usable as context: every slot except the center one.
        context_slots = [slot for slot in range(span) if slot != self.skip_window]

        batch = np.ndarray(shape=[self.batch_size], dtype=np.int32)
        labels = np.ndarray(shape=[self.batch_size, 1], dtype=np.int32)

        # Sliding window over the corpus; the center word sits at index skip_window.
        window = collections.deque(maxlen=span)
        for _ in range(span):
            window.append(data[self._data_index])
            self._data_index = (self._data_index + 1) % n_words

        for i in range(self.batch_size // self.num_skips):
            chosen = np.random.choice(context_slots, self.num_skips, replace=False)
            for j, slot in enumerate(chosen):
                batch[i * self.num_skips + j] = window[self.skip_window]
                labels[i * self.num_skips + j, 0] = window[int(slot)]
            window.append(data[self._data_index])
            self._data_index = (self._data_index + 1) % n_words

        # Rewind by one window so the next batch's first center word is the one
        # right after this batch's last center word (no words are skipped).
        self._data_index = (self._data_index + n_words - span) % n_words
        return batch, labels

    def _generate_batch(self):
        """Return the next batch together with the indexed corpus.

        Returns (batch, labels, data_list, count, dictionary, reversed_dictionary).
        The first call yields the batch at the start of the corpus. Later calls
        reuse the already-indexed corpus and yield the following batches.
        """
        batch, labels = self.next_batch()
        return (batch, labels, self.data_list, self.count,
                self.dictionary, self.reversed_dictionary)


def _print_nearest(sim, reversed_dictionary, top_k=8):
    """Print the `top_k` nearest words for each validation word.

    Row j of `sim` scores the notebook-level `valid_examples[j]` against every
    vocabulary id.
    """
    for j, valid_id in enumerate(valid_examples):
        valid_word = reversed_dictionary[valid_id]
        # Rank 0 is skipped: it is expected to be the validation word itself.
        nearest = (-sim[j, :]).argsort()[1:top_k + 1]
        log_str = "nearest to %s:" % valid_word
        for close_id in nearest:
            log_str = "%s %s," % (log_str, reversed_dictionary[close_id])
        print(log_str)


def train_model(model, X_batch, Y_batch, n_epochs, data_list, count, dictionary, reversed_dictionary,
                batch_fn=None,
                graph_dir='/home/aeros/GitHubRepos/tensorzone/graphs/skipgram/',
                ckpt_dir='/home/aeros/GitHubRepos/tensorzone/checkpoints/skipgram/',
                report_every=2000, eval_every=10000, top_k=8):
    """Run `n_epochs` optimizer steps on `model` with logging, checkpoints and neighbour printouts.

    Args:
        model: SkipGram instance whose build_graph() has already been called.
        X_batch, Y_batch: center ids and [batch, 1] context ids fed at step 0.
            Without `batch_fn` they are fed at every step, i.e. training keeps
            repeating this single batch.
        n_epochs: number of optimizer steps (one batch each).
        data_list, count, dictionary: unused; kept for call compatibility.
        reversed_dictionary: id -> word mapping used for the neighbour printout.
        batch_fn: optional zero-argument callable returning a fresh (X, Y) pair.
            It is called before every step after the first.
        graph_dir: TensorBoard log directory.
        ckpt_dir: checkpoint directory. The latest checkpoint there, if any, is
            restored before training.
        report_every: steps between average-loss printouts and checkpoint saves.
        eval_every: steps between nearest-neighbour printouts.
        top_k: neighbours printed per validation word.
    """
    if not os.path.isdir(ckpt_dir):
        # tf.train.Saver.save refuses to write into a missing parent directory.
        os.makedirs(ckpt_dir)
    ckpt_prefix = os.path.join(ckpt_dir, 'skipgram')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter(graph_dir, sess.graph)
        try:
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(ckpt_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)

            x, y = X_batch, Y_batch
            loss_sum, window_steps = 0.0, 0
            for step in range(n_epochs):
                if step > 0 and batch_fn is not None:
                    x, y = batch_fn()
                _, loss_batch, summary = sess.run(
                    [model.optimizer, model.loss, model.summary_op],
                    feed_dict={model.center_words: x, model.target_labels: y})
                loss_sum += loss_batch
                window_steps += 1
                writer.add_summary(summary, global_step=step)

                if step > 0 and step % report_every == 0:
                    print("Loss at step {} : {:5.1f}".format(step, loss_sum / window_steps))
                    saver.save(sess, ckpt_prefix, global_step=step)
                    loss_sum, window_steps = 0.0, 0

                if step % eval_every == 0:
                    _print_nearest(sess.run(model.similarity), reversed_dictionary, top_k)
        finally:
            # Flush pending TensorBoard events even if training is interrupted.
            writer.close()


def main():
    """Index the corpus, build the skip-gram graph and train it.

    Every optimizer step after the first is fed the next corpus batch from
    GenerateBatch.next_batch, rather than repeating one batch throughout.
    Hyper-parameters come from the notebook-level configuration globals.
    """
    gen = GenerateBatch(file_name, vocab_size, batch_size, num_skips, skip_window)
    X_batch, Y_batch, data_list, count, dictionary, reversed_dictionary = gen._generate_batch()
    model = SkipGram(batch_size, num_sampled, learning_rate, embed_size, vocab_size)
    model.build_graph()
    train_model(model, X_batch, Y_batch, n_epochs, data_list, count, dictionary,
                reversed_dictionary, batch_fn=gen.next_batch)

In [None]:
# Run the whole pipeline: corpus indexing, graph construction and n_epochs training steps.
main()