In [1]:
import torch
from torch.autograd import Variable
import numpy as np
import torch.functional as F
import torch.nn.functional as F

from scipy.spatial import distance

In [44]:
corpus = [
    'he is a king',
    'she is a queen',
    'he is a man',
    'she is a woman',
    'warsaw is poland capital',
    'berlin is germany capital',
    'paris is france capital',   
]

In [45]:
def tokenize_corpus(corpus):
    tokens = [x.split() for x in corpus]
    return tokens

tokenized_corpus = tokenize_corpus(corpus)

In [4]:
vocabulary = []
for sentence in tokenized_corpus:
    for token in sentence:
        if token not in vocabulary:
            vocabulary.append(token)

word2idx = {w: idx for (idx, w) in enumerate(vocabulary)}
idx2word = {idx: w for (idx, w) in enumerate(vocabulary)}

vocabulary_size = len(vocabulary)

In [5]:
window_size = 2
idx_pairs = []
for sentence in tokenized_corpus:
    indices = [word2idx[word] for word in sentence]
    for center_word_pos in range(len(indices)):
        for w in range(-window_size, window_size + 1):
            context_word_pos = center_word_pos + w
            # make soure not jump out sentence
            if context_word_pos < 0 or context_word_pos >= len(indices) or center_word_pos == context_word_pos:
                continue
            context_word_idx = indices[context_word_pos]
            idx_pairs.append((indices[center_word_pos], context_word_idx))

idx_pairs = np.array(idx_pairs) # it will be useful to have this as numpy array

In [6]:
def get_input_layer(word_idx):
    x = torch.zeros(vocabulary_size).float()
    x[word_idx] = 1.0
    return x

In [7]:
embedding_dims = 5
W1 = Variable(torch.randn(embedding_dims, vocabulary_size).float(), requires_grad=True)
W2 = Variable(torch.randn(vocabulary_size, embedding_dims).float(), requires_grad=True)
num_epochs = 1001
learning_rate = 0.001

for epo in range(num_epochs):
    loss_val = 0
    for data, target in idx_pairs:
        x = Variable(get_input_layer(data)).float()
        y_true = Variable(torch.from_numpy(np.array([target])).long())

        z1 = torch.matmul(W1, x)
        z2 = torch.matmul(W2, z1)    
        log_softmax = F.log_softmax(z2, dim=0)

        loss = F.nll_loss(log_softmax.view(1,-1), y_true)
        loss_val += loss.data.item()
        loss.backward()
    
        with torch.no_grad():
            W1 -= learning_rate * W1.grad
            W2 -= learning_rate * W2.grad

            W1.grad.zero_()
            W2.grad.zero_()

    if epo % 10 == 0:    
        print('Loss at epo {0}: {1}'.format(epo, loss_val/len(idx_pairs))) #'Loss at epo {epo}: {loss_val/len(idx_pairs)}'

Loss at epo 0: 6.8664751699992586
Loss at epo 10: 5.437300780841283
Loss at epo 20: 4.77494215965271
Loss at epo 30: 4.387421384879521
Loss at epo 40: 4.109059355940137
Loss at epo 50: 3.892965911115919
Loss at epo 60: 3.7176239899226595
Loss at epo 70: 3.571108218601772
Loss at epo 80: 3.4462477394512723
Loss at epo 90: 3.33836875643049
Loss at epo 100: 3.244168015888759
Loss at epo 110: 3.16115779876709
Loss at epo 120: 3.087382744039808
Loss at epo 130: 3.0212683745792934
Loss at epo 140: 2.9615271261760165
Loss at epo 150: 2.9070978283882143
Loss at epo 160: 2.8571009039878845
Loss at epo 170: 2.810803851059505
Loss at epo 180: 2.767595742430006
Loss at epo 190: 2.726966186932155
Loss at epo 200: 2.688487502506801
Loss at epo 210: 2.65180424451828
Loss at epo 220: 2.616621516432081
Loss at epo 230: 2.582699012756348
Loss at epo 240: 2.5498445408684867
Loss at epo 250: 2.5179105503218513
Loss at epo 260: 2.4867896352495467
Loss at epo 270: 2.456411632469722
Loss at epo 280: 2.426739

In [8]:
# W1.shape, W2.shape

In [9]:
with torch.no_grad():
    emb = {word: torch.matmul(W1, get_input_layer(i)).numpy() for i, word in enumerate(vocabulary)}

In [10]:
print(distance.cosine(emb['man'], emb['woman']))
print(distance.cosine(emb['man'], emb['he']))
print(distance.cosine(emb['man'], emb['she']))
print(distance.cosine(emb['woman'], emb['she']))
print()
print(distance.cosine(emb['king'], emb['he']))
print(distance.cosine(emb['king'], emb['she']))
print(distance.cosine(emb['king'], emb['queen']))
print(distance.cosine(emb['queen'], emb['he']))
print(distance.cosine(emb['queen'], emb['she']))
print()
print(distance.cosine(emb['warsaw'], emb['berlin']))
print(distance.cosine(emb['warsaw'], emb['paris']))
print(distance.cosine(emb['warsaw'], emb['poland']))
print(distance.cosine(emb['warsaw'], emb['germany']))
print(distance.cosine(emb['berlin'], emb['germany']))
print(distance.cosine(emb['berlin'], emb['poland']))
print(distance.cosine(emb['berlin'], emb['france']))

0.2198285460472107
0.2600386142730713
0.06145864725112915
0.3356438875198364

0.03295016288757324
0.17506688833236694
0.6612101793289185
0.5454409420490265
0.2948251962661743

0.17275750637054443
0.9971998913679272
0.7850916236639023
0.7926154434680939
0.8764673694968224
1.3115880191326141
0.9882218223065138


In [11]:
with torch.no_grad():
    emb = {word: W1[:, i] * W2[i] for i, word in enumerate(vocabulary)}

In [12]:
print(distance.cosine(emb['man'], emb['woman']))
print(distance.cosine(emb['man'], emb['he']))
print(distance.cosine(emb['man'], emb['she']))
print(distance.cosine(emb['woman'], emb['she']))
print()
print(distance.cosine(emb['king'], emb['he']))
print(distance.cosine(emb['king'], emb['she']))
print(distance.cosine(emb['king'], emb['queen']))
print(distance.cosine(emb['queen'], emb['he']))
print(distance.cosine(emb['queen'], emb['she']))
print()
print(distance.cosine(emb['warsaw'], emb['berlin']))
print(distance.cosine(emb['warsaw'], emb['paris']))
print(distance.cosine(emb['warsaw'], emb['poland']))
print(distance.cosine(emb['warsaw'], emb['germany']))
print(distance.cosine(emb['berlin'], emb['germany']))
print(distance.cosine(emb['berlin'], emb['poland']))
print(distance.cosine(emb['berlin'], emb['france']))

0.052087247371673584
0.5862727761268616
0.17826193571090698
0.08458787202835083

0.03320729732513428
0.3027147650718689
0.47494280338287354
0.6454004049301147
0.3522189259529114

1.5133419036865234
1.2108029425144196
0.9838755503296852
0.8900401517748833
0.6909601092338562
1.3766942024230957
1.4267953634262085


### same things in real dataset

In [77]:
import gzip
import gensim

import re

from collections import Counter
from tqdm import tqdm

from collections import OrderedDict

import pickle

In [17]:
def read_input(input_file):
    """This method reads the input file which is in gzip format"""

    with gzip.open(input_file, 'rb') as f:
        for i, line in enumerate(f):

            # do some pre-processing and return list of words for each review
            # text
            yield gensim.utils.simple_preprocess(line)

In [20]:
data_file = 'opinrankdatasetwithjudgments.tar.gz' # 'OpinRankDatasetWithJudgments.zip'
documents = list(read_input (data_file))

In [34]:
documents[36]

['query', 'good', 'cleanliness', 'nice', 'room', 'excellent', 'staff']

In [36]:
c = Counter()
for line in tqdm(documents):
    for word in line:
        c[word] += 1

100%|██████████| 792362/792362 [00:21<00:00, 37601.16it/s] 


In [58]:
c.most_common(10)

[('the', 3095891),
 ('and', 1631827),
 ('to', 1175956),
 ('was', 927119),
 ('in', 805277),
 ('we', 677762),
 ('of', 671223),
 ('is', 627155),
 ('for', 612576),
 ('it', 600419)]

In [71]:
# vocabulary = []
# for sentence in tqdm(documents):
#     for token in sentence:
#         if token not in vocabulary:
#             vocabulary.append(token)

# word2idx = {w: idx for (idx, w) in enumerate(vocabulary)}
# idx2word = {idx: w for (idx, w) in enumerate(vocabulary)}

# vocabulary_size = len(vocabulary)

In [60]:
def ha(source):
    pattern = r'[ \'\(\)\[\]]'
    return '{}'.format(re.sub(pattern, '', str(source)))

In [61]:
ha(tokenized_corpus).split(',')

['he',
 'is',
 'a',
 'king',
 'she',
 'is',
 'a',
 'queen',
 'he',
 'is',
 'a',
 'man',
 'she',
 'is',
 'a',
 'woman',
 'warsaw',
 'is',
 'poland',
 'capital',
 'berlin',
 'is',
 'germany',
 'capital',
 'paris',
 'is',
 'france',
 'capital']

In [62]:
list(OrderedDict.fromkeys(ha(tokenized_corpus).split(',')).keys())

['he',
 'is',
 'a',
 'king',
 'she',
 'queen',
 'man',
 'woman',
 'warsaw',
 'poland',
 'capital',
 'berlin',
 'germany',
 'paris',
 'france']

In [64]:
all_words = ha(documents).split(',')

In [65]:
all_words[:3]

['ustar', 'webuser', 'webuser']

In [68]:
vocabulary = list(OrderedDict.fromkeys(all_words).keys())

In [73]:
vocabulary_size = len(vocabulary)
word2idx = {w: idx for (idx, w) in enumerate(vocabulary)}
idx2word = {idx: w for (idx, w) in enumerate(vocabulary)}

In [74]:
vocabulary_size

189478

In [75]:
window_size = 4
idx_pairs = []
for sentence in tqdm(documents):
    indices = [word2idx[word] for word in sentence]
    for center_word_pos in range(len(indices)):
        for w in range(-window_size, window_size + 1):
            context_word_pos = center_word_pos + w
            # make soure not jump out sentence
            if context_word_pos < 0 or context_word_pos >= len(indices) or center_word_pos == context_word_pos:
                continue
            context_word_idx = indices[context_word_pos]
            idx_pairs.append((indices[center_word_pos], context_word_idx))

idx_pairs = np.array(idx_pairs) # it will be useful to have this as numpy array


  0%|          | 0/792362 [00:00<?, ?it/s][A
  0%|          | 3717/792362 [00:00<00:21, 36875.77it/s][A
  1%|          | 5800/792362 [00:00<00:27, 28874.12it/s][A
  1%|          | 7538/792362 [00:00<00:31, 25055.68it/s][A
  2%|▏         | 12138/792362 [00:00<00:25, 30275.61it/s][A
  2%|▏         | 16105/792362 [00:00<00:24, 32140.13it/s][A
  2%|▏         | 19330/792362 [00:00<00:24, 32166.22it/s][A
  3%|▎         | 22323/792362 [00:00<00:26, 29560.00it/s][A
  3%|▎         | 25115/792362 [00:00<00:26, 29365.24it/s][A
  4%|▎         | 29257/792362 [00:00<00:24, 30625.83it/s][A
  4%|▍         | 32921/792362 [00:01<00:24, 31181.81it/s][A
  5%|▍         | 36219/792362 [00:01<00:26, 28912.96it/s][A
  5%|▍         | 39029/792362 [00:01<00:26, 28321.35it/s][A
  5%|▌         | 42107/792362 [00:01<00:26, 28480.62it/s][A
  6%|▌         | 45201/792362 [00:01<00:26, 28642.96it/s][A
  6%|▌         | 48069/792362 [00:01<00:26, 27807.52it/s][A
  6%|▋         | 50648/792362 [00:01<00:2

 25%|██▌       | 200415/792362 [00:42<02:04, 4738.92it/s][A
 25%|██▌       | 200501/792362 [00:42<02:05, 4727.36it/s][A
 25%|██▌       | 200596/792362 [00:42<02:05, 4718.50it/s][A
 25%|██▌       | 200700/792362 [00:42<02:05, 4709.73it/s][A
 25%|██▌       | 200820/792362 [00:42<02:05, 4701.50it/s][A
 25%|██▌       | 200920/792362 [00:42<02:06, 4690.67it/s][A
 25%|██▌       | 201015/792362 [00:42<02:06, 4681.76it/s][A
 25%|██▌       | 201129/792362 [00:43<02:06, 4673.47it/s][A
 25%|██▌       | 201229/792362 [00:43<02:06, 4663.19it/s][A
 25%|██▌       | 201325/792362 [00:43<02:07, 4651.02it/s][A
 25%|██▌       | 201414/792362 [00:43<02:07, 4640.87it/s][A
 25%|██▌       | 201500/792362 [00:43<02:07, 4630.97it/s][A
 25%|██▌       | 201598/792362 [00:43<02:07, 4622.52it/s][A
 25%|██▌       | 201686/792362 [00:43<02:08, 4611.99it/s][A
 25%|██▌       | 201798/792362 [00:43<02:08, 4603.98it/s][A
 25%|██▌       | 201890/792362 [00:43<02:08, 4594.96it/s][A
 25%|██▌       | 201981/

 29%|██▊       | 227527/792362 [01:12<02:59, 3153.69it/s][A
 29%|██▊       | 227605/792362 [01:12<02:59, 3147.56it/s][A
 29%|██▊       | 227728/792362 [01:12<02:59, 3144.88it/s][A
 29%|██▉       | 227839/792362 [01:12<02:59, 3142.00it/s][A
 29%|██▉       | 227944/792362 [01:12<02:59, 3139.08it/s][A
 29%|██▉       | 228040/792362 [01:12<03:00, 3132.78it/s][A
 29%|██▉       | 228124/792362 [01:12<03:00, 3127.77it/s][A
 29%|██▉       | 228242/792362 [01:13<03:00, 3125.09it/s][A
 29%|██▉       | 228331/792362 [01:13<03:00, 3121.90it/s][A
 29%|██▉       | 228451/792362 [01:13<03:00, 3119.27it/s][A
 29%|██▉       | 228549/792362 [01:13<03:00, 3115.88it/s][A
 29%|██▉       | 228644/792362 [01:13<03:01, 3110.29it/s][A
 29%|██▉       | 228728/792362 [01:13<03:01, 3105.91it/s][A
 29%|██▉       | 228816/792362 [01:13<03:01, 3102.90it/s][A
 29%|██▉       | 228897/792362 [01:13<03:01, 3098.88it/s][A
 29%|██▉       | 228974/792362 [01:13<03:02, 3094.63it/s][A
 29%|██▉       | 229046/

 32%|███▏      | 251287/792362 [01:42<03:41, 2447.00it/s][A
 32%|███▏      | 251395/792362 [01:42<03:41, 2445.66it/s][A
 32%|███▏      | 252030/792362 [01:42<03:40, 2449.46it/s][A
 32%|███▏      | 252343/792362 [01:42<03:40, 2450.11it/s][A
 32%|███▏      | 252613/792362 [01:43<03:40, 2446.32it/s][A
 32%|███▏      | 252832/792362 [01:43<03:41, 2440.69it/s][A
 32%|███▏      | 253006/792362 [01:43<03:41, 2436.24it/s][A
 32%|███▏      | 253148/792362 [01:44<03:41, 2432.65it/s][A
 32%|███▏      | 253268/792362 [01:44<03:41, 2430.15it/s][A
 32%|███▏      | 253375/792362 [01:44<03:42, 2426.49it/s][A
 32%|███▏      | 253466/792362 [01:44<03:42, 2424.83it/s][A
 32%|███▏      | 253555/792362 [01:44<03:42, 2422.67it/s][A
 32%|███▏      | 253639/792362 [01:44<03:42, 2421.16it/s][A
 32%|███▏      | 253723/792362 [01:44<03:42, 2419.61it/s][A
 32%|███▏      | 253831/792362 [01:44<03:42, 2418.27it/s][A
 32%|███▏      | 253934/792362 [01:45<03:42, 2416.93it/s][A
 32%|███▏      | 254028/

 36%|███▌      | 282340/792362 [02:13<04:01, 2108.98it/s][A
 36%|███▌      | 282436/792362 [02:13<04:01, 2108.12it/s][A
 36%|███▌      | 282542/792362 [02:14<04:01, 2107.29it/s][A
 36%|███▌      | 282644/792362 [02:14<04:01, 2106.46it/s][A
 36%|███▌      | 282740/792362 [02:14<04:02, 2105.25it/s][A
 36%|███▌      | 282831/792362 [02:14<04:02, 2104.26it/s][A
 36%|███▌      | 282968/792362 [02:14<04:02, 2103.71it/s][A
 36%|███▌      | 283072/792362 [02:14<04:02, 2102.38it/s][A
 36%|███▌      | 283168/792362 [02:14<04:02, 2101.36it/s][A
 36%|███▌      | 283261/792362 [02:14<04:02, 2100.12it/s][A
 36%|███▌      | 283349/792362 [02:14<04:02, 2099.07it/s][A
 36%|███▌      | 283439/792362 [02:15<04:02, 2098.17it/s][A
 36%|███▌      | 283560/792362 [02:15<04:02, 2097.51it/s][A
 36%|███▌      | 283657/792362 [02:15<04:02, 2096.14it/s][A
 36%|███▌      | 283747/792362 [02:15<04:02, 2095.19it/s][A
 36%|███▌      | 283861/792362 [02:15<04:02, 2094.46it/s][A
 36%|███▌      | 283957/

 39%|███▉      | 309487/792362 [05:17<08:15, 975.32it/s][A
 39%|███▉      | 309584/792362 [05:17<08:15, 975.31it/s][A
 39%|███▉      | 309685/792362 [05:17<08:14, 975.31it/s][A
 39%|███▉      | 309810/792362 [05:17<08:14, 975.39it/s][A
 39%|███▉      | 309913/792362 [05:17<08:14, 975.29it/s][A
 39%|███▉      | 310008/792362 [05:17<08:14, 975.14it/s][A
 39%|███▉      | 310148/792362 [05:18<08:14, 975.27it/s][A
 39%|███▉      | 310250/792362 [05:18<08:14, 975.23it/s][A
 39%|███▉      | 310373/792362 [05:18<08:14, 975.31it/s][A
 39%|███▉      | 310480/792362 [05:18<08:14, 975.34it/s][A
 39%|███▉      | 310586/792362 [05:18<08:13, 975.28it/s][A
 39%|███▉      | 310722/792362 [05:18<08:13, 975.40it/s][A
 39%|███▉      | 310832/792362 [05:18<08:13, 975.41it/s][A
 39%|███▉      | 310957/792362 [05:18<08:13, 975.49it/s][A
 39%|███▉      | 311069/792362 [05:18<08:13, 975.49it/s][A
 39%|███▉      | 311177/792362 [05:18<08:13, 975.51it/s][A
 39%|███▉      | 311301/792362 [05:19<08

 42%|████▏     | 336613/792362 [05:51<07:55, 957.58it/s][A
 42%|████▏     | 336705/792362 [05:51<07:55, 957.54it/s][A
 43%|████▎     | 336795/792362 [05:51<07:55, 957.51it/s][A
 43%|████▎     | 336883/792362 [05:51<07:55, 957.46it/s][A
 43%|████▎     | 336969/792362 [05:51<07:55, 957.41it/s][A
 43%|████▎     | 337062/792362 [05:52<07:55, 957.41it/s][A
 43%|████▎     | 337167/792362 [05:52<07:55, 957.43it/s][A
 43%|████▎     | 337259/792362 [05:52<07:55, 957.32it/s][A
 43%|████▎     | 337344/792362 [05:52<07:55, 957.27it/s][A
 43%|████▎     | 337428/792362 [05:52<07:55, 957.22it/s][A
 43%|████▎     | 337517/792362 [05:52<07:55, 957.20it/s][A
 43%|████▎     | 337607/792362 [05:52<07:55, 957.18it/s][A
 43%|████▎     | 337728/792362 [05:52<07:54, 957.25it/s][A
 43%|████▎     | 337824/792362 [05:52<07:54, 957.23it/s][A
 43%|████▎     | 337940/792362 [05:53<07:54, 957.27it/s][A
 43%|████▎     | 338039/792362 [05:53<07:54, 957.27it/s][A
 43%|████▎     | 338138/792362 [05:53<07

 46%|████▋     | 366781/792362 [06:22<07:23, 959.29it/s][A
 46%|████▋     | 366867/792362 [06:22<07:23, 959.21it/s][A
 46%|████▋     | 366982/792362 [06:22<07:23, 959.26it/s][A
 46%|████▋     | 367073/792362 [06:22<07:23, 958.83it/s][A
 46%|████▋     | 367169/792362 [06:22<07:23, 958.83it/s][A
 46%|████▋     | 367297/792362 [06:23<07:23, 958.92it/s][A
 46%|████▋     | 367392/792362 [06:23<07:23, 958.90it/s][A
 46%|████▋     | 367506/792362 [06:23<07:23, 958.94it/s][A
 46%|████▋     | 367605/792362 [06:23<07:22, 958.89it/s][A
 46%|████▋     | 367698/792362 [06:23<07:22, 958.82it/s][A
 46%|████▋     | 367793/792362 [06:23<07:22, 958.82it/s][A
 46%|████▋     | 367927/792362 [06:23<07:22, 958.92it/s][A
 46%|████▋     | 368036/792362 [06:23<07:22, 958.95it/s][A
 46%|████▋     | 368141/792362 [06:23<07:22, 958.95it/s][A
 46%|████▋     | 368244/792362 [06:24<07:22, 958.97it/s][A
 46%|████▋     | 368403/792362 [06:24<07:22, 959.13it/s][A
 47%|████▋     | 368522/792362 [06:24<07

 50%|█████     | 397821/792362 [06:52<06:49, 963.41it/s][A
 50%|█████     | 397973/792362 [06:53<06:49, 963.54it/s][A
 50%|█████     | 398101/792362 [06:53<06:49, 963.60it/s][A
 50%|█████     | 398226/792362 [06:53<06:49, 963.63it/s][A
 50%|█████     | 398363/792362 [06:53<06:48, 963.73it/s][A
 50%|█████     | 398488/792362 [06:53<06:48, 963.57it/s][A
 50%|█████     | 398622/792362 [06:53<06:48, 963.66it/s][A
 50%|█████     | 398737/792362 [06:53<06:48, 963.65it/s][A
 50%|█████     | 398846/792362 [06:53<06:48, 963.67it/s][A
 50%|█████     | 398953/792362 [06:54<06:48, 963.64it/s][A
 50%|█████     | 399093/792362 [06:54<06:48, 963.74it/s][A
 50%|█████     | 399229/792362 [06:54<06:47, 963.84it/s][A
 50%|█████     | 399365/792362 [06:54<06:47, 963.93it/s][A
 50%|█████     | 399490/792362 [06:54<06:47, 963.96it/s][A
 50%|█████     | 399610/792362 [06:54<06:47, 963.92it/s][A
 50%|█████     | 399723/792362 [06:54<06:47, 963.96it/s][A
 50%|█████     | 399833/792362 [06:54<06

 72%|███████▏  | 573117/792362 [07:25<02:50, 1286.84it/s][A
 72%|███████▏  | 574045/792362 [07:25<02:49, 1288.60it/s][A
 73%|███████▎  | 574944/792362 [07:25<02:48, 1290.33it/s][A
 73%|███████▎  | 575875/792362 [07:25<02:47, 1292.12it/s][A
 73%|███████▎  | 576839/792362 [07:25<02:46, 1294.00it/s][A
 73%|███████▎  | 577760/792362 [07:25<02:45, 1295.77it/s][A
 73%|███████▎  | 578680/792362 [07:25<02:44, 1297.53it/s][A
 73%|███████▎  | 579622/792362 [07:26<02:43, 1299.35it/s][A
 73%|███████▎  | 580540/792362 [07:26<02:42, 1301.03it/s][A
 73%|███████▎  | 581497/792362 [07:26<02:41, 1302.88it/s][A
 74%|███████▎  | 582434/792362 [07:26<02:40, 1304.69it/s][A
 74%|███████▎  | 583335/792362 [07:26<02:40, 1306.42it/s][A
 74%|███████▎  | 584236/792362 [07:26<02:39, 1308.13it/s][A
 74%|███████▍  | 585272/792362 [07:26<02:38, 1310.16it/s][A
 74%|███████▍  | 586245/792362 [07:26<02:37, 1312.04it/s][A
 74%|███████▍  | 587193/792362 [07:26<02:36, 1313.86it/s][A
 74%|███████▍  | 588128/

In [78]:
with open('idx_pairs.pickle', 'wb') as f:
    pickle.dump(idx_pairs, f)

OverflowError: cannot serialize a bytes object larger than 4 GiB

In [84]:
embedding_dims = 128
W1 = Variable(torch.randn(embedding_dims, vocabulary_size).float(), requires_grad=True)
W2 = Variable(torch.randn(vocabulary_size, embedding_dims).float(), requires_grad=True)
num_epochs = 101
learning_rate = 0.001

for epo in range(num_epochs):
    loss_val = 0
    for data, target in idx_pairs:
        x = Variable(get_input_layer(data)).float()
        y_true = Variable(torch.from_numpy(np.array([target])).long())

        z1 = torch.matmul(W1, x)
        z2 = torch.matmul(W2, z1)    
        log_softmax = F.log_softmax(z2, dim=0)

        loss = F.nll_loss(log_softmax.view(1,-1), y_true)
        loss_val += loss.data.item()
        loss.backward()
    
        with torch.no_grad():
            W1 -= learning_rate * W1.grad
            W2 -= learning_rate * W2.grad

            W1.grad.zero_()
            W2.grad.zero_()

#     if epo % 10 == 0:    
#         print('Loss at epo {0}: {1}'.format(epo, loss_val/len(idx_pairs))) #'Loss at epo {epo}: {loss_val/len(idx_pairs)}'
    print('Loss at epo {0}: {1}'.format(epo, loss_val/len(idx_pairs)))

KeyboardInterrupt: 

In [5]:
import time

In [9]:
current_time = time.time()
time.sleep(5)
time.time() - current_time

5.003559827804565