In [1]:
import numpy as np
import torch
from torch.autograd import Variable
import torch.functional as F
import torch.nn.functional as F
#https://towardsdatascience.com/implementing-word2vec-in-pytorch-skip-gram-model-e6bae040d2fb

In [2]:
corpus = [
    'he is a king',
    'she is a queen',
    'he is a man',
    'she is a woman',
    'warsaw is poland capital',
    'berlin is germany capital',
    'paris is france capital',
]

In [4]:
def tokenize_corpus(corpus):
    tokens = [x.split() for x in corpus]
    return tokens

tokenized_corpus = tokenize_corpus(corpus)

In [5]:
vocabulary = []
for sentence in tokenized_corpus:
    for token in sentence:
        if token not in vocabulary:
            vocabulary.append(token)

word2idx = {w: idx for (idx, w) in enumerate(vocabulary)}
idx2word = {idx: w for (idx, w) in enumerate(vocabulary)}

vocabulary_size = len(vocabulary)

In [6]:
idx2word

{0: 'he',
 1: 'is',
 2: 'a',
 3: 'king',
 4: 'she',
 5: 'queen',
 6: 'man',
 7: 'woman',
 8: 'warsaw',
 9: 'poland',
 10: 'capital',
 11: 'berlin',
 12: 'germany',
 13: 'paris',
 14: 'france'}

In [7]:
window_size = 2
idx_pairs = []
# for each sentence
for sentence in tokenized_corpus:
    indices = [word2idx[word] for word in sentence]
    # for each word, threated as center word
    for center_word_pos in range(len(indices)):
        # for each window position
        for w in range(-window_size, window_size + 1):
            context_word_pos = center_word_pos + w
            # make soure not jump out sentence
            if context_word_pos < 0 or context_word_pos >= len(indices) or center_word_pos == context_word_pos:
                continue
            context_word_idx = indices[context_word_pos]
            idx_pairs.append((indices[center_word_pos], context_word_idx))

idx_pairs = np.array(idx_pairs) # it will be useful to have this as numpy array

In [8]:
#Input layer
def get_input_layer(word_idx):
    x = torch.zeros(vocabulary_size).float()
    x[word_idx] = 1.0
    return x

In [9]:
embedding_dims = 5
W1 = Variable(torch.randn(embedding_dims, vocabulary_size).float(), requires_grad=True)
W2 = Variable(torch.randn(vocabulary_size, embedding_dims).float(), requires_grad=True)
num_epochs = 101
learning_rate = 0.001

for epo in range(num_epochs):
    loss_val = 0
    for data, target in idx_pairs:
        x = Variable(get_input_layer(data)).float()
        y_true = Variable(torch.from_numpy(np.array([target])).long())

        z1 = torch.matmul(W1, x)
        z2 = torch.matmul(W2, z1)
    
        log_softmax = F.log_softmax(z2, dim=0)

        loss = F.nll_loss(log_softmax.view(1,-1), y_true)
        loss_val += loss.data[0]
        loss.backward()
        W1.data -= learning_rate * W1.grad.data
        W2.data -= learning_rate * W2.grad.data

        W1.grad.data.zero_()
        W2.grad.data.zero_()
    if epo % 10 == 0:    
        print(f'Loss at epo {epo}: {loss_val/len(idx_pairs)}')
        



Loss at epo 0: 4.113731861114502
Loss at epo 10: 3.8165783882141113
Loss at epo 20: 3.58565616607666
Loss at epo 30: 3.401162624359131
Loss at epo 40: 3.251281261444092
Loss at epo 50: 3.127729892730713
Loss at epo 60: 3.0244221687316895
Loss at epo 70: 2.9368245601654053
Loss at epo 80: 2.8615474700927734
Loss at epo 90: 2.7960410118103027
Loss at epo 100: 2.7383763790130615


In [16]:
idx_pairs

array([[ 0,  1],
       [ 0,  2],
       [ 1,  0],
       [ 1,  2],
       [ 1,  3],
       [ 2,  0],
       [ 2,  1],
       [ 2,  3],
       [ 3,  1],
       [ 3,  2],
       [ 4,  1],
       [ 4,  2],
       [ 1,  4],
       [ 1,  2],
       [ 1,  5],
       [ 2,  4],
       [ 2,  1],
       [ 2,  5],
       [ 5,  1],
       [ 5,  2],
       [ 0,  1],
       [ 0,  2],
       [ 1,  0],
       [ 1,  2],
       [ 1,  6],
       [ 2,  0],
       [ 2,  1],
       [ 2,  6],
       [ 6,  1],
       [ 6,  2],
       [ 4,  1],
       [ 4,  2],
       [ 1,  4],
       [ 1,  2],
       [ 1,  7],
       [ 2,  4],
       [ 2,  1],
       [ 2,  7],
       [ 7,  1],
       [ 7,  2],
       [ 8,  1],
       [ 8,  9],
       [ 1,  8],
       [ 1,  9],
       [ 1, 10],
       [ 9,  8],
       [ 9,  1],
       [ 9, 10],
       [10,  1],
       [10,  9],
       [11,  1],
       [11, 12],
       [ 1, 11],
       [ 1, 12],
       [ 1, 10],
       [12, 11],
       [12,  1],
       [12, 10],
       [10,  1

In [15]:
W2

tensor([[ 0.3497, -1.4642,  1.0571,  0.0045,  0.0010],
        [ 0.8269, -0.4358,  0.6130,  0.1759,  1.4404],
        [-0.0569, -0.4408, -0.4868, -1.3611, -1.0915],
        [ 1.7092, -0.2991, -0.1926, -1.1277,  0.2957],
        [-0.2882,  1.6593,  0.0396,  0.0230, -0.3410],
        [ 0.5864, -0.8022,  0.9407, -0.2941, -2.1316],
        [ 1.9038, -1.4988,  1.1339, -0.3454,  0.1369],
        [ 0.3563, -0.4898,  0.8352,  0.6977, -1.1130],
        [-1.9751,  0.8893, -0.9875,  1.3380,  0.0128],
        [-1.4534,  0.0092, -0.8266,  1.3806,  0.4391],
        [ 0.2096, -0.0563,  0.4370,  0.2366,  1.4569],
        [-0.4202,  0.3631,  0.2115,  0.2002,  0.1481],
        [ 0.5035, -1.9162, -0.4839, -0.1943, -1.8723],
        [-1.0708, -0.4171, -1.1511,  0.6019,  0.7890],
        [ 0.0876, -0.4830,  0.2926,  0.5758,  0.1390]])

In [47]:
corpus

['he is a king',
 'she is a queen',
 'he is a man',
 'she is a woman',
 'warsaw is poland capital',
 'berlin is germany capital',
 'paris is france capital']