### Load preprocessed data

If you'd like to play around with this notebook, start by downloading the skipgram dataset from here:

https://www.dropbox.com/s/nd1zxh538o6psal/skipgram_full.npz

WARNING: it's a 1Gb download, so it may take a while!

In [1]:
import numpy as np

codes = np.load("../data/skipgram_full.npz")['coded']
# Remove duplicate skipgrams
codes = codes[codes[:, 0] != codes[:, 1]]
code2token = np.load("../data/skipgram_full.npz")['c2t'].tolist()
token2code = np.load("../data/skipgram_full.npz")['t2c'].tolist()

In [3]:
# First column is the first token code
# second column is the 2nd token code
# third column is the skip gram count
# fourth is PMI * 1e6
codes

array([[  13835,    3257,    4605,  592814],
       [  12071,    3257,      16,  491071],
       [   4136,    3257,       2, -621270],
       ...,
       [  12293,    1390,       1, 1092727],
       [   5103,    1390,       1, 2368132],
       [   6789,    1390,       1,  427689]], dtype=int32)

In [4]:
train_x = codes[:, :2].copy().astype(np.int64)
train_y = codes[:, 3].astype(np.float32) / 1e6
train_y

array([ 0.592814,  0.491071, -0.62127 , ...,  1.092727,  2.368132,
        0.427689], dtype=float32)

In [5]:
train_y.max()

12.09618

In [6]:
top_codes = np.argsort(train_y)[-10:]
[[code2token[c[0]], code2token[c[1]]] for c in codes[top_codes, :2]]

[['norris', 'roundhouse'],
 ['palpatine', 'skywalker'],
 ['palpatine', 'sith'],
 ['roundhouse', 'norris'],
 ['lankan', 'sri'],
 ['palpatine', 'anakin'],
 ['skywalker', 'palpatine'],
 ['anakin', 'palpatine'],
 ['blahblah', 'blah'],
 ['blah', 'blahblah']]

In [7]:
n_user = np.max(train_x[:, :2]) + 1
n_item = np.max(train_x[:, :2]) + 1
n_user

14003

### Define the MF Model

In [16]:
import torch
from torch import nn
import torch.nn.functional as F

def l2_regularize(array):
    loss = torch.sum(array ** 2.0)
    return loss


class MF(nn.Module):
    itr = 0
    
    def __init__(self, n_user, n_item, k=18, c_vector=1.0, c_bias=1.0, writer=None):
        super(MF, self).__init__()
        self.writer = writer
        self.k = k
        self.n_user = n_user
        self.n_item = n_item
        self.c_bias = c_bias
        self.c_vector = c_vector
        self.user = nn.Embedding(n_user, k)
        self.item = nn.Embedding(n_item, k)
        self.user.weight.data.normal_(0, 1.0 / n_user)
        self.item.weight.data.normal_(0, 1.0 / n_item)
        
        # We've added new terms here:
        self.bias_user = nn.Embedding(n_user, 1)
        self.bias_item = nn.Embedding(n_item, 1)
        self.bias = nn.Parameter(torch.ones(1))

    
    def __call__(self, train_x):
        user_id = train_x[:, 0]
        item_id = train_x[:, 1]
        vector_user = self.user(user_id)
        vector_item = self.item(item_id)
        bias_user = self.bias_user(user_id).squeeze()
        bias_item = self.bias_item(item_id).squeeze()
        biases = (self.bias + bias_user + bias_item)
        ui_interaction = torch.sum(vector_user * vector_item, dim=1)
        prediction = ui_interaction + biases
        return prediction
    
    def loss(self, prediction, target):
        loss_mse = F.mse_loss(prediction, target.squeeze())
        prior_bias_user =  l2_regularize(self.bias_user.weight) * self.c_bias
        prior_bias_item = l2_regularize(self.bias_item.weight) * self.c_bias
        prior_user =  l2_regularize(self.user.weight) * self.c_vector
        prior_item = l2_regularize(self.item.weight) * self.c_vector
        total = loss_mse #+ prior_user + prior_item
        for name, var in locals().items():
            if type(var) is torch.Tensor and var.nelement() == 1 and self.writer is not None:
                self.writer.add_scalar(name, var, self.itr)
        return total

### Train model

In [17]:
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Loss
from tensorboardX import SummaryWriter
from ignite.metrics import MeanSquaredError

from loader import Loader
from datetime import datetime

#### Hyperparameters

In [40]:
lr = 1e-3
k = 128
c_bias = 1e-9
c_vector = 1e-9
log_dir = 'runs/simple_mf_05_word2vec_' + str(datetime.now()).replace(' ', '_')
print(log_dir)

runs/simple_mf_05_word2vec_2018-08-23_02:37:14.471432


In [41]:
writer = SummaryWriter(log_dir=log_dir)
model = MF(n_user, n_item,  k=k, c_bias=c_bias, 
           c_vector=c_vector, writer=writer)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
trainer = create_supervised_trainer(model, optimizer, model.loss)
metrics = {'accuracy': MeanSquaredError()}
train_loader = Loader(train_x, train_y, batchsize=1024)


def log_training_loss(engine, log_interval=400):
    epoch = engine.state.epoch
    itr = engine.state.iteration
    fmt = "Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
    msg = fmt.format(epoch, itr, len(train_loader), engine.state.output)
    model.itr = itr
    if itr % log_interval == 0:
        print(msg)

trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED, handler=log_training_loss)

model

MF(
  (user): Embedding(14003, 128)
  (item): Embedding(14003, 128)
  (bias_user): Embedding(14003, 1)
  (bias_item): Embedding(14003, 1)
)

In [42]:
model.load_state_dict(torch.load("model_05_word2vec"))

#### Run model

In [43]:
trainer.run(train_loader, max_epochs=25)

Epoch[1] Iteration[400/60208] Loss: 1.07
Epoch[1] Iteration[800/60208] Loss: 1.04
Epoch[1] Iteration[1200/60208] Loss: 1.12
Epoch[1] Iteration[1600/60208] Loss: 0.92
Epoch[1] Iteration[2000/60208] Loss: 0.87
Epoch[1] Iteration[2400/60208] Loss: 0.80
Epoch[1] Iteration[2800/60208] Loss: 0.93
Epoch[1] Iteration[3200/60208] Loss: 0.78
Epoch[1] Iteration[3600/60208] Loss: 0.73
Epoch[1] Iteration[4000/60208] Loss: 0.69
Epoch[1] Iteration[4400/60208] Loss: 0.71
Epoch[1] Iteration[4800/60208] Loss: 0.61
Epoch[1] Iteration[5200/60208] Loss: 0.68
Epoch[1] Iteration[5600/60208] Loss: 0.65
Epoch[1] Iteration[6000/60208] Loss: 0.68
Epoch[1] Iteration[6400/60208] Loss: 0.61
Epoch[1] Iteration[6800/60208] Loss: 0.61
Epoch[1] Iteration[7200/60208] Loss: 0.58
Epoch[1] Iteration[7600/60208] Loss: 0.57
Epoch[1] Iteration[8000/60208] Loss: 0.59
Epoch[1] Iteration[8400/60208] Loss: 0.59
Epoch[1] Iteration[8800/60208] Loss: 0.54
Epoch[1] Iteration[9200/60208] Loss: 0.55
Epoch[1] Iteration[9600/60208] Loss:

Epoch[2] Iteration[77200/60208] Loss: 0.37
Epoch[2] Iteration[77600/60208] Loss: 0.34
Epoch[2] Iteration[78000/60208] Loss: 0.32
Epoch[2] Iteration[78400/60208] Loss: 0.34
Epoch[2] Iteration[78800/60208] Loss: 0.38
Epoch[2] Iteration[79200/60208] Loss: 0.33
Epoch[2] Iteration[79600/60208] Loss: 0.34
Epoch[2] Iteration[80000/60208] Loss: 0.34
Epoch[2] Iteration[80400/60208] Loss: 0.34
Epoch[2] Iteration[80800/60208] Loss: 0.34
Epoch[2] Iteration[81200/60208] Loss: 0.33
Epoch[2] Iteration[81600/60208] Loss: 0.35
Epoch[2] Iteration[82000/60208] Loss: 0.37
Epoch[2] Iteration[82400/60208] Loss: 0.32
Epoch[2] Iteration[82800/60208] Loss: 0.37
Epoch[2] Iteration[83200/60208] Loss: 0.35
Epoch[2] Iteration[83600/60208] Loss: 0.35
Epoch[2] Iteration[84000/60208] Loss: 0.36
Epoch[2] Iteration[84400/60208] Loss: 0.36
Epoch[2] Iteration[84800/60208] Loss: 0.35
Epoch[2] Iteration[85200/60208] Loss: 0.35
Epoch[2] Iteration[85600/60208] Loss: 0.35
Epoch[2] Iteration[86000/60208] Loss: 0.37
Epoch[2] It

Epoch[3] Iteration[152400/60208] Loss: 0.32
Epoch[3] Iteration[152800/60208] Loss: 0.37
Epoch[3] Iteration[153200/60208] Loss: 0.35
Epoch[3] Iteration[153600/60208] Loss: 0.34
Epoch[3] Iteration[154000/60208] Loss: 0.33
Epoch[3] Iteration[154400/60208] Loss: 0.32
Epoch[3] Iteration[154800/60208] Loss: 0.30
Epoch[3] Iteration[155200/60208] Loss: 0.34
Epoch[3] Iteration[155600/60208] Loss: 0.32
Epoch[3] Iteration[156000/60208] Loss: 0.35
Epoch[3] Iteration[156400/60208] Loss: 0.33
Epoch[3] Iteration[156800/60208] Loss: 0.31
Epoch[3] Iteration[157200/60208] Loss: 0.35
Epoch[3] Iteration[157600/60208] Loss: 0.33
Epoch[3] Iteration[158000/60208] Loss: 0.37
Epoch[3] Iteration[158400/60208] Loss: 0.35
Epoch[3] Iteration[158800/60208] Loss: 0.34
Epoch[3] Iteration[159200/60208] Loss: 0.37
Epoch[3] Iteration[159600/60208] Loss: 0.34
Epoch[3] Iteration[160000/60208] Loss: 0.36
Epoch[3] Iteration[160400/60208] Loss: 0.32
Epoch[3] Iteration[160800/60208] Loss: 0.34
Epoch[3] Iteration[161200/60208]

Epoch[4] Iteration[227200/60208] Loss: 0.37
Epoch[4] Iteration[227600/60208] Loss: 0.33
Epoch[4] Iteration[228000/60208] Loss: 0.32
Epoch[4] Iteration[228400/60208] Loss: 0.38
Epoch[4] Iteration[228800/60208] Loss: 0.33
Epoch[4] Iteration[229200/60208] Loss: 0.34
Epoch[4] Iteration[229600/60208] Loss: 0.35
Epoch[4] Iteration[230000/60208] Loss: 0.35
Epoch[4] Iteration[230400/60208] Loss: 0.32
Epoch[4] Iteration[230800/60208] Loss: 0.36
Epoch[4] Iteration[231200/60208] Loss: 0.36
Epoch[4] Iteration[231600/60208] Loss: 0.36
Epoch[4] Iteration[232000/60208] Loss: 0.33
Epoch[4] Iteration[232400/60208] Loss: 0.36
Epoch[4] Iteration[232800/60208] Loss: 0.36
Epoch[4] Iteration[233200/60208] Loss: 0.36
Epoch[4] Iteration[233600/60208] Loss: 0.34
Epoch[4] Iteration[234000/60208] Loss: 0.36
Epoch[4] Iteration[234400/60208] Loss: 0.36
Epoch[4] Iteration[234800/60208] Loss: 0.35
Epoch[4] Iteration[235200/60208] Loss: 0.34
Epoch[4] Iteration[235600/60208] Loss: 0.33
Epoch[4] Iteration[236000/60208]

Epoch[6] Iteration[302000/60208] Loss: 0.30
Epoch[6] Iteration[302400/60208] Loss: 0.33
Epoch[6] Iteration[302800/60208] Loss: 0.32
Epoch[6] Iteration[303200/60208] Loss: 0.32
Epoch[6] Iteration[303600/60208] Loss: 0.31
Epoch[6] Iteration[304000/60208] Loss: 0.33
Epoch[6] Iteration[304400/60208] Loss: 0.26
Epoch[6] Iteration[304800/60208] Loss: 0.29
Epoch[6] Iteration[305200/60208] Loss: 0.33
Epoch[6] Iteration[305600/60208] Loss: 0.30
Epoch[6] Iteration[306000/60208] Loss: 0.28
Epoch[6] Iteration[306400/60208] Loss: 0.33
Epoch[6] Iteration[306800/60208] Loss: 0.32
Epoch[6] Iteration[307200/60208] Loss: 0.36
Epoch[6] Iteration[307600/60208] Loss: 0.30
Epoch[6] Iteration[308000/60208] Loss: 0.34
Epoch[6] Iteration[308400/60208] Loss: 0.31
Epoch[6] Iteration[308800/60208] Loss: 0.34
Epoch[6] Iteration[309200/60208] Loss: 0.37
Epoch[6] Iteration[309600/60208] Loss: 0.34
Epoch[6] Iteration[310000/60208] Loss: 0.32
Epoch[6] Iteration[310400/60208] Loss: 0.34
Epoch[6] Iteration[310800/60208]

KeyboardInterrupt: 

In [44]:
torch.save(model.state_dict(), "model_05_word2vec")

#### Save the embeddings

In [45]:
label_token = ['|' + code2token[c] for c in range(n_user)]
writer.add_embedding(model.user.weight)
# writer.add_embedding(model.item.weight, metadata=label_token)

### Introspect the model

Evaluate what urban dictionary thinks are similar words.

In [55]:
vectors_raw = model.user.weight.data.numpy()
vectors = vectors_raw / np.sqrt((vectors_raw**2.0).sum(axis=1)[:, None])

In [56]:
(vectors[0]**2.0).sum()

0.99999994

In [96]:
def find_closest(token, n=10):
    code = token2code[token]
    vector = vectors[code]
    similarity = np.sum(vector[None, :] * vectors, axis=1)
    closest = np.argsort(similarity)[::-1]
    for code in closest[1:n]:
        print(code2token[code], similarity[code])

In [108]:
find_closest('dude')

bro 0.6443894
chick 0.6427469
guy 0.6156572
cool 0.5742106
bitch 0.5732945
chill 0.5504999
wanna 0.5483899
hey 0.53496593
mad 0.5258949


In [115]:
find_closest('netflix')

hookup 0.51350355
makeout 0.4876396
skype 0.46457273
sleepover 0.46274903
spouse 0.4424118
loneliness 0.43613356
reciprocate 0.43607166
infatuation 0.43111098
threesome 0.43035924


In [114]:
find_closest('lol')

wtf 0.6517912
chat 0.59220165
acronym 0.58055633
lmao 0.5773032
omg 0.55501175
somebody 0.551707
haha 0.5454682
abbreviation 0.5257285
fucking 0.5094592


In [112]:
find_closest('hipster')

hipsters 0.8625888
indie 0.67652184
ironic 0.63480437
vintage 0.63287544
trend 0.58198345
thrift 0.58075386
pretentious 0.5771992
conformist 0.56134546
subculture 0.5545582


In [97]:
find_closest('crunk')

hyphy 0.59062827
trashed 0.52575123
shizzle 0.48595563
kool 0.48234645
rhyme 0.4805321
poppin 0.47744057
chillin 0.46551502
wack 0.46236467
hella 0.45885378


In [98]:
find_closest('bromance')

romantically 0.5936279
platonic 0.56198347
butch 0.4824332
glee 0.4821005
payne 0.4686876
hetero 0.46699834
pairing 0.464538
fanfiction 0.46309686
intimacy 0.46035388


In [99]:
find_closest('barbie')

conceited 0.55570185
doll 0.53594625
brunette 0.52610683
anorexic 0.52409136
blond 0.51191294
vain 0.503785
jessica 0.49590558
bimbo 0.49487546
mascara 0.48396283


In [100]:
find_closest('relationship')

romantic 0.6252151
friendship 0.6043344
couple 0.5868452
sexual 0.5751699
boyfriend 0.5647651
marry 0.56442153
date 0.55326617
feeling 0.5497228
friend 0.5419754


In [101]:
find_closest('pope')

orthodox 0.65916073
protestant 0.6568552
salvation 0.6357822
christianity 0.62910753
scripture 0.6278157
bible 0.6104638
catholic 0.60814005
messiah 0.5917543
christ 0.5841886


In [102]:
find_closest('trump')

palin 0.45582134
donald 0.43337244
republicans 0.43279046
shithead 0.41710183
wreak 0.41436175
barack 0.41338554
baller 0.41276947
slander 0.4060114
mel 0.39643386


In [118]:
find_closest('selfie')

selfies 0.6768813
instagram 0.58078086
photo 0.5547765
pic 0.5447346
snapchat 0.54272944
upload 0.52603865
photographer 0.5154379
caption 0.49573278
tweet 0.47855204


### Subtract and add word vectors

In [123]:
def add_subtract(center, minus, plus, n=10):
    vector = (vectors[token2code[center]]
             - vectors[token2code[minus]]
             + vectors[token2code[plus]])
    similarity = np.sum(vector[None, :] * vectors, axis=1)
    closest = np.argsort(similarity)[::-1]
    for code in closest[2:n]:
        print(code2token[code])

In [124]:
add_subtract('burrito', 'mexican', 'italian')

italy
guido
hamburger
spaghetti
cheeseburger
steak
guidos
patty


In [125]:
add_subtract('drunk', 'beer', 'weed')

drunk
ganja
shrooms
chronic
paranoid
pothead
kush
fucked
