In [167]:
# =========================
# CBOW with Negative Sampling
# =========================

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import random

# 1. Hyperparameters
embedding_dim = 10
context_size = 2   # number of words on each side
num_negatives = 5  # number of negative samples
epochs = 50
lr = 0.01



In [168]:
# -----------------------------
# 3. Toy Corpus (~20 sentences)
# -----------------------------
corpus = [
    "the cat sat down",
    "the cat ate food",
    "the dog sat down",
    "the dog ate food",
    "a cat chased a mouse",
    "the dog chased the cat",
    "a dog barked loudly",
    "the cat meowed softly",
    "the bird sang sweetly",
    "a bird flew away",
    "the fish swam fast",
    "a fish jumped high",
    "the boy played ball",
    "the girl sang song",
    "a boy read book",
    "a girl wrote letter",
    "the sun shines bright",
    "the moon glows softly",
    "the stars twinkle bright",
    "a cat slept quietly"
]

PAD = "<PAD>"

# 2. Vocabulary
tokens = sorted(list(set(" ".join(corpus).split()))) + [PAD]
word2idx = {w: i for i, w in enumerate(tokens)}
idx2word = {i: w for w, i in word2idx.items()}
PAD_idx = word2idx[PAD]

V = len(tokens)

In [169]:
# 3. Noise distribution for negative sampling
import collections
counts = collections.Counter(" ".join(corpus).split())
total = sum(counts.values())
freqs = torch.tensor([counts[w]/total for w in tokens], dtype=torch.float)
# Use unigram^3/4 for negative sampling (Mikolov et al.)
noise_dist = freqs ** 0.75
noise_dist = noise_dist / noise_dist.sum()

# Print the noise distribution with words
print(f"{'Word':<10} | {'Noise Probability':>20}")
print("-" * 35)
for w, p in zip(tokens, noise_dist):
    print(f"{w:<10} | {p.item():>20.6f}")

Word       |    Noise Probability
-----------------------------------
a          |             0.073365
ate        |             0.025938
away       |             0.015423
ball       |             0.015423
barked     |             0.015423
bird       |             0.025938
book       |             0.015423
boy        |             0.025938
bright     |             0.025938
cat        |             0.059127
chased     |             0.025938
dog        |             0.043623
down       |             0.025938
fast       |             0.015423
fish       |             0.025938
flew       |             0.015423
food       |             0.025938
girl       |             0.025938
glows      |             0.015423
high       |             0.015423
jumped     |             0.015423
letter     |             0.015423
loudly     |             0.015423
meowed     |             0.015423
moon       |             0.015423
mouse      |             0.015423
played     |             0.015423
quietly    |

In [170]:
# --------------------------
# 4. Generate CBOW pairs
# --------------------------
def generate_cbow_pairs(corpus, context_size=2):
    pairs = []
    window_len = 2 * context_size

    for sentence in corpus:
        words = sentence.split()
        N = len(words)

        for i, target in enumerate(words):
            context_ids = []

            # fill left + right context positions
            for offset in range(-context_size, context_size + 1):
                if offset == 0:
                    continue  # skip target
                j = i + offset
                if 0 <= j < N:
                    context_ids.append(word2idx[words[j]])
                else:
                    context_ids.append(PAD_idx)  # pad

            # now context_ids is always length 2*context_size
            pairs.append((context_ids, word2idx[target]))

    return pairs
pairs = generate_cbow_pairs(corpus, context_size)

print(f"Total CBOW pairs: {len(pairs)}")



Total CBOW pairs: 82


In [171]:

# --------------------------
# 5. Dataset
# --------------------------
class CBOWDataset(Dataset):
    def __init__(self, pairs):
        self.pairs = pairs
    def __len__(self):
        return len(self.pairs)
    def __getitem__(self, idx):
        contexts, target = self.pairs[idx]
        return torch.tensor(contexts), torch.tensor(target)

dataset = CBOWDataset(pairs)
dataloader = DataLoader(dataset, batch_size=3, shuffle=True)


In [172]:
print("\n=== Example DataLoader Iteration ===")
for batch_i, (contexts, targets) in enumerate(dataloader):
    print(f"\nBatch {batch_i+1}:")
    for c, t in zip(contexts, targets):
        ctx_words = [idx2word[i.item()] for i in c]
        tgt_word = idx2word[t.item()]
        print(f"  Context: {ctx_words} → Target: '{tgt_word}'")
    if batch_i == 1:  # just print 2 batches
        break


=== Example DataLoader Iteration ===

Batch 1:
  Context: ['the', 'cat', 'down', '<PAD>'] → Target: 'sat'
  Context: ['<PAD>', 'a', 'flew', 'away'] → Target: 'bird'
  Context: ['a', 'girl', 'letter', '<PAD>'] → Target: 'wrote'

Batch 2:
  Context: ['<PAD>', 'a', 'barked', 'loudly'] → Target: 'dog'
  Context: ['<PAD>', 'a', 'chased', 'a'] → Target: 'cat'
  Context: ['<PAD>', '<PAD>', 'cat', 'chased'] → Target: 'a'


In [173]:
# --------------------------
# 6. Model
# --------------------------
class CBOWNegativeSampling(nn.Module):
    def __init__(self, vocab_size, embedding_dim, PAD_idx):
        super().__init__()
        self.in_embed = nn.Embedding(vocab_size, embedding_dim)
        self.out_embed = nn.Embedding(vocab_size, embedding_dim)
        self.PAD_idx = PAD_idx


    def forward(self, contexts, target, neg_samples):

        # ------------------------
        # MASKED CONTEXT AVERAGE
        # ------------------------
        mask = (contexts != self.PAD_idx).float()         # [B, 2C]
        embeds = self.in_embed(contexts)                  # [B, 2C, D]
        masked_embeds = embeds * mask.unsqueeze(2)        # [B, 2C, D]

        sum_embeds = masked_embeds.sum(dim=1)             # [B, D]
        count = mask.sum(dim=1).unsqueeze(1)              # [B, 1]
        count = torch.clamp(count, min=1)

        v_c = sum_embeds / count                           # [B, D]

        # ------------------------
        # POSITIVE SCORE
        # ------------------------
        u_o = self.out_embed(target)                      # [B, D]
        pos_score = (v_c * u_o).sum(dim=1)
        pos_loss = F.logsigmoid(pos_score)

        # ------------------------
        # NEGATIVE SCORE
        # ------------------------
        u_k = self.out_embed(neg_samples)                 # [B, K, D]
        neg_score = torch.bmm(u_k, v_c.unsqueeze(2)).squeeze(dim=2)
        neg_loss = F.logsigmoid(-neg_score).sum(1)

        return -(pos_loss + neg_loss).mean()        

model = CBOWNegativeSampling(V, embedding_dim, PAD_idx)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [174]:
def count_parameters(model):
    print("\n=== Model Parameters ===")
    total_params = 0

    for name, param in model.named_parameters():
        if param.requires_grad:
            n_params = param.numel()
            mem = n_params * param.element_size()  # bytes
            total_params += n_params
            print(f"{name:20s} | {n_params:10,d} params")

    print("-" * 55)
    print(f"Total Trainable Params: {total_params:,}")

count_parameters(model)



=== Model Parameters ===
in_embed.weight      |        430 params
out_embed.weight     |        430 params
-------------------------------------------------------
Total Trainable Params: 860


In [175]:

for epoch in range(epochs):
    total_loss = 0
    for contexts, target in dataloader:
        neg_samples = torch.multinomial(noise_dist, len(target) * num_negatives, replacement=True)
        neg_samples = neg_samples.view(len(target), num_negatives)

        optimizer.zero_grad()
        loss = model(contexts, target, neg_samples)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}")

Epoch 1/50, Loss: 171.6441
Epoch 2/50, Loss: 171.6962
Epoch 3/50, Loss: 144.1562
Epoch 4/50, Loss: 148.6950
Epoch 5/50, Loss: 124.0212
Epoch 6/50, Loss: 124.3845
Epoch 7/50, Loss: 126.8215
Epoch 8/50, Loss: 117.6695
Epoch 9/50, Loss: 108.4824
Epoch 10/50, Loss: 102.3632
Epoch 11/50, Loss: 98.7818
Epoch 12/50, Loss: 92.5184
Epoch 13/50, Loss: 80.7378
Epoch 14/50, Loss: 80.4390
Epoch 15/50, Loss: 75.0329
Epoch 16/50, Loss: 69.5878
Epoch 17/50, Loss: 65.3495
Epoch 18/50, Loss: 67.0577
Epoch 19/50, Loss: 62.1975
Epoch 20/50, Loss: 55.8747
Epoch 21/50, Loss: 58.2444
Epoch 22/50, Loss: 55.1815
Epoch 23/50, Loss: 51.1636
Epoch 24/50, Loss: 55.6556
Epoch 25/50, Loss: 49.2184
Epoch 26/50, Loss: 49.9983
Epoch 27/50, Loss: 48.2944
Epoch 28/50, Loss: 46.1428
Epoch 29/50, Loss: 43.7625
Epoch 30/50, Loss: 43.7921
Epoch 31/50, Loss: 42.0808
Epoch 32/50, Loss: 43.3042
Epoch 33/50, Loss: 39.9031
Epoch 34/50, Loss: 39.3197
Epoch 35/50, Loss: 42.5963
Epoch 36/50, Loss: 37.7143
Epoch 37/50, Loss: 36.5981


In [180]:
import torch

def predict_top_context(center_word, top_k=5):
    model.eval()
    with torch.no_grad():
        # get index of the center word
        center_idx = torch.tensor([word2idx[center_word]])
        # get embedding of the center word
        v_c = model.in_embed(center_idx)  # [1, D]
        v_c_norm = v_c / v_c.norm(dim=1, keepdim=True)

        # get all output embeddings and normalize
        u_all = model.out_embed.weight  # [V, D]
        u_all_norm = u_all / u_all.norm(dim=1, keepdim=True)

        # cosine similarity
        scores = torch.matmul(u_all_norm, v_c_norm.t()).squeeze()  # [V]

        # exclude center word itself
        scores[word2idx[center_word]] = -float('inf')
        topk_scores, topk_idx = torch.topk(scores, top_k)

        top_words = [idx2word[i.item()] for i in topk_idx]
        top_scores = [s.item() for s in topk_scores]

        print(f"\nTop {top_k} cosine-similar context words for '{center_word}':")
        print(f"{'Word':<12} | {'Cosine Sim':>10}")
        print("-" * 28)
        for w, s in zip(top_words, top_scores):
            print(f"{w:<12} | {s:>10.4f}")
        print()

        return list(zip(top_words, top_scores))

# Example usage
center = 'cat'
_ = predict_top_context(center, top_k=5)



Top 5 cosine-similar context words for 'cat':
Word         | Cosine Sim
----------------------------
down         |     0.1469
the          |     0.1242
slept        |     0.0764
a            |     0.0730
chased       |     0.0311

