# Word2Vec

To simplify, we will train the following small corpus from NLTK.

In [1]:
from nltk.corpus import inaugural

raw_corpus = inaugural.raw("1789-Washington.txt")
print(f"words (including punctuation): {len(inaugural.words('1793-Washington.txt'))}")
raw_corpus

words (including punctuation): 147


'Fellow-Citizens of the Senate and of the House of Representatives:\n\nAmong the vicissitudes incident to life no event could have filled me with greater anxieties than that of which the notification was transmitted by your order, and received on the 14th day of the present month. On the one hand, I was summoned by my Country, whose voice I can never hear but with veneration and love, from a retreat which I had chosen with the fondest predilection, and, in my flattering hopes, with an immutable decision, as the asylum of my declining years -- a retreat which was rendered every day more necessary as well as more dear to me by the addition of habit to inclination, and of frequent interruptions in my health to the gradual waste committed on it by time. On the other hand, the magnitude and difficulty of the trust to which the voice of my country called me, being sufficient to awaken in the wisest and most experienced of her citizens a distrustful scrutiny into his qualifications, could not

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L

from collections import Counter
from torch.utils.data import Dataset, DataLoader
from torch import optim

## Aside: understanding `nn.Embedding` in PyTorch

`nn.Embedding` generates a $n\times m$ matrix where n = number of words (ie. vocabulary) and m = dimension of each word embedding. The creation of a `nn.Embedding` object will randomly initialize its entries (user can specify distribution to draw from) in model class relation.

A forward pass to a `nn.Embedding` object takes a torch tensor of `torch.int64` data type (ie. `torch.LongTensor` type) and returns the corresponding indicies row of embeddings to the integer entries of the input. (ex. [0] will return the 1st embedding vector, see examples in cells below). 


In [3]:
embedding = nn.Embedding(3, 3) # create a 3x3 matrix of embeddings
print(embedding.weight)
input = torch.tensor([0, 1], dtype=torch.int64) # returns the 1st and 2nd ROW of the matrix
print(input.dtype)
embedding(input)

Parameter containing:
tensor([[ 0.7938, -1.7842, -0.1968],
        [-0.1610, -1.7579, -2.6122],
        [ 0.3320,  0.9222, -0.3271]], requires_grad=True)
torch.int64


tensor([[ 0.7938, -1.7842, -0.1968],
        [-0.1610, -1.7579, -2.6122]], grad_fn=<EmbeddingBackward0>)

In [4]:
embedding = nn.Embedding(5, 3) # create a 5x3 matrix of embeddings
print(embedding.weight)
input = torch.tensor([[0, 1], [3,4]], dtype=torch.long) 
# returns 3rd order tensor with first entry in 1st dimension = 1st and 2nd row embedding vectors and second entry = 4th and 5th row embedding vectors
embedding(input)

Parameter containing:
tensor([[-1.3579,  0.0149,  1.2664],
        [-1.2165, -0.4107,  0.3443],
        [-1.1959, -0.7602,  0.4023],
        [-1.6276,  0.8704, -0.8173],
        [-1.0997, -0.4593,  0.1944]], requires_grad=True)


tensor([[[-1.3579,  0.0149,  1.2664],
         [-1.2165, -0.4107,  0.3443]],

        [[-1.6276,  0.8704, -0.8173],
         [-1.0997, -0.4593,  0.1944]]], grad_fn=<EmbeddingBackward0>)

## Skip-gram architecture

### Model inputs

Skip-gram word2vec takes in a pair of integers that represent a word pairing between a "center word" and an "outside word" the integers are usually the index of the word (or token) in the vocabulary (list object of words/token the model is to be trained on). The center word here is self-explanatory. What determines if the word should be outside is based on the `window` parameter of the model where it determines how many slots are to be considered (usually 2-4 slots). 

In [108]:
# create dataset class to handle corpus and feeding dataset for training + inference (ie. generate integer pairs from a corpus)

class SkipGramDataset(Dataset):
    def __init__(
        self, 
        corpus: str, 
        window_size: int = 2,
        min_count: int = 5 # required minimum times appeared by the word in corpus to be incorporated into vocabulary
    ):
        self.window_size = window_size

        self.tokens = self._tokenize(corpus)
        self.vocab, self.word2idx, self.idx2word = self._build_vocab(min_count)
        self.pairs = self._generate_pairs()
        
    def __len__(self) -> int:
        return len(self.pairs)
    
    def __getitem__(self, index: int):
        target = self.pairs[index]["target"]
        context = self.pairs[index]["context"]
        return torch.tensor([target], dtype=torch.long), torch.tensor([context], dtype=torch.long)
    
    def _tokenize(self, corpus: str) -> list[str]:
        tokens = corpus.lower().split()
        return tokens
    
    def _build_vocab(self, min_count: int) -> tuple[list[str], dict[str, int], dict[int, str]]:
        """Build vocabulary from tokens where it takes only words that occur more than minimal count and indices them in order of appearance"""
        word_counts = Counter(self.tokens)
        vocab = [word for word, count in word_counts.items() if count >= min_count]
        word_2_index_dict = {word: idx for idx, word in enumerate(vocab)}
        index_2_word_dict = {idx: word for word, idx in word_2_index_dict.items()}
        return vocab, word_2_index_dict, index_2_word_dict
    
    def _generate_pairs(self) -> list[dict[str, int]]:
        """
            For each word i in the corpus which are in the vocabulary:
                create a pair where:
                    1. target: index of the word i
                    2. context: index of nth closest word (that is in the vocabulary) to i (where n is dictated by window_size)
                append pair to pair list
                repeat for same target with (n - 1)th closest applicable word 
                    and onwards until window size exhausted
        """
        pairs = []
        indexed_tokens = [self.word2idx[token] for token in self.tokens if token in self.vocab]

        for i, target_index in enumerate(indexed_tokens):
            #Get context window
            start = max(0, i - self.window_size)
            end = min(len(indexed_tokens), i + self.window_size + 1)

            context_indices = [indexed_tokens[j] for j in range(start, end) if j != i]

            # Skip-gram: predict context from target
            for context_index in context_indices:
                pairs.append({
                    "target": target_index, 
                    "context": context_index
                })

        return pairs



In [109]:
corpus_dataset = SkipGramDataset(raw_corpus)
corpus_dataset[0]

(tensor([0]), tensor([1]))

### Model

In [123]:
class SkipGram(L.LightningModule):
    def __init__(self, vocab_size: int, embedding_dim: int, learning_rate: float):
        super().__init__()
        self.v_embeddings = nn.Embedding(vocab_size, embedding_dim) 
        self.u_embeddings = nn.Embedding(vocab_size, embedding_dim)

        # Initialize with small random values
        self.v_embeddings.weight.data.uniform_(-0.5 / embedding_dim, 0.5 / embedding_dim)
        self.u_embeddings.weight.data.uniform_(-0.5 / embedding_dim, 0.5 / embedding_dim)

        # Hyperparamters
        self.learning_rate = learning_rate

    def forward(self, center_word: torch.Tensor, outside_word: torch.Tensor) -> torch.Tensor:
        v_c = self.v_embeddings(center_word) #the embeddings are essentially row vectors
        u_o = self.u_embeddings(outside_word) #the embeddings are essentially row vectors

        numerator = torch.exp((u_o @ v_c.transpose(1, 2)).squeeze(2)) 
        lower_product = torch.exp((self.u_embeddings.weight.data @ v_c.transpose(1, 2)).squeeze(2)) 
        denominator = torch.sum(lower_product)

        probability = numerator / denominator

        return probability
    
    def loss(self, center_word: torch.Tensor, outside_word: torch.Tensor) -> torch.Tensor:
        _probability = self.forward(center_word, outside_word)
        loss = -torch.mean(torch.log(_probability).unsqueeze(1))

        return loss
    
    def training_step(self, batch: tuple[torch.Tensor], batch_idx: int):
        _input_target, _input_context = batch
        training_loss = self.loss(_input_target, _input_context)
        return training_loss
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer



In [125]:
# test some outputs
model = SkipGram(len(corpus_dataset.vocab), 10, learning_rate=0.001)
output = model(torch.LongTensor([[0]]), torch.LongTensor([[2]]))
output

tensor([[0.0250]], grad_fn=<DivBackward0>)

In [126]:
loss = model.loss(torch.LongTensor([[0]]), torch.LongTensor([[2]]))
loss

tensor(3.6887, grad_fn=<NegBackward0>)

### Training

In [127]:
max_epochs = 200
batch_size = 10

model = SkipGram(len(corpus_dataset.vocab), 10, learning_rate=0.001)
trainer = L.Trainer(max_epochs=200)

# lr_find_results = trainer.tuner.lr_find(model, train_dataloaders=)

dataloader = DataLoader(corpus_dataset, batch_size=batch_size)

trainer.fit(model, train_dataloaders=dataloader)


ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
/home/tony/anaconda3/envs/rec-sys/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name         | Type      | Params | Mode 
---------------------------------------------------
0 | v_embeddings | Embedding | 400    | train
1 | u_embeddings | Embedd

Epoch 199: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 279/279 [00:00<00:00, 604.05it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=200` reached.


Epoch 199: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 279/279 [00:00<00:00, 600.95it/s, v_num=0]


In [131]:
model(torch.LongTensor([[0]]), torch.LongTensor([[2]]))

tensor([[0.1402]], grad_fn=<DivBackward0>)

### Get embeddings for given word

In [None]:
def get_embedding(word: str, corpus: Dataset = corpus_dataset) -> torch.Tensor:
    try:
        idx = corpus.word2idx[word]
    except KeyError:
        raise Exception(f"{word} is not part of the vocabulary")
    else:
        _u = model.u_embeddings.weight.data[idx]
        _v = model.v_embeddings.weight.data[idx]

        embedding = (_u + _v) / 2 # the entries of the embedding vector for a given word is the mean of the u and v components

    return embedding

In [148]:
this = get_embedding("this")
this

tensor([-3.5031, -3.2055, -1.6266,  3.9733, -5.6469, -5.0813,  2.1924, -1.2415,
         4.9197, -3.2067])