### Import the required libraries

In [25]:
import random, collections, math, os, zipfile, time, re
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior() 

from matplotlib import pylab
%matplotlib inline

from six.moves import range
from six.moves.urllib.request import urlretrieve

### Download and extract the text8 dataset

In [26]:
dataset_link = 'http://mattmahoney.net/dc/'  # base URL the archive name is appended to
zip_file = 'text8.zip'                       # local archive name (also the remote file name)

def data_download(zip_file):
    """Download ``zip_file`` from ``dataset_link`` unless it already exists locally.

    Prints exactly one status message: 'Dataset downloaded' after a fresh
    download, or 'Dataset already exists' when the file was already present.

    Args:
        zip_file: local path of the archive; it is also appended to
            ``dataset_link`` to build the download URL.

    Returns:
        None.
    """
    if os.path.exists(zip_file):
        print('Dataset already exists')
    else:
        urlretrieve(dataset_link + zip_file, zip_file)
        print('Dataset downloaded')
    return None
# Fetch the archive named by `zip_file` from `dataset_link` unless it is already on disk; prints its status.
data_download(zip_file)

Dataset already exists


In [27]:
extracted_folder = 'dataset'
# Derive the read path from extracted_folder so the two can never drift apart.
text8_path = os.path.join(extracted_folder, 'text8')

# Check for the extracted file itself rather than the folder: an existing but
# incomplete folder (e.g. from an interrupted extraction) is re-extracted.
if not os.path.isfile(text8_path):
    with zipfile.ZipFile(zip_file) as zf:
        zf.extractall(extracted_folder)

# Explicit encoding so the read does not depend on the platform's locale default.
with open(text8_path, encoding='utf-8') as ft_:
    full_text = ft_.read()

### Text processing

In [28]:
def text_processing(ft8_text):
    """Lower-case ``ft8_text``, turn punctuation into named tokens and split on whitespace.

    Each punctuation mark is replaced by a placeholder surrounded by spaces
    (e.g. ``"end."`` -> ``"end <period> "``), so after splitting the
    placeholder is a token of its own instead of being glued to the
    preceding word.

    Args:
        ft8_text: raw text to tokenize.

    Returns:
        list of str tokens (words and punctuation placeholders).
    """
    punctuation_tokens = {
        '.': '<period>',
        ',': '<comma>',
        '"': '<quotation>',
        ';': '<semicolon>',
        '!': '<exclamation>',
        '?': '<question>',
        '(': '<paren_l>',
        ')': '<paren_r>',
        '--': '<hyphen>',
        ':': '<colon>',
    }
    ft8_text = ft8_text.lower()
    for mark, token in punctuation_tokens.items():
        # Pad on both sides so the placeholder is split off from neighbouring words.
        ft8_text = ft8_text.replace(mark, ' {} '.format(token))
    return ft8_text.split()

# Tokenize the whole corpus into a flat list of lower-cased word / punctuation-placeholder tokens.
ft_tokens = text_processing(full_text)

#### Remove rare words (noise) from the corpus

In [29]:
min_word_count = 7  # words occurring this many times or fewer are treated as noise and dropped

word_cnt = collections.Counter(ft_tokens)
print(len(word_cnt))  # vocabulary size before filtering

shortlisted_words = [w for w in ft_tokens if word_cnt[w] > min_word_count]
print(shortlisted_words[:10])

print(len(shortlisted_words))
# Surviving vocabulary size, read off the Counter instead of building a set over every kept token.
print("Unique ones: ", sum(1 for c in word_cnt.values() if c > min_word_count))

253854
['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse', 'first', 'used', 'against']
16616688
Unique ones:  53721


#### Create dictionaries mapping each word to its frequency rank (id) and each id back to its word

In [30]:
def dict_creation(shortlisted_words):
    """Build word<->id lookup tables ordered by descending frequency.

    The most frequent word gets id 0. Words with equal counts keep the order
    in which they first appear in ``shortlisted_words``.

    Returns:
        (dictionary_, rev_dictionary_): ``word -> id`` and ``id -> word`` dicts.
    """
    ranked_words = [word for word, _ in collections.Counter(shortlisted_words).most_common()]
    dictionary_ = {}
    rev_dictionary_ = {}
    for rank, word in enumerate(ranked_words):
        rev_dictionary_[rank] = word
        dictionary_[word] = rank
    return dictionary_, rev_dictionary_

dictionary_, rev_dictionary_ = dict_creation(shortlisted_words)

# Encode the filtered corpus as vocabulary ids (despite its name, `words_cnt` holds ids, not counts).
words_cnt = list(map(dictionary_.__getitem__, shortlisted_words))

print(words_cnt[0])

5233


## Let's start with the Skip-Gram model

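### Subsample frequent words

Each occurrence of word $w$ is discarded with probability $P_{drop}(w) = 1 - \sqrt{t / f(w)}$, where $f(w)$ is the word's relative frequency in the corpus and $t$ is the threshold `thresh`. Very frequent words are therefore dropped often, while words with $f(w) \le t$ are always kept.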

In [31]:
thresh = 0.00005
subsample_seed = 42  # fixed seed so train_words is identical across kernel restarts

word_counts = collections.Counter(words_cnt)
total_count = len(words_cnt)
freqs = {word: count / total_count for word, count in word_counts.items()}
# Probability of discarding each occurrence of a word; words with freq <= thresh get
# p_drop <= 0 and are therefore always kept.
p_drop = {word: 1 - np.sqrt(thresh / freqs[word]) for word in word_counts}

# Local RNG: reproducible, and does not disturb the global `random` state used by later cells.
subsample_rng = random.Random(subsample_seed)
train_words = [word for word in words_cnt if p_drop[word] < subsample_rng.random()]

print(train_words[0])

5233


In [32]:
def skipG_target_set_generation(batch_, batch_index, word_window):
    """Return the distinct context words around ``batch_[batch_index]`` for the SkipGram model.

    A radius R is drawn uniformly from [1, word_window] (via ``np.random``),
    and the words within R positions on either side of the centre, clipped
    at the start of ``batch_``, are collected. Duplicates are removed, so the
    order of the returned list is not meaningful.

    ``batch_`` is concatenated with ``+``, so it is expected to be a list.
    """
    radius = np.random.randint(1, word_window + 1)
    left = max(batch_index - radius, 0)
    right = batch_index + radius + 1
    context = batch_[left:batch_index] + batch_[batch_index + 1:right]
    return list(set(context))

In [33]:
def skipG_batch_creation(short_words, batch_length, word_window):
    """Yield ``(input_words, label_words)`` Skip-Gram training pairs, one batch at a time.

    ``short_words`` is truncated to a whole number of ``batch_length`` chunks
    (the trailing partial chunk is discarded). Within each chunk, every word
    is paired with each of its context words as returned by
    ``skipG_target_set_generation``. The centre word is repeated once per
    context word, so both yielded lists have the same length.
    """
    usable = (len(short_words) // batch_length) * batch_length
    short_words = short_words[:usable]

    for chunk_start in range(0, usable, batch_length):
        chunk = short_words[chunk_start:chunk_start + batch_length]
        centers, contexts = [], []
        for position, center in enumerate(chunk):
            neighbours = skipG_target_set_generation(chunk, position, word_window)
            # Replicate the centre word so inputs and labels line up element-wise.
            contexts.extend(neighbours)
            centers.extend([center] * len(neighbours))
        yield centers, contexts

In [34]:
tf_graph = tf.Graph()
with tf_graph.as_default():
    # Centre-word ids fed to the model.
    input_ = tf.placeholder(dtype=tf.int32, shape=[None], name='input_')
    # Context-word ids; 2-D because sampled_softmax_loss takes labels as (batch, num_true).
    label_ = tf.placeholder(dtype=tf.int32, shape=[None, None], name='label_')

    # 300-dim embedding table, one row per vocabulary id, initialized uniformly in [-1, 1).
    word_embed = tf.Variable(tf.random_uniform((len(rev_dictionary_), 300), minval=-1, maxval=1))
    embedding = tf.nn.embedding_lookup(word_embed, input_)

In [35]:
vocabulary_size = len(rev_dictionary_)

with tf_graph.as_default():
    # Output-side parameters: one 300-dim weight row and one bias per vocabulary word.
    sf_weights = tf.Variable(tf.truncated_normal((vocabulary_size, 300), stddev=0.1))
    sf_bias = tf.Variable(tf.zeros(vocabulary_size))

    # Sampled softmax: score the true context word against num_sampled=100 sampled classes
    # instead of normalizing over the full vocabulary.
    loss_fn = tf.nn.sampled_softmax_loss(
        weights=sf_weights,
        biases=sf_bias,
        labels=label_,
        inputs=embedding,
        num_sampled=100,
        num_classes=vocabulary_size,
    )
    cost_fn = tf.reduce_mean(loss_fn)
    optim = tf.train.AdamOptimizer().minimize(cost_fn)

In [36]:
with tf_graph.as_default():
    validation_cnt = 16    # number of words whose nearest neighbours are reported
    validation_dict = 100  # size of each id range the validation words are sampled from

    # Ids follow frequency rank (see dict_creation): half of the validation words come
    # from ids [0, 100) (most frequent), half from ids [1000, 1100) (less frequent).
    validation_words = np.array(random.sample(range(validation_dict), validation_cnt // 2))
    validation_words = np.append(validation_words,
                                 random.sample(range(1000, 1000 + validation_dict), validation_cnt // 2))
    validation_data = tf.constant(validation_words, dtype=tf.int32)

    # L2-normalize every embedding row so the dot products below are cosine similarities.
    # `keepdims` replaces the deprecated `keep_dims` argument.
    normalization_embed = word_embed / tf.sqrt(tf.reduce_sum(tf.square(word_embed), 1, keepdims=True))
    validation_embed = tf.nn.embedding_lookup(normalization_embed, validation_data)
    # Row i: similarity of validation_words[i] to every vocabulary word.
    word_similarity = tf.matmul(validation_embed, tf.transpose(normalization_embed))

Instructions for updating:
keep_dims is deprecated, use keepdims instead


In [None]:
epochs = 2             # Kept low so readers can replicate the run quickly; increase to 100 or more given enough compute
batch_length = 1000
word_window = 10

log_every = 100        # iterations between training-loss reports
validate_every = 2000  # iterations between nearest-neighbour reports
top_k = 8              # number of nearest neighbours printed per validation word
checkpoint_path = os.path.join("model_checkpoint", "skipGram_text8.ckpt")


def _print_nearest_words(similarity_, top_k):
    """Print the ``top_k`` most similar vocabulary words for every validation word.

    ``similarity_`` is the evaluated ``word_similarity`` matrix: row ``i`` holds the
    similarity of ``validation_words[i]`` to every vocabulary id.
    """
    for i in range(validation_cnt):
        validated_words = rev_dictionary_[validation_words[i]]
        # Descending similarity; position 0 is skipped, presumably the word itself (self-similarity is maximal).
        nearest = (-similarity_[i, :]).argsort()[1:top_k + 1]
        log = 'Nearest to %s:' % validated_words
        for k in range(top_k):
            close_word = rev_dictionary_[nearest[k]]
            log = '%s %s,' % (log, close_word)
        print(log)


with tf_graph.as_default():
    saver = tf.train.Saver()

with tf.Session(graph=tf_graph) as sess:
    iteration = 1
    loss = 0
    sess.run(tf.global_variables_initializer())

    for e in range(1, epochs + 1):
        batches = skipG_batch_creation(train_words, batch_length, word_window)
        start = time.time()
        for x, y in batches:
            # sampled_softmax_loss expects labels shaped (batch, num_true) with num_true=1.
            train_loss, _ = sess.run([cost_fn, optim],
                                     feed_dict={input_: x, label_: np.array(y)[:, None]})
            loss += train_loss

            if iteration % log_every == 0:
                end = time.time()
                print("Epoch {}/{}".format(e, epochs), ", Iteration: {}".format(iteration),
                      ", Avg. Training loss: {:.4f}".format(loss / log_every),
                      ", Processing : {:.4f} sec/batch".format((end - start) / log_every))
                loss = 0
                start = time.time()

            if iteration % validate_every == 0:
                _print_nearest_words(word_similarity.eval(), top_k)

            iteration += 1

    # tf.train.Saver.save raises if the checkpoint's parent directory is missing.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    save_path = saver.save(sess, checkpoint_path)
    embed_mat = sess.run(normalization_embed)

Epoch 1/2 , Iteration: 100 , Avg. Training loss: 6.1584 , Processing : 0.2850 sec/batch
Epoch 1/2 , Iteration: 200 , Avg. Training loss: 6.1111 , Processing : 0.2601 sec/batch
Epoch 1/2 , Iteration: 300 , Avg. Training loss: 6.0691 , Processing : 0.2497 sec/batch
Epoch 1/2 , Iteration: 400 , Avg. Training loss: 6.0186 , Processing : 0.2514 sec/batch
Epoch 1/2 , Iteration: 500 , Avg. Training loss: 5.9604 , Processing : 0.2534 sec/batch
Epoch 1/2 , Iteration: 600 , Avg. Training loss: 5.9675 , Processing : 0.2533 sec/batch
Epoch 1/2 , Iteration: 700 , Avg. Training loss: 5.9078 , Processing : 0.2562 sec/batch
Epoch 1/2 , Iteration: 800 , Avg. Training loss: 5.7870 , Processing : 0.2674 sec/batch
Epoch 1/2 , Iteration: 900 , Avg. Training loss: 5.7256 , Processing : 0.2871 sec/batch
Epoch 1/2 , Iteration: 1000 , Avg. Training loss: 5.6398 , Processing : 0.2625 sec/batch
Epoch 1/2 , Iteration: 1100 , Avg. Training loss: 5.4889 , Processing : 0.2663 sec/batch
Epoch 1/2 , Iteration: 1200 , 