In [0]:
import torch
from torch.autograd import Variable
import numpy as np
import torch.functional as F
import torch.nn.functional as F

In [0]:

# Toy training corpus: one sentence per line, words separated by single spaces.
# splitlines() turns the block into a list with one string per sentence.
corpus = """\
he is a king
she is a queen
he is a man
she is a woman
warsaw is poland capital
berlin is germany capital
paris is france capital""".splitlines()

In [0]:
def tokenize_corpus(corpus):
    """Split every sentence of ``corpus`` into word tokens.

    Args:
        corpus: iterable of sentence strings.

    Returns:
        A list containing, for each sentence in input order, the list of
        tokens produced by ``sentence.split()`` (splits on any whitespace).
    """
    tokenized = []
    for sentence in corpus:
        tokenized.append(sentence.split())
    return tokenized

# Tokenise the whole corpus once; later cells build the vocabulary and the
# skip-gram training pairs from `tokenized_corpus`.
tokenized_corpus = tokenize_corpus(corpus)
print(tokenized_corpus)

[['he', 'is', 'a', 'king'], ['she', 'is', 'a', 'queen'], ['he', 'is', 'a', 'man'], ['she', 'is', 'a', 'woman'], ['warsaw', 'is', 'poland', 'capital'], ['berlin', 'is', 'germany', 'capital'], ['paris', 'is', 'france', 'capital']]


In [0]:
# Vocabulary: every distinct token, in order of first appearance in the corpus.
# dict.fromkeys de-duplicates in a single pass while preserving insertion order;
# the previous `token not in vocabulary` list scan was quadratic in vocabulary size.
vocabulary = list(dict.fromkeys(
    token for sentence in tokenized_corpus for token in sentence
))

# Lookups in both directions between words and their integer ids
# (ids are 0 .. len(vocabulary) - 1, assigned in first-appearance order).
word2idx = {w: idx for (idx, w) in enumerate(vocabulary)}
idx2word = dict(enumerate(vocabulary))

vocabulary_size = len(vocabulary)

In [0]:
# Inspect the word -> integer id mapping.
word2idx

{'a': 2,
 'berlin': 11,
 'capital': 10,
 'france': 14,
 'germany': 12,
 'he': 0,
 'is': 1,
 'king': 3,
 'man': 6,
 'paris': 13,
 'poland': 9,
 'queen': 5,
 'she': 4,
 'warsaw': 8,
 'woman': 7}

In [0]:
# Skip-gram training pairs: each centre word is paired with every word at most
# `window_size` positions to its left or right within the same sentence.
window_size = 2
idx_pairs = []
# for each sentence
for sentence in tokenized_corpus:
    indices = [word2idx[word] for word in sentence]
    # for each word, treated as the centre word
    for center_word_pos, center_word_idx in enumerate(indices):
        # clamp the window to the sentence bounds so we never jump outside it
        window_start = max(0, center_word_pos - window_size)
        window_end = min(len(indices), center_word_pos + window_size + 1)
        for context_word_pos in range(window_start, window_end):
            if context_word_pos == center_word_pos:
                continue  # a word is not its own context
            idx_pairs.append((center_word_idx, indices[context_word_pos]))

# Shape (num_pairs, 2): column 0 = centre-word id, column 1 = context-word id.
# dtype/reshape keep the array 2-D and integer-typed even when no pairs exist,
# so `for data, target in idx_pairs` still unpacks correctly.
idx_pairs = np.array(idx_pairs, dtype=int).reshape(-1, 2)

In [0]:
# Peek at the first ten (centre-word id, context-word id) pairs.
idx_pairs[:10]

array([[0, 1],
       [0, 2],
       [1, 0],
       [1, 2],
       [1, 3],
       [2, 0],
       [2, 1],
       [2, 3],
       [3, 1],
       [3, 2]])

![Skip-gram network diagram](https://miro.medium.com/max/377/1*uYiqfNrUIzkdMrmkBWGMPw.png)

In [0]:
def get_input_layer(word_idx, size=None):
    """One-hot encode a centre word as the network's input layer.

    Args:
        word_idx: integer id of the word (e.g. a value from ``word2idx``).
        size: length of the encoding; defaults to the global
            ``vocabulary_size`` when omitted.

    Returns:
        A 1-D float32 tensor of shape ``(size,)`` that is all zeros except
        for a 1.0 at position ``word_idx``.
    """
    if size is None:
        size = vocabulary_size
    x = torch.zeros(size, dtype=torch.float32)
    x[word_idx] = 1.0
    return x

In [0]:
# Hyperparameters for the two-layer skip-gram network.
embedding_dims = 5
num_epochs = 1010
learning_rate = 0.001

# Seed the RNG so the random weight initialisation (and the loss curve) is reproducible.
torch.manual_seed(0)

# W1: (embedding_dims, vocabulary_size) — projects a one-hot centre word to its embedding.
# W2: (vocabulary_size, embedding_dims) — maps the embedding to one score per vocabulary word.
# Plain tensors with requires_grad=True replace the deprecated autograd.Variable wrapper.
W1 = torch.randn(embedding_dims, vocabulary_size, requires_grad=True)
W2 = torch.randn(vocabulary_size, embedding_dims, requires_grad=True)

for epo in range(num_epochs):
    loss_val = 0.0
    # Plain per-pair SGD: one forward/backward/update per (centre, context) pair.
    for data, target in idx_pairs:
        x = get_input_layer(data)
        y_true = torch.tensor([target], dtype=torch.long)

        z1 = torch.matmul(W1, x)   # embedding of the centre word
        z2 = torch.matmul(W2, z1)  # unnormalised scores over the vocabulary

        log_softmax = F.log_softmax(z2, dim=0)
        # nll_loss expects a (batch, classes) input, hence the view to (1, vocabulary_size).
        loss = F.nll_loss(log_softmax.view(1, -1), y_true)
        loss_val += loss.item()
        loss.backward()

        # Update the weights outside autograd tracking, then reset the gradients,
        # which would otherwise accumulate across pairs.
        with torch.no_grad():
            W1 -= learning_rate * W1.grad
            W2 -= learning_rate * W2.grad
            W1.grad.zero_()
            W2.grad.zero_()
    if epo % 10 == 0:
        print(f'Loss at epo {epo}: {loss_val/len(idx_pairs)}')

Loss at epo 0: 4.538836489404951
Loss at epo 10: 4.053764954635075
Loss at epo 20: 3.6870625819478717
Loss at epo 30: 3.3981100388935634
Loss at epo 40: 3.164670308998653
Loss at epo 50: 2.9728501898901802
Loss at epo 60: 2.8136727179799763
Loss at epo 70: 2.681133336680276
Loss at epo 80: 2.5708919491086686
Loss at epo 90: 2.4794341121401104
Loss at epo 100: 2.4036593164716447
Loss at epo 110: 2.3407671809196473
Loss at epo 120: 2.2882858651024955
Loss at epo 130: 2.2441212603024074
Loss at epo 140: 2.206563547679356
Loss at epo 150: 2.174257889815739
Loss at epo 160: 2.1461487650871276
Loss at epo 170: 2.1214211889675685
Loss at epo 180: 2.099448001384735
Loss at epo 190: 2.079744052886963
Loss at epo 200: 2.0619311877659388
Loss at epo 210: 2.04571270431791
Loss at epo 220: 2.030852677140917
Loss at epo 230: 2.0171612892832074
Loss at epo 240: 2.0044845581054687
Loss at epo 250: 1.9926958067076548
Loss at epo 260: 1.9816904408591134
Loss at epo 270: 1.9713806578091213
Loss at epo 28