<a href="https://colab.research.google.com/github/ochaudha/sample/blob/main/TransformerEncoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
https://towardsdatascience.com/word2vec-with-pytorch-implementing-original-paper-2cd7040120b0/

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
from torch.utils.data import Dataset, DataLoader

# --- Positional Encoding Class ---
class PositionalEncoding(nn.Module):
    """
    Implements the positional encoding for the Transformer model.
    It adds sinusoidal positional encodings to the input embeddings
    to inject information about the relative or absolute position of tokens
    in the sequence.
    """
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe) # Register 'pe' as a buffer, not a parameter

    def forward(self, x):
        """
        Args:
            x: Input tensor (batch_size, sequence_length, d_model)
        Returns:
            Tensor with positional encoding added.
        """
        # Add positional encoding to the input.
        # self.pe is truncated to the input sequence length.
        x = x + self.pe[:x.size(0), :]
        return x

# --- Transformer Encoder Layer Class ---
class TransformerEncoderLayer(nn.Module):
    """
    A single layer of the Transformer Encoder.
    It consists of a multi-head self-attention mechanism,
    followed by a position-wise feed-forward network.
    Each sub-layer also includes a residual connection and layer normalization.
    """
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        # Multi-head self-attention module
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=False)
        # Position-wise feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Layer normalization components
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # Dropout layers for residual connections
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Activation function for the feed-forward network
        self.activation = nn.ReLU()

    def forward(self, src, src_mask=None, src_key_padding_mask=None, **kwargs):
        """
        Args:
            src: The input sequence to the encoder layer (sequence_length, batch_size, d_model).
            src_mask: An optional mask for the src sequence (optional).
            src_key_padding_mask: An optional mask for src keys indicating which elements
                                  to ignore in the attention computation (batch_size, sequence_length).
        Returns:
            The output of the Transformer Encoder Layer (sequence_length, batch_size, d_model).
        """
        # Multi-head self-attention
        # src, src, src are used for Q, K, V respectively in self-attention
        attn_output, _ = self.self_attn(src, src, src, attn_mask=src_mask,
                                        key_padding_mask=src_key_padding_mask)

        # First residual connection and layer normalization
        src = src + self.dropout1(attn_output)
        src = self.norm1(src)

        # Position-wise feed-forward network
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))

        # Second residual connection and layer normalization
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

# --- Transformer Encoder Model Class ---
class TransformerEncoder(nn.Module):
    """
    The full Transformer Encoder model.
    It comprises an embedding layer, positional encoding,
    and a stack of TransformerEncoderLayers.
    """
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward, dropout):
        super(TransformerEncoder, self).__init__()

        # Token embedding layer
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Positional encoding layer
        self.pos_encoder = PositionalEncoding(d_model)
        # Single encoder layer instance to be replicated by nn.TransformerEncoder
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        # Stack of TransformerEncoderLayers
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)
        self.d_model = d_model

        # Linear layer for the final output, mapping d_model to vocab_size
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        Args:
            src: Input tensor of token indices (sequence_length, batch_size).
            src_mask: An optional mask for the src sequence (optional).
            src_key_padding_mask: An optional mask for src keys indicating which elements
                                  to ignore in the attention computation (batch_size, sequence_length).
        Returns:
            The output logits for each token in the vocabulary for each position
            (sequence_length, batch_size, vocab_size).
        """
        # 1. Embed tokens and scale by sqrt(d_model) as per the paper
        src = self.embedding(src) * math.sqrt(self.d_model)
        # 2. Add positional encodings
        src = self.pos_encoder(src)
        # 3. Pass through the Transformer Encoder stack
        output = self.transformer_encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        # 4. Apply a linear layer to get logits over the vocabulary
        output = self.output_layer(output)
        return output

# --- Dummy Data Setup ---
# Define the vocabulary
vocab = ["<pad>", "hello", "world", "transformer", "encoder", "pytorch", "<sos>", "<eos>"]
# Create mappings from word to index and index to word
word_to_idx = {word: idx for idx, word in enumerate(vocab)}
idx_to_word = {idx: word for idx, word in enumerate(vocab)}
vocab_size = len(vocab)
pad_idx = word_to_idx["<pad>"] # Index for padding token

# Example dummy sentences (sequences)
dummy_data = [
    ["<sos>", "hello", "world", "<eos>"],
    ["<sos>", "transformer", "encoder", "pytorch", "<eos>"],
    ["<sos>", "hello", "encoder", "<eos>"]
]

# Determine the maximum sequence length in the dummy data
max_len = max(len(seq) for seq in dummy_data)

# Tokenize and pad the dummy data
tokenized_data = []
for seq in dummy_data:
    token_ids = [word_to_idx[token] for token in seq]
    # Pad sequences to max_len
    token_ids += [pad_idx] * (max_len - len(token_ids))
    tokenized_data.append(token_ids)

# Convert tokenized data to a PyTorch tensor
# Transpose to get [sequence_length, batch_size] as expected by the Transformer
input_tensor = torch.tensor(tokenized_data, dtype=torch.long).transpose(0, 1)

# Create a key padding mask to ignore padding tokens during attention calculation
# A True value indicates that the corresponding key token will be ignored.
src_key_padding_mask = (input_tensor == pad_idx).transpose(0, 1) # [batch_size, sequence_length]

# --- Model Instantiation ---
# Define model hyperparameters
d_model = 64            # Dimension of the model (embedding size)
nhead = 4               # Number of attention heads
num_encoder_layers = 2  # Number of Transformer encoder layers
dim_feedforward = 128   # Dimension of the feed-forward network
dropout = 0.1           # Dropout rate

# Instantiate the Transformer Encoder model
model = TransformerEncoder(vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward, dropout)

# --- Training (Example with a dummy next-token prediction task) ---
# For a simple next-token prediction task, the target sequence is
# the input sequence shifted by one position, with the last token padded.
target_tensor = torch.roll(input_tensor, shifts=-1, dims=0)
target_tensor[-1, :] = pad_idx # Pad the last token of each sequence

# Define the loss function (CrossEntropyLoss is suitable for classification/prediction tasks)
# ignore_index=pad_idx ensures that padding tokens do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
print("\n--- Starting Training ---")
for epoch in range(num_epochs):
    model.train() # Set the model to training mode
    optimizer.zero_grad() # Clear gradients from previous iteration

    # Forward pass: get model output (logits)
    output = model(input_tensor, src_key_padding_mask=src_key_padding_mask)

    # Reshape output and target for CrossEntropyLoss
    # output: [sequence_length * batch_size, vocab_size]
    # target: [sequence_length * batch_size]
    output = output.view(-1, vocab_size)
    target = target_tensor.view(-1)

    # Calculate the loss
    loss = criterion(output, target)

    # Backward pass: compute gradients
    loss.backward()
    # Optimizer step: update model parameters
    optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

print("Training finished.")

# --- Inference (Example) ---
model.eval() # Set the model to evaluation mode (disables dropout, etc.)
print("\n--- Starting Inference Example ---")
with torch.no_grad(): # Disable gradient calculations for inference
    # Take the first sequence from the dummy input for inference
    sample_input = input_tensor[:, 0].unsqueeze(1) # [sequence_length, 1]
    # Take the corresponding key padding mask for the first sequence
    sample_mask = src_key_padding_mask[0].unsqueeze(0) # [1, sequence_length]

    # Get the output representation from the model
    # The output here will be logits for each token in the vocabulary.
    # For a real application, you might use argmax to get the predicted token ID.
    output_logits = model(sample_input, src_key_padding_mask=sample_mask)

    print("\nInference Output Logits (for the first sequence):")
    print(f"Shape of output_logits: {output_logits.shape}")
    # Display the actual logits (for illustration)
    # print(output_logits)

    # To get the predicted word for each position in the sample sequence:
    predicted_token_ids = torch.argmax(output_logits, dim=-1).squeeze(1)
    predicted_words = [idx_to_word[idx.item()] for idx in predicted_token_ids]

    original_sequence = [idx_to_word[idx.item()] for idx in input_tensor[:, 0]]
    print(f"Original Sequence (first batch item): {original_sequence}")
    print(f"Predicted Tokens (for the first sequence): {predicted_words}")

    # You could also examine the output of the transformer_encoder directly if you
    # wanted the learned contextual embeddings before the final linear layer.
    # For that, you would need a separate forward method in the model or
    # to access the internal output of self.transformer_encoder.
    # For now, `output_logits` is the direct output from your current model setup.


--- Starting Training ---
Epoch [1/10], Loss: 2.0706
Epoch [2/10], Loss: 1.8092
Epoch [3/10], Loss: 1.5785
Epoch [4/10], Loss: 1.2701
Epoch [5/10], Loss: 1.1294
Epoch [6/10], Loss: 0.9963
Epoch [7/10], Loss: 0.8781
Epoch [8/10], Loss: 0.7436
Epoch [9/10], Loss: 0.6703
Epoch [10/10], Loss: 0.6024
Training finished.

--- Starting Inference Example ---

Inference Output Logits (for the first sequence):
Shape of output_logits: torch.Size([5, 1, 8])
Original Sequence (first batch item): ['<sos>', 'hello', 'world', '<eos>', '<pad>']
Predicted Tokens (for the first sequence): ['hello', 'world', '<eos>', 'world', 'world']


In [4]:


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import numpy as np
import random # Although not explicitly used in the final version, good for general utilities
from tqdm.auto import tqdm # For nice progress bars in Colab

# --- 0. Configuration and Hyperparameters ---
class Config1:
    def __init__(self):
        self.embedding_dim = 100        # Dimension of word embeddings.
                                        # Intuition: This is the size of the vector space
                                        # where words will "live". Higher dimension allows
                                        # capturing more nuanced semantic relationships.
                                        # Why: A vector of this size will represent each word,
                                        # capturing its meaning based on its context.

        self.window_size = 4            # Context window size (4 history + 4 future words).
                                        # Intuition: How many words around a target word
                                        # define its 'context'. Larger window means broader context.
                                        # Why: Word2Vec's core idea is "words that appear in similar
                                        # contexts have similar meanings." This parameter defines
                                        # what "context" means.

        self.num_epochs = 5             # Number of training epochs.
                                        # Intuition: How many times the model sees the entire dataset.
                                        # Why: More epochs allow the model to refine its embeddings
                                        # by iteratively learning from more examples.

        self.batch_size = 64            # Batch size for training.
                                        # Intuition: Number of (context, target) pairs processed
                                        # in one optimization step.
                                        # Why: Training in batches balances computational efficiency
                                        # (processing multiple examples at once) with memory usage
                                        # and the stability of gradient updates.

        self.learning_rate = 0.001      # Learning rate for Adam optimizer.
                                        # Intuition: How big a step the optimizer takes when
                                        # updating model weights based on the loss gradient.
                                        # Why: Controls how quickly the model learns. Too high,
                                        # it might overshoot; too low, it learns very slowly.

        self.min_freq = 5               # Minimum frequency for words to be included in vocab.
                                        # Intuition: Words appearing less than this threshold
                                        # are treated as '<unk>' (unknown).
                                        # Why: Reduces vocabulary size, saving memory and
                                        # computation. Rare words often don't provide enough
                                        # context to learn robust embeddings anyway.

        self.negative_samples = 5       # Number of negative samples for Negative Sampling.
                                        # Intuition: For each positive (target, context) pair,
                                        # we sample this many "noise" words that are *not* in
                                        # the context. (Note: This is set but not fully utilized
                                        # in this specific "Plain Softmax" implementation, but
                                        # is crucial for a full Negative Sampling setup.)
                                        # Why: Speeds up training for very large vocabularies
                                        # by turning a multi-class classification problem (predict
                                        # one of V words) into a binary classification problem
                                        # (is this a context word or not?) for a few selected words.

        self.use_cbow = True            # Set to True for CBOW, False for Skip-Gram.
                                        # Intuition: Chooses between the two Word2Vec architectures.
                                        # Why: Allows switching between predicting target from context
                                        # (CBOW) or predicting context from target (Skip-Gram).

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                                        # Intuition: Determines whether to use GPU (CUDA) or CPU.
                                        # Why: GPUs are much faster for deep learning computations.

# --- 0. Configuration and Hyperparameters ---
class Config:
    def __init__(self):
        self.embedding_dim = 100
        self.window_size = 4
        self.num_epochs = 5
        self.batch_size = 64
        self.learning_rate = 0.001
        # Change min_freq from 5 to 1
        self.min_freq = 1  # <--- Change this line
        self.negative_samples = 5
        self.use_cbow = True
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#config = Config()
#print(f"Using device: {config.device}")

config = Config()
print(f"Using device: {config.device}")

# --- 1. Data Processing and Dataset Class ---
class Word2VecDataset(Dataset):
    def __init__(self, text, window_size, vocab_to_idx, idx_to_vocab, word_counts, ns_probs, device, model_type="cbow"):
        self.window_size = window_size
        self.vocab_to_idx = vocab_to_idx
        self.idx_to_vocab = idx_to_vocab
        self.word_counts = word_counts
        self.ns_probs = ns_probs
        self.device = device
        self.model_type = model_type
        self.data = self._build_data(text)

    def _build_data(self, text):
        raw_words = text.lower().split()
        # Convert raw words to their numerical indices. Unknown words map to '<unk>'.
        indexed_words = [self.vocab_to_idx.get(word, self.vocab_to_idx['<unk>']) for word in raw_words]

        data = []
        for i, target_word_idx in enumerate(indexed_words):
            if target_word_idx == self.vocab_to_idx['<unk>']:
                continue # Why: We usually don't train on or try to predict unknown words.

            context_indices = []
            # Gather context words within the window
            for j in range(max(0, i - self.window_size), min(len(indexed_words), i + self.window_size + 1)):
                if i != j: # Exclude the target word itself
                    context_word_idx = indexed_words[j]
                    if context_word_idx != self.vocab_to_idx['<unk>']:
                        context_indices.append(context_word_idx)

            if self.model_type == "cbow":
                # For CBOW, input is context, target is the word
                # Intuition: Predict the middle word from its surrounding words.
                # Example: ([The, brown, fox, over], quick)
                if context_indices: # Only add if there are valid context words
                    data.append((context_indices, target_word_idx))
            elif self.model_type == "skipgram":
                # For Skip-Gram, input is the word, target is each context word
                # Intuition: Predict surrounding words from the middle word.
                # Example: (quick, The), (quick, brown), (quick, fox), (quick, over)
                for context_word_idx in context_indices:
                    data.append((target_word_idx, context_word_idx))
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.model_type == "cbow":
            context_indices, target_word_idx = self.data[idx]
            context_tensor = torch.tensor(context_indices, dtype=torch.long)
            target_tensor = torch.tensor(target_word_idx, dtype=torch.long)
            return context_tensor, target_tensor
        elif self.model_type == "skipgram":
            target_word_idx, context_word_idx = self.data[idx]
            target_tensor = torch.tensor(target_word_idx, dtype=torch.long)
            context_tensor = torch.tensor(context_word_idx, dtype=torch.long)
            return target_tensor, context_tensor

def build_vocab_and_mappings(text, min_freq):
    words = text.lower().split()
    word_counts = Counter(words)

    # Filter words by minimum frequency
    vocab = sorted([word for word, count in word_counts.items() if count >= min_freq])

    # Add special tokens
    vocab = ['<pad>', '<unk>'] + vocab

    vocab_to_idx = {word: i for i, word in enumerate(vocab)}
    idx_to_vocab = {i: word for i, word in enumerate(vocab)}

    # Calculate sampling probabilities for Negative Sampling (P(w) = count(w)^(3/4))
    # Intuition: Words that occur more frequently are sampled less aggressively
    # as negative examples, while less frequent words are sampled more often
    # in proportion to their smoothed frequency. This helps balance the training.
    # Why: The original paper found that raising the unigram distribution to the
    # 3/4 power empirically performed better for negative sampling, giving a slight
    # boost to rarer words' chances of being selected as negatives.
    total_words = sum(word_counts.values())
    ns_probs = np.zeros(len(vocab))
    for i, word in idx_to_vocab.items():
        if word in word_counts:
            # $P(w) = \frac{count(w)^{3/4}}{\sum_{w'} count(w')^{3/4}}$
            # This formula gives a probability distribution where rarer words
            # have a relatively higher chance of being sampled than in a raw
            # frequency distribution.
            ns_probs[i] = word_counts[word]**0.75 # Intuition: Raw frequency count raised to power 0.75
                                                  # Why: Empirically found to be effective for negative sampling.
                                                  # It "smooths" the distribution, reducing the dominance of
                                                  # extremely common words.
    ns_probs /= np.sum(ns_probs) # Normalize to sum to 1.
                                 # Intuition: Ensures `ns_probs` is a valid probability distribution.
                                 # Why: Required for `torch.multinomial` which expects probabilities summing to 1.

    return vocab_to_idx, idx_to_vocab, word_counts, ns_probs

# Collate function for DataLoader (especially for CBOW where context length varies)
def collate_fn_cbow(batch):
    contexts = [item[0] for item in batch]
    targets = torch.stack([item[1] for item in batch])

    # Pad contexts to the maximum length in the batch
    max_len = max(len(c) for c in contexts)
    padded_contexts = torch.zeros(len(contexts), max_len, dtype=torch.long)
    for i, context in enumerate(contexts):
        padded_contexts[i, :len(context)] = context

    return padded_contexts, targets

# --- 2. Model Architectures ---
class CBOW(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(CBOW, self).__init__()
        # `nn.Embedding` layer: A lookup table for word vectors.
        # It takes integer indices and returns their corresponding dense vectors.
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # `nn.Linear` layer: Projects the context embedding (average of context words)
        # back to the size of the vocabulary. This output represents the unnormalized
        # log-probabilities (logits) for each word in the vocabulary being the target word.
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, context_word_indices):
        # context_word_indices: (batch_size, max_context_len)
        embeds = self.embeddings(context_word_indices) # (batch_size, max_context_len, embedding_dim)
        # Intuition: For each word index in the context, retrieve its high-dimensional vector representation.

        # Average the embeddings of the context words (OlgaChernytska's note)
        # This is the core of CBOW: the context is represented by the average of its word embeddings.
        sum_embeds = embeds.sum(dim=1) # (batch_size, embedding_dim)
                                       # Intuition: Summing all embedding vectors for words in the context.
                                       # Why: Simpler to compute than averaging directly if padded.

        num_words = (context_word_indices != 0).sum(dim=1).float().unsqueeze(1) # (batch_size, 1)
                                                                               # Intuition: Count of non-padding (non-zero) words in each context.
                                                                               # Why: Essential for correct averaging, otherwise padded zeros would skew the average.
        num_words[num_words == 0] = 1.0 # Avoid division by zero if context is empty (e.g., all padding, though should be avoided by `if context_indices` in dataset)
                                        # Why: Prevents runtime errors for potentially empty contexts (e.g., short sentences).

        avg_embeds = sum_embeds / num_words # (batch_size, embedding_dim)
                                            # Intuition: $E_{context} = \frac{1}{|C|} \sum_{w \in C} E_w$
                                            # This is the aggregated representation of the context.
                                            # Why: The CBOW model assumes the target word can be predicted
                                            # from a combined (averaged) representation of its surrounding words.

        output = self.linear(avg_embeds) # (batch_size, vocab_size) - logits
                                         # Intuition: Project the context embedding to a vector of size `vocab_size`.
                                         # Each value in this vector is a raw score (logit) for a word in the vocabulary.
                                         # Why: These logits will be used by `CrossEntropyLoss` to calculate
                                         # the probability of each word being the correct target word.
        return output

class SkipGram(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(SkipGram, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # Output layer: Projects the target word's embedding to a vector of `vocab_size`.
        # This output represents the unnormalized log-probabilities (logits) for each word
        # in the vocabulary being a context word.
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, target_word_index):
        # target_word_index: (batch_size,)
        # Get embedding for the target word
        embed = self.embeddings(target_word_index) # (batch_size, embedding_dim)
                                                   # Intuition: Retrieve the vector representation of the single target word.
                                                   # Why: In Skip-Gram, the central word is used to predict its context.

        output = self.linear(embed) # (batch_size, vocab_size) - logits
                                    # Intuition: Project the target word's embedding to a vector of size `vocab_size`.
                                    # Each value is a raw score for a word in the vocabulary being a context word.
                                    # Why: These logits will be used by `CrossEntropyLoss` to calculate
                                    # the probability of each word being the correct context word.
        return output

# --- 3. Trainer Class ---
class Word2VecTrainer:
    def __init__(self, model, dataloader, optimizer, criterion, config, vocab_to_idx, ns_probs):
        self.model = model.to(config.device)
        self.dataloader = dataloader
        self.optimizer = optimizer
        self.criterion = criterion
        self.config = config
        self.vocab_to_idx = vocab_to_idx
        self.ns_probs = ns_probs
        self.norm_embeddings = True # Flag to apply norm constraint

    def _sample_negatives(self, batch_size, num_negative_samples, exclude_target_idx=None):
        # Sample negative word indices based on their frequency probabilities (ns_probs)
        neg_samples = torch.multinomial(
            torch.from_numpy(self.ns_probs).float(), # Convert pre-calculated `ns_probs` to a PyTorch tensor.
                                                    # Intuition: This is our weighted probability distribution
                                                    # for sampling words.
                                                    # Why: `torch.multinomial` needs a tensor of probabilities.
            batch_size * num_negative_samples,      # Total number of samples needed.
                                                    # Intuition: We need `num_negative_samples` for each item in the batch.
                                                    # Why: Efficiently generate all negative samples in one go.
            replacement=True                        # Allow sampling the same word multiple times.
                                                    # Why: Ensures we can always get the required number of samples.
        ).to(self.config.device)
        return neg_samples.view(batch_size, num_negative_samples)
        # Intuition: Reshape the flat list of samples into a (batch_size, num_negative_samples) matrix.
        # Why: Each row corresponds to a batch item, with its respective negative samples.

    def train(self):
        self.model.train() # Sets the model to training mode (e.g., enables dropout if present, not here).
        total_loss = 0
        pbar = tqdm(self.dataloader, desc="Training") # Visual progress bar.

        for i, (inputs, targets) in enumerate(pbar):
            inputs = inputs.to(self.config.device)
            targets = targets.to(self.config.device)

            self.optimizer.zero_grad() # Clear gradients from the previous iteration.
                                       # Intuition: Prevents gradients from accumulating across batches.
                                       # Why: Each batch's gradient is independent for parameter updates.

            if self.config.use_cbow:
                predictions = self.model(inputs) # (batch_size, vocab_size) - logits
                                                 # Intuition: The model predicts the target word's identity
                                                 # based on its context words.
                loss = self.criterion(predictions, targets)
            else:
                predictions = self.model(inputs) # (batch_size, vocab_size) - logits
                                                 # Intuition: The model predicts the context word's identity
                                                 # based on the target word.
                loss = self.criterion(predictions, targets)

            loss.backward() # Compute gradients of the loss with respect to model parameters.
                            # Intuition: The "backpropagation" step. It calculates how much each parameter
                            # contributed to the error.
                            # Why: These gradients are used by the optimizer to update the weights.

            self.optimizer.step() # Update model parameters using the computed gradients.
                                  # Intuition: Adjusts the model's weights to reduce the loss.
                                  # Why: This is how the model learns to make better predictions.

            # Apply embedding norm regularization (if enabled)
            if self.norm_embeddings:
                with torch.no_grad(): # Operations inside this block won't track gradients.
                                      # Why: Normalization is a fixed post-processing step, not part of
                                      # the learnable gradient-based optimization.
                    norm = self.model.embeddings.weight.norm(2, dim=1, keepdim=True)
                                      # Intuition: $||v||_2 = \sqrt{\sum_{i=1}^{D} v_i^2}$
                                      # Calculate the L2 (Euclidean) norm for each embedding vector.
                                      # `dim=1` means compute norm for each row (each word embedding).
                                      # `keepdim=True` keeps the dimension for broadcasting.
                                      # Why: The L2 norm measures the "length" of the vector.
                    self.model.embeddings.weight.div_(norm.clamp(min=1e-6))
                                      # Intuition: $v' = v / ||v||_2$
                                      # Divide each embedding vector by its L2 norm, effectively scaling it to have a length of 1.
                                      # `clamp(min=1e-6)` prevents division by zero for extremely small norms.
                                      # Why: Regularizing embeddings to have unit norm can improve stability,
                                      # prevent exploding gradients, and ensure that similarity measures
                                      # (like cosine similarity) are solely based on direction, not magnitude.

            total_loss += loss.item() # Accumulate loss for the epoch.
            pbar.set_postfix(loss=loss.item()) # Update progress bar with current loss.

        avg_loss = total_loss / len(self.dataloader)
        return avg_loss

# --- 4. Main Execution Block ---
if __name__ == "__main__":
    # Dummy text for demonstration
    dummy_text = """
    The quick brown fox jumps over the lazy dog. The dog barks.
    Word embeddings capture semantic meanings. PyTorch is a deep learning framework.
    Natural language processing is a fascinating field. Word2Vec is a classic algorithm.
    """

    # Build vocabulary
    vocab_to_idx, idx_to_vocab, word_counts, ns_probs = build_vocab_and_mappings(dummy_text, config.min_freq)
    vocab_size = len(vocab_to_idx)
    print(f"Vocabulary size: {vocab_size}")

    # Create dataset and dataloader
    if config.use_cbow:
        print("Using CBOW model.")
        dataset = Word2VecDataset(dummy_text, config.window_size, vocab_to_idx, idx_to_vocab, word_counts, ns_probs, config.device, model_type="cbow")
        # `collate_fn=collate_fn_cbow` is crucial here because CBOW contexts can
        # have different lengths, requiring padding within each batch.
        dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn_cbow)
        model = CBOW(vocab_size, config.embedding_dim)
    else:
        print("Using Skip-Gram model.")
        dataset = Word2VecDataset(dummy_text, config.window_size, vocab_to_idx, idx_to_vocab, word_counts, ns_probs, config.device, model_type="skipgram")
        # Skip-Gram inputs (target word) and outputs (single context word) are always fixed size,
        # so default collate_fn works.
        dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
        model = SkipGram(vocab_size, config.embedding_dim)

    # Define optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    # `nn.CrossEntropyLoss`:
    # Intuition: This loss function measures how well the model's predicted probability
    # distribution over the vocabulary matches the true target word. It's used for
    # multi-class classification problems.
    # It internally applies a Softmax function to the `predictions` (logits) and then
    # calculates the Negative Log Likelihood (NLL) of the true target.
    # Why: For a target word $y$, it aims to maximize the probability $P(y|context)$ (CBOW)
    # or $P(context|y)$ (Skip-Gram). Minimizing the negative log-likelihood is equivalent to
    # maximizing the likelihood.
    # For a batch, the loss is the average of $-\log(P(\text{correct_word}))$.
    criterion = nn.CrossEntropyLoss()

    # Initialize trainer and start training
    trainer = Word2VecTrainer(model, dataloader, optimizer, criterion, config, vocab_to_idx, ns_probs)

    for epoch in range(config.num_epochs):
        epoch_loss = trainer.train()
        print(f"Epoch {epoch+1}/{config.num_epochs}, Average Loss: {epoch_loss:.4f}")

    print("\nTraining complete!")

    # --- 5. Inference / Using Embeddings ---
    print("\n--- Inference Example: Finding Similar Words ---")

    # Get the learned embeddings
    word_embeddings = model.embeddings.weight.cpu().detach().numpy()
    # `.cpu().detach().numpy()`:
    # `.cpu()`: Moves the tensor from GPU to CPU.
    # `.detach()`: Creates a new tensor that doesn't track gradients, important for inference.
    # `.numpy()`: Converts the PyTorch tensor to a NumPy array for easier numerical operations.
    # Why: Prepare embeddings for non-PyTorch numerical computations (like numpy dot product)
    # and to ensure they are on the CPU for general use outside of training.

    def get_embedding(word):
        idx = vocab_to_idx.get(word.lower(), vocab_to_idx['<unk>'])
        return word_embeddings[idx]

    def find_similar_words(word, top_n=5):
        if word.lower() not in vocab_to_idx:
            print(f"'{word}' not in vocabulary.")
            return

        word_vec = get_embedding(word)

        # Calculate cosine similarity with all other words
        similarities = []
        for i, emb in enumerate(word_embeddings):
            if i == vocab_to_idx.get(word.lower(), -1): # Skip self-comparison
                continue

            # Compute cosine similarity: $\cos(\theta) = \frac{\mathbf{A} \cdot \mathbf{B}}{||\mathbf{A}||_2 ||\mathbf{B}||_2}$
            # Intuition: Measures the cosine of the angle between two vectors.
            # Ranges from -1 (opposite) to 1 (same direction), 0 for orthogonal.
            # Why: For unit vectors (normalized to length 1), the dot product IS the cosine similarity.
            # This is why normalizing embeddings is useful for direct similarity comparisons.
            dot_product = np.dot(word_vec, emb) # $\mathbf{A} \cdot \mathbf{B}$
                                                # Intuition: A measure of how much two vectors point in the same direction,
                                                # scaled by their magnitudes.
                                                # Why: The numerator of the cosine similarity formula.
            norm_word_vec = np.linalg.norm(word_vec) # $||\mathbf{A}||_2$
                                                     # Intuition: Length (magnitude) of the first vector.
                                                     # Why: Denominator of the cosine similarity.
            norm_emb = np.linalg.norm(emb)           # $||\mathbf{B}||_2$
                                                     # Intuition: Length (magnitude) of the second vector.
                                                     # Why: Denominator of the cosine similarity.

            if norm_word_vec == 0 or norm_emb == 0:
                similarity = -1 # Handle division by zero.
            else:
                similarity = dot_product / (norm_word_vec * norm_emb)
                                # Intuition: The final cosine similarity value.
                                # If vectors are already unit length, this simplifies to just `dot_product`.

            similarities.append((idx_to_vocab[i], similarity))

        similarities.sort(key=lambda x: x[1], reverse=True) # Sort by similarity (descending).
        print(f"\nWords similar to '{word}':")
        for w, sim in similarities[:top_n]:
            print(f"- {w}: {sim:.4f}")

    # Test similarity
    find_similar_words("dog")
    find_similar_words("pytorch")
    find_similar_words("language")
    find_similar_words("nonexistentword") # Test unknown word

    # Access a specific embedding
    print(f"\nEmbedding for 'dog' (first 5 dimensions): {get_embedding('dog')[:5]}")
    print(f"Embedding for '<unk>' (first 5 dimensions): {get_embedding('<unk>')[:5]}")


Using device: cpu
Vocabulary size: 31
Using CBOW model.


Training:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1/5, Average Loss: 3.4903


Training:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 2/5, Average Loss: 3.4327


Training:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 3/5, Average Loss: 3.4284


Training:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 4/5, Average Loss: 3.4245


Training:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 5/5, Average Loss: 3.4208

Training complete!

--- Inference Example: Finding Similar Words ---

Words similar to 'dog':
- learning: 0.2270
- <pad>: 0.2069
- barks.: 0.2013
- quick: 0.1457
- fox: 0.1140

Words similar to 'pytorch':
- embeddings: 0.2398
- quick: 0.2138
- a: 0.1763
- jumps: 0.1747
- the: 0.1113

Words similar to 'language':
- capture: 0.1632
- meanings.: 0.1579
- algorithm.: 0.1046
- over: 0.1030
- jumps: 0.0951
'nonexistentword' not in vocabulary.

Embedding for 'dog' (first 5 dimensions): [-0.06064522  0.09076002  0.11505468 -0.05188991  0.10866774]
Embedding for '<unk>' (first 5 dimensions): [-0.1416734  -0.06804858  0.22849047 -0.13469714  0.02342173]
