# Re-making word2vec in PyTorch

After partially reading [Lena Voita's NLP course notes on word embeddings](https://lena-voita.github.io/nlp_course/word_embeddings.html), I wanted to try implementing word2vec in PyTorch to test my understanding.

## Imports

In [None]:
import sys
sys.path.append("..")

import re

from helpers.stats import mem_size

from gensim.utils import simple_preprocess

import torch
import torch.optim as optim
import torch.nn.functional as F


# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Load and pre-process data

In [None]:
# Read the whole corpus into one string. Decode explicitly as UTF-8 instead of
# relying on the platform's locale default (e.g. cp1252 on Windows).
# NOTE(review): assumes data/imdb_text.txt is UTF-8 encoded — confirm.
with open("data/imdb_text.txt", encoding="utf-8") as f:
    all_text = f.read()

In [None]:
# pre-processing

text = all_text

# Remove HTML tags such as "<br />". `[^>]*` stops at the first ">", so each
# tag is matched on its own; a greedy `<.*>` would also delete the review text
# sitting between two tags on the same line. Tags become a space so the words
# on either side of a tag are not glued together.
text = re.sub(r"<[^>]*>", " ", text)

# Split the cleaned text into a flat list of word tokens with gensim.
tokens = simple_preprocess(text)

## Define helpers and params

In [None]:
embedding_dim = 300
# Odd window so the centre word has the same number of context words
# (window_size // 2) on each side.
window_size = 11
centre_index = window_size // 2

# Sorted so the word -> index mapping is identical on every run: the iteration
# order of list(set(...)) depends on str hash randomisation (PYTHONHASHSEED),
# which would make saved embedding rows map to different words next session.
index_to_word = sorted(set(tokens))
word_to_index = {w: i for i, w in enumerate(index_to_word)}
# Every token is in the vocabulary, so direct lookup cannot fail.
indices_list = [word_to_index[w] for w in tokens]

vocab_size = len(index_to_word)


def i2wt(x):
    """Return the vocabulary word for a single-element index tensor ``x``."""
    return index_to_word[int(x.item())]

## Initialise parameters for optimisation

In [None]:
# Two embedding tables, one row per vocabulary word: V holds the centre-word
# vectors and U the context-word vectors.
# The N(0, 1) draws are scaled by 1/sqrt(embedding_dim) so the initial dot
# products u·v have variance 1/embedding_dim (near 0) and the softmax starts
# near uniform. Unscaled, their variance is embedding_dim, which saturates the
# softmax and makes exp() of the scores prone to overflow.
init_scale = embedding_dim ** -0.5
V = torch.randn((vocab_size, embedding_dim), device=device) * init_scale
U = torch.randn((vocab_size, embedding_dim), device=device) * init_scale

V.requires_grad = True
U.requires_grad = True

# All sliding windows of `window_size` consecutive token indices, one per row:
# training[k] == indices_list[k : k + window_size], shape (n_windows, window_size).
# Tensor.unfold works directly on the 1-D integer tensor, so the indices never
# round-trip through float32 or the image-oriented F.unfold API.
indices = torch.tensor(indices_list, dtype=torch.long, device=device)
training = indices.unfold(0, window_size, 1)

n_training = training.shape[0]

## Training loop over single (centre word, context word) pairs

This was extremely slow:

```
O(epochs * training_size * window_size)  (Python-level iterations)
```
For each of those iterations, we need to multiply all of U (the context vectors) by the current centre word vector (`V[i]`, a row of V).


In [None]:
# NOTE(review): this per-pair loop is disabled and kept for reference only.
# Known problems if it is re-enabled:
#   * `j_exp[i]` indexes the scores with the training-row index `i`; the
#     context word's score is `j_exp[u_idx]` (`u_idx` is otherwise unused).
#   * It subtracts the exponentiated score instead of its log; the negative
#     log-softmax is `-torch.log(j_exp[u_idx]) + torch.log(j_exp.sum())`.
#   * The backward/step lines are commented out twice, and the three `break`s
#     exit every loop after the first (centre, context) pair.
# from tqdm import tqdm

# lr = 0.1

# optimiser = optim.SGD(params=[V, U], lr=0.1)

# for epoch in range(10):
#     print(f"epoch {epoch}")
#     for i in tqdm(range(training.size()[0])):
#         # print(" ".join(map(i2wt, training[i, :])))
#         # print(i2wt(training[i, centre_index]))
#         for j in range(window_size):
#             if j == centre_index:
#                 continue


#             row = training[i, :]
#             v_idx = int(training[i, centre_index].item())
#             u_idx = int(training[i, j].item())

#             v = V[v_idx]

#             j_exp = torch.exp(torch.matmul(v.unsqueeze(0), U.unsqueeze(-1)).squeeze())
#             j_v_u = -j_exp[i] + torch.log(j_exp.sum())

# #             optimiser.zero_grad()

# #             j_v_u.backward()

# #             optimiser.step()
            
#             break
#         break
#     break


## Training loop over batches of (centre word, context word) pairs

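This is significantly faster, though still quite slow. I'm also not sure whether the objective function I compute is correct. For reference, the skip-gram objective for a corpus of $T$ tokens and a context window of $m$ words on each side is

$$J = -\frac{1}{T}\sum_{t=1}^{T}\sum_{\substack{-m \le j \le m \\ j \ne 0}} \log \frac{\exp\left(u_{w_{t+j}}^\top v_{w_t}\right)}{\sum_{w \in \text{vocab}} \exp\left(u_w^\top v_{w_t}\right)}$$

```
O(epochs * (training_size // batch_size) * window_size)  (Python-level iterations)
```
For each of those iterations, we need to multiply all of U (the context vectors) by the current batch of centre word vectors.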

In [None]:
from tqdm import tqdm

torch.cuda.empty_cache()

batch_size = 1000
n_epochs = 10

lr = 0.1
optimiser = optim.SGD(params=[V, U], lr=lr)

# Column positions of the context words inside each window (all but the centre).
context_cols = [i for i in range(window_size) if i != centre_index]

# (start, end) row ranges that cover every window; `end` is exclusive and
# clipped to n_training, so the final window is included.
training_idx_list = [
    (start, min(start + batch_size, n_training))
    for start in range(0, n_training, batch_size)
]

for epoch in range(n_epochs):
    for batch_start_idx, batch_end_idx in tqdm(training_idx_list, desc=f"epoch {epoch}"):
        # int64 indices: required by gather(), and valid for embedding lookup.
        training_batch = training[batch_start_idx:batch_end_idx].long()
        centre_idx = training_batch[:, centre_index]   # (B,)
        context_idx = training_batch[:, context_cols]  # (B, window_size - 1)

        # Score u_w · v_c of every vocabulary word w against each centre word c.
        # The centre vectors are the same for every context position, so this
        # is computed once per batch: shape (B, vocab_size).
        logits = V[centre_idx] @ U.T
        # log_softmax is numerically stable, unlike exp()/sum(), which
        # overflows to inf (and then nan) for large scores.
        log_probs = F.log_softmax(logits, dim=1)

        # Skip-gram negative log-likelihood of the observed context words,
        # averaged over all (centre, context) pairs in the batch. A per-batch
        # mean keeps the gradient scale independent of the corpus size.
        loss = -log_probs.gather(1, context_idx).mean()

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

In [None]:
# 9617