In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud

from collections import Counter
import collections
import numpy as np
import random
import math

import pandas as pd
import scipy
import sklearn
from sklearn.metrics.pairwise import cosine_similarity

USE_CUDA = torch.cuda.is_available()

random.seed(1)
np.random.seed(1)
torch.manual_seed(1)
if USE_CUDA:
    torch.cuda.manual_seed(1)
    
## hyper parameters
C = 3 # context window
K = 100 # number of negative samples

EMBEDDING_SIZE = 100
MAX_VOCAB_SIZE = 30000
BATCH_SIZE = 128
NUM_EPOCHS = 2
LEARNING_RATE = 0.2



def word_tokenize(text):
    return text.split()

In [24]:
with open("text8/text8.train.txt") as fin:
    text = fin.read()
text = text.split()
vocab = dict(collections.Counter(text).most_common(MAX_VOCAB_SIZE -1))
vocab['<unk>'] = len(text) -np.sum(list(vocab.values()))

idx_to_word = [word for word in vocab.keys()]
word_to_idx = {word:i for i, word in enumerate(idx_to_word)}

word_counts = np.array(list(vocab.values()), dtype=np.float32)
word_freqs = word_counts / np.sum(word_counts)
word_freqs = word_freqs ** (3./4.)
word_freqs = word_freqs / np.sum(word_freqs)

VOCAB_SIZE = len(idx_to_word)

In [25]:
class WordEmbeddingDataset(tud.Dataset):
    def __init__(self, text, word_to_idx, word_freqs):
        super(WordEmbeddingDataset, self).__init__()
        self.text_encoded = [word_to_idx.get(word, word_to_idx["<unk>"]) for word in text]
        self.text_encoded = torch.LongTensor(self.text_encoded)
        self.word_freqs = torch.Tensor(word_freqs)
        
    
    def __len__(self):
        return len(self.text_encoded)
        
    def __getitem__(self, idx):
        center_word = self.text_encoded[idx]
        pos_indices = list(range(idx-C, idx)) + list(range(idx+1, idx+C))
        pos_indices = [i % len(self.text_encoded) for i in pos_indices]
        pos_words = self.text_encoded[pos_indices]
        neg_words = torch.multinomial(self.word_freqs, K * pos_words.shape[0], replacement=True)
        
        return center_word, pos_words, neg_words    

In [26]:
dataset = WordEmbeddingDataset(text, word_to_idx, word_freqs)
dataloader = tud.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [27]:
next(iter(dataloader))

[tensor([  819,    45,   621,    15,  1797,    29,   328,   157,    25,   598,
             9,    13,    25,    12,     5,  4344,     3,    13,     0,     5,
          1532,   648,     9,   937,    16, 22599,    85,  7406,  2801,   419,
          1238,     1,   966,  1655,   644,     6,    16, 18573, 11226,    37,
           261,  1514,  3537,     1, 29999,   644,     4,   210,   110,     5,
          3316,  1454,    29,     7,     0,   825,     2,  3992,  2991,  9029,
          1881,     0, 20161,    13,     5,     4, 12028,  7117,   394,     3,
         27580,  3642,    36,  2050,    92,     8, 23976,  2184,   335,   339,
          1314,    15,    34,   284,  4247,  2389,    25,  8552,     0,  1467,
           131,  5437,     1, 10596,     2,     4,  1963,    37,     5,   401,
          2111,     2,     6,     0,    14,    10,     1,     9,  5363, 12439,
          8464,     1,   432,   298,  4171, 11035,     0,  3513,     4,   969,
            28,   836, 29999,    88,     5,     6,  

In [31]:
class EmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(EmbeddingModel, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        
        self.in_embed = torch.nn.Embedding(self.vocab_size, self.embed_size, sparse=False)
        self.out_embed = torch.nn.Embedding(self.vocab_size, self.embed_size, sparse=False)
        
        
    def forward(self, input_labels, pos_labels, neg_labels):
        batch_size = input_labels.shape[0]
        input_embedding = self.in_embed(input_labels) # batch_size * embed_size
        pos_embedding = self.out_embed(pos_labels) # batch_size * (2*C) * embed_size
        neg_embedding = self.out_embed(neg_labels) # batch_size * (2*C*K) * embed_size
        
        log_pos = torch.bmm(pos_embedding, input_embedding.unsqueeze(2)).squeeze() # batch_size * (2*C) 
        log_neg = torch.bmm(neg_embedding, -input_embedding.unsqueeze(2)).squeeze() # batch_size * (2*C*K)
        
        log_pos = torch.nn.functional.logsigmoid(log_pos).sum(1)
        log_neg = torch.nn.functional.logsigmoid(log_neg).sum(1)
        loss = log_pos + log_neg
        
        return -loss
    
    def input_embeddings(self):
        return self.in_embed.weight.data.cpu().numpy()
        
        

In [32]:
model = EmbeddingModel(VOCAB_SIZE, EMBEDDING_SIZE)
if USE_CUDA:
    model = model.cuda()

In [None]:
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
for e in range(NUM_EPOCHS):
    for i, (input_labels, pos_labels, neg_labels) in enumerate(dataloader):
        if USE_CUDA:
            input_labels = input_labels.cuda()
            pos_labels = pos_labels.cuda()
            neg_labels = neg_labels.cuda()
            
        optimizer.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels).mean()
        loss.backward()
        optimizer.step()
        
        if i % 100 == 0:
            print("epoch", e, "iteration", i, loss.item())
        

epoch 0 iteration 0 2060.41748046875
epoch 0 iteration 100 1007.7798461914062
epoch 0 iteration 200 836.9866943359375
epoch 0 iteration 300 702.9419555664062
epoch 0 iteration 400 587.60302734375
epoch 0 iteration 500 597.7804565429688
epoch 0 iteration 600 548.6334838867188
epoch 0 iteration 700 474.4078063964844
epoch 0 iteration 800 338.656982421875
epoch 0 iteration 900 409.16668701171875
epoch 0 iteration 1000 408.77423095703125
epoch 0 iteration 1100 328.84130859375
epoch 0 iteration 1200 366.0126037597656
epoch 0 iteration 1300 328.85736083984375
epoch 0 iteration 1400 257.6401672363281
epoch 0 iteration 1500 224.3773956298828
epoch 0 iteration 1600 273.71319580078125
epoch 0 iteration 1700 257.4461975097656
epoch 0 iteration 1800 196.97882080078125
epoch 0 iteration 1900 206.11093139648438
epoch 0 iteration 2000 151.7388153076172
epoch 0 iteration 2100 134.87025451660156
epoch 0 iteration 2200 161.03024291992188
epoch 0 iteration 2300 240.2861785888672
epoch 0 iteration 2400 20

epoch 0 iteration 19700 36.82936477661133
epoch 0 iteration 19800 35.133052825927734
epoch 0 iteration 19900 46.634822845458984
epoch 0 iteration 20000 34.05497360229492
epoch 0 iteration 20100 37.341453552246094
epoch 0 iteration 20200 33.34629440307617
epoch 0 iteration 20300 33.2215461730957
epoch 0 iteration 20400 33.385459899902344
epoch 0 iteration 20500 39.61796569824219
epoch 0 iteration 20600 36.553009033203125
epoch 0 iteration 20700 40.778472900390625
epoch 0 iteration 20800 37.51244354248047
epoch 0 iteration 20900 38.208953857421875
epoch 0 iteration 21000 46.90431594848633
epoch 0 iteration 21100 40.07072448730469
epoch 0 iteration 21200 44.681190490722656
epoch 0 iteration 21300 39.24074172973633
epoch 0 iteration 21400 39.98447036743164
epoch 0 iteration 21500 36.696807861328125
epoch 0 iteration 21600 35.18543243408203
epoch 0 iteration 21700 36.54738235473633
epoch 0 iteration 21800 38.53687286376953
epoch 0 iteration 21900 33.47993469238281
epoch 0 iteration 22000 33

epoch 0 iteration 39100 33.10449981689453
epoch 0 iteration 39200 30.095802307128906
epoch 0 iteration 39300 31.059370040893555
epoch 0 iteration 39400 33.98817443847656
epoch 0 iteration 39500 31.330595016479492
epoch 0 iteration 39600 34.35719680786133
epoch 0 iteration 39700 32.04148483276367
epoch 0 iteration 39800 33.97158432006836
epoch 0 iteration 39900 29.082427978515625
epoch 0 iteration 40000 32.803733825683594
epoch 0 iteration 40100 30.91243553161621
epoch 0 iteration 40200 35.47956085205078
epoch 0 iteration 40300 32.240482330322266
epoch 0 iteration 40400 36.18156814575195
epoch 0 iteration 40500 35.317806243896484
epoch 0 iteration 40600 30.628522872924805
epoch 0 iteration 40700 31.970947265625
epoch 0 iteration 40800 31.413223266601562
epoch 0 iteration 40900 31.606416702270508
epoch 0 iteration 41000 33.627166748046875
epoch 0 iteration 41100 32.99311065673828
epoch 0 iteration 41200 33.8876953125
epoch 0 iteration 41300 31.08798599243164
epoch 0 iteration 41400 32.91