<a href="https://colab.research.google.com/github/easonwangzk/UChicago/blob/main/Skip_gram.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import re
import json
import torch
import random
import requests
import numpy as np
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from collections import Counter, OrderedDict
from itertools import chain
from typing import List

## Device setup

In [2]:
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} as device.")

Using cuda as device.


## Configs

In [3]:
data_dir = "./data"
model_dir = "./models"
# debug=True selects the tiny hyper-parameters below; the training loop also
# checks this flag and stops after a single, verbosely printed batch.
debug = True

if debug:
    CONTEXT_WINDOW = 2
    EMBEDDING_SIZE = 5
    MIN_FREQ = 5
    BATCH_SIZE = 3
    N_EPOCHS = 1
else:
    CONTEXT_WINDOW = 4
    EMBEDDING_SIZE = 100
    MIN_FREQ = 25
    BATCH_SIZE = 64
    N_EPOCHS = 3

# Seed the Python, NumPy and PyTorch generators so that weight initialisation
# and DataLoader shuffling repeat across "Restart Kernel -> Run All".
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Create dirs

In [4]:
# Create the data and model folders; exist_ok=True keeps re-runs from failing
# when they are already there.
for directory in (data_dir, model_dir):
    os.makedirs(directory, exist_ok=True)

## Download and tokenize text

In [5]:
url = "https://www.gutenberg.org/cache/epub/7370/pg7370.txt"
# A timeout keeps a stalled connection from hanging the notebook, and
# raise_for_status() stops an HTTP error page from being tokenized as corpus.
response = requests.get(url, timeout=30)
response.raise_for_status()
raw_text = response.text.lower()
# Keep only lowercase letters and whitespace (digits and punctuation are dropped).
raw_text = re.sub(r'[^a-z\s]', '', raw_text)
# Each non-blank line of the file is treated as one "sentence" (a token list).
sentences = [line.split() for line in raw_text.split('\n') if line.strip()]
print(f"Number of sentences: {len(sentences):,}")

Number of sentences: 4,957



## Vocabulary class (same as CBOW version)

In [6]:
class Vocab:
    """Two-way mapping between tokens and integer indices.

    Index order is: the special tokens first (``unk_token`` is inserted at
    position 0 when it is not already listed in ``specials``), then every word
    of ``word_counts`` whose count is at least ``min_freq``, in the iteration
    order of ``word_counts``.

    Args:
        word_counts: Mapping of word -> count. Its iteration order decides
            both the index order and which words survive ``max_size``.
        min_freq: Minimum count a word needs to be kept.
        max_size: Optional cap on the vocabulary size, specials included.
            Specials are never dropped, so the cap is only exceeded when the
            specials alone outnumber it.
        specials: Extra special tokens placed ahead of the regular words.
        unk_token: Token used for anything that is not in the vocabulary.
    """

    def __init__(self, word_counts: OrderedDict, min_freq: int = 1, max_size: int = None, specials: List[str] = None, unk_token: str = "<unk>"):
        self.word_counts = word_counts
        self.min_freq = min_freq
        self.max_size = max_size
        self.unk_token = unk_token
        # Copy so the caller's list is not mutated by the insert below.
        self.specials = list(specials) if specials else []

        if self.unk_token not in self.specials:
            self.specials.insert(0, self.unk_token)

        self.token2idx = {}
        self.idx2token = []
        self._prepare_vocab()

    def __len__(self):
        """Return the number of index slots (specials included)."""
        return len(self.idx2token)

    def __contains__(self, value):
        """Return True when ``value`` is a known token."""
        # Hash lookup instead of scanning the idx2token list. Unhashable
        # values can never be tokens, so report False as the list scan did.
        try:
            return value in self.token2idx
        except TypeError:
            return False

    def _prepare_vocab(self):
        """Populate ``idx2token`` and ``token2idx`` from the stored settings."""
        vocab_list = self.specials.copy()
        special_set = set(self.specials)
        filtered_words = [
            word
            for word, freq in self.word_counts.items()
            if freq >= self.min_freq and word not in special_set
        ]
        if self.max_size is not None:
            # Clamp at zero: a negative bound would slice from the tail and
            # keep words even though the specials already fill the budget.
            room = max(self.max_size - len(self.specials), 0)
            filtered_words = filtered_words[:room]
        vocab_list.extend(filtered_words)
        self.idx2token = vocab_list
        self.token2idx = {word: idx for idx, word in enumerate(vocab_list)}

    def get_token(self, idx: int) -> str:
        """Return the token at ``idx``; out-of-range (or negative) gives ``unk_token``."""
        return self.idx2token[idx] if 0 <= idx < len(self.idx2token) else self.unk_token

    def get_index(self, token: str) -> int:
        """Return the index of ``token``; unknown tokens map to the ``unk_token`` index."""
        return self.token2idx.get(token, self.token2idx[self.unk_token])

    def get_tokens(self, indices: List[int]) -> List[str]:
        """Vectorised ``get_token`` over a list of indices."""
        return [self.get_token(idx) for idx in indices]

    def get_indices(self, tokens: List[str]) -> List[int]:
        """Vectorised ``get_index`` over a list of tokens."""
        return [self.get_index(token) for token in tokens]

## Padding for skip-gram

In [7]:
def pad_sentences(sentences: List[List[str]], context_length: int, pad_token: str = "<pad>") -> List[List[str]]:
    """Surround every sentence with ``context_length`` pad tokens on each side.

    Args:
        sentences: Token lists to pad; the input lists are not modified.
        context_length: Number of ``pad_token`` copies added before and after.
        pad_token: Token used as padding.

    Returns:
        A new list holding one new, padded token list per input sentence.
    """
    border = [pad_token] * context_length
    # List concatenation builds a fresh list each time, so sharing `border` is safe.
    return [border + sentence + border for sentence in sentences]

In [8]:
# Rebinds `sentences` to its padded form; running this cell a second time
# would pad the already-padded lists again.
sentences = pad_sentences(sentences=sentences, context_length=CONTEXT_WINDOW)

## Build vocab

In [9]:
# Count every token of the (padded) corpus; the counts are handed over in
# first-seen order, and "<pad>" is registered as a special token.
vocab = Vocab(
    word_counts=OrderedDict(Counter(chain.from_iterable(sentences))),
    min_freq=MIN_FREQ,
    specials=["<pad>"],
)
print(f"Size of Vocabulary: {len(vocab):,}")

Size of Vocabulary: 1,158



## Skip-gram pair generation

In [10]:
def generate_skipgram_pairs(sentences: List[List[str]], context_length: int, vocab: "Vocab"):
    """Build (center, context) index pairs for skip-gram training.

    Every position that has ``context_length`` tokens on both sides is used as
    a center word; it is paired with each of the ``context_length`` tokens
    before it and after it. Positions closer than ``context_length`` to either
    end are never centers, so sentences are expected to be padded beforehand.

    Args:
        sentences: Token lists (already padded by the caller).
        context_length: Number of neighbours taken on each side of the center.
        vocab: Object whose ``get_indices(tokens)`` returns a list of ints.

    Returns:
        Two 1-D ``torch.long`` tensors of equal length: center-word indices
        and the matching context-word indices.
    """
    inputs = []
    outputs = []
    for sentence in sentences:
        encoded = vocab.get_indices(sentence)
        for center_idx in range(context_length, len(encoded) - context_length):
            left = encoded[center_idx - context_length:center_idx]
            right = encoded[center_idx + 1:center_idx + context_length + 1]
            context = left + right
            # One training pair per context word, all sharing the same center.
            inputs.extend([encoded[center_idx]] * len(context))
            outputs.extend(context)
    # Pin the dtype: torch.tensor([]) would otherwise come back as float32
    # when no pair was produced, which is not a valid index/target tensor.
    return torch.tensor(inputs, dtype=torch.long), torch.tensor(outputs, dtype=torch.long)

# Center-word indices and their context-word indices, aligned element by element.
inputs, outputs = generate_skipgram_pairs(
    sentences=sentences,
    context_length=CONTEXT_WINDOW,
    vocab=vocab,
)
print(f"Number of training examples: {len(inputs):,}")

Number of training examples: 236,704


## Dataset class

In [11]:
class SkipGramDataset(Dataset):
    """Dataset that serves aligned ``(input, target)`` items by position.

    ``inputs`` and ``targets`` are stored as given; item ``i`` is the pair
    ``(inputs[i], targets[i])`` and the length is taken from ``inputs``.
    """

    def __init__(self, inputs, targets):
        self.inputs, self.targets = inputs, targets

    def __len__(self):
        # The dataset size follows the inputs container.
        return len(self.inputs)

    def __getitem__(self, idx):
        center = self.inputs[idx]
        context = self.targets[idx]
        return center, context

## Skip-gram model

In [12]:
class SkipGram(nn.Module):
    """Embedding lookup followed by a linear projection onto the vocabulary.

    Args:
        vocab_size: Number of rows in the embedding table and number of
            output logits.
        embedding_dim: Width of each embedding vector.
    """

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.linear = nn.Linear(in_features=embedding_dim, out_features=vocab_size)

    def forward(self, center_words):
        """Return the vocabulary logits for the given center-word indices."""
        return self.linear(self.embeddings(center_words))

    def debug_forward(self, center_words):
        """Same computation as ``forward``, printing each intermediate tensor."""
        hidden = self.embeddings(center_words)
        print("\nembeddings shape:", hidden.shape)
        print(hidden)
        logits = self.linear(hidden)
        print("\nlogits shape:", logits.shape)
        print(logits)
        return logits

## Instantiate model

In [13]:
# Build the network, then move its parameters to the selected device.
model = SkipGram(vocab_size=len(vocab), embedding_dim=EMBEDDING_SIZE)
model = model.to(device)
print(model)

SkipGram(
  (embeddings): Embedding(1158, 5)
  (linear): Linear(in_features=5, out_features=1158, bias=True)
)


## Loss and optimizer

In [14]:
# Targets equal to the <unk> index are left out of the loss.
# NOTE(review): "<pad>" targets coming from the padded sentence edges are not
# ignored and still contribute to the loss — confirm that this is intended.
criterion = nn.CrossEntropyLoss(ignore_index=vocab.token2idx[vocab.unk_token])
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

## Dataloader

In [15]:
# Wrap the (center, context) tensors and serve them in shuffled mini-batches.
dataset = SkipGramDataset(inputs=inputs, targets=outputs)
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

## Training loop

In [16]:
# Train for N_EPOCHS passes over the dataloader. In debug mode a single batch
# is pushed through debug_forward (which prints its tensors) and both loops
# stop right after it, so no epoch summary is printed.
for epoch in range(N_EPOCHS):
    total_loss = 0.0
    for batch_inputs, batch_outputs in dataloader:
        batch_inputs, batch_outputs = batch_inputs.to(device), batch_outputs.to(device)

        optimizer.zero_grad()
        if debug:
            predictions = model.debug_forward(batch_inputs)
        else:
            # Call the module rather than .forward() so nn.Module hooks run.
            predictions = model(batch_inputs)

        loss = criterion(predictions, batch_outputs)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        if debug: break
    if debug: break
    # The criterion already returns a per-batch mean, so the epoch figure is
    # the average over batches; dividing by len(dataset) would understate it
    # by roughly a factor of BATCH_SIZE.
    print(f"Epoch {epoch+1}/{N_EPOCHS}, Loss: {total_loss/max(len(dataloader), 1):.4f}")


embeddings shape: torch.Size([3, 5])
tensor([[ 0.0099,  0.8007, -0.2172, -1.7865, -0.1345],
        [-0.1325, -1.2426, -0.1149,  1.1431,  0.3546],
        [-2.8135,  0.0679,  0.0196, -0.9808,  0.5849]], device='cuda:0',
       grad_fn=<EmbeddingBackward0>)

logits shape: torch.Size([3, 1158])
tensor([[ 0.6994,  0.6161,  0.1455,  ...,  0.0060,  0.0517,  0.4258],
        [-0.0163,  0.1403, -0.6382,  ..., -0.2566, -0.6915, -0.8407],
        [-0.2563, -0.1448, -0.0099,  ..., -0.1927,  0.5144,  0.6950]],
       device='cuda:0', grad_fn=<AddmmBackward0>)


## Save trained model weights and vocab

In [18]:
import torch.nn.functional as F
import pickle

In [19]:
# Persist only the embedding matrix (not the linear layer) plus the pickled
# vocabulary needed to map words to its rows.
torch.save(obj=model.embeddings.weight.data, f=f"{model_dir}/weights.pt")
with open(f"{model_dir}/vocab.pkl", mode="wb") as vocab_file:
    pickle.dump(vocab, vocab_file)


## Define function to compute closest words

In [20]:
def closest_words(embeddings, vocab, word, n=10):
    """Return the tokens whose embeddings are most cosine-similar to ``word``.

    Args:
        embeddings: 2-D tensor indexed by vocabulary index along dim 0.
        vocab: Object exposing ``token2idx``, ``get_index`` and ``get_token``.
        word: Query token; must be a key of ``vocab.token2idx``.
        n: Maximum number of neighbours to return. Values larger than the
            number of other rows are clamped instead of raising.

    Returns:
        A list of ``(token, similarity)`` tuples sorted by decreasing
        similarity; the query word itself is never included.

    Raises:
        ValueError: If ``word`` is not in the vocabulary.
    """
    if word not in vocab.token2idx:
        raise ValueError(f"'{word}' not in vocabulary")

    word_idx = vocab.get_index(word)
    word_embedding = embeddings[word_idx]

    similarities = F.cosine_similarity(word_embedding.unsqueeze(0), embeddings, dim=1)
    # Mask the query with -inf rather than -1: -1 is a legitimate cosine
    # value, so the query could otherwise tie with a real neighbour.
    similarities[word_idx] = float("-inf")

    # topk raises when k exceeds the number of rows; never ask for more than
    # the rows that remain once the query itself is excluded.
    k = max(0, min(n, similarities.numel() - 1))
    if k == 0:
        return []

    top = similarities.topk(k)
    # Convert to plain Python ints/floats before touching the vocabulary.
    return [
        (vocab.get_token(idx), score)
        for idx, score in zip(top.indices.tolist(), top.values.tolist())
    ]

In [21]:
# Without a GPU, remap the saved tensor onto the CPU; with one, keep
# torch.load's default placement (map_location=None).
loaded_embeddings = torch.load(
    f"{model_dir}/weights.pt",
    weights_only=True,
    map_location=None if torch.cuda.is_available() else torch.device("cpu"),
)

# pickle.load can execute arbitrary code; the file read here is the one this
# notebook wrote in the save cell, so it is trusted.
with open(f"{model_dir}/vocab.pkl", "rb") as vocab_file:
    loaded_vocab = pickle.load(vocab_file)

## Run similarity search

In [22]:
print("Trained model:")
# Nearest neighbours of "love" according to the embeddings reloaded from disk.
print(closest_words(loaded_embeddings, loaded_vocab, "love", n=10))

Trained model:
[('was', 0.9637019038200378), ('make', 0.905871570110321), ('judge', 0.9049926996231079), ('body', 0.9028736352920532), ('brought', 0.8975038528442383), ('minority', 0.8772290945053101), ('wise', 0.8752244710922241), ('kingdom', 0.8742192387580872), ('grown', 0.8639228940010071), ('food', 0.8584514856338501)]



## Compare with untrained model

In [23]:
# Freshly initialised model (left on the CPU) used as a baseline for the
# same query.
model_untrained = SkipGram(len(vocab), EMBEDDING_SIZE)
untrained_embeddings = model_untrained.embeddings.weight.data

print("\nUntrained model:")
print(closest_words(untrained_embeddings, vocab, "love", n=10))


Untrained model:
[('or', 0.9682488441467285), ('draw', 0.9493830800056458), ('agreed', 0.9472517371177673), ('paragraph', 0.9458328485488892), ('many', 0.9302278161048889), ('grants', 0.928244411945343), ('measure', 0.9108309149742126), ('understood', 0.9096970558166504), ('freemen', 0.9056882262229919), ('power', 0.89534991979599)]
