In [None]:
# =========================
# Skip-Gram with Negative Sampling
# =========================

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random

# 1. Hyperparameters
embedding_dim = 10  # size of each word vector
context_size = 2    # number of words on each side
num_negatives = 5   # number of negative samples
# NOTE(review): the saved training output further down reports 20 epochs, so it
# was produced before this value was changed; re-run the notebook to refresh it.
epochs = 25
lr = 0.01           # learning rate passed to the optimizer

# Seed every RNG the notebook relies on (embedding initialisation, DataLoader
# shuffling, negative sampling) so "Restart & Run All" reproduces the outputs.
seed = 42
random.seed(seed)
torch.manual_seed(seed)



In [2]:
# -------------------------------------------
# 2. Toy corpus (20 sentences) and vocabulary
# -------------------------------------------
corpus = [
    "the cat sat down",
    "the cat ate food",
    "the dog sat down",
    "the dog ate food",
    "a cat chased a mouse",
    "the dog chased the cat",
    "a dog barked loudly",
    "the cat meowed softly",
    "the bird sang sweetly",
    "a bird flew away",
    "the fish swam fast",
    "a fish jumped high",
    "the boy played ball",
    "the girl sang song",
    "a boy read book",
    "a girl wrote letter",
    "the sun shines bright",
    "the moon glows softly",
    "the stars twinkle bright",
    "a cat slept quietly"
]

# Vocabulary: every distinct whitespace-separated token, sorted so that the
# word <-> index mapping is the same on every run.
tokens = sorted({word for sentence in corpus for word in sentence.split()})
V = len(tokens)
word2idx = dict(zip(tokens, range(V)))
idx2word = dict(enumerate(tokens))

In [3]:
# 3. Noise distribution for negative sampling
import collections

# Raw unigram counts over the whole corpus, turned into relative frequencies
# ordered like `tokens` (i.e. by vocabulary index).
counts = collections.Counter(word for sentence in corpus for word in sentence.split())
total = sum(counts.values())
freqs = torch.tensor([counts[w] / total for w in tokens], dtype=torch.float)

# Mikolov et al.: raise the unigram distribution to the 3/4 power, then renormalise.
noise_dist = freqs.pow(0.75)
noise_dist = noise_dist / noise_dist.sum()

# Tabulate each word next to its sampling probability
print(f"{'Word':<10} | {'Noise Probability':>20}")
print("-" * 35)
for w, p in zip(tokens, noise_dist.tolist()):
    print(f"{w:<10} | {p:>20.6f}")

Word       |    Noise Probability
-----------------------------------
a          |             0.073365
ate        |             0.025938
away       |             0.015423
ball       |             0.015423
barked     |             0.015423
bird       |             0.025938
book       |             0.015423
boy        |             0.025938
bright     |             0.025938
cat        |             0.059127
chased     |             0.025938
dog        |             0.043623
down       |             0.025938
fast       |             0.015423
fish       |             0.025938
flew       |             0.015423
food       |             0.025938
girl       |             0.025938
glows      |             0.015423
high       |             0.015423
jumped     |             0.015423
letter     |             0.015423
loudly     |             0.015423
meowed     |             0.015423
moon       |             0.015423
mouse      |             0.015423
played     |             0.015423
quietly    |

In [4]:
# 4. Generate Skip-Gram pairs
def generate_skipgram_pairs(corpus, context_size=2, vocab=None):
    """Build (center, context) index pairs for skip-gram training.

    Each sentence is split on whitespace. For every word position, one pair is
    emitted for each other position at most ``context_size`` steps away inside
    the same sentence; windows never cross sentence boundaries.

    Args:
        corpus: iterable of sentence strings.
        context_size: number of positions considered on each side of the center.
        vocab: mapping from word to integer index. When omitted, the
            notebook-level ``word2idx`` is used, which preserves the original
            two-argument call.

    Returns:
        A list of ``(center_idx, context_idx)`` tuples, ordered by sentence,
        then center position, then context position.

    Raises:
        KeyError: if a word in ``corpus`` is missing from the vocabulary.
    """
    if vocab is None:
        vocab = word2idx  # fall back to the notebook-level vocabulary

    pairs = []
    for sentence in corpus:
        # Translate the sentence once instead of re-looking-up each word per window.
        ids = [vocab[word] for word in sentence.split()]
        for i, target_idx in enumerate(ids):
            lo = max(0, i - context_size)
            hi = min(len(ids), i + context_size + 1)
            pairs.extend((target_idx, ids[j]) for j in range(lo, hi) if j != i)
    return pairs

# Materialise all training pairs for the toy corpus using the configured window.
pairs = generate_skipgram_pairs(corpus, context_size=context_size)

# 5. Dataset
class SkipGramDataset(Dataset):
    """Map-style dataset that serves precomputed skip-gram pairs by position."""

    def __init__(self, pairs):
        """Keep a reference to ``pairs``; the sequence is not copied."""
        self.pairs = pairs

    def __len__(self):
        """Return how many pairs are stored."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the pair stored at position ``idx``, unchanged."""
        pair = self.pairs[idx]
        return pair

# Wrap the pairs and serve them in shuffled mini-batches of four.
dataset = SkipGramDataset(pairs)
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
)


In [5]:
# Peek at one shuffled batch, decoding the indices back into words
for center, context in dataloader:
    center_words = [idx2word[i] for i in center.tolist()]
    context_words = [idx2word[i] for i in context.tolist()]

    print("Center words:", center_words)
    print("Context words:", context_words)
    break  # Display only the first batch

Center words: ['twinkle', 'dog', 'the', 'chased']
Context words: ['stars', 'sat', 'sang', 'dog']


In [6]:

# 6. Skip-Gram Model

class SkipGramNS(nn.Module):
    """Skip-gram model trained with the negative-sampling objective.

    Two embedding tables are kept: ``in_embed`` for center words and
    ``out_embed`` for context / negative words.
    """

    def __init__(self, vocab_size, embedding_dim):
        """Create both ``vocab_size x embedding_dim`` embedding tables."""
        super().__init__()
        self.in_embed = nn.Embedding(vocab_size, embedding_dim)
        self.out_embed = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, center, context, neg_samples):
        """Return the mean negative-sampling loss for a batch.

        Args:
            center: index tensor of shape [B] with the center words.
            context: index tensor of shape [B] with the observed context words.
            neg_samples: index tensor of shape [B, K] with sampled noise words.

        Returns:
            A scalar tensor: the batch mean of
            ``-(logsigmoid(v_c . u_o) + sum_k logsigmoid(-v_c . u_k))``.
        """
        v_c = self.in_embed(center)           # [B, D]
        u_o = self.out_embed(context)         # [B, D]
        u_k = self.out_embed(neg_samples)     # [B, K, D]

        # positive score
        pos_score = torch.sum(v_c * u_o, dim=1)  # [B]
        pos_loss = F.logsigmoid(pos_score)

        # negative score: bmm gives [B, K, 1]. Squeeze only that trailing axis;
        # a bare squeeze() would also drop the batch axis when B == 1 (or the
        # K axis when K == 1) and make the sum over dim 1 fail.
        neg_score = torch.bmm(u_k, v_c.unsqueeze(2)).squeeze(2)  # [B, K]
        neg_loss = F.logsigmoid(-neg_score).sum(1)                # [B]

        return -(pos_loss + neg_loss).mean()  # mean over batch

# Instantiate the model over the full vocabulary and optimise all of its
# parameters with Adam at the configured learning rate.
model = SkipGramNS(vocab_size=V, embedding_dim=embedding_dim)
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)


In [7]:

# 7. Training loop
# The prediction helpers further down call model.eval(); switch back to
# training mode explicitly so re-running this cell afterwards behaves the same.
model.train()
for epoch in range(epochs):
    total_loss = 0  # sum of per-batch mean losses over this epoch
    for center, context in dataloader:
        # Draw K noise words per example from the unigram^0.75 distribution.
        # All B*K draws are made in one call and reshaped to [B, K].
        neg_samples = torch.multinomial(noise_dist, len(center)*num_negatives, replacement=True)
        neg_samples = neg_samples.view(len(center), num_negatives)

        optimizer.zero_grad()
        loss = model(center, context, neg_samples)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}")


Epoch 1/20, Loss: 435.8474
Epoch 2/20, Loss: 368.7665
Epoch 3/20, Loss: 314.8301
Epoch 4/20, Loss: 280.7955
Epoch 5/20, Loss: 251.7243
Epoch 6/20, Loss: 243.0566
Epoch 7/20, Loss: 200.1051
Epoch 8/20, Loss: 195.5157
Epoch 9/20, Loss: 169.0322
Epoch 10/20, Loss: 169.3974
Epoch 11/20, Loss: 157.2252
Epoch 12/20, Loss: 153.0578
Epoch 13/20, Loss: 147.8771
Epoch 14/20, Loss: 137.3757
Epoch 15/20, Loss: 132.1392
Epoch 16/20, Loss: 127.3621
Epoch 17/20, Loss: 127.7830
Epoch 18/20, Loss: 118.1502
Epoch 19/20, Loss: 120.0880
Epoch 20/20, Loss: 119.2882


In [8]:
import torch

def predict_top_context(center_word, top_k=5):
    """Return the ``top_k`` highest-scoring context words for ``center_word``.

    The score of a candidate is the dot product between the center word's
    input embedding and the candidate's output embedding. The center word
    itself is not excluded from the candidates.

    Relies on the notebook-level ``model``, ``word2idx`` and ``idx2word``.

    NOTE(review): a later cell defines a function with this same name that
    returns ``(word, score)`` tuples instead; once that cell runs it shadows
    this definition.

    Args:
        center_word: a word present in ``word2idx``.
        top_k: number of words wanted; capped at the vocabulary size.

    Returns:
        A list of words, best score first.

    Raises:
        KeyError: if ``center_word`` is not in ``word2idx``.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            center_idx = torch.tensor([word2idx[center_word]])
            v_c = model.in_embed(center_idx)  # [1, D]

            # Score against every output embedding; squeeze only the trailing
            # singleton so the result is always a 1-D [V] tensor.
            u_all = model.out_embed.weight  # [V, D]
            scores = torch.matmul(u_all, v_c.t()).squeeze(1)  # [V]

            # torch.topk raises when k exceeds the number of scores.
            k = min(top_k, scores.numel())
            _, topk_idx = torch.topk(scores, k)
            return [idx2word[i] for i in topk_idx.tolist()]
    finally:
        # Leave the model in whichever mode the caller had it in.
        model.train(was_training)

# Example usage: query a couple of center words and show their predictions
for center in ('cat', 'boy'):
    top_context = predict_top_context(center, top_k=5)
    print(f"Top predicted context words for '{center}': {top_context}")


Top predicted context words for 'cat': ['swam', 'ate', 'away', 'bird', 'moon']
Top predicted context words for 'boy': ['loudly', 'moon', 'played', 'book', 'down']


In [11]:
import torch

def predict_top_context(center_word, top_k=5):
    """Print and return the best-scoring context words for ``center_word``.

    The score of a candidate is the dot product between the center word's
    input embedding and the candidate's output embedding. The center word
    itself is masked out, so at most ``V - 1`` words can be returned.

    Relies on the notebook-level ``model``, ``word2idx`` and ``idx2word``.
    This definition replaces the earlier function of the same name, which
    returned bare words without scores.

    Args:
        center_word: a word present in ``word2idx``.
        top_k: number of words wanted; capped at the vocabulary size minus one.

    Returns:
        A list of ``(word, score)`` tuples, best score first.

    Raises:
        KeyError: if ``center_word`` is not in ``word2idx``.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            self_idx = word2idx[center_word]
            v_c = model.in_embed(torch.tensor([self_idx]))  # [1, D]

            # Score against every output embedding; squeeze only the trailing
            # singleton so the result is always a 1-D [V] tensor.
            u_all = model.out_embed.weight  # [V, D]
            scores = torch.matmul(u_all, v_c.t()).squeeze(1)  # [V]

            # Exclude the center word itself, and never ask for more entries
            # than remain (torch.topk raises if k exceeds the tensor size, and
            # k == V would hand back the masked -inf entry).
            scores[self_idx] = -float('inf')
            k = min(top_k, scores.numel() - 1)
            topk_scores, topk_idx = torch.topk(scores, k)

            top_words = [idx2word[i] for i in topk_idx.tolist()]
            top_scores = topk_scores.tolist()
    finally:
        # Leave the model in whichever mode the caller had it in.
        model.train(was_training)

    # Report the number actually shown; equals top_k unless it was capped.
    print(f"\nTop {len(top_words)} predicted context words for '{center_word}':")
    print(f"{'Word':<12} | {'Score':>10}")
    print("-" * 25)
    for w, s in zip(top_words, top_scores):
        print(f"{w:<12} | {s:>10.4f}")
    print()

    return list(zip(top_words, top_scores))

# Example usage: the function prints its own table, so the return value is discarded
for center in ('cat', 'fish'):
    _ = predict_top_context(center, top_k=5)



Top 5 predicted context words for 'cat':
Word         |      Score
-------------------------
swam         |    -0.4673
ate          |    -0.8462
away         |    -0.9408
bird         |    -0.9638
moon         |    -1.0851


Top 5 predicted context words for 'fish':
Word         |      Score
-------------------------
bright       |    -0.7526
the          |    -0.7723
fast         |    -0.9207
jumped       |    -0.9452
swam         |    -1.3696

