# Assignment 1.3: Naive word2vec (40 points)

This task can be formulated very simply. Follow this [paper](https://arxiv.org/pdf/1411.2738.pdf) and implement word2vec like a two-layer neural network with matrices $W$ and $W'$. One matrix projects words to low-dimensional 'hidden' space and the other - back to high-dimensional vocabulary space.

![word2vec](https://i.stack.imgur.com/6eVXZ.jpg)

You can use TensorFlow/PyTorch and code from your previous task.

## Results of this task: (30 points)
 * trained word vectors (mention somewhere, how long it took to train)
 * plotted loss (so we can see that it has converged)
 * function to map token to corresponding word vector
 * beautiful visualizations (PCE, T-SNE), you can use TensorBoard and play with your vectors in 3D (don't forget to add screenshots to the task)

## Extra questions: (10 points)
 * Intrinsic evaluation: you can find datasets [here](http://download.tensorflow.org/data/questions-words.txt)
 * Extrinsic evaluation: you can use [these](https://medium.com/@dataturks/rare-text-classification-open-datasets-9d340c8c508e)

Also, you can find any other datasets for quantitative evaluation.

Again. It is **highly recommended** to read this [paper](https://arxiv.org/pdf/1411.2738.pdf)

Example of visualization in tensorboard:
https://projector.tensorflow.org

Example of 2D visualisation:

![2dword2vec](https://www.tensorflow.org/images/tsne.png)

In [1]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim

from pathlib import Path
from pprint import pprint

UNK_TOKEN = '<UNK>'

np.random.seed(4242)
random.seed(4242)

In [2]:
USE_GPU = True

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

device

device(type='cuda')

In [3]:
from collections import Counter


class CBOWBatcher:
    THRESHOLD = 5
    def __init__(self, dataset, window_size=2, threshold=THRESHOLD):
        self.window_size = window_size
        self.threshold = threshold
        self.c = Counter(dataset)
        # all the words we have plus <UNK> token for rare words
        unique = {w for w in dataset if self.c[w] > self.threshold}
        self.word2ind = {w: i for i, w in enumerate(sorted(unique))}
        self.word2ind[UNK_TOKEN] = len(self.word2ind)
        self.ind2word = {i: w for w, i in self.word2ind.items()}
        # We need to store only the numbers of the words here, as we have their numbers already
        # we create a padded array for tokens to process all the words from corpus
        # remove all the uncommon words here
        self.tokens = ([self.word2ind[UNK_TOKEN]] * window_size) +\
            [self.word2ind.get(w, self.word2ind[UNK_TOKEN]) for w in dataset] +\
            ([self.word2ind[UNK_TOKEN]] * window_size)
        self.vocab_size = len(set(self.tokens))
        assert self.vocab_size == len(self.word2ind)
        assert all(t < self.vocab_size for t in self.tokens)
        pprint(f'Corpus size: {len(dataset)}')
        pprint(f'Actual count of words used: {self.vocab_size}')
        pprint(f'{len(dataset)} words in dataset tokenized to {len(self.tokens)} tokens')

    def get_batch(self, batch_size=512):
        X = [None] * batch_size
        y = [None] * batch_size
        current = 0
        for start in np.random.permutation(range(len(self.tokens) - 2 * window_size)):
            center = start + window_size
            X[current] = [self.tokens[i]
                          for i in range(center - window_size, center + window_size + 1) if i != center]
            y[current] = self.tokens[center]
            current += 1
            if current == batch_size:
                # We need the generator, so only `yield ` is an option here
                yield torch.from_numpy(np.asarray(X)).to(device=device),\
                      torch.from_numpy(np.asarray(y)).to(device=device)
                # clean the buffer after we yielded it and we got back our process here
                X = [None] * batch_size
                y = [None] * batch_size
                current = 0
        if current:
            # if batch didn't get to the full size but the corpus ended
            yield torch.from_numpy(np.asarray(X[:current])).to(device=device),\
                  torch.from_numpy(np.asarray(y[:current])).to(device=device)         


In [4]:
test8_Data = Path.cwd() / 'text8'
with test8_Data.open() as f:
    # 1. simple cleaning: lowering all the words
    text8 = [a.lower() for line in f for a in line.split()]
    batcher = CBOWBatcher(text8, threshold=6)

'Corpus size: 17005207'
'Actual count of words used: 58113'
'17005207 words in dataset tokenized to 17005211 tokens'


In [14]:
class CBOWW2V(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size=256, window=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embedding_size * window * 2)
        self.relu = nn.ReLU(inplace=False)
        self.W1 = nn.Linear(embedding_size * window * 2, vocab_size)
        nn.init.xavier_normal_(self.W1.weight)

    def forward(self, x):
        # get the embedding by indices
        x = self.embed(x)
        # hidden linear layer
        x = self.relu(x)
        # get the predictions
        x = self.W1(x)
        # we need only 1 word by the given ones
        # here we got 4 options, so let's average them
        return x.mean(dim=1)


def test_CBOWW2V_shapes():
    window_size = 2
    batch_size = 64
    vocab_size = 50
    x = torch.zeros((batch_size, window_size * 2), dtype=torch.long)
    model = CBOWW2V(vocab_size, 42)
    scores = model(x)
    assert scores.size() == torch.Size([batch_size, vocab_size]), scores.size()


test_CBOWW2V_shapes()


In [5]:
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm


EACH_PRINT = 100
writer = SummaryWriter() 
def train_model(model, optimizer, epochs=1, max_steps=None):
    loss = nn.CrossEntropyLoss()
    for e in range(epochs):
        total_loss = 0
        t = tqdm(batcher.get_batch(1024), desc=f'Epoch {e}')
        for step, (x, y) in enumerate(t):
            if step > max_steps:
                break
            model.train()
            x = x.to(device=device, dtype=torch.long)
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            check = loss(scores, y)

            optimizer.zero_grad()
            check.backward()
            optimizer.step()
            total_loss += check.data
            average_loss = float(total_loss / (step + 1))
            writer.add_scalar('Current loss/train', check.data, step)
            writer.add_scalar('Total loss/train', total_loss, step)
            writer.add_scalar('Average loss/train', average_loss, step)
            t.set_postfix(loss=check.data)
            if not step % EACH_PRINT:
                pprint(f'Iteration {step}, current loss = {check.data:.4f}, average loss = {average_loss:.4f}')


In [7]:
# pprint(device)

# learning_rate = 1.568
# embedding_size = 222
# window_size = 2
# model = CBOWW2V(batcher.vocab_size, embedding_size)
# model = model.to(device=device)
# optimizer = optim.ASGD(model.parameters(), lr=learning_rate)

# train_model(model, optimizer)

device(type='cuda')


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 0', max=1.0, style=ProgressStyle(…

'Iteration 0, loss = 10.9621'
'Iteration 100, loss = 8.4109'
'Iteration 200, loss = 8.2269'
'Iteration 300, loss = 9.8921'
'Iteration 400, loss = 10.7657'
'Iteration 500, loss = 7.2920'
'Iteration 600, loss = 7.5492'
'Iteration 700, loss = 8.0928'
'Iteration 800, loss = 10.1871'
'Iteration 900, loss = 7.1046'
'Iteration 1000, loss = 9.2898'
'Iteration 1100, loss = 6.7658'
'Iteration 1200, loss = 7.3498'
'Iteration 1300, loss = 7.7469'
'Iteration 1400, loss = 8.1395'
'Iteration 1500, loss = 7.5857'
'Iteration 1600, loss = 6.9428'
'Iteration 1700, loss = 7.1464'
'Iteration 1800, loss = 7.0798'
'Iteration 1900, loss = 6.7013'
'Iteration 2000, loss = 7.1450'
'Iteration 2100, loss = 6.9293'
'Iteration 2200, loss = 7.1913'
'Iteration 2300, loss = 6.7560'
'Iteration 2400, loss = 6.8238'
'Iteration 2500, loss = 6.8997'
'Iteration 2600, loss = 6.3742'
'Iteration 2700, loss = 6.6186'
'Iteration 2800, loss = 6.8392'
'Iteration 2900, loss = 7.0350'
'Iteration 3000, loss = 6.7030'
'Iteration 3100, 

In [7]:
# results for the model are:
# 16607/? [4:12:24<00:00, 1.10it/s, loss=tensor(6.5550, device='cuda:0')]
# This is 1 epoch on the whole corpus


![Loss](imgs/LossGraph.png)

In [8]:
pprint(device)

learning_rate = 0.01568
embedding_size = 222
window_size = 2
model2 = CBOWW2V(batcher.vocab_size, embedding_size)
model2 = model2.to(device=device)
optimizer2 = optim.ASGD(model2.parameters(), lr=learning_rate)

train_model(model2, optimizer2, epochs=10, max_steps=700)

device(type='cuda')


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 0', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 10.9680, average loss = 10.9680'
'Iteration 100, current loss = 9.1175, average loss = 9.9843'
'Iteration 200, current loss = 8.6529, average loss = 9.4582'
'Iteration 300, current loss = 8.4976, average loss = 9.1552'
'Iteration 400, current loss = 8.1041, average loss = 8.9558'
'Iteration 500, current loss = 8.1481, average loss = 8.8054'
'Iteration 600, current loss = 8.1087, average loss = 8.6950'
'Iteration 700, current loss = 8.0470, average loss = 8.6041'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 1', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 8.0247, average loss = 8.0247'
'Iteration 100, current loss = 7.9059, average loss = 7.9931'
'Iteration 200, current loss = 8.2303, average loss = 7.9785'
'Iteration 300, current loss = 7.7628, average loss = 7.9552'
'Iteration 400, current loss = 7.8094, average loss = 7.9355'
'Iteration 500, current loss = 7.8529, average loss = 7.9205'
'Iteration 600, current loss = 7.6086, average loss = 7.9012'
'Iteration 700, current loss = 7.6729, average loss = 7.8821'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 2', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.8019, average loss = 7.8019'
'Iteration 100, current loss = 7.7943, average loss = 7.7382'
'Iteration 200, current loss = 7.8658, average loss = 7.7205'
'Iteration 300, current loss = 7.5498, average loss = 7.7065'
'Iteration 400, current loss = 7.7082, average loss = 7.7016'
'Iteration 500, current loss = 7.6759, average loss = 7.6934'
'Iteration 600, current loss = 7.6919, average loss = 7.6850'
'Iteration 700, current loss = 7.7010, average loss = 7.6764'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 3', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.4418, average loss = 7.4418'
'Iteration 100, current loss = 7.6411, average loss = 7.5812'
'Iteration 200, current loss = 7.5978, average loss = 7.5744'
'Iteration 300, current loss = 7.5095, average loss = 7.5723'
'Iteration 400, current loss = 7.4470, average loss = 7.5719'
'Iteration 500, current loss = 7.5134, average loss = 7.5641'
'Iteration 600, current loss = 7.3744, average loss = 7.5571'
'Iteration 700, current loss = 7.7160, average loss = 7.5485'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 4', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.4782, average loss = 7.4782'
'Iteration 100, current loss = 7.3805, average loss = 7.4965'
'Iteration 200, current loss = 7.2169, average loss = 7.4731'
'Iteration 300, current loss = 7.5933, average loss = 7.4820'
'Iteration 400, current loss = 7.6021, average loss = 7.4738'
'Iteration 500, current loss = 7.3354, average loss = 7.4703'
'Iteration 600, current loss = 7.4310, average loss = 7.4674'
'Iteration 700, current loss = 7.4962, average loss = 7.4598'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 5', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.4099, average loss = 7.4099'
'Iteration 100, current loss = 7.4543, average loss = 7.4244'
'Iteration 200, current loss = 7.4628, average loss = 7.4186'
'Iteration 300, current loss = 7.3844, average loss = 7.4103'
'Iteration 400, current loss = 7.3773, average loss = 7.4003'
'Iteration 500, current loss = 7.1500, average loss = 7.3967'
'Iteration 600, current loss = 7.3995, average loss = 7.3919'
'Iteration 700, current loss = 7.4419, average loss = 7.3897'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 6', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.3353, average loss = 7.3353'
'Iteration 100, current loss = 7.2744, average loss = 7.3554'
'Iteration 200, current loss = 7.5885, average loss = 7.3536'
'Iteration 300, current loss = 7.3594, average loss = 7.3457'
'Iteration 400, current loss = 7.3388, average loss = 7.3371'
'Iteration 500, current loss = 7.3997, average loss = 7.3339'
'Iteration 600, current loss = 7.3427, average loss = 7.3323'
'Iteration 700, current loss = 7.4980, average loss = 7.3287'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 7', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.3640, average loss = 7.3640'
'Iteration 100, current loss = 7.3848, average loss = 7.3240'
'Iteration 200, current loss = 7.3153, average loss = 7.3159'
'Iteration 300, current loss = 7.2021, average loss = 7.3072'
'Iteration 400, current loss = 7.4389, average loss = 7.3004'
'Iteration 500, current loss = 7.2328, average loss = 7.2982'
'Iteration 600, current loss = 7.1257, average loss = 7.2935'
'Iteration 700, current loss = 7.3307, average loss = 7.2913'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 8', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.3852, average loss = 7.3852'
'Iteration 100, current loss = 7.4382, average loss = 7.2554'
'Iteration 200, current loss = 7.1681, average loss = 7.2598'
'Iteration 300, current loss = 7.1810, average loss = 7.2585'
'Iteration 400, current loss = 7.2736, average loss = 7.2502'
'Iteration 500, current loss = 7.2018, average loss = 7.2486'
'Iteration 600, current loss = 7.2317, average loss = 7.2432'
'Iteration 700, current loss = 7.1801, average loss = 7.2414'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 9', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.2066, average loss = 7.2066'
'Iteration 100, current loss = 7.3967, average loss = 7.2225'
'Iteration 200, current loss = 7.2669, average loss = 7.2240'
'Iteration 300, current loss = 7.2199, average loss = 7.2141'
'Iteration 400, current loss = 7.2148, average loss = 7.2115'
'Iteration 500, current loss = 6.9840, average loss = 7.2127'
'Iteration 600, current loss = 7.1612, average loss = 7.2111'
'Iteration 700, current loss = 7.0293, average loss = 7.2105'


In [9]:
train_model(model2, optimizer2, epochs=20, max_steps=700)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 0', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.1295, average loss = 7.1295'
'Iteration 100, current loss = 7.2120, average loss = 7.1972'
'Iteration 200, current loss = 7.1200, average loss = 7.2000'
'Iteration 300, current loss = 7.2095, average loss = 7.1947'
'Iteration 400, current loss = 7.2956, average loss = 7.1919'
'Iteration 500, current loss = 7.0616, average loss = 7.1899'
'Iteration 600, current loss = 7.1213, average loss = 7.1869'
'Iteration 700, current loss = 7.0558, average loss = 7.1831'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 1', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.1132, average loss = 7.1132'
'Iteration 100, current loss = 7.0275, average loss = 7.1538'
'Iteration 200, current loss = 7.2400, average loss = 7.1659'
'Iteration 300, current loss = 7.0808, average loss = 7.1666'
'Iteration 400, current loss = 7.1553, average loss = 7.1674'
'Iteration 500, current loss = 7.1802, average loss = 7.1639'
'Iteration 600, current loss = 7.3558, average loss = 7.1584'
'Iteration 700, current loss = 7.1528, average loss = 7.1511'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 2', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.2398, average loss = 7.2398'
'Iteration 100, current loss = 7.3172, average loss = 7.1277'
'Iteration 200, current loss = 7.2658, average loss = 7.1313'
'Iteration 300, current loss = 7.1282, average loss = 7.1333'
'Iteration 400, current loss = 7.1849, average loss = 7.1337'
'Iteration 500, current loss = 7.1314, average loss = 7.1338'
'Iteration 600, current loss = 7.1601, average loss = 7.1331'
'Iteration 700, current loss = 7.1546, average loss = 7.1313'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 3', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.1869, average loss = 7.1869'
'Iteration 100, current loss = 7.2182, average loss = 7.0946'
'Iteration 200, current loss = 7.2987, average loss = 7.1055'
'Iteration 300, current loss = 7.2859, average loss = 7.1060'
'Iteration 400, current loss = 7.1160, average loss = 7.1065'
'Iteration 500, current loss = 6.9330, average loss = 7.1038'
'Iteration 600, current loss = 7.0466, average loss = 7.1023'
'Iteration 700, current loss = 7.1770, average loss = 7.1010'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 4', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.1334, average loss = 7.1334'
'Iteration 100, current loss = 7.2832, average loss = 7.1116'
'Iteration 200, current loss = 7.0400, average loss = 7.0942'
'Iteration 300, current loss = 7.0896, average loss = 7.0943'
'Iteration 400, current loss = 7.1456, average loss = 7.0897'
'Iteration 500, current loss = 6.9408, average loss = 7.0889'
'Iteration 600, current loss = 6.9917, average loss = 7.0860'
'Iteration 700, current loss = 7.0964, average loss = 7.0847'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 5', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.0194, average loss = 7.0194'
'Iteration 100, current loss = 7.0589, average loss = 7.0737'
'Iteration 200, current loss = 6.9681, average loss = 7.0737'
'Iteration 300, current loss = 6.8842, average loss = 7.0706'
'Iteration 400, current loss = 6.9667, average loss = 7.0719'
'Iteration 500, current loss = 7.0857, average loss = 7.0727'
'Iteration 600, current loss = 7.1763, average loss = 7.0759'
'Iteration 700, current loss = 6.9774, average loss = 7.0775'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 6', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.1645, average loss = 7.1645'
'Iteration 100, current loss = 7.2422, average loss = 7.0808'
'Iteration 200, current loss = 6.8786, average loss = 7.0740'
'Iteration 300, current loss = 7.1151, average loss = 7.0674'
'Iteration 400, current loss = 7.1330, average loss = 7.0644'
'Iteration 500, current loss = 7.2414, average loss = 7.0602'
'Iteration 600, current loss = 6.9224, average loss = 7.0569'
'Iteration 700, current loss = 7.0394, average loss = 7.0594'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 7', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 6.9963, average loss = 6.9963'
'Iteration 100, current loss = 7.0061, average loss = 7.0370'
'Iteration 200, current loss = 7.2579, average loss = 7.0404'
'Iteration 300, current loss = 6.8759, average loss = 7.0413'
'Iteration 400, current loss = 6.9649, average loss = 7.0363'
'Iteration 500, current loss = 7.0348, average loss = 7.0333'
'Iteration 600, current loss = 6.9418, average loss = 7.0368'
'Iteration 700, current loss = 6.9809, average loss = 7.0355'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 8', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.0157, average loss = 7.0157'
'Iteration 100, current loss = 7.0527, average loss = 7.0305'
'Iteration 200, current loss = 7.0582, average loss = 7.0375'
'Iteration 300, current loss = 7.1487, average loss = 7.0368'
'Iteration 400, current loss = 7.0949, average loss = 7.0383'
'Iteration 500, current loss = 7.1698, average loss = 7.0384'
'Iteration 600, current loss = 6.9820, average loss = 7.0331'
'Iteration 700, current loss = 7.2715, average loss = 7.0313'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 9', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 7.0320, average loss = 7.0320'
'Iteration 100, current loss = 7.1789, average loss = 7.0214'
'Iteration 200, current loss = 6.9223, average loss = 7.0115'
'Iteration 300, current loss = 7.0856, average loss = 7.0180'
'Iteration 400, current loss = 7.1150, average loss = 7.0200'
'Iteration 500, current loss = 6.9547, average loss = 7.0209'
'Iteration 600, current loss = 6.9909, average loss = 7.0203'
'Iteration 700, current loss = 6.9866, average loss = 7.0185'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 10', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 7.1111, average loss = 7.1111'
'Iteration 100, current loss = 7.0652, average loss = 7.0272'
'Iteration 200, current loss = 7.0514, average loss = 7.0172'
'Iteration 300, current loss = 7.2668, average loss = 7.0178'
'Iteration 400, current loss = 7.1276, average loss = 7.0114'
'Iteration 500, current loss = 6.9808, average loss = 7.0094'
'Iteration 600, current loss = 7.0064, average loss = 7.0102'
'Iteration 700, current loss = 6.9056, average loss = 7.0102'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 11', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 7.0241, average loss = 7.0241'
'Iteration 100, current loss = 7.0751, average loss = 7.0365'
'Iteration 200, current loss = 7.0186, average loss = 7.0205'
'Iteration 300, current loss = 7.1641, average loss = 7.0117'
'Iteration 400, current loss = 6.8292, average loss = 7.0055'
'Iteration 500, current loss = 6.8222, average loss = 7.0018'
'Iteration 600, current loss = 6.9733, average loss = 6.9996'
'Iteration 700, current loss = 6.9183, average loss = 6.9989'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 12', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 6.9884, average loss = 6.9884'
'Iteration 100, current loss = 7.1014, average loss = 6.9843'
'Iteration 200, current loss = 6.9448, average loss = 6.9908'
'Iteration 300, current loss = 6.9677, average loss = 6.9922'
'Iteration 400, current loss = 7.1003, average loss = 6.9930'
'Iteration 500, current loss = 6.9260, average loss = 6.9919'
'Iteration 600, current loss = 6.9840, average loss = 6.9873'
'Iteration 700, current loss = 7.0020, average loss = 6.9858'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 13', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 7.0499, average loss = 7.0499'
'Iteration 100, current loss = 6.8055, average loss = 6.9609'
'Iteration 200, current loss = 6.9975, average loss = 6.9665'
'Iteration 300, current loss = 7.2060, average loss = 6.9730'
'Iteration 400, current loss = 6.7319, average loss = 6.9720'
'Iteration 500, current loss = 6.8883, average loss = 6.9760'
'Iteration 600, current loss = 6.9818, average loss = 6.9799'
'Iteration 700, current loss = 7.0498, average loss = 6.9798'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 14', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 6.9616, average loss = 6.9616'
'Iteration 100, current loss = 6.9267, average loss = 6.9541'
'Iteration 200, current loss = 7.1456, average loss = 6.9608'
'Iteration 300, current loss = 7.0253, average loss = 6.9618'
'Iteration 400, current loss = 6.9814, average loss = 6.9608'
'Iteration 500, current loss = 7.1156, average loss = 6.9604'
'Iteration 600, current loss = 6.8700, average loss = 6.9602'
'Iteration 700, current loss = 6.8378, average loss = 6.9602'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 15', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 7.0308, average loss = 7.0308'
'Iteration 100, current loss = 6.9007, average loss = 6.9620'
'Iteration 200, current loss = 6.8562, average loss = 6.9646'
'Iteration 300, current loss = 6.9673, average loss = 6.9581'
'Iteration 400, current loss = 6.9321, average loss = 6.9617'
'Iteration 500, current loss = 6.9375, average loss = 6.9621'
'Iteration 600, current loss = 7.0997, average loss = 6.9603'
'Iteration 700, current loss = 6.9635, average loss = 6.9588'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 16', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 7.0912, average loss = 7.0912'
'Iteration 100, current loss = 7.0080, average loss = 6.9603'
'Iteration 200, current loss = 6.8647, average loss = 6.9471'
'Iteration 300, current loss = 6.8438, average loss = 6.9504'
'Iteration 400, current loss = 6.9325, average loss = 6.9488'
'Iteration 500, current loss = 7.0225, average loss = 6.9509'
'Iteration 600, current loss = 6.6456, average loss = 6.9513'
'Iteration 700, current loss = 6.9428, average loss = 6.9504'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 17', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 6.9459, average loss = 6.9459'
'Iteration 100, current loss = 7.2064, average loss = 6.9236'
'Iteration 200, current loss = 6.9516, average loss = 6.9321'
'Iteration 300, current loss = 6.8978, average loss = 6.9355'
'Iteration 400, current loss = 6.8991, average loss = 6.9336'
'Iteration 500, current loss = 6.8931, average loss = 6.9392'
'Iteration 600, current loss = 7.0754, average loss = 6.9382'
'Iteration 700, current loss = 7.1194, average loss = 6.9401'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 18', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 6.9645, average loss = 6.9645'
'Iteration 100, current loss = 6.7213, average loss = 6.9349'
'Iteration 200, current loss = 7.0140, average loss = 6.9411'
'Iteration 300, current loss = 7.0105, average loss = 6.9393'
'Iteration 400, current loss = 6.9580, average loss = 6.9372'
'Iteration 500, current loss = 6.9227, average loss = 6.9350'
'Iteration 600, current loss = 6.9804, average loss = 6.9342'
'Iteration 700, current loss = 6.8111, average loss = 6.9315'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 19', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 6.9600, average loss = 6.9600'
'Iteration 100, current loss = 7.0127, average loss = 6.9270'
'Iteration 200, current loss = 7.0226, average loss = 6.9347'
'Iteration 300, current loss = 6.7182, average loss = 6.9242'
'Iteration 400, current loss = 6.9684, average loss = 6.9238'
'Iteration 500, current loss = 6.9947, average loss = 6.9228'
'Iteration 600, current loss = 7.0337, average loss = 6.9249'
'Iteration 700, current loss = 6.8715, average loss = 6.9234'


In [12]:
torch.save(model2.state_dict(), 'model2.dict')

In [19]:
class CBOWW2VSparse(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size=256, window=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embedding_size * window * 2, sparse=True)
        nn.init.xavier_normal_(self.embed.weight)
        self.relu = nn.ReLU(inplace=False)
        self.W1 = nn.Linear(embedding_size * window * 2, vocab_size)
        nn.init.xavier_normal_(self.W1.weight)

    def forward(self, x):
        # get the embedding by indices
        x = self.embed(x)
        # hidden non-linear layer
        x = self.relu(x)
        # get the predictions
        x = self.W1(x)
        # we need only 1 word by the given ones
        # here we got 4 options, so let's average them
        return x.mean(dim=1)


def test_CBOWW2VSparse_shapes():
    window_size = 2
    batch_size = 64
    vocab_size = 50
    x = torch.zeros((batch_size, window_size * 2), dtype=torch.long)
    model = CBOWW2VSparse(vocab_size, 42)
    scores = model(x)
    assert scores.size() == torch.Size([batch_size, vocab_size]), scores.size()


test_CBOWW2VSparse_shapes()

In [10]:
class MupltipleOptimizer:
    def __init__(self, *opts):
        self.optimiizers = opts

    def zero_grad(self):
        for o in self.optimiizers:
            o.zero_grad()

    def step(self):
        for o in self.optimiizers:
            o.step()

In [23]:
pprint(device)

learning_rate = .01568
embedding_size = 222
window_size = 2
model_sparse = CBOWW2VSparse(batcher.vocab_size, embedding_size)
model_sparse = model_sparse.to(device=device)
optimizer_sparse_sgd = optim.ASGD(model_sparse.parameters(), lr=learning_rate)

optimizer_sparse = optim.SparseAdam([model_sparse.embed.weight], lr=learning_rate)
optimizer_dense = optim.AdamW([model_sparse.W1.weight], lr=learning_rate)
unified = MupltipleOptimizer(optimizer_sparse, optimizer_dense)

train_model(model_sparse, unified, epochs=10, max_steps=700)

device(type='cuda')


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 0', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 10.9704, average loss = 10.9704'
'Iteration 100, current loss = 7.3439, average loss = 7.7937'
'Iteration 200, current loss = 7.0171, average loss = 7.4469'
'Iteration 300, current loss = 6.7441, average loss = 7.2650'
'Iteration 400, current loss = 6.8543, average loss = 7.1461'
'Iteration 500, current loss = 6.5714, average loss = 7.0605'
'Iteration 600, current loss = 6.5671, average loss = 6.9928'
'Iteration 700, current loss = 6.6814, average loss = 6.9393'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 1', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 6.6316, average loss = 6.6316'
'Iteration 100, current loss = 6.5776, average loss = 6.5133'
'Iteration 200, current loss = 6.3682, average loss = 6.5062'
'Iteration 300, current loss = 6.3670, average loss = 6.4928'
'Iteration 400, current loss = 6.3308, average loss = 6.4713'
'Iteration 500, current loss = 6.3482, average loss = 6.4623'
'Iteration 600, current loss = 6.4137, average loss = 6.4473'
'Iteration 700, current loss = 6.4525, average loss = 6.4385'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 2', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 6.3551, average loss = 6.3551'
'Iteration 100, current loss = 6.3796, average loss = 6.3051'
'Iteration 200, current loss = 6.3772, average loss = 6.3077'
'Iteration 300, current loss = 6.3886, average loss = 6.3011'
'Iteration 400, current loss = 6.2027, average loss = 6.2983'
'Iteration 500, current loss = 6.1874, average loss = 6.2947'
'Iteration 600, current loss = 6.1945, average loss = 6.2893'
'Iteration 700, current loss = 6.1623, average loss = 6.2862'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 3', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 6.0598, average loss = 6.0598'
'Iteration 100, current loss = 6.2218, average loss = 6.2008'
'Iteration 200, current loss = 6.1234, average loss = 6.2042'
'Iteration 300, current loss = 6.1225, average loss = 6.1958'
'Iteration 400, current loss = 6.2250, average loss = 6.1939'
'Iteration 500, current loss = 6.0521, average loss = 6.1937'
'Iteration 600, current loss = 6.1842, average loss = 6.1885'
'Iteration 700, current loss = 5.9642, average loss = 6.1820'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 4', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 5.9090, average loss = 5.9090'
'Iteration 100, current loss = 6.2164, average loss = 6.1049'
'Iteration 200, current loss = 6.1861, average loss = 6.1118'
'Iteration 300, current loss = 5.9330, average loss = 6.1078'
'Iteration 400, current loss = 6.0642, average loss = 6.1032'
'Iteration 500, current loss = 6.2734, average loss = 6.0990'
'Iteration 600, current loss = 6.3212, average loss = 6.0933'
'Iteration 700, current loss = 6.1814, average loss = 6.0868'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 5', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 6.0261, average loss = 6.0261'
'Iteration 100, current loss = 6.2415, average loss = 6.0043'
'Iteration 200, current loss = 6.0821, average loss = 6.0065'
'Iteration 300, current loss = 6.0680, average loss = 6.0058'
'Iteration 400, current loss = 6.0997, average loss = 6.0058'
'Iteration 500, current loss = 5.9992, average loss = 6.0046'
'Iteration 600, current loss = 6.0252, average loss = 6.0068'
'Iteration 700, current loss = 6.0102, average loss = 6.0045'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 6', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 6.0180, average loss = 6.0180'
'Iteration 100, current loss = 5.9538, average loss = 5.9215'
'Iteration 200, current loss = 6.0872, average loss = 5.9273'
'Iteration 300, current loss = 6.0350, average loss = 5.9292'
'Iteration 400, current loss = 5.8667, average loss = 5.9294'
'Iteration 500, current loss = 5.7535, average loss = 5.9255'
'Iteration 600, current loss = 5.7600, average loss = 5.9244'
'Iteration 700, current loss = 6.1275, average loss = 5.9233'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 7', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 5.9308, average loss = 5.9308'
'Iteration 100, current loss = 6.0055, average loss = 5.8864'
'Iteration 200, current loss = 5.7475, average loss = 5.8746'
'Iteration 300, current loss = 5.9079, average loss = 5.8729'
'Iteration 400, current loss = 5.8566, average loss = 5.8675'
'Iteration 500, current loss = 5.7602, average loss = 5.8705'
'Iteration 600, current loss = 5.7143, average loss = 5.8695'
'Iteration 700, current loss = 5.8394, average loss = 5.8674'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 8', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 5.7799, average loss = 5.7799'
'Iteration 100, current loss = 5.9232, average loss = 5.8180'
'Iteration 200, current loss = 5.7526, average loss = 5.8208'
'Iteration 300, current loss = 5.5690, average loss = 5.8171'
'Iteration 400, current loss = 5.7215, average loss = 5.8164'
'Iteration 500, current loss = 5.6684, average loss = 5.8137'
'Iteration 600, current loss = 5.8820, average loss = 5.8154'
'Iteration 700, current loss = 5.6981, average loss = 5.8168'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 9', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 5.9115, average loss = 5.9115'
'Iteration 100, current loss = 5.8279, average loss = 5.7590'
'Iteration 200, current loss = 5.7010, average loss = 5.7584'
'Iteration 300, current loss = 5.8210, average loss = 5.7569'
'Iteration 400, current loss = 5.9012, average loss = 5.7576'
'Iteration 500, current loss = 5.9202, average loss = 5.7575'
'Iteration 600, current loss = 5.7001, average loss = 5.7524'
'Iteration 700, current loss = 5.8715, average loss = 5.7553'


In [24]:
train_model(model_sparse, unified, epochs=20, max_steps=700)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 0', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 5.5959, average loss = 5.5959'
'Iteration 100, current loss = 5.6760, average loss = 5.7075'
'Iteration 200, current loss = 5.8361, average loss = 5.7002'
'Iteration 300, current loss = 5.6588, average loss = 5.7048'
'Iteration 400, current loss = 5.5189, average loss = 5.7042'
'Iteration 500, current loss = 5.7943, average loss = 5.7035'
'Iteration 600, current loss = 5.5454, average loss = 5.7083'
'Iteration 700, current loss = 5.6094, average loss = 5.7048'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 1', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 5.5468, average loss = 5.5468'
'Iteration 100, current loss = 5.6099, average loss = 5.6823'
'Iteration 200, current loss = 5.5720, average loss = 5.6834'
'Iteration 300, current loss = 5.7842, average loss = 5.6819'
'Iteration 400, current loss = 5.7390, average loss = 5.6833'
'Iteration 500, current loss = 5.8492, average loss = 5.6816'
'Iteration 600, current loss = 5.6568, average loss = 5.6775'
'Iteration 700, current loss = 5.8624, average loss = 5.6767'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 2', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 5.6205, average loss = 5.6205'
'Iteration 100, current loss = 5.7074, average loss = 5.6307'
'Iteration 200, current loss = 5.5815, average loss = 5.6292'
'Iteration 300, current loss = 5.8131, average loss = 5.6379'
'Iteration 400, current loss = 5.7137, average loss = 5.6431'
'Iteration 500, current loss = 5.5229, average loss = 5.6414'
'Iteration 600, current loss = 5.6599, average loss = 5.6417'
'Iteration 700, current loss = 5.6678, average loss = 5.6433'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 3', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 5.5268, average loss = 5.5268'
'Iteration 100, current loss = 5.5974, average loss = 5.6233'
'Iteration 200, current loss = 5.5486, average loss = 5.6114'
'Iteration 300, current loss = 5.8275, average loss = 5.6108'
'Iteration 400, current loss = 5.7025, average loss = 5.6074'
'Iteration 500, current loss = 5.5937, average loss = 5.6051'
'Iteration 600, current loss = 5.6032, average loss = 5.6103'
'Iteration 700, current loss = 5.4337, average loss = 5.6111'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 4', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 5.4827, average loss = 5.4827'
'Iteration 100, current loss = 5.6601, average loss = 5.6020'
'Iteration 200, current loss = 5.5506, average loss = 5.5916'
'Iteration 300, current loss = 5.6371, average loss = 5.5821'
'Iteration 400, current loss = 5.5776, average loss = 5.5852'
'Iteration 500, current loss = 5.4829, average loss = 5.5832'
'Iteration 600, current loss = 5.6198, average loss = 5.5810'
'Iteration 700, current loss = 5.4474, average loss = 5.5822'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 5', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 5.5171, average loss = 5.5171'
'Iteration 100, current loss = 5.4627, average loss = 5.5592'
'Iteration 200, current loss = 5.4842, average loss = 5.5600'
'Iteration 300, current loss = 5.4610, average loss = 5.5552'
'Iteration 400, current loss = 5.4366, average loss = 5.5563'
'Iteration 500, current loss = 5.5516, average loss = 5.5581'
'Iteration 600, current loss = 5.5926, average loss = 5.5569'
'Iteration 700, current loss = 5.5448, average loss = 5.5594'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 6', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 5.5375, average loss = 5.5375'
'Iteration 100, current loss = 5.3871, average loss = 5.5236'
'Iteration 200, current loss = 5.4896, average loss = 5.5293'
'Iteration 300, current loss = 5.6773, average loss = 5.5268'
'Iteration 400, current loss = 5.3307, average loss = 5.5277'
'Iteration 500, current loss = 5.5461, average loss = 5.5320'
'Iteration 600, current loss = 5.6761, average loss = 5.5357'
'Iteration 700, current loss = 5.8169, average loss = 5.5382'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 7', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 5.5612, average loss = 5.5612'
'Iteration 100, current loss = 5.4203, average loss = 5.5011'
'Iteration 200, current loss = 5.4915, average loss = 5.5081'
'Iteration 300, current loss = 5.5567, average loss = 5.5082'
'Iteration 400, current loss = 5.5004, average loss = 5.5114'
'Iteration 500, current loss = 5.6424, average loss = 5.5105'
'Iteration 600, current loss = 5.3507, average loss = 5.5101'
'Iteration 700, current loss = 5.5856, average loss = 5.5103'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 8', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 5.5607, average loss = 5.5607'
'Iteration 100, current loss = 5.3913, average loss = 5.4960'
'Iteration 200, current loss = 5.4474, average loss = 5.5029'
'Iteration 300, current loss = 5.4376, average loss = 5.4962'
'Iteration 400, current loss = 5.6037, average loss = 5.5000'
'Iteration 500, current loss = 5.5136, average loss = 5.5019'
'Iteration 600, current loss = 5.5337, average loss = 5.5030'
'Iteration 700, current loss = 5.3739, average loss = 5.5021'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 9', max=1.0, style=ProgressStyle(…

'Iteration 0, current loss = 5.7752, average loss = 5.7752'
'Iteration 100, current loss = 5.3922, average loss = 5.4920'
'Iteration 200, current loss = 5.4338, average loss = 5.4733'
'Iteration 300, current loss = 5.2856, average loss = 5.4789'
'Iteration 400, current loss = 5.4065, average loss = 5.4809'
'Iteration 500, current loss = 5.6303, average loss = 5.4840'
'Iteration 600, current loss = 5.2364, average loss = 5.4838'
'Iteration 700, current loss = 5.6276, average loss = 5.4833'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 10', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 5.4897, average loss = 5.4897'
'Iteration 100, current loss = 5.6393, average loss = 5.4518'
'Iteration 200, current loss = 5.4222, average loss = 5.4574'
'Iteration 300, current loss = 5.4602, average loss = 5.4601'
'Iteration 400, current loss = 5.4505, average loss = 5.4563'
'Iteration 500, current loss = 5.3720, average loss = 5.4657'
'Iteration 600, current loss = 5.5528, average loss = 5.4644'
'Iteration 700, current loss = 5.6596, average loss = 5.4658'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 11', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 5.4101, average loss = 5.4101'
'Iteration 100, current loss = 5.2903, average loss = 5.4442'
'Iteration 200, current loss = 5.4560, average loss = 5.4433'
'Iteration 300, current loss = 5.5116, average loss = 5.4425'
'Iteration 400, current loss = 5.4492, average loss = 5.4454'
'Iteration 500, current loss = 5.2834, average loss = 5.4467'
'Iteration 600, current loss = 5.4853, average loss = 5.4469'
'Iteration 700, current loss = 5.4545, average loss = 5.4472'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 12', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 5.5172, average loss = 5.5172'
'Iteration 100, current loss = 5.5324, average loss = 5.4248'
'Iteration 200, current loss = 5.3916, average loss = 5.4267'
'Iteration 300, current loss = 5.3900, average loss = 5.4284'
'Iteration 400, current loss = 5.4383, average loss = 5.4305'
'Iteration 500, current loss = 5.5002, average loss = 5.4331'
'Iteration 600, current loss = 5.5623, average loss = 5.4387'
'Iteration 700, current loss = 5.3103, average loss = 5.4390'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 13', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 5.5202, average loss = 5.5202'
'Iteration 100, current loss = 5.2835, average loss = 5.4227'
'Iteration 200, current loss = 5.5624, average loss = 5.4299'
'Iteration 300, current loss = 5.3550, average loss = 5.4299'
'Iteration 400, current loss = 5.3757, average loss = 5.4354'
'Iteration 500, current loss = 5.4357, average loss = 5.4352'
'Iteration 600, current loss = 5.1937, average loss = 5.4350'
'Iteration 700, current loss = 5.5063, average loss = 5.4353'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 14', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 5.4458, average loss = 5.4458'
'Iteration 100, current loss = 5.1538, average loss = 5.4054'
'Iteration 200, current loss = 5.2387, average loss = 5.4134'
'Iteration 300, current loss = 5.5732, average loss = 5.4172'
'Iteration 400, current loss = 5.3950, average loss = 5.4204'
'Iteration 500, current loss = 5.4625, average loss = 5.4235'
'Iteration 600, current loss = 5.4129, average loss = 5.4278'
'Iteration 700, current loss = 5.3759, average loss = 5.4278'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 15', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 5.4432, average loss = 5.4432'
'Iteration 100, current loss = 5.4627, average loss = 5.4066'
'Iteration 200, current loss = 5.4371, average loss = 5.4119'
'Iteration 300, current loss = 5.4437, average loss = 5.4115'
'Iteration 400, current loss = 5.6615, average loss = 5.4164'
'Iteration 500, current loss = 5.6129, average loss = 5.4183'
'Iteration 600, current loss = 5.4566, average loss = 5.4196'
'Iteration 700, current loss = 5.5308, average loss = 5.4228'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 16', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 5.5361, average loss = 5.5361'
'Iteration 100, current loss = 5.5175, average loss = 5.4087'
'Iteration 200, current loss = 5.3944, average loss = 5.4137'
'Iteration 300, current loss = 5.2948, average loss = 5.4127'
'Iteration 400, current loss = 5.4504, average loss = 5.4202'
'Iteration 500, current loss = 5.2421, average loss = 5.4195'
'Iteration 600, current loss = 5.6143, average loss = 5.4213'
'Iteration 700, current loss = 5.5085, average loss = 5.4219'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 17', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 5.3135, average loss = 5.3135'
'Iteration 100, current loss = 5.5225, average loss = 5.3963'
'Iteration 200, current loss = 5.3442, average loss = 5.4049'
'Iteration 300, current loss = 5.3848, average loss = 5.4041'
'Iteration 400, current loss = 5.3864, average loss = 5.3995'
'Iteration 500, current loss = 5.3407, average loss = 5.4026'
'Iteration 600, current loss = 5.4646, average loss = 5.4034'
'Iteration 700, current loss = 5.2799, average loss = 5.4073'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 18', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 5.3225, average loss = 5.3225'
'Iteration 100, current loss = 5.2660, average loss = 5.3756'
'Iteration 200, current loss = 5.4380, average loss = 5.3914'
'Iteration 300, current loss = 5.3265, average loss = 5.3979'
'Iteration 400, current loss = 5.2032, average loss = 5.3957'
'Iteration 500, current loss = 5.3936, average loss = 5.4013'
'Iteration 600, current loss = 5.2801, average loss = 5.4004'
'Iteration 700, current loss = 5.6433, average loss = 5.4042'


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epoch 19', max=1.0, style=ProgressStyle…

'Iteration 0, current loss = 5.3133, average loss = 5.3133'
'Iteration 100, current loss = 5.2537, average loss = 5.3852'
'Iteration 200, current loss = 5.4570, average loss = 5.3927'
'Iteration 300, current loss = 5.2854, average loss = 5.3888'
'Iteration 400, current loss = 5.3531, average loss = 5.3945'
'Iteration 500, current loss = 5.3620, average loss = 5.3961'
'Iteration 600, current loss = 5.3412, average loss = 5.3938'
'Iteration 700, current loss = 5.5851, average loss = 5.3959'


Looks like it can't be better than current state at ~5.4, comparing with previous ~6.9

![Sparse Loss](imgs/Average_loss_train.svg)

In [49]:
def embedding_by_index(ind):
    if ind not in batcher.ind2word:
        return None
    return model_sparse.embed(torch.LongTensor([ind]).to(device))


def embedding_by_word(w):
    if w not in batcher.word2ind:
        return None
    return embedding_by_index(batcher.word2ind[w])

In [89]:
import pandas as pd


STOP_WORDS = {'the', 'of', 'and', 'in', 'a', 'to', 'is', 'as', 'for', 's', 'was', 'by', 'that'}

sample = embedding_by_word('the')
columns = sample.size()[1]

samples = batcher.c.most_common(10000)
df = pd.DataFrame(index=range(len(samples)), columns=[str(i) for i in range(columns)])
for i in range(len(samples)):
    df.iloc[i] = embedding_by_word(samples[i][0]).cpu().detach().numpy()
df['label'] = pd.Series([x[0] for x in samples])
df['color'] = pd.Series(['r' if x[0] in STOP_WORDS else
                         'b' if x[1] > 100000 else
                         'y' if x[1] > 10000 else
                         'g' for x in samples])

In [97]:
df.loc[:, ['label', 'color']].to_csv('metadata.tsv', sep='\t', index=False)
df.iloc[:, :888].to_csv('embeddings.tsv', sep='\t', header=False, index=False)

PCA results via linkes wibsite:

- Numeric words are red, stop words are blue, popular words are pink
- As you can see, they cluster good enough
- Hockey and USSR are near :)
- Nearest word for `husband` is `oh` :))
- Near `criticisms` one can find `nazi` cluster
- Months are clustering - for each month the nearest tensor is another month. Also related words are: `born`, `broadway`, and `voting`, which makes much sense for me. One exception is `may` as it's get clustered with verbs.

![PCA results](imgs/PCA.png)

T-SNE results:

Overall picture after some iterations
![TSNE](imgs/TSNE1.png)

Small cluster of the popular words extracted separately
![TSNE](imgs/TSNE2.png)

AS example, we can examine numeric words, most of them are quite near
![TSNE](imgs/TSNE3.png)

In [114]:
from gensim.models.keyedvectors import WordEmbeddingsKeyedVectors

embedding = WordEmbeddingsKeyedVectors(vector_size=888)
for i, n in enumerate(tqdm(batcher.word2ind)):
    embedding.add(entities=n, weights=embedding_by_word(n).cpu().detach())
    if not i % 100:
        pprint(f'{i}, {n}')

HBox(children=(FloatProgress(value=0.0, max=58113.0), HTML(value='')))

'0, a'
'100, abeda'
'200, abraham'
'300, acadians'
'400, accumulate'
'500, acquires'
'600, adapters'
'700, adjusted'
'800, adrienne'
'900, aerial'
'1000, afield'
'1100, agile'
'1200, aikman'
'1300, akimbo'
'1400, alchemist'
'1500, alfaro'
'1600, allegedly'
'1700, almonds'
'1800, alto'
'1900, ambition'
'2000, ammonite'
'2100, amyraut'
'2200, ancestral'
'2300, angelico'
'2400, annabel'
'2500, antarctica'
'2600, antiproton'
'2700, apatosaurus'
'2800, appeased'
'2900, approximations'
'3000, arbor'
'3100, arduous'
'3200, arlene'
'3300, arrive'
'3400, ascalon'
'3500, aspartic'
'3600, associated'
'3700, asymmetry'
'3800, atrial'
'3900, aucklanders'
'4000, australians'
'4100, autumn'
'4200, award'
'4300, aztlan'
'4400, badb'
'4500, bakunin'
'4600, bamiyan'
'4700, baptized'
'4800, barratry'
'4900, basler'
'5000, battling'
'5100, beans'
'5200, beeching'
'5300, belgica'
'5400, benelux'
'5500, bernadette'
'5600, bevan'
'5700, biennial'
'5800, biodiversity'
'5900, bishop'
'6000, blair'
'6100, blobs

'47000, senegal'
'47100, sequencers'
'47200, seti'
'47300, shadowing'
'47400, shares'
'47500, sheltered'
'47600, shiragami'
'47700, shoulder'
'47800, shyness'
'47900, signaled'
'48000, similarities'
'48100, sinner'
'48200, skel'
'48300, slashes'
'48400, slotted'
'48500, smoot'
'48600, soapbox'
'48700, sold'
'48800, sonar'
'48900, soundness'
'49000, spanning'
'49100, specter'
'49200, spinoffs'
'49300, sporting'
'49400, squaw'
'49500, stainless'
'49600, starfish'
'49700, steadfastly'
'49800, stereotypically'
'49900, stitch'
'50000, stowe'
'50100, strengthen'
'50200, struggles'
'50300, subcarrier'
'50400, subscripts'
'50500, success'
'50600, suited'
'50700, sunnyvale'
'50800, superstructure'
'50900, surname'
'51000, suzuki'
'51100, swinnerton'
'51200, symptoms'
'51300, szl'
'51400, tails'
'51500, tamar'
'51600, taranto'
'51700, taxa'
'51800, teddy'
'51900, temmu'
'52000, terahertz'
'52100, tetra'
'52200, theirs'
'52300, thermionic'
'52400, thoroughgoing'
'52500, thuringiensis'
'52600, til

<gensim.models.keyedvectors.WordEmbeddingsKeyedVectors at 0x14045015160>

In [115]:
embedding.save('keyed_values.dir')

In [132]:
# 0.003 :(
accuracy, result = embedding.evaluate_word_analogies('questions-words.txt')
pprint(accuracy)
for r in result:
    correct_len = len(r['correct'])
    incorrect_len = len(r['incorrect'])
    pprint(f'{r["section"]}: {correct_len} / {(correct_len + incorrect_len)}')

0.0033216679518358145
'capital-common-countries: 6 / 506'
'capital-world: 3 / 3224'
'currency: 0 / 548'
'city-in-state: 4 / 2128'
'family: 28 / 420'
'gram1-adjective-to-adverb: 0 / 992'
'gram2-opposite: 5 / 650'
'gram3-comparative: 5 / 1332'
'gram4-superlative: 0 / 870'
'gram5-present-participle: 0 / 1056'
'gram6-nationality-adjective: 2 / 1521'
'gram7-past-tense: 1 / 1482'
'gram8-plural: 2 / 1260'
'gram9-plural-verbs: 0 / 870'
'Total accuracy: 56 / 16859'


In [129]:
from gensim.test.utils import datapath

pprint(
    (embedding.n_similarity(["king"], ["duke"]),
     embedding.n_similarity(["king"], ["queen"]),
     embedding.most_similar(positive=['woman', 'king'], negative=['man']),
     embedding.n_similarity(['sushi', 'shop'], ['japanese', 'restaurant']),))

embedding.evaluate_word_pairs(datapath('wordsim353.tsv'))

(0.11921588873527891,
 0.19238090419572795,
 [('goddess', 0.26827317476272583),
  ('ivan', 0.25932806730270386),
  ('count', 0.25258105993270874),
  ('successor', 0.2509012818336487),
  ('elder', 0.24425522983074188),
  ('assistant', 0.24366986751556396),
  ('terrified', 0.24023668467998505),
  ('dissolution', 0.24021600186824799),
  ('caesaris', 0.2356097400188446),
  ('audrey', 0.23385654389858246)],
 0.46028097442197746)


((0.1611660995499102, 0.0024933373498505786),
 SpearmanrResult(correlation=0.14772197868394188, pvalue=0.0056238029600451005),
 0.84985835694051)