In [1]:
import numpy as np
import theano
import six.moves.cPickle
import os, re, json

from keras.preprocessing import sequence, text
from keras.optimizers import SGD, RMSprop, Adagrad
from keras.utils import np_utils, generic_utils
from keras.models import Sequential
from keras.layers.embeddings import WordContextProduct, Embedding
from six.moves import range
from six.moves import zip

Using gpu device 0: GRID K520


In [2]:
max_features = 50000 # vocabulary size: top 50,000 most common words in data
skip_top = 100 # ignore top 100 most common words
nb_epoch = 1
dim_proj = 256 # embedding space dimension

save = True
load_model = False
load_tokenizer = True
train_model = True
save_dir = os.path.expanduser("~/.keras/models")
model_load_fname = "HN_skipgram_model.pkl"
model_save_fname = "HN_skipgram_model.pkl"
tokenizer_fname = "HN_tokenizer.pkl"

data_path = os.path.expanduser("~/")+"HNCommentsAll.1perline.json"

In [3]:
# text preprocessing utils
html_tags = re.compile(r'<.*?>')
to_replace = [('&#x27;', "'")]
hex_tags = re.compile(r'&.*?;')

def clean_comment(comment):
    c = str(comment.encode("utf-8"))
    c = html_tags.sub(' ', c)
    for tag, char in to_replace:
        c = c.replace(tag, char)
    c = hex_tags.sub(' ', c)
    return c

def text_generator(path=data_path):
    f = open(path)
    for i, l in enumerate(f):
        comment_data = json.loads(l)
        comment_text = comment_data["comment_text"]
        comment_text = clean_comment(comment_text)
        if i % 10000 == 0:
            print(i)
        yield comment_text
    f.close()

In [5]:
# model management
if load_tokenizer:
    print('Load tokenizer...')
    tokenizer = six.moves.cPickle.load(open(os.path.join(save_dir, tokenizer_fname), 'rb'))
else:
    print("Fit tokenizer...")
    tokenizer = text.Tokenizer(nb_words=max_features)
    tokenizer.fit_on_texts(text_generator())
    if save:
        print("Save tokenizer...")
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        six.moves.cPickle.dump(tokenizer, open(os.path.join(save_dir, tokenizer_fname), "wb"))


Fit tokenizer...
0
10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000
160000
170000
180000
190000
200000
210000
220000
230000
240000
250000
260000
270000
280000
290000
300000
310000
320000
330000
340000
350000
360000
370000
380000
390000
400000
410000
420000
430000
440000
450000
460000
470000
480000
490000
500000
510000
520000
530000
540000
550000
560000
570000
580000
590000
600000
610000
620000
630000
640000
650000
660000
670000
680000
690000
700000
710000
720000
730000
740000
750000
760000
770000
780000
790000
800000
810000
820000
830000
840000
850000
860000
870000
880000
890000
900000
910000
920000
930000
940000
950000
960000
970000
980000
990000
1000000
1010000
1020000
1030000
1040000
1050000
1060000
1070000
1080000
1090000
1100000
1110000
1120000
1130000
1140000
1150000
1160000
1170000
1180000
1190000
1200000
1210000
1220000
1230000
1240000
1250000
1260000
1270000
1280000
1290000
1300000
1310000
1320000
1330000
1340000
1350000
1360000
1

In [8]:
ls ~/.keras/models

HN_tokenizer.pkl


In [4]:
tokenizer = six.moves.cPickle.load(open(os.path.join(save_dir, tokenizer_fname), 'rb'))

In [5]:
tokenizer.document_count

5845908

In [6]:
tokenizer.filters

'!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n'

In [12]:
for key in tokenizer.word_counts.keys():
    print key
    print tokenizer.word_counts[key]
    break

ftdna
5


In [13]:
for key in tokenizer.word_index.keys():
    print key
    print tokenizer.word_index[key]
    break

ftdna
197942


In [14]:
for key in tokenizer.word_docs.keys():
    print key
    print tokenizer.word_docs[key]
    break

ftdna
5


In [15]:
sampling_table = sequence.make_sampling_table(max_features)

for i, seq in enumerate(tokenizer.texts_to_sequences_generator(text_generator())):
    print i
    print seq
    print 
    couples, labels = sequence.skipgrams(seq, max_features, window_size=4, negative_samples=1., sampling_table=sampling_table)
    print couples
    print labels
    break

0
0
[67, 11, 44, 20, 2, 2087, 13, 3, 943, 11399, 2, 587, 1, 140, 11597, 46, 1, 81, 831, 56, 2, 880, 12, 10814, 1197, 22, 11, 91, 2, 277, 53, 9, 839, 55, 61, 5, 41, 10]

[[10814, 9256], [11399, 2], [11399, 13003], [10814, 91], [10814, 880], [943, 1], [91, 48182], [11399, 943], [11399, 13], [943, 4838], [91, 35764], [11399, 2087], [10814, 37241], [11597, 46], [91, 45176], [91, 10814], [943, 49865], [11399, 140], [91, 2], [91, 10042], [943, 13], [11399, 33764], [11597, 27125], [91, 4077], [943, 2], [943, 26688], [11399, 1], [11597, 1], [10814, 28034], [91, 53], [10814, 56], [10814, 22], [10814, 40207], [10814, 12], [10814, 31487], [11597, 20483], [11597, 587], [11399, 587], [11399, 44736], [11597, 1], [943, 2], [11597, 29849], [91, 48270], [11597, 25261], [10814, 11], [91, 44786], [943, 11756], [943, 3], [11399, 36687], [11597, 35004], [91, 1197], [943, 11399], [943, 587], [11597, 6931], [943, 28406], [10814, 45564], [10814, 118], [11399, 13712], [91, 277], [11399, 39915], [943, 29603], [

In [17]:
# training process
if train_model:
    if load_model:
        print('Load model...')
        model = six.moves.cPickle.load(open(os.path.join(save_dir, model_load_fname), 'rb'))
    else:
        print('Build model...')
        model = Sequential()
        model.add(WordContextProduct(max_features, proj_dim=dim_proj, init="uniform"))
        model.compile(loss='mse', optimizer='rmsprop')

    sampling_table = sequence.make_sampling_table(max_features)

    for e in range(nb_epoch):
        print('-'*40)
        print('Epoch', e)
        print('-'*40)

        progbar = generic_utils.Progbar(tokenizer.document_count)
        samples_seen = 0
        losses = []
        
        for i, seq in enumerate(tokenizer.texts_to_sequences_generator(text_generator())):
            # get skipgram couples for one text in the dataset
            couples, labels = sequence.skipgrams(seq, max_features, window_size=4, negative_samples=1., sampling_table=sampling_table)
            if couples:
                # one gradient update per sentence (one sentence = a few 1000s of word couples)
                X = np.array(couples, dtype="int32")
                loss = model.train(X, labels)
                losses.append(loss)
                if len(losses) % 100 == 0:
                    progbar.update(i, values=[("loss", np.mean(losses))])
                    losses = []
                samples_seen += len(labels)
        print('Samples seen:', samples_seen)
    print("Training completed!")

    if save:
        print("Saving model...")
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        six.moves.cPickle.dump(model, open(os.path.join(save_dir, model_save_fname), "wb"))

Build model...
----------------------------------------
('Epoch', 0)
----------------------------------------
0
   9955/5845908 [..............................] - ETA: 86679s - loss: 0.249910000
  19969/5845908 [..............................] - ETA: 86977s - loss: 0.249520000
  29957/5845908 [..............................] - ETA: 87000s - loss: 0.247530000
  39894/5845908 [..............................] - ETA: 86817s - loss: 0.243540000
  49917/5845908 [..............................] - ETA: 86646s - loss: 0.238650000
  59958/5845908 [..............................] - ETA: 86460s - loss: 0.233560000
  69969/5845908 [..............................] - ETA: 86289s - loss: 0.228870000
  79918/5845908 [..............................] - ETA: 86136s - loss: 0.224480000
  89986/5845908 [..............................] - ETA: 85902s - loss: 0.220390000
  99908/5845908 [..............................] - ETA: 85736s - loss: 0.2165100000
 109939/5845908 [..............................] - ETA: 8

In [18]:
# takes 24 housrs

In [19]:
print("It's test time!")

# recover the embedding weights trained with skipgram:
weights = model.layers[0].get_weights()[0]

It's test time!


In [20]:
weights[:skip_top] = np.zeros((skip_top, dim_proj))
norm_weights = np_utils.normalize(weights)

word_index = tokenizer.word_index
reverse_word_index = dict([(v, k) for k, v in list(word_index.items())])
word_index = tokenizer.word_index

In [22]:
def embed_word(w):
    i = word_index.get(w)
    if (not i) or (i<skip_top) or (i>=max_features):
        return None
    return norm_weights[i]

def closest_to_point(point, nb_closest=10):
    proximities = np.dot(norm_weights, point)
    tups = list(zip(list(range(len(proximities))), proximities))
    tups.sort(key=lambda x: x[1], reverse=True)
    return [(reverse_word_index.get(t[0]), t[1]) for t in tups[:nb_closest]]  

def closest_to_word(w, nb_closest=10):
    i = word_index.get(w)
    if (not i) or (i<skip_top) or (i>=max_features):
        return []
    return closest_to_point(norm_weights[i].T, nb_closest)

In [23]:
''' the resuls in comments below were for: 
    5.8M HN comments
    dim_proj = 256
    nb_epoch = 2
    optimizer = rmsprop
    loss = mse
    max_features = 50000
    skip_top = 100
    negative_samples = 1.
    window_size = 4
    and frequency subsampling of factor 10e-5. 
'''

words = ["article", # post, story, hn, read, comments
"3", # 6, 4, 5, 2
"two", # three, few, several, each
"great", # love, nice, working, looking
"data", # information, memory, database
"money", # company, pay, customers, spend
"years", # ago, year, months, hours, week, days
"android", # ios, release, os, mobile, beta
"javascript", # js, css, compiler, library, jquery, ruby
"look", # looks, looking
"business", # industry, professional, customers
"company", # companies, startup, founders, startups
"after", # before, once, until
"own", # personal, our, having
"us", # united, country, american, tech, diversity, usa, china, sv
"using", # javascript, js, tools (lol)
"here", # hn, post, comments
]

for w in words:
    res = closest_to_word(w)
    print('====', w)
    for r in res:
        print(r)

('====', 'article')
('article', 1.0000002)
('post', 0.90891558)
('story', 0.89286608)
('posted', 0.89106327)
('here', 0.8900885)
('comments', 0.88936681)
('reddit', 0.88504016)
('pg', 0.88066208)
('posts', 0.87696922)
('thread', 0.87472731)
('====', '3')
('3', 0.99999988)
('6', 0.94339204)
('9', 0.94330382)
('2', 0.94284344)
('ff', 0.93928117)
('32', 0.93828988)
('24', 0.93781793)
('7', 0.93774015)
('36', 0.93521035)
('released', 0.93484235)
('====', 'two')
('two', 1.0)
('typically', 0.93890905)
('quantity', 0.93832433)
('defining', 0.93728578)
('sustain', 0.93695474)
('evolve', 0.93685579)
('letting', 0.93666261)
('dying', 0.93665802)
('generations', 0.93659782)
('avoiding', 0.93625355)
('====', 'great')
('great', 1.0)
('looking', 0.93422735)
('tell', 0.93256807)
('posting', 0.93174744)
('wish', 0.92921531)
('helpful', 0.92886698)
('informative', 0.92884183)
('cool', 0.92871445)
('idea', 0.92823637)
('fun', 0.92710769)
('====', 'data')
('data', 1.0000002)
('storage', 0.90516901)
('sto

In [25]:
closest_to_word('book')

[('book', 0.99999988),
 ('books', 0.93772721),
 ('tutorial', 0.9375813),
 ('paywall', 0.93704069),
 ('intro', 0.93491739),
 ('screenshots', 0.93375456),
 ('redirects', 0.9337393),
 ('favorite', 0.93335485),
 ('repo', 0.93316227),
 ('ff', 0.9330492)]

In [27]:
closest_to_word('paypal')

[('paypal', 0.99999976),
 ('stripe', 0.94726515),
 ('listing', 0.94713485),
 ('doge', 0.94603837),
 ('belongs', 0.94497705),
 ('coinbase', 0.94476163),
 ('automating', 0.94357979),
 ('3gs', 0.94293594),
 ('heartbleed', 0.9426229),
 ('placement', 0.94246304)]

In [29]:
closest_to_word('iphone')

[('iphone', 1.0),
 ('ipad', 0.93695891),
 ('mac', 0.93089706),
 ('android', 0.92847413),
 ('osx', 0.92499757),
 ('mobile', 0.9189598),
 ('desktop', 0.9188664),
 ('kindle', 0.91885668),
 ('app', 0.9180249),
 ('ios', 0.91645032)]

In [30]:
closest_to_word('samsung')

[('samsung', 1.0000002),
 ('nexus', 0.94058442),
 ('htc', 0.93408191),
 ('motorola', 0.93351519),
 ('shipped', 0.93209493),
 ('droid', 0.92996943),
 ('shows', 0.92826527),
 ('contains', 0.9279933),
 ('salesforce', 0.92773867),
 ('gem', 0.92740226)]

In [31]:
closest_to_word('obama')

[('obama', 0.99999988),
 ('clinton', 0.91198832),
 ('cue', 0.90583628),
 ('florida', 0.90488577),
 ('screencasts', 0.90177119),
 ('pending', 0.90155625),
 ('hits', 0.90109777),
 ('putin', 0.9001627),
 ('groupon', 0.90015745),
 ('esp', 0.90002215)]