WORD2VEC IN TENSORFLOW

In [2]:
import tensorflow as tf
import nltk
import urllib
import collections
import numpy as np
import random

Define a function that opens the text file and splits it into a list of words. The number of unique words is printed, and the 10 most frequent words are displayed.

In [3]:
def read_data(filename = 'text8', data_dir='data'):
    """Read a whitespace-separated corpus file and return its tokens.

    Args:
        filename: name of the file inside ``data_dir`` (default ``'text8'``).
        data_dir: directory holding the file (default ``'data'``, giving the
            same ``data/<filename>`` path as before).

    Returns:
        list of str: the file contents split on any run of whitespace.
    """
    path = f"{data_dir}/{filename}"
    # `with` closes the handle even if reading raises; the file was
    # previously left open.
    with open(path, 'r') as text_file:
        return text_file.read().split()

words = read_data()
print(len(set(words)))  # Total number of unique words in the corpus

# Frequency of every word; display the 10 most frequent ones.
# most_common(10) picks the top entries without sorting the whole vocabulary.
count = collections.Counter(words)
count.most_common(10)

253854


[('the', 1061396),
 ('of', 593677),
 ('and', 416629),
 ('one', 411764),
 ('in', 372201),
 ('a', 325873),
 ('to', 316376),
 ('zero', 264975),
 ('nine', 250430),
 ('two', 192644)]

Define a function that creates `data`, `count`, `word_dictionary` and `reversed_dictionary`. Rare words (words outside the vocabulary we define) are replaced with the 'UNK' token.

- `word_dictionary` maps each word to its integer label.
- `count` is a list whose elements are pairs of a word and its number of occurrences. Its first element is 'UNK' together with the total count of all words that fall outside our vocabulary.
- `data` is obtained by replacing each word in `words` (the list returned by `read_data`) with its integer label. In other words, the whole corpus is converted to a list of integers.
- `reversed_dictionary` is the reverse mapping of `word_dictionary`, from integer labels back to words.

In [5]:
def create_train_data(words, vocab_size):
    """Encode a token list as integer labels over a fixed-size vocabulary.

    Args:
        words: list of tokens (e.g. the output of ``read_data``).
        vocab_size: total vocabulary size, including the 'UNK' slot.

    Returns:
        data: list of int, one label per token in ``words``; tokens outside
            the vocabulary get label 0 ('UNK').
        count: list whose first element is ``['UNK', n_unknown]`` followed by
            the ``vocab_size - 1`` most frequent ``(word, frequency)`` tuples
            from ``collections.Counter.most_common``.
        word_dictionary: dict mapping word -> integer label.
        reversed_dictionary: dict mapping integer label -> word.
    """
    # Slot 0 is reserved for 'UNK'; -1 is a placeholder until the real
    # number of out-of-vocabulary tokens is known.
    count = [['UNK', -1]]
    # One slot is already taken by 'UNK', hence vocab_size - 1.
    count.extend(collections.Counter(words).most_common(vocab_size - 1))

    # Labels are handed out in insertion order: 'UNK' -> 0, the most frequent
    # word -> 1, and so on.
    word_dictionary = {}
    for entry in count:
        word_dictionary[entry[0]] = len(word_dictionary)

    # Encode the corpus, counting how many tokens fall back to 'UNK'.
    data = []
    unk_count = 0
    for token in words:
        label = word_dictionary.get(token)
        if label is None:
            label = 0
            unk_count += 1
        data.append(label)

    count[0][1] = unk_count
    reversed_dictionary = {label: word for word, label in word_dictionary.items()}
    return data, count, word_dictionary, reversed_dictionary

# Build the training data with a 50,000-entry vocabulary ('UNK' included);
# the resulting objects are inspected in the next cell.
data, count, word_dictionary, reversed_dictionary = create_train_data(
    words, vocab_size=50000)

The data is now prepared. Let's check the first 10 elements of `data` and `count`.

In [6]:
# First 10 encoded tokens, and the first 10 vocabulary entries
# ('UNK' first, then words in descending frequency).
(data[:10], count[:10])

([5244, 3081, 12, 6, 195, 2, 3137, 46, 59, 156],
 [['UNK', 418391],
  ('the', 1061396),
  ('of', 593677),
  ('and', 416629),
  ('one', 411764),
  ('in', 372201),
  ('a', 325873),
  ('to', 316376),
  ('zero', 264975),
  ('nine', 250430)])

In [7]:
data_index = 0
# Global cursor into the corpus. generate_batch() declares it `global` so that
# consecutive calls continue where the previous batch stopped.

def generate_batch(batch_size, num_skips, skip_window, corpus=None):
    """Generate one skip-gram training batch.

    Each centre (target) word is paired with ``num_skips`` distinct context
    words chosen at random from the ``skip_window`` words on either side.

    Args:
        batch_size: number of (target, context) pairs; must be a multiple of
            ``num_skips``.
        num_skips: pairs emitted per target word; at most ``2 * skip_window``.
        skip_window: number of words considered on each side of the target.
        corpus: sequence of integer word labels. Defaults to the module-level
            ``data`` list built by ``create_train_data``.

    Returns:
        (batch, labels): ``batch`` is an int32 array of shape (batch_size,)
        holding target labels; ``labels`` is an int32 array of shape
        (batch_size, 1) holding the matching context labels.

    Raises:
        ValueError: if the corpus is shorter than one window
            (``2 * skip_window + 1`` words).

    Side effects:
        Advances the global ``data_index``, wrapping around at the end of the
        corpus, so the next call continues from where this one stopped.
    """
    global data_index
    if corpus is None:
        corpus = data  # module-level corpus built by create_train_data
    assert batch_size % num_skips == 0
    assert num_skips <= 2 * skip_window
    span = 2 * skip_window + 1  # [ skip_window target skip_window ]
    if len(corpus) < span:
        raise ValueError(
            f"corpus has {len(corpus)} words but one window needs {span}")

    batch = np.ndarray(shape=(batch_size,), dtype=np.int32)
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)

    # Sliding window over the corpus; maxlen makes append() drop the oldest word.
    buffer = collections.deque(maxlen=span)
    if data_index + span > len(corpus):
        data_index = 0
    buffer.extend(corpus[data_index:data_index + span])
    data_index += span

    # Each target word fills num_skips consecutive rows of the batch.
    for i in range(batch_size // num_skips):
        # Window positions other than the centre, in random order.
        context_words = [w for w in range(span) if w != skip_window]
        random.shuffle(context_words)
        words_to_use = collections.deque(context_words)
        for j in range(num_skips):
            row = i * num_skips + j
            batch[row] = buffer[skip_window]  # target (centre) word
            # pop() yields a distinct context position for every j.
            labels[row, 0] = buffer[words_to_use.pop()]
        # Slide the window one word to the right, wrapping at the corpus end.
        if data_index == len(corpus):
            # deque has no slice assignment (buffer[:] = ... raises TypeError);
            # extending by a full span replaces every element thanks to maxlen.
            buffer.extend(corpus[:span])
            data_index = span
        else:
            buffer.append(corpus[data_index])
            data_index += 1

    # Backtrack a little bit to avoid skipping words at the end of a batch.
    data_index = (data_index + len(corpus) - span) % len(corpus)
    return batch, labels

# A tiny batch (8 pairs, 2 per target word, 1 word of context on each side)
# to sanity-check generate_batch.
batch, labels = generate_batch(
    batch_size=8,
    num_skips=2,
    skip_window=1,
)

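Illustrating how `generate_batch` chooses context positions for `skip_window = 1`. The window has 3 positions with the target at index 1, so the context positions are 0 and 2. `pop()` removes them from the right end of the deque; inside `generate_batch` the list is shuffled first.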

In [20]:
# Window of 3 positions with the target at position 1: the other two are context.
context_words = list(filter(lambda pos: pos != 1, range(3)))
print(context_words)
words_to_use = collections.deque(context_words)
words_to_use.pop()  # drops the right-most position, leaving only position 0
print(words_to_use)

[0, 2]
deque([0])


In [21]:
skip_window = 1       # Words of context taken on each side of the target word.
num_skips = 2         # (target, context) pairs generated per target word.
embedding_size = 128  # Dimension of the embedding vector.
batch_size = 128      # (target, context) pairs per training batch.