## Demo notebook: training word2vec embeddings with the negative-sampling procedure proposed by Mikolov et al.
* The word embeddings are trained on the WikiText2 training dataset.

In [1]:
import numpy as np
import torch
from torch import nn
import random
import get_data as gd
import word2vec_NEG as w2v_neg
import pickle
# Seed every RNG this notebook imports (not only torch) so that a
# Restart & Run All is repeatable.
# NOTE(review): whether the data pipeline in `get_data` draws from `random` or
# `numpy.random` is not visible from this notebook — confirm.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
# Kept as the last expression so the cell still displays the torch Generator.
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f290b0fd930>

### Get `torch.utils.data.DataLoader` and `vocab.Vocab` objects;

In [2]:
# Build the batch loader and the vocabulary in a single call to the project
# helper; "../data" is the data-directory argument.
# NOTE(review): `vocab_min_freq=5` presumably drops tokens seen fewer than
# 5 times — confirm against `get_data.get_wikitext2_data_neg`.
loader, vocab = gd.get_wikitext2_data_neg("../data", vocab_min_freq=5)

In [3]:
# Sanity check: report how many entries the vocabulary holds.
print(f"size of WikiText2 vocab is: {len(vocab)}")

size of WikiText2 vocab is: 17027


In [4]:
print("Check the contents of a single batch:\n")
# Only the first batch is inspected. Apart from `centers`, every tensor is
# expected to have shape
# (batch_size, 2 * window_size + 2 * window_size * num_negatives).
for centers, contexts_and_negatives, coefficients, mask_pads in loader:
    # One shape per line, in the order the loader yields the tensors.
    print(
        *(
            part.shape
            for part in (centers, contexts_and_negatives, coefficients, mask_pads)
        ),
        sep="\n",
    )
    break

Check the contents of a single batch:

torch.Size([512, 1])
torch.Size([512, 48])
torch.Size([512, 48])
torch.Size([512, 48])


### Set up the model;

In [5]:
class EmbeddingsModelNEG(nn.Module):
    """Word2vec model holding two embedding tables: center and context.

    Args:
        vocab_size: number of rows in each embedding table.
        embed_size: dimensionality of every embedding vector.
        **kwargs: forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, vocab_size, embed_size, **kwargs):
        super().__init__(**kwargs)
        # The center table is created before the context table.
        self.embed_center = nn.Embedding(vocab_size, embed_size)
        self.embed_context = nn.Embedding(vocab_size, embed_size)

    def forward(self, centers, contexts_and_negatives, coefficients):
        """Return coefficient-scaled dot products between centers and the other words.

        Args:
            centers: 2-D index tensor (batch, n_centers); looked up in the
                center table.
            contexts_and_negatives: 2-D index tensor (batch, k); looked up in
                the context table.
            coefficients: tensor (batch, k) multiplied element-wise into the
                scores. NOTE(review): presumably the per-position sign/padding
                weights produced by the loader — confirm in `get_data`.

        Returns:
            Tensor of shape (batch, n_centers, k).
        """
        center_vecs = self.embed_center(centers)
        other_vecs = self.embed_context(contexts_and_negatives)
        # (batch, n_centers, embed) x (batch, embed, k) -> (batch, n_centers, k)
        scores = torch.bmm(center_vecs, other_vecs.transpose(1, 2))
        # The singleton axis lets the coefficients broadcast over the centers axis.
        return scores * coefficients.unsqueeze(1)

In [6]:
# One embedding row per vocabulary entry; vectors are 100-dimensional.
vocab_size = len(vocab)
embed_size = 100
model = EmbeddingsModelNEG(vocab_size, embed_size)

### Define the loss;

In [7]:
def log_sigmoid(x):
    """Return log(sigmoid(x)) element-wise, computed in a numerically stable way.

    The textbook form ``log(1 / (1 + exp(-x)))`` breaks for large negative
    inputs: ``exp(-x)`` overflows to ``inf``, the ratio becomes 0 and the log
    returns ``-inf`` (with unusable gradients). PyTorch's built-in
    ``logsigmoid`` evaluates the same function without that overflow.

    Args:
        x: tensor of any shape.

    Returns:
        Tensor with the same shape as ``x``.
    """
    return torch.nn.functional.logsigmoid(x)

In [8]:
class EmbeddingLossNEG(nn.Module):
    """Loss module: the negated sum of masked log-sigmoid scores."""

    def __init__(self):
        super().__init__()

    def forward(self, prods, mask_pads):
        """Return a scalar: minus the sum of ``log_sigmoid(prods)`` where the mask keeps it.

        Args:
            prods: score tensor as returned by the model's forward pass.
            mask_pads: tensor that receives a singleton axis at dim 1 before
                being multiplied into the log-sigmoid scores.
                NOTE(review): presumably 1 for real positions and 0 for
                padding — confirm against the loader.
        """
        log_probs = log_sigmoid(prods)
        # The extra axis at dim 1 lines the mask up with the score tensor.
        keep = mask_pads.unsqueeze(1)
        return (log_probs * keep).sum().neg()

In [9]:
# Single loss instance; it is passed to `train_loop` in the training cells.
loss_fn = EmbeddingLossNEG()

### Optimise with Adam;

In [10]:
# Adam over all of the model's parameters, with a learning rate of 5e-4.
optim = torch.optim.Adam(model.parameters(), lr=5e-4)

### Define train loop for 1 epoch;

In [11]:
def train_loop(model, loss_fn, optim, loader):
    """Run one pass over ``loader``, taking an optimiser step on every batch.

    The loss is printed roughly five times per pass (every ``len(loader) // 5``
    batches, at least every batch) and the summed loss once at the end.

    Args:
        model: module called as ``model(centers, contexts_and_negatives, coefficients)``.
        loss_fn: callable invoked as ``loss_fn(prods, mask_pads)``; must return
            a scalar tensor.
        optim: optimiser whose ``zero_grad()`` / ``step()`` are called per batch.
        loader: sized iterable yielding 4-tuples
            ``(centers, contexts_and_negatives, coefficients, mask_pads)``.

    Returns:
        None; progress is reported via ``print``.
    """
    model.train()
    num_batches = len(loader)
    report_every = max(1, num_batches // 5)
    running_loss = 0.
    for step, batch in enumerate(loader, start=1):
        centers, contexts_and_negatives, coefficients, mask_pads = batch
        scores = model(centers, contexts_and_negatives, coefficients)
        loss = loss_fn(scores, mask_pads)
        # Read the scalar once; it feeds both the log line and the running total.
        loss_value = loss.item()
        if step % report_every == 0:
            print(f"train_loss: {loss_value:.5f}\tprogress: {step}/{num_batches}")
        running_loss += loss_value
        # Standard update: clear stale gradients, backpropagate, apply the step.
        optim.zero_grad()
        loss.backward()
        optim.step()
    print(f"total train loss: {running_loss:.5f}\n")

### Start training;

In [12]:
# Epochs 1-50: each iteration prints a banner and runs one full pass over `loader`.
for t in range(50):
    print(f"Epoch {t + 1}:", "-------------------------", sep="\n")
    train_loop(model, loss_fn, optim, loader)

Epoch 1:
-------------------------


  Variable._execution_engine.run_backward(


train_loss: 54512.91016	progress: 148/741
train_loss: 53060.00000	progress: 296/741
train_loss: 53230.69531	progress: 444/741
train_loss: 48379.25000	progress: 592/741
train_loss: 48962.58203	progress: 740/741
total train loss: 39111040.25000

Epoch 2:
-------------------------
train_loss: 49894.75000	progress: 148/741
train_loss: 49520.80859	progress: 296/741
train_loss: 45538.88281	progress: 444/741
train_loss: 46069.74219	progress: 592/741
train_loss: 43267.92578	progress: 740/741
total train loss: 35282268.53711

Epoch 3:
-------------------------
train_loss: 43608.15234	progress: 148/741
train_loss: 42912.37109	progress: 296/741
train_loss: 42652.69531	progress: 444/741
train_loss: 42193.89844	progress: 592/741
train_loss: 43532.42188	progress: 740/741
total train loss: 32338957.28711

Epoch 4:
-------------------------
train_loss: 41045.20703	progress: 148/741
train_loss: 41196.01953	progress: 296/741
train_loss: 39322.68750	progress: 444/741
train_loss: 39548.21094	progress: 592

train_loss: 6312.71924	progress: 740/741
total train loss: 4643113.49170

Epoch 31:
-------------------------
train_loss: 6427.05225	progress: 148/741
train_loss: 6036.92285	progress: 296/741
train_loss: 6217.67139	progress: 444/741
train_loss: 5720.13867	progress: 592/741
train_loss: 6021.55664	progress: 740/741
total train loss: 4482233.39197

Epoch 32:
-------------------------
train_loss: 5705.34521	progress: 148/741
train_loss: 5907.08838	progress: 296/741
train_loss: 5981.23535	progress: 444/741
train_loss: 6254.97705	progress: 592/741
train_loss: 5774.42480	progress: 740/741
total train loss: 4334296.56787

Epoch 33:
-------------------------
train_loss: 5668.79004	progress: 148/741
train_loss: 6035.21973	progress: 296/741
train_loss: 6043.55029	progress: 444/741
train_loss: 5540.29785	progress: 592/741
train_loss: 5081.78711	progress: 740/741
total train loss: 4198055.81995

Epoch 34:
-------------------------
train_loss: 5179.75098	progress: 148/741
train_loss: 5567.91943	prog

In [13]:
# Epochs 51-100: continues with the same `model` and `optim` objects, so this
# cell resumes from whatever state they are in when it runs.
for t in range(50, 100):
    print(f"Epoch {t + 1}:", "-------------------------", sep="\n")
    train_loop(model, loss_fn, optim, loader)

Epoch 51:
-------------------------
train_loss: 3965.02368	progress: 148/741
train_loss: 3837.23486	progress: 296/741
train_loss: 3869.54492	progress: 444/741
train_loss: 3679.48779	progress: 592/741
train_loss: 3939.10498	progress: 740/741
total train loss: 2885157.80493

Epoch 52:
-------------------------
train_loss: 3884.69849	progress: 148/741
train_loss: 3973.03247	progress: 296/741
train_loss: 3890.01270	progress: 444/741
train_loss: 3732.89160	progress: 592/741
train_loss: 3902.59619	progress: 740/741
total train loss: 2846031.04248

Epoch 53:
-------------------------
train_loss: 3710.52319	progress: 148/741
train_loss: 3775.92578	progress: 296/741
train_loss: 3891.46875	progress: 444/741
train_loss: 3876.96191	progress: 592/741
train_loss: 3725.39087	progress: 740/741
total train loss: 2808657.16040

Epoch 54:
-------------------------
train_loss: 3904.47681	progress: 148/741
train_loss: 3773.42065	progress: 296/741
train_loss: 3655.56763	progress: 444/741
train_loss: 3837.68

train_loss: 2778.05737	progress: 148/741
train_loss: 2860.36426	progress: 296/741
train_loss: 2881.45459	progress: 444/741
train_loss: 2975.36572	progress: 592/741
train_loss: 3063.57178	progress: 740/741
total train loss: 2132849.22687

Epoch 82:
-------------------------
train_loss: 2762.65332	progress: 148/741
train_loss: 2816.11426	progress: 296/741
train_loss: 2820.40308	progress: 444/741
train_loss: 2754.27637	progress: 592/741
train_loss: 2675.73853	progress: 740/741
total train loss: 2116256.39459

Epoch 83:
-------------------------
train_loss: 2658.85205	progress: 148/741
train_loss: 2728.62671	progress: 296/741
train_loss: 2929.54443	progress: 444/741
train_loss: 3007.12061	progress: 592/741
train_loss: 2916.16553	progress: 740/741
total train loss: 2099979.73425

Epoch 84:
-------------------------
train_loss: 2730.79443	progress: 148/741
train_loss: 2803.50488	progress: 296/741
train_loss: 2728.15259	progress: 444/741
train_loss: 2892.21509	progress: 592/741
train_loss: 28

### Serialize the embeddings (both center and context) as `numpy` arrays (`.npy` files);

In [14]:
# Export both embedding tables as NumPy arrays in the working directory.
# `detach()` replaces the legacy `.data` accessor, and `.cpu()` keeps the export
# working if the model has been moved to an accelerator (it returns the tensor
# itself when it is already on the CPU).
np.save("embeds_wiki_centers.npy", model.embed_center.weight.detach().cpu().numpy())
np.save("embeds_wiki_contexts.npy", model.embed_context.weight.detach().cpu().numpy())