# Skip-gram model of Text8 data

In this notebook, I use the code in this repository to train word embeddings with a skip-gram model on the Text8 data (about 100 MB of pre-cleaned text taken from a 2006 Wikipedia dump). This is essentially my own, extended implementation of an exercise from Google's Deep Learning course on Udacity, so the results should be very similar.

```python
import zipfile

import tools as t  # provides DataContainer
from batchgen import ContinuousBatchGenerator
from model import W2VModel
```



First, we read the data. If needed, the .zip file can be downloaded from http://mattmahoney.net/dc/textdata

```python
filename = "data/text8.zip"
with zipfile.ZipFile(filename) as f:
    # The archive contains a single text file; decode it and split it into words.
    data = f.read(f.namelist()[0]).decode().split()
```

`tools.DataContainer` handles the operations for manipulating the data. We use it to convert the data into a list of IDs, each of which corresponds to a unique word. Words that are too rare are replaced with ID 0. The `unknown_token` is only needed if we want to convert IDs back into words at some point.

```python
# Keep at most 100000 distinct words; rarer words get ID 0.
datacontainer = t.DataContainer(data, maxwords=100000,
                                unknown_token="_UNKNOWN_")
```
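
To illustrate what this conversion does, here is a minimal pure-Python sketch. It assumes a frequency-ranked vocabulary in which the `maxwords` most frequent words get IDs 1..`maxwords` and everything else maps to 0; `build_vocab` and `to_ids` are hypothetical helpers, not the actual `tools.DataContainer` implementation.

```python
from collections import Counter


def build_vocab(words, maxwords, unknown_token="_UNKNOWN_"):
    """Hypothetical helper: give the `maxwords` most frequent words IDs 1..maxwords.

    ID 0 is reserved for rare words and labelled with `unknown_token`.
    """
    word_to_id = {unknown_token: 0}
    ranked = Counter(words).most_common(maxwords)
    for new_id, (word, _count) in enumerate(ranked, start=1):
        word_to_id[word] = new_id
    return word_to_id


def to_ids(words, word_to_id):
    # Words that did not make it into the vocabulary fall back to ID 0.
    return [word_to_id.get(w, 0) for w in words]


words = "the cat sat on the mat the cat".split()
vocab = build_vocab(words, maxwords=2)
ids = to_ids(words, vocab)
id_to_word = {i: w for w, i in vocab.items()}
print(ids)                           # [1, 2, 0, 0, 1, 0, 1, 2]
print([id_to_word[i] for i in ids])  # rare words come back as '_UNKNOWN_'
```

The reverse mapping shows why the unknown token only matters when converting back: all rare words collapse onto ID 0 and can only be displayed as `_UNKNOWN_`.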

The generator chosen here takes the data as one continuous list and creates an iterator that feeds batches of data to the model. Here we use the skip-gram model, which, given a word in the text, tries to predict a single randomly chosen word from the window around it.

```python
# Iterator over skip-gram training batches of size 64.
datagen = ContinuousBatchGenerator(datacontainer.data, 4, 2, batch_size=64)
```
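
A minimal sketch of skip-gram batch generation, assuming each center word is paired with one context word drawn uniformly from the `window` positions on either side and that the corpus wraps around at its ends. `skipgram_batches` and its parameters are hypothetical; they do not correspond to the unnamed positional arguments (`4`, `2`) of `ContinuousBatchGenerator`.

```python
import random


def skipgram_batches(ids, window, batch_size, seed=0):
    """Hypothetical generator yielding endless (inputs, labels) batches.

    Each center word is paired with ONE context word chosen uniformly
    from the `window` positions on either side of it.
    """
    if window < 1:
        raise ValueError("window must be at least 1")
    rng = random.Random(seed)
    offsets = [o for o in range(-window, window + 1) if o != 0]
    n = len(ids)
    pos = 0
    while True:
        inputs, labels = [], []
        for _ in range(batch_size):
            inputs.append(ids[pos])
            # Wrap around at the corpus ends (an assumption of this sketch).
            labels.append(ids[(pos + rng.choice(offsets)) % n])
            pos = (pos + 1) % n
        yield inputs, labels


ids = list(range(10))
batches = skipgram_batches(ids, window=2, batch_size=4)
inputs, labels = next(batches)
print(inputs)  # [0, 1, 2, 3]
print(labels)  # for each input, one neighbor at most 2 positions away
```

Because the generator walks through the list continuously, consecutive batches pick up where the previous one stopped.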

The model itself is a rather simple part of the program: it consists only of an embedding lookup table and a single densely connected layer with a softmax activation function. The skip-gram model needs rather long training to produce good results, especially for less common words.
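
To make this architecture concrete, here is a minimal pure-Python sketch of the forward pass and its full-softmax cross-entropy loss for a single (center, context) pair. The helper names, the initialization and the toy sizes are assumptions for illustration; the repository's `W2VModel` is built with TensorFlow and may compute its loss differently.

```python
import math
import random


def init_params(vocab_size, embed_dim, seed=0):
    """Hypothetical init: random embedding table plus one dense output layer."""
    rng = random.Random(seed)
    embeddings = [[rng.uniform(-1.0, 1.0) for _ in range(embed_dim)] for _ in range(vocab_size)]
    weights = [[rng.gauss(0.0, 0.1) for _ in range(embed_dim)] for _ in range(vocab_size)]
    biases = [0.0] * vocab_size
    return embeddings, weights, biases


def skipgram_loss(center_id, context_id, embeddings, weights, biases):
    """Full-softmax cross-entropy for predicting `context_id` from `center_id`."""
    e = embeddings[center_id]  # embedding lookup: one row of the table
    # Dense layer: one logit per vocabulary word.
    logits = [sum(w_k * e_k for w_k, e_k in zip(w, e)) + b for w, b in zip(weights, biases)]
    m = max(logits)  # shift by the max logit for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[context_id]  # = -log softmax(logits)[context_id]


vocab_size, embed_dim = 5, 3
embeddings, weights, biases = init_params(vocab_size, embed_dim)
print(skipgram_loss(0, 1, embeddings, weights, biases))

# With an all-zero output layer every word is equally likely, so the loss is ln V.
zero_w = [[0.0] * embed_dim for _ in range(vocab_size)]
print(skipgram_loss(0, 1, embeddings, zero_w, [0.0] * vocab_size), math.log(vocab_size))
```

The all-zero case is a useful sanity check: an untrained model that predicts uniformly over V words has a loss of ln V, and training should push the loss below that.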

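```python
# 100001 is presumably the vocabulary size: maxwords (100000) plus the unknown ID 0.
model = W2VModel(100001, 1000, save_path="./weights/skip-gram-text8.save")
model.train(datagen, 500001)
# Attach the trained embeddings so words can be compared by nearest neighbors.
datacontainer.add_embeddings(model.final_embeddings)
```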

```
INFO:tensorflow:Restoring parameters from ./weights/skip-gram-text8.save
Weights and embeddings loaded from ./weights/skip-gram-text8.save
Starting training with 100000 steps.
Step: 2000: loss = 5.449187862932682
Step: 4000: loss = 5.500022613659501
Step: 6000: loss = 5.471134647727013
Step: 8000: loss = 5.584422599673271
Step: 10000: loss = 5.435981060504913
Step: 12000: loss = 5.484842420578003
Step: 14000: loss = 5.5697263346910475
Step: 16000: loss = 5.311738821268082
Step: 18000: loss = 5.510166151165962
Step: 20000: loss = 5.537298298954964
Step: 22000: loss = 5.570484763622284
Step: 24000: loss = 5.652584486663342
Step: 26000: loss = 5.6008817142248155
Step: 28000: loss = 5.541498960852623
Step: 30000: loss = 5.4085981355905535
Step: 32000: loss = 5.407236380815506
Step: 34000: loss = 5.35655417740345
Step: 36000: loss = 5.563209484457969
Step: 38000: loss = 5.4056957747936245
Step: 40000: loss = 5.012357564091682
Step: 42000: loss = 5.000904550552368
Step: 44000: loss = 5.44974
```

Let's assess the quality of the results by inspecting the nearest neighbors of a few words.

```python
some_words = ['two', 'january', 'russia', 'banana', 'mouse']
for word in some_words:
    print("Words closest to '{}' are:".format(word))
    # Skip index 0 (presumably the query word itself) and show the next five.
    print(datacontainer.closest_to(word)[1:6])
```

```
Words closest to 'two' are:
['zero', 'five', 'four', 'one', 'three']
Words closest to 'january' are:
['july', 'february', 'october', 'april', 'december']
Words closest to 'russia' are:
['italy', 'spain', 'germany', 'china', 'bulgaria']
Words closest to 'banana' are:
['pathos', 'distinctive', 'sadist', 'berman', 'seljuk']
Words closest to 'mouse' are:
['sources', 'call', 'file', 'playing', 'uses']
```
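
A nearest-neighbor lookup like `closest_to` is commonly implemented by ranking all words by the cosine similarity of their embeddings to the query's embedding. The sketch below assumes that approach; the function, its signature and the toy two-dimensional embeddings are hypothetical and not taken from `tools.DataContainer`.

```python
import math


def closest_to(word, word_to_id, id_to_word, embeddings, k=5):
    """Hypothetical lookup: the k words with the highest cosine similarity to `word`."""
    def unit(v):
        norm = math.sqrt(sum(x * x for x in v)) or 1.0  # avoid dividing by zero
        return [x / norm for x in v]

    normed = [unit(v) for v in embeddings]
    query = normed[word_to_id[word]]
    sims = [sum(a * b for a, b in zip(query, v)) for v in normed]
    ranked = sorted(range(len(normed)), key=lambda i: sims[i], reverse=True)
    return [id_to_word[i] for i in ranked[:k]]


vocab = {"king": 0, "queen": 1, "apple": 2, "pear": 3}
id_to_word = {i: w for w, i in vocab.items()}
emb = [[1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.2, 0.9]]
print(closest_to("king", vocab, id_to_word, emb, k=3))  # ['king', 'queen', 'pear']
```

Since every vector has cosine similarity 1 with itself, the query word ranks first, which is consistent with slicing the result from index 1 in the cell above.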
