# Skip-gram word embedding 
In this notebook, we train a word embedding model using the skip-gram technique. In particular, we use both the *dot product* trick and the *negative sampling* trick to speed up learning. Given that said, training a word embedding is still highly expensive: you should expect a few hours of training if you want to train a good word embedding (more time would probably be required if you only use CPU!)

In [1]:
# train our model using brown corpus
from nltk.corpus import brown
print(brown.categories())

['adventure', 'belles_lettres', 'editorial', 'fiction', 'government', 'hobbies', 'humor', 'learned', 'lore', 'mystery', 'news', 'religion', 'reviews', 'romance', 'science_fiction']


In [2]:
sents = []
for genre in brown.categories():
    sents += brown.sents(categories=genre)
print(sents[5])
print(len(sents))

['Sometimes', 'he', 'woke', 'up', 'in', 'the', 'middle', 'of', 'the', 'night', 'thinking', 'of', 'Ann', ',', 'and', 'then', 'could', 'not', 'get', 'back', 'to', 'sleep', '.']
57340


In [3]:
# lower-case and stem all tokens

from nltk.stem import PorterStemmer
from tqdm.notebook import tqdm

stemmer = PorterStemmer()

all_tokens = []
train_sents = []

for ss in tqdm(sents):
    train_sents.append([stemmer.stem(token.lower()) for token in ss])
    all_tokens += train_sents[-1]

  0%|          | 0/57340 [00:00<?, ?it/s]

In [4]:
from nltk import FreqDist
fd = FreqDist(all_tokens)
fd.most_common(10)

[('the', 69971),
 (',', 58334),
 ('.', 49346),
 ('of', 36413),
 ('and', 28853),
 ('to', 26158),
 ('a', 23195),
 ('in', 21337),
 ('it', 10618),
 ('that', 10594)]

In [5]:
# build the vocabulary

vocab = list(set(all_tokens))
print('vocab size', len(vocab))
print(vocab[:10])

vocab size 34543
['poshest', '$0.9', 'hungarian-born', 'invigor', 'charg', 'center-punch', 'writ', 'inflect', 'springtim', '1800']


In [6]:
# build dictionaries to help us mapping between words and their indices

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

word_to_idx = {word:idx for idx, word in enumerate(vocab)}
idx_to_word = {idx:word for idx, word in enumerate(vocab)}
embd_dim = 100 # dimension of the wanted word vectors; a strong word vector would require 100 or more dimensions

In [7]:
# example code for computing pairwise dot product
# we will use this trick below in EmbeddingLearner
max1 = torch.randn(5,10)
max2 = torch.randn(5,10)
dot_product = torch.bmm(max1.view(5,1,10),max2.view(5,10,1)).squeeze()
print(dot_product)

tensor([-0.7216,  1.3227,  0.5370, -0.1401,  2.2933])


In [8]:
# create the neural network for learning embeddings
class EmbedingLearner(nn.Module):

    def __init__(self, vocab, embd_dim, device):
        super(EmbedingLearner, self).__init__()
        self.u_vecs = nn.Embedding(len(vocab), embd_dim)
        self.v_vecs = nn.Embedding(len(vocab), embd_dim)
        self.embd_dim = embd_dim
        self.device = device
        self.to(self.device)

    def forward(self, word_pairs_idx):
        pair_num = len(word_pairs_idx)
        center_lookup = torch.tensor([pair[0] for pair in word_pairs_idx], dtype=torch.long).to(self.device)
        context_lookup = torch.tensor([pair[1] for pair in word_pairs_idx], dtype=torch.long).to(self.device)
        center_vecs = self.u_vecs(center_lookup)
        context_vecs = self.v_vecs(context_lookup)
        sim_scores = torch.bmm(center_vecs.view(pair_num, 1, self.embd_dim), context_vecs.view(pair_num, self.embd_dim, 1)).squeeze()
        return sim_scores
        

In [9]:
# function for constructing mini-batches
import random
def get_mini_batch(word_to_idx, tokens, center_idx, win_size):
    vocab_num = len(word_to_idx)
    center_word = tokens[center_idx]
    center_word_idx = word_to_idx[center_word]
    word_pairs = []
    # first we build positive examples, i.e. center words with real context words
    for i in range(center_idx-win_size, center_idx+win_size+1):
        if i < 0 or i >= len(tokens): continue
        if i == center_idx: continue
        context_word = tokens[i]
        word_pairs.append( (center_word_idx, word_to_idx[context_word], 1) )
    # then we use negative sampling to find some non-context words
    pos_examples_num = len(word_pairs)
    context_word_idx = set([tup[1] for tup in word_pairs])
    while len(word_pairs) < 2*pos_examples_num:
        neg_word_idx = random.randint(0,vocab_num-1)
        while neg_word_idx in context_word_idx:
            neg_word_idx = random.randint(0,vocab_num-1)
        word_pairs.append( (center_word_idx, neg_word_idx, -1) )
    return word_pairs

In [10]:
# test the mini_batch constructor
get_mini_batch(word_to_idx, train_sents[10], center_idx=3, win_size=2)

[(4853, 12795, 1),
 (4853, 12360, 1),
 (4853, 9497, 1),
 (4853, 4445, 1),
 (4853, 31447, -1),
 (4853, 19437, -1),
 (4853, 4461, -1),
 (4853, 28007, -1)]

In [11]:
import numpy as np

# initialize the embedding learner
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embd_learner = EmbedingLearner(vocab, embd_dim, device)

def get_word_vec(embd_learner, word):
    if word not in vocab: return None
    lookup = torch.tensor(word_to_idx[word], dtype=torch.long).to(device)
    u_vec = embd_learner.u_vecs(torch.tensor(lookup)).cpu().detach().numpy()
    v_vec = embd_learner.v_vecs(torch.tensor(lookup)).cpu().detach().numpy()
    return np.mean([v_vec, u_vec], axis=0)

from sklearn.metrics.pairwise import cosine_similarity
def get_most_similar(word, word_vecs):
    similar_list = []
    word_vec = word_vecs[word]
    all_vecs = np.array(list(word_vecs.values())).reshape(-1,embd_dim)
    cos_sims = list(cosine_similarity(word_vec.reshape(1,-1), all_vecs)[0])
    sorted_cos_sims = sorted(cos_sims,reverse=True)
    for v in sorted_cos_sims:
        word_id = cos_sims.index(v)
        similar_list.append((idx_to_word[word_id], v))
    return similar_list
    
# get word vectors (before training)
word_vecs = {}
for word in tqdm(vocab):
    vec = get_word_vec(embd_learner,word)
    if vec is not None:
        word_vecs[word] = vec

# print the most similar words for a given word, to see how the word embeddings work
# note that we have not trained our word embeddings here
# hence the most similar words should be rather random
print(get_most_similar('read', word_vecs)[:10])

  0%|          | 0/34543 [00:00<?, ?it/s]

  # Remove the CWD from sys.path while we load stuff.
  # This is added back by InteractiveShellApp.init_path()


[('read', 0.9999999), ('cabdriv', 0.39111134), ('honest-to-betsi', 0.3885032), ("peabody'", 0.3740201), ('admonit', 0.36070418), ('kahler-craft', 0.35043496), ('pickoff', 0.34816372), ("dartmouth'", 0.34729564), ('barton', 0.34363657), ('folk-danc', 0.34258938)]


In [12]:
#embd_learner = EmbedingLearner(vocab, embd_dim, device)

In [13]:
num_epochs = 3
lr = 1e-5
window_size = 10

# init optimizer
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import random

optimizer = optim.Adam(params=embd_learner.parameters(), lr=lr)

for epo in tqdm(range(num_epochs), desc='epoch'):
    embd_learner.train()
    epoch_loss = []
    cnt = 0
    for sent in tqdm(train_sents, desc='sent'):
        if len(sent) < 10: continue # skip very short sentences
        cnt += 1
        if (cnt+1)%500 == 0: print('avg loss until step {}: {}'.format(cnt+1, np.mean(epoch_loss)))
        for center_idx in range(0,len(sent)):
            idx_pairs = get_mini_batch(word_to_idx, sent, center_idx, window_size)
            
            # Step 1: Clear the gradients 
            optimizer.zero_grad()

            # Step 2: Compute the forward pass of the model
            sim_scores = embd_learner([(pp[0],pp[1]) for pp in idx_pairs])
            #print('\nprobs', probs)
            #print('log probs', torch.log(probs))

            # Step 3: Compute the loss value that we wish to optimize
            true_labels = torch.tensor([pair[2] for pair in idx_pairs], dtype=torch.float).to(device)
            loss = -torch.dot(sim_scores, true_labels)
            #loss = -torch.dot(torch.log(probs), torch.tensor([pair[2] for pair in idx_pairs], dtype=torch.float).to(device)) # cross-entropy loss
            #loss = -torch.dot(10.*(probs-1), true_labels)

            # Step 4: Propagate the loss signal backward
            loss.backward()

            # Step 5: Trigger the optimizer to perform one update
            optimizer.step()
            
            #print(loss.detach().numpy())
            epoch_loss.append(loss.cpu().detach().numpy())
            
        
            
    print('avg loss in epoch {}: {:4f}'.format(epo, np.mean(epoch_loss)))
            

epoch:   0%|          | 0/3 [00:00<?, ?it/s]

sent:   0%|          | 0/57340 [00:00<?, ?it/s]

avg loss until step 500: 0.5654630661010742
avg loss until step 1000: -0.4153224527835846
avg loss until step 1500: -0.2386307567358017
avg loss until step 2000: -0.6134337782859802
avg loss until step 2500: -0.979987382888794
avg loss until step 3000: -1.3266180753707886
avg loss until step 3500: -1.5214871168136597
avg loss until step 4000: -1.6783417463302612
avg loss until step 4500: -2.046870470046997
avg loss until step 5000: -2.3686416149139404
avg loss until step 5500: -2.6272451877593994
avg loss until step 6000: -3.0577385425567627
avg loss until step 6500: -3.548492193222046
avg loss until step 7000: -4.0123138427734375
avg loss until step 7500: -4.525312423706055
avg loss until step 8000: -4.9944539070129395
avg loss until step 8500: -5.722128391265869
avg loss until step 9000: -6.3657097816467285
avg loss until step 9500: -6.937943458557129
avg loss until step 10000: -7.338605880737305
avg loss until step 10500: -7.801496982574463
avg loss until step 11000: -8.224465370178

sent:   0%|          | 0/57340 [00:00<?, ?it/s]

avg loss until step 500: -277.73687744140625
avg loss until step 1000: -279.0705871582031
avg loss until step 1500: -277.1864013671875
avg loss until step 2000: -275.29010009765625
avg loss until step 2500: -284.4161376953125
avg loss until step 3000: -289.00189208984375
avg loss until step 3500: -299.31292724609375
avg loss until step 4000: -314.55706787109375
avg loss until step 4500: -326.62945556640625
avg loss until step 5000: -336.1578063964844
avg loss until step 5500: -344.5446472167969
avg loss until step 6000: -351.8552551269531
avg loss until step 6500: -357.798828125
avg loss until step 7000: -365.8656921386719
avg loss until step 7500: -371.743896484375
avg loss until step 8000: -380.8100280761719
avg loss until step 8500: -392.4014892578125
avg loss until step 9000: -401.3846435546875
avg loss until step 9500: -409.787353515625
avg loss until step 10000: -414.1379699707031
avg loss until step 10500: -417.41571044921875
avg loss until step 11000: -421.9937744140625
avg los

sent:   0%|          | 0/57340 [00:00<?, ?it/s]

avg loss until step 500: -2067.375244140625
avg loss until step 1000: -2066.0966796875
avg loss until step 1500: -2043.3858642578125
avg loss until step 2000: -2020.3819580078125
avg loss until step 2500: -2070.76123046875
avg loss until step 3000: -2091.785888671875
avg loss until step 3500: -2151.398193359375
avg loss until step 4000: -2243.364501953125
avg loss until step 4500: -2312.21826171875
avg loss until step 5000: -2362.970703125
avg loss until step 5500: -2406.37255859375
avg loss until step 6000: -2439.256591796875
avg loss until step 6500: -2461.867431640625
avg loss until step 7000: -2498.333251953125
avg loss until step 7500: -2518.751220703125
avg loss until step 8000: -2560.49267578125
avg loss until step 8500: -2611.74755859375
avg loss until step 9000: -2646.80615234375
avg loss until step 9500: -2680.222412109375
avg loss until step 10000: -2691.02685546875
avg loss until step 10500: -2695.86669921875
avg loss until step 11000: -2708.39306640625
avg loss until step 

In [14]:
# recompute all words embeddings (after training)
word_vecs = {}
for word in vocab:
    vec = get_word_vec(embd_learner,word)
    if vec is not None:
        word_vecs[word] = vec

  # Remove the CWD from sys.path while we load stuff.
  # This is added back by InteractiveShellApp.init_path()


In [40]:
# see whether the similar words found by the trained word embeddings make sense 
print(get_most_similar('we', word_vecs)[:10]) 

[('we', 0.9999999), ('it', 0.8605505), ('that', 0.85932636), ('the', 0.85546327), ('of', 0.85371566), (',', 0.8522324), ('in', 0.85171664), ('.', 0.8510599), ('to', 0.8467591), ('as', 0.8466495)]
