📝 **Author:** Amirhossein Heydari - 📧 **Email:** amirhosseinheydari78@gmail.com - 📍 **Linktree:** [linktr.ee/mr_pylin](https://linktr.ee/mr_pylin)

---

# Dependencies

In [1]:
from collections import Counter

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader, Dataset
from torchinfo import summary

from datasets import load_dataset

In [2]:
# reproducibility: fix the RNG seed and pin the cuDNN flags
seed = 42
torch.manual_seed(seed)
# turn off cuDNN's benchmark mode and request its deterministic mode
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# log
device

# Word Embedding using Word2Vec [Skip-Grams method]
🌟 **Example**:
   - Sentence: **The quick brown fox jumps over the lazy dog.**
   - **WINDOW_SIZE: 2**
      - For "The" (center word), context words are ["quick", "brown"].
      - For "quick" (center word), context words are ["The", "brown", "fox"].
      - For "brown" (center word), context words are ["quick", "The", "fox", "jumps"].
      - For "fox" (center word), context words are ["brown", "quick", "jumps", "over"].
      - For "jumps" (center word), context words are ["fox", "brown", "over", "the"].
      - For "over" (center word), context words are ["jumps", "fox", "the", "lazy"].
      - For "the" (center word), context words are ["over", "jumps", "lazy", "dog"].
      - For "lazy" (center word), context words are ["the", "over", "dog"].
      - For "dog" (center word), context words are ["lazy", "the"].
   - **NEGATIVE_SAMPLES: 5**
      - Suppose the center word is "brown" and its context words are ["quick", "The", "fox", "jumps"].
      - For each (center, context) pair, we also randomly draw 5 words from the vocabulary (e.g., "lazy", "dog", etc.) and treat them as negative samples.

## Hyperparameters

In [4]:
# model
EMBEDDING_DIM = 100  # length of each word vector

# skip-gram sampling
WINDOW_SIZE = 2  # context words taken on each side of the center word
NEGATIVE_SAMPLES = 5  # negative words drawn per (center, context) pair

# optimization
BATCH_SIZE = 2048
EPOCHS = 2
LEARNING_RATE = 0.01

## Step 1: Load and preprocess the dataset

In [5]:
# load the "wikitext-2-raw-v1" configuration of the "wikitext" dataset, using ../../datasets/ as the cache directory
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", cache_dir="../../datasets/")
# keep only the "text" field of the "train" split; the cells below tokenize it
train_text = dataset["train"]["text"]

In [None]:
# whitespace-tokenize every non-blank line, then flatten the lines into one token stream
tokenized_sentences = [line.split() for line in train_text if line.strip()]
words = [token for tokens in tokenized_sentences for token in tokens]

# token -> number of occurrences in the training text
vocab = Counter()
vocab.update(words)

# log
print(f"vocabulary size             : {len(vocab)}")
print(f"top 5 most  frequent tokens : {vocab.most_common(5)}")
print(f"top 5 least frequent tokens : {sorted(vocab.items(), key=lambda item: item[1])[:5]}")

In [None]:
# keep only tokens that occur at least `min_freq` times
min_freq = 5
vocab = {token: count for token, count in vocab.items() if count >= min_freq}

# ids follow the insertion order of the filtered vocabulary
idx2word = dict(enumerate(vocab))
word2idx = {token: token_id for token_id, token in idx2word.items()}
vocab_size = len(word2idx)

# log
print(f"vocabulary size             : {vocab_size}")
print(f"top 5 most  frequent tokens : {sorted(vocab.items(), key=lambda item: item[1], reverse=True)[:5]}")
print(f"top 5 least frequent tokens : {sorted(vocab.items(), key=lambda item: item[1])[:5]}")

In [None]:
# build the positive (center, context) index pairs for skip-gram training
def generate_skip_grams(
    words: list[str], word2idx: dict[str, int], window_size: int, vocab: dict[str, int]
) -> list[tuple[int, int]]:
    """Return every (center index, context index) pair found in `words`.

    Each word that has an entry in `word2idx` is used as a center word and is
    paired with every in-vocabulary word located at most `window_size`
    positions before or after it. Words missing from `word2idx` are never
    emitted, but they still occupy their position inside the window.

    No negative sampling happens here; only positive pairs are produced.

    Args:
        words: the token stream, in reading order.
        word2idx: mapping from token to integer id.
        window_size: number of positions inspected on each side of the center.
        vocab: accepted for call compatibility; not read by this function.

    Returns:
        The pairs ordered by center position, then by context position.
    """
    total = len(words)
    pairs = []
    for position, center_word in enumerate(words):
        if center_word in word2idx:
            center_idx = word2idx[center_word]
            # clamp the window to the bounds of the token stream
            first = max(0, position - window_size)
            stop = min(total, position + window_size + 1)
            pairs.extend(
                (center_idx, word2idx[words[neighbor]])
                for neighbor in range(first, stop)
                if neighbor != position and words[neighbor] in word2idx
            )
    return pairs


# generate the positive (center, context) pairs for the whole token stream
skip_grams = generate_skip_grams(words, word2idx, WINDOW_SIZE, vocab)

# log
# The demo is derived from the data itself rather than from a hard-coded
# (center, context) id pair, so it keeps working when the corpus, the
# tokenization or `min_freq` changes.
sample_words = words[1000:1007]
print(f"total number of skip-grams         : {len(skip_grams)}")
print(f"sequence of words [1000 to 1006]   : {sample_words}")
# rare words have no entry in word2idx, so only the surviving ones are mapped
print("{word:idx} for the above sequence  :", {w: word2idx[w] for w in sample_words if w in word2idx})
print(f"skip-gram pairs for above sequence : {generate_skip_grams(sample_words, word2idx, WINDOW_SIZE, vocab)}")

## Step 2: Create Dataset and DataLoader

In [9]:
class SkipGramDataset(Dataset):
    """Dataset of skip-gram pairs with negatives sampled on the fly.

    Each item is a `(center, context, negatives)` triple: two scalar `long`
    tensors holding the word ids of one positive pair, and a 1-D `long` tensor
    with `neg_samples` word ids drawn uniformly (with replacement) from
    `range(vocab_size)`. The draw is independent of the pair, so a negative may
    coincide with the center or the true context word.
    """

    def __init__(self, skip_grams: list[tuple[int, int]], vocab_size: int, neg_samples: int) -> None:
        """Store the pairs and prepare the uniform sampling weights.

        Args:
            skip_grams: sequence of (center id, context id) pairs.
            vocab_size: number of distinct word ids to sample negatives from.
            neg_samples: how many negatives to draw per item.
        """
        self.skip_grams = skip_grams
        self.vocab_size = vocab_size
        self.neg_samples = neg_samples
        # Built once instead of on every __getitem__ call: the weights never
        # change, and rebuilding a vocab_size-long list per sample is wasteful.
        self._neg_weights = torch.ones(vocab_size)

    def __len__(self) -> int:
        """Return the number of positive pairs."""
        return len(self.skip_grams)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the pair at `idx` together with freshly sampled negatives."""
        center, context = self.skip_grams[idx]
        negatives = torch.multinomial(self._neg_weights, self.neg_samples, replacement=True)
        return torch.tensor(center, dtype=torch.long), torch.tensor(context, dtype=torch.long), negatives

In [None]:
# wrap the skip-gram pairs in a Dataset and serve them in shuffled mini-batches
dataset = SkipGramDataset(skip_grams, vocab_size, NEGATIVE_SAMPLES)
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
num_batches = len(data_loader)

# log
# Look the demo pair up once (list.index is a linear scan) and fall back to the
# first sample when that exact pair does not exist for the current data.
try:
    first_sample_idx = skip_grams.index((401, 200))
except ValueError:
    first_sample_idx = 0

# never read past the end of the dataset
for i in range(min(WINDOW_SIZE * 2, len(dataset) - first_sample_idx)):
    temp_sample = dataset[first_sample_idx + i]
    print(temp_sample)
    print("  negative samples:")
    for negative in temp_sample[2].tolist():
        print("    ", {negative: idx2word[negative]})
    print()

## Step 3: Define and Initialize the Word2Vec model

In [11]:
class Word2Vec(nn.Module):
    """Skip-gram Word2Vec trained with a negative-sampling loss.

    Two embedding tables are kept: `embeddings` is looked up for center words
    and `context_embeddings` for context and negative words.
    """

    def __init__(self, vocab_size, embedding_dim):
        """Create both embedding tables with `vocab_size` rows of size `embedding_dim`."""
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.context_embeddings = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, center, context, negatives):
        """Return the mean negative-sampling loss of a batch.

        Args:
            center: word ids of the center words, shape (batch,).
            context: word ids of the true context words, shape (batch,).
            negatives: word ids of the negative samples, shape (batch, num_negatives).

        Returns:
            A scalar tensor: the batch mean of
            `-log(sigmoid(u_ctx . v_c)) - sum_k log(sigmoid(-u_neg_k . v_c))`.
        """
        center_embed = self.embeddings(center)  # (batch, dim)
        context_embed = self.context_embeddings(context)  # (batch, dim)
        neg_embed = self.context_embeddings(negatives)  # (batch, num_negatives, dim)

        # positive scores
        # logsigmoid is used instead of log(sigmoid(x)): the latter underflows
        # to log(0) = -inf once the score becomes strongly negative.
        positive_score = torch.sum(center_embed * context_embed, dim=1)
        positive_loss = -nn.functional.logsigmoid(positive_score)

        # negative scores
        # squeeze only the trailing singleton produced by bmm; a bare squeeze()
        # would also drop the batch or negatives axis whenever its size is 1
        # and make the sum over dim=1 fail.
        negative_score = torch.bmm(neg_embed, center_embed.unsqueeze(2)).squeeze(2)
        negative_loss = -torch.sum(nn.functional.logsigmoid(-negative_score), dim=1)

        return (positive_loss + negative_loss).mean()

In [12]:
# build the model on the selected device and let Adam update all of its parameters
model = Word2Vec(vocab_size, EMBEDDING_DIM).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# log: bare expression, shown as the notebook cell's output
model

In [None]:
# summarize the model on one random batch: (center ids, context ids, negative ids)
summary(
    model,
    input_data=(
        torch.randint(0, vocab_size, (BATCH_SIZE,)).to(device),
        torch.randint(0, vocab_size, (BATCH_SIZE,)).to(device),
        torch.randint(0, vocab_size, (BATCH_SIZE, NEGATIVE_SAMPLES)).to(device),
    ),
)

## Step 5: Training Loop

In [None]:
# train for EPOCHS passes over the data, logging every 100th iteration
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    # log
    print(f"epoch {epoch + 1:>0{len(str(EPOCHS))}}/{EPOCHS}")

    for i, batch in enumerate(data_loader):
        # move the three batch tensors to the device used above
        center, context, negatives = (tensor.to(device) for tensor in batch)

        # one optimization step; the model's forward pass returns the loss itself
        optimizer.zero_grad()
        loss = model(center, context, negatives)
        loss.backward()
        optimizer.step()

        # running sum of the per-batch losses of this epoch
        batch_loss = loss.item()
        total_loss += batch_loss

        # log
        if not i % 100:
            print(
                f"  iteration {i + 1:>0{len(str(num_batches))}}/{num_batches}  |  loss: {batch_loss:.4f}  |  total loss: {total_loss:.4f}"
            )

    # log
    print(f"epoch {epoch + 1:>0{len(str(EPOCHS))}}/{EPOCHS}  |  total loss: {total_loss:.4f}")
    print("-" * 50)

## Step 6: Save embeddings

In [None]:
# Take the center-word embedding matrix, detached from autograd and moved to
# the CPU, so the saved file can be loaded on a machine without a GPU.
# `.detach()` replaces `.data`, which bypasses autograd's safety checks.
embeddings = model.embeddings.weight.detach().cpu()
torch.save(embeddings, "../../assets/embeddings/word2vec_embeddings.pt")

# log
print()

## Step 7: Visualize word embeddings using t-SNE

In [20]:
def visualize_embeddings(embeddings: torch.Tensor, idx2word: dict[int, str], num_points: int = 100) -> None:
    """Plot a 2-D t-SNE projection of the first `num_points` embedding rows.

    Args:
        embeddings: 2-D tensor whose row `i` is the vector of word id `i`.
        idx2word: mapping from word id to the label drawn next to its point.
        num_points: number of leading rows to project and plot.

    Raises:
        ValueError: if fewer than two rows are available to project.
    """
    # detach first so tensors that still require grad can be converted to numpy
    points = embeddings[:num_points].detach().cpu().numpy()
    n_points = len(points)
    if n_points < 2:
        raise ValueError(f"need at least 2 embeddings to run t-SNE, got {n_points}")

    # t-SNE requires perplexity < number of samples. scikit-learn's default of
    # 30 is kept whenever possible, so plots with more than 30 points are unchanged.
    perplexity = float(min(30, n_points - 1))
    tsne = TSNE(n_components=2, random_state=seed, perplexity=perplexity)
    reduced_embeddings = tsne.fit_transform(points)

    plt.figure(figsize=(10, 10))
    for idx, (x, y) in enumerate(reduced_embeddings):
        # one scatter call per point, as before, so every word is drawn separately
        plt.scatter(x, y)
        plt.text(x, y, idx2word[idx], fontsize=9)
    plt.show()

In [None]:
# plot the saved embeddings with the default number of points
visualize_embeddings(embeddings, idx2word)