In [48]:
# =========================
# Skip-Gram with Negative Sampling
# =========================

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random

# 1. Hyperparameters
embedding_dim = 10
context_size = 2   # number of words on each side
num_negatives = 5  # number of negative samples
epochs = 25
lr = 0.01



In [49]:
# -----------------------------
# 3. Toy Corpus (~20 sentences)
# -----------------------------
corpus = [
    "the cat sat down",
    "the cat ate food",
    "the dog sat down",
    "the dog ate food",
    "a cat chased a mouse",
    "the dog chased the cat",
    "a dog barked loudly",
    "the cat meowed softly",
    "the bird sang sweetly",
    "a bird flew away",
    "the fish swam fast",
    "a fish jumped high",
    "the boy played ball",
    "the girl sang song",
    "a boy read book",
    "a girl wrote letter",
    "the sun shines bright",
    "the moon glows softly",
    "the stars twinkle bright",
    "a cat slept quietly"
]


# 2. Vocabulary
tokens = sorted(list(set(" ".join(corpus).split())))
word2idx = {w: i for i, w in enumerate(tokens)}
idx2word = {i: w for w, i in word2idx.items()}
V = len(tokens)

In [50]:
# 3. Noise distribution for negative sampling
import collections
counts = collections.Counter(" ".join(corpus).split())
total = sum(counts.values())
freqs = torch.tensor([counts[w]/total for w in tokens], dtype=torch.float)
# Use unigram^3/4 for negative sampling (Mikolov et al.)
noise_dist = freqs ** 0.75
noise_dist = noise_dist / noise_dist.sum()

# Print the noise distribution with words
print(f"{'Word':<10} | {'Noise Probability':>20}")
print("-" * 35)
for w, p in zip(tokens, noise_dist):
    print(f"{w:<10} | {p.item():>20.6f}")

Word       |    Noise Probability
-----------------------------------
a          |             0.073365
ate        |             0.025938
away       |             0.015423
ball       |             0.015423
barked     |             0.015423
bird       |             0.025938
book       |             0.015423
boy        |             0.025938
bright     |             0.025938
cat        |             0.059127
chased     |             0.025938
dog        |             0.043623
down       |             0.025938
fast       |             0.015423
fish       |             0.025938
flew       |             0.015423
food       |             0.025938
girl       |             0.025938
glows      |             0.015423
high       |             0.015423
jumped     |             0.015423
letter     |             0.015423
loudly     |             0.015423
meowed     |             0.015423
moon       |             0.015423
mouse      |             0.015423
played     |             0.015423
quietly    |

In [51]:
# 4. Generate Skip-Gram pairs
def generate_skipgram_pairs(corpus, context_size=2):
    pairs = []
    for sentence in corpus:
        words = sentence.split()
        for i, target in enumerate(words):
            target_idx = word2idx[target]
            # context window
            for j in range(max(0, i - context_size), min(len(words), i + context_size + 1)):
                if j != i:
                    context_idx = word2idx[words[j]]
                    pairs.append((target_idx, context_idx))
    return pairs

pairs = generate_skipgram_pairs(corpus, context_size)

# 5. Dataset
class SkipGramDataset(Dataset):
    def __init__(self, pairs):
        self.pairs = pairs
    def __len__(self):
        return len(self.pairs)
    def __getitem__(self, idx):
        return self.pairs[idx]

dataset = SkipGramDataset(pairs)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)


In [52]:
# Example of iterating through the DataLoader and printing words
for center, context in dataloader:
    center_words = [idx2word[idx.item()] for idx in center]
    context_words = [idx2word[idx.item()] for idx in context]
    
    print("Center words:", center_words)
    print("Context words:", context_words)
    break  # Display only the first batch

Center words: ['stars', 'girl', 'glows', 'ate']
Context words: ['the', 'sang', 'moon', 'the']


In [53]:

# 6. Skip-Gram Model

class SkipGramNS(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.in_embed = nn.Embedding(vocab_size, embedding_dim)
        self.out_embed = nn.Embedding(vocab_size, embedding_dim)
    
    def forward(self, center, context, neg_samples):
        # center: [B]
        # context: [B]
        # neg_samples: [B, K]
        v_c = self.in_embed(center)           # [B, D]
        u_o = self.out_embed(context)         # [B, D]
        u_k = self.out_embed(neg_samples)     # [B, K, D]

        # positive score
        pos_score = torch.sum(v_c * u_o, dim=1)  # [B]
        pos_loss = F.logsigmoid(pos_score)

        # negative score
        neg_score = torch.bmm(u_k, v_c.unsqueeze(2)).squeeze()  # [B, K]
        neg_loss = F.logsigmoid(-neg_score).sum(1)             # [B]

        return -(pos_loss + neg_loss).mean()  # mean over batch

model = SkipGramNS(V, embedding_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


In [54]:
def count_parameters(model):
    print("\n=== Model Parameters ===")
    total_params = 0

    for name, param in model.named_parameters():
        if param.requires_grad:
            n_params = param.numel()
            mem = n_params * param.element_size()  # bytes
            total_params += n_params
            print(f"{name:20s} | {n_params:10,d} params")

    print("-" * 55)
    print(f"Total Trainable Params: {total_params:,}")

count_parameters(model)



=== Model Parameters ===
in_embed.weight      |        420 params
out_embed.weight     |        420 params
-------------------------------------------------------
Total Trainable Params: 840


In [55]:

# 7. Training loop
for epoch in range(epochs):
    total_loss = 0
    for center, context in dataloader:
        # generate negative samples
        neg_samples = torch.multinomial(noise_dist, len(center)*num_negatives, replacement=True)
        neg_samples = neg_samples.view(len(center), num_negatives)

        optimizer.zero_grad()
        loss = model(center, context, neg_samples)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}")


Epoch 1/25, Loss: 409.4568
Epoch 2/25, Loss: 360.4789
Epoch 3/25, Loss: 329.2030
Epoch 4/25, Loss: 280.4951
Epoch 5/25, Loss: 255.1034
Epoch 6/25, Loss: 227.7789
Epoch 7/25, Loss: 215.8456
Epoch 8/25, Loss: 194.4384
Epoch 9/25, Loss: 178.8957
Epoch 10/25, Loss: 180.0030
Epoch 11/25, Loss: 155.6350
Epoch 12/25, Loss: 159.0434
Epoch 13/25, Loss: 145.7386
Epoch 14/25, Loss: 139.7112
Epoch 15/25, Loss: 134.7986
Epoch 16/25, Loss: 125.8668
Epoch 17/25, Loss: 116.9548
Epoch 18/25, Loss: 122.2632
Epoch 19/25, Loss: 116.8848
Epoch 20/25, Loss: 112.1864
Epoch 21/25, Loss: 110.2257
Epoch 22/25, Loss: 111.3834
Epoch 23/25, Loss: 107.0279
Epoch 24/25, Loss: 105.1920
Epoch 25/25, Loss: 106.6072


In [56]:
import torch

def predict_top_context_dot(center_word, top_k=5):
    model.eval()
    with torch.no_grad():
        # index of the center word
        center_idx = torch.tensor([word2idx[center_word]])

        # input embedding of center word  [1, D]
        v_c = model.in_embed(center_idx)

        # all output embeddings  [V, D]
        u_all = model.out_embed.weight

        # ---- DOT PRODUCT SCORES (SGNS objective) ----
        # score = v_c @ u_o^T
        scores = torch.matmul(u_all, v_c.t()).squeeze()     # [V]

        # exclude center word itself
        scores[word2idx[center_word]] = -float('inf')

        # get top-k
        topk_scores, topk_idx = torch.topk(scores, top_k)

        top_words = [idx2word[i.item()] for i in topk_idx]
        top_scores = [s.item() for s in topk_scores]

        print(f"\nTop {top_k} dot-product neighbors for '{center_word}':")
        print(f"{'Word':<12} | {'Dot-Product':>12}")
        print("-" * 30)
        for w, s in zip(top_words, top_scores):
            print(f"{w:<12} | {s:>12.4f}")
        print()

        return list(zip(top_words, top_scores))


# Example
center = "cat"
_ = predict_top_context_dot(center, top_k=5)



Top 5 dot-product neighbors for 'cat':
Word         |  Dot-Product
------------------------------
slept        |      -0.5227
quietly      |      -0.6172
the          |      -0.8214
chased       |      -1.1401
meowed       |      -1.2194

