In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import Counter
from torch.utils.data import Dataset, DataLoader
import string
import random
from collections import Counter

In [2]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [3]:
with open('drive/MyDrive/data/text8', 'r') as file:
  content = file.read()


## Pre-processing data

In [4]:
def clean_encode_data(text):
  for punctuation in string.punctuation:
    text = text.replace(punctuation, '')
  text = text.strip()
  split_text = text.split()
  print(len(split_text))
  word_and_index, index_and_word = {}, {}
  limit, cnt = 10000, 0
  for id, (word, count) in enumerate(Counter(split_text).items()):
    if count > 5 and id < limit:
      word_and_index[word] = cnt
      index_and_word[cnt] = word
      cnt += 1
  encoded_text = [word_and_index[word] for word in word_and_index]
  return encoded_text, word_and_index, index_and_word


In [5]:
encoded_text, word_and_index, index_and_word = clean_encode_data(content)


17005207


## Skip-gram (negative sampling) model

In [6]:
class generate_data():
  def __init__(self, window_size, text):
    self.window_size = window_size
    self.text = text

  def get_target(self, word_id):
    start = max(0, word_id - self.window_size)
    end = min(len(self.text), word_id + self.window_size)
    return self.text[start: word_id] + self.text[word_id + 1: end]

  def make_dataset(self, num_samples):
    dataset = []
    for center in self.text:
      targets = self.get_target(center)
      for context in targets:
        neg_samples = np.random.choice(self.text, num_samples)
        neg_samples = [torch.tensor([sample]) for sample in neg_samples]
        dataset.append((torch.tensor([center]), torch.tensor([context]), torch.tensor(neg_samples)))
    return dataset


In [14]:
class skip_gram(nn.Module):
  def __init__(self, vocab_size, embed_size):
    super(skip_gram, self).__init__()

    self.center_embeddings = nn.Embedding(vocab_size, embed_size)
    self.context_embeddings = nn.Embedding(vocab_size, embed_size)

    self.center_embeddings.weight.data.uniform_(-1, 1)
    self.context_embeddings.weight.data.uniform_(-1, 1)

  def forward(self, center, context, neg_samples):
    emb_center = self.center_embeddings(center)
    emb_context = self.context_embeddings(context)
    emb_negsamples = self.context_embeddings(neg_samples)

    return emb_center, emb_context, emb_negsamples

In [15]:
class loss_var(nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self, center_vector, context_vector, neg_vectors):
    pos_loss = torch.mul(center_vector, context_vector).squeeze().sum().sigmoid().log()
    neg_loss = torch.mul(-neg_vectors, center_vector).sum(dim = 1).sigmoid().log().sum()

    return -(pos_loss + neg_loss)

## Training

In [16]:
vocab_size = len(encoded_text)
emb_size = 100
neg_samples = 5

model = skip_gram(vocab_size = vocab_size, embed_size = emb_size)
loss = loss_var()
optimizer = optim.SGD(model.parameters(), lr = 0.01)

gen_data = generate_data(window_size = 4, text = encoded_text)
dataset = gen_data.make_dataset(num_samples = neg_samples)

dataset[:5]


[(tensor([0]), tensor([1]), tensor([8514, 7633,  930, 1348,  809])),
 (tensor([0]), tensor([2]), tensor([3687, 3219, 2294, 3070, 6165])),
 (tensor([0]), tensor([3]), tensor([2434, 2931, 3062, 8835, 5068])),
 (tensor([1]), tensor([0]), tensor([2482, 5613, 2420, 8799, 3141])),
 (tensor([1]), tensor([2]), tensor([8859, 8259, 2460,  129, 5649]))]

In [21]:
num_epochs = 10
device = 'cpu'
model.to(device)

for epoch in range(num_epochs):
  for center, context, neg_samples in dataset:
    center.to(device); context.to(device); neg_samples.to(device)
    emb_center, emb_context, emb_negsamples = model(center, context, neg_samples)
    loss_val = loss(emb_center, emb_context, emb_negsamples)
    optimizer.zero_grad()
    loss_val.backward()
    optimizer.step()
  if epoch % 1 == 0:
    print("Epoch {}, loss {}".format(epoch, loss_val.item()))


Epoch 0, loss 1.6876938343048096
Epoch 1, loss 1.4038118124008179
Epoch 2, loss 1.1912139654159546
Epoch 3, loss 1.0296595096588135
Epoch 4, loss 0.9045617580413818
Epoch 5, loss 0.8056939840316772
Epoch 6, loss 0.7259702682495117
Epoch 7, loss 0.6604741215705872
Epoch 8, loss 0.6057572364807129
Epoch 9, loss 0.5593625903129578


## Compute similarity and Test

In [57]:
def words_similarity(word, model):
  word_emb = model.center_embeddings(word)
  word_emb /= torch.norm(word_emb, dim = 1).squeeze()

  context_weights = model.context_embeddings.weight.data
  context_weights /= torch.norm(context_weights, dim = 1).reshape()

  shape_0, shape_1 = context_weights.shape
  similarities = torch.matmul(word_emb, context_weights.reshape(shape_1, shape_0))
  values, indexes = torch.topk(similarities, 5)
  most_similar = [index_and_word[id] for id in indexes[0]]
  return most_similar

In [58]:
test = index_and_word[0]
words_similarity(torch.tensor([0]), model)

RuntimeError: The size of tensor a (100) must match the size of tensor b (9145) at non-singleton dimension 1