In [1]:
import re
from functools import partial
from collections import Counter
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

In [2]:
torch.manual_seed(0)

<torch._C.Generator at 0x117c99fd0>

In [3]:
embedding_dim = 300
max_norm = 1
min_freq = 50
num_words = 4
max_sequence_length = 256

epochs = 10
batch_size = 96
learning_rate = 0.005

In [4]:
class Vocab:
    def __init__(self, stoi, default_index):
        self.stoi = stoi
        self.itos = {i: t for t, i in stoi.items()}
        self.default_index = default_index

    def __len__(self):
        return len(self.stoi)

    def __getitem__(self, token):
        return self.stoi.get(token, self.default_index)

    def __call__(self, tokens):
        return self.lookup_indices(tokens)

    def lookup_indices(self, tokens):
        return [self[token] for token in tokens]

    def lookup_token(self, index):
        return self.itos[index]

    def forward(self, tokens):
        return self.lookup_indices(tokens)

    def get_stoi(self):
        return self.stoi

    def set_default_index(self, index):
        self.default_index = index

In [None]:
def build_vocab_from_iterator(token_iter, specials=["<unk>"], min_freq=1):
    counter = Counter()
    for tokens in token_iter:
        counter.update(tokens)

    # Filter by min_freq and sort by frequency descending, then alphabetically
    word_freq = sorted([(w, f) for w, f in counter.items() if f >= min_freq], key=lambda x: (-x[1], x[0]))
    words = specials + [w for w, _ in word_freq]
    stoi = {word: idx for idx, word in enumerate(words)}
    default_index = stoi["<unk>"]
    return Vocab(stoi, default_index)


def get_tokenizer(tokenizer_name, language="en"):
    if tokenizer_name != "basic_english":
        raise ValueError(f"Unsupported tokenizer: {tokenizer_name}")
    def basic_english_tokenize(text):
        text = text.lower()
        text = re.sub(r"([.,!?;])", r" \1 ", text)
        text = re.sub(r"[^a-zA-Z.,!?;]+", r" ", text)
        return text.split()
    return basic_english_tokenize


def build_vocab(data_iter, tokenizer):
    vocab = build_vocab_from_iterator(map(tokenizer, data_iter), specials=["<unk>"], min_freq=min_freq)
    vocab.set_default_index(vocab["<unk>"])
    return vocab


def collate_cbow(batch, text_pipeline):
    inputs, outputs = [], []
    for text in batch:
        token_ids = text_pipeline(text)
        if len(token_ids) < num_words * 2 + 1:
            continue
        token_ids = token_ids[:max_sequence_length]
        for i in range(len(token_ids) - num_words * 2):
            sequence = token_ids[i : num_words * 2 + i + 1]
            outputs.append(sequence.pop(num_words))
            inputs.append(sequence)
    return torch.tensor(inputs, dtype=torch.long), torch.tensor(outputs, dtype=torch.long)


def get_dataloader_and_vocab(ds_type, batch_size, shuffle=True, vocab=None):
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split=ds_type)
    tokenizer = get_tokenizer("basic_english")

    def data_iter():
        for example in dataset:
            yield example["text"]

    if not vocab:
        vocab = build_vocab(data_iter(), tokenizer)

    text_pipeline = lambda x: vocab(tokenizer(x))  # vocab(tokens) works via forward/lookup_indices

    def collate_fn(batch):
        texts = [example["text"] for example in batch]
        return collate_cbow(texts, text_pipeline)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
    return dataloader, vocab

In [6]:
train_dataloader, vocab = get_dataloader_and_vocab("train", batch_size)
valid_dataloader, _ = get_dataloader_and_vocab("validation", batch_size, vocab=vocab)
vocab_size = len(vocab.get_stoi())
print(f"Vocabulary size: {vocab_size}")
train_steps, valid_steps = len(train_dataloader), len(valid_dataloader)

Vocabulary size: 3867


In [7]:
class CBOW(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim=embedding_dim, max_norm=max_norm)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def __call__(self, x):
        x = self.embeddings(x)
        x = x.mean(axis=1)
        x = self.linear(x)
        return x

In [8]:
model = CBOW(vocab_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

In [9]:
for _ in range(epochs):
    model.train()
    for i, (x, y) in (t := tqdm(enumerate(train_dataloader), total=train_steps)):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        t.set_description(f"train loss {loss.item():.2f}")
    model.eval()
    total = 0.0
    for i, (x, y) in enumerate(valid_dataloader):
        loss = criterion(model(x), y)
        total += loss.item()
    print(f"validation loss {(total/valid_steps):.2f}")

train loss 5.35: 100%|██████████| 383/383 [00:22<00:00, 16.85it/s]


validation loss 5.25


train loss 4.86: 100%|██████████| 383/383 [00:20<00:00, 18.29it/s]


validation loss 5.12


train loss 5.04: 100%|██████████| 383/383 [00:21<00:00, 18.18it/s]


validation loss 5.03


train loss 4.95: 100%|██████████| 383/383 [00:22<00:00, 17.37it/s]


validation loss 4.97


train loss 4.97: 100%|██████████| 383/383 [00:22<00:00, 16.90it/s]


validation loss 4.92


train loss 4.70: 100%|██████████| 383/383 [00:22<00:00, 16.95it/s]


validation loss 4.89


train loss 5.00: 100%|██████████| 383/383 [00:21<00:00, 18.08it/s]


validation loss 4.85


train loss 4.94: 100%|██████████| 383/383 [00:21<00:00, 17.89it/s]


validation loss 4.83


train loss 4.89: 100%|██████████| 383/383 [00:21<00:00, 18.13it/s]


validation loss 4.80


train loss 4.77: 100%|██████████| 383/383 [00:21<00:00, 17.58it/s]


validation loss 4.78


In [10]:
embeddings = next(model.parameters()).detach().numpy()
norms = ((embeddings ** 2).sum(axis=1) ** 0.5).reshape(-1, 1)
embeddings_norm = embeddings / norms
embeddings_norm.shape

(3867, 300)

In [11]:
def get_similar(word, n=10):
    word_id = vocab[word]
    if word_id == 0:
        print("out of vocabulary word")
        return {}
    word_vec = embeddings_norm[word_id].flatten()
    dists = np.matmul(embeddings_norm, word_vec).flatten()
    top_ids = np.argsort(-dists)[1:n+1]
    top_dict = {}
    for sim_word_id in top_ids:
        sim_word = vocab.lookup_token(sim_word_id)
        top_dict[sim_word] = dists[sim_word_id]
    return top_dict


for word, score in get_similar("father").items():
    print(f"{word}: {score:.3f}")

wife: 0.863
brother: 0.852
daughter: 0.821
mother: 0.801
son: 0.733
friend: 0.711
husband: 0.702
opponent: 0.690
successor: 0.656
death: 0.638


In [12]:
emb = embeddings[vocab["king"]] - embeddings[vocab["man"]] + embeddings[vocab["woman"]]
norm = (emb ** 2).sum() ** 0.5
emb_norm = (emb / norm).flatten()
dists = np.matmul(embeddings_norm, emb_norm).flatten()
word_ids = np.argsort(-dists)[:10]

for word_id in word_ids:
    print(f"{vocab.lookup_token(word_id)}: {dists[word_id]:.3f}")

king: 0.714
queen: 0.655
lord: 0.607
ambassador: 0.595
representative: 0.595
church: 0.551
mother: 0.550
palace: 0.549
nation: 0.539
henry: 0.534


In [13]:
class SkipGram(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim=embedding_dim, max_norm=max_norm)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        x = self.embeddings(x)
        x = self.linear(x)
        return x

In [None]:
def collate_skipgram(batch, text_pipeline):
    targets, contexts = [], []
    for text in batch:
        token_ids = text_pipeline(text)
        if len(token_ids) < 2:
            continue
        token_ids = token_ids[:max_sequence_length]
        for i in range(len(token_ids)):
            for offset in range(1, num_words + 1):
                if i - offset >= 0:
                    targets.append(token_ids[i])
                    contexts.append(token_ids[i - offset])
                if i + offset < len(token_ids):
                    targets.append(token_ids[i])
                    contexts.append(token_ids[i + offset])

    if not targets:
        return torch.empty(0, dtype=torch.long), torch.empty(0, dtype=torch.long)

    return torch.tensor(targets, dtype=torch.long), torch.tensor(contexts, dtype=torch.long)


def get_dataloader_and_vocab(ds_type, batch_size, shuffle=True, vocab=None):
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split=ds_type)
    tokenizer = get_tokenizer("basic_english")

    def data_iter():
        for example in dataset:
            yield example["text"]

    if not vocab:
        vocab = build_vocab(data_iter(), tokenizer)

    text_pipeline = lambda x: vocab(tokenizer(x))

    def collate_fn(batch):
        texts = [example["text"] for example in batch]
        return collate_skipgram(texts, text_pipeline)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
    return dataloader, vocab

In [15]:
train_dataloader, vocab = get_dataloader_and_vocab("train", batch_size)
valid_dataloader, _ = get_dataloader_and_vocab("validation", batch_size)
vocab_size = len(vocab.get_stoi())
print(f"Vocabulary size: {vocab_size}")
train_steps, valid_steps = len(train_dataloader), len(valid_dataloader)

Vocabulary size: 3867


In [16]:
model = SkipGram(vocab_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

In [17]:
for _ in range(epochs):
    model.train()
    for i, (x, y) in (t := tqdm(enumerate(train_dataloader), total=train_steps)):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        t.set_description(f"train loss {loss.item():.2f}")
    model.eval()
    total = 0.0
    for i, (x, y) in enumerate(valid_dataloader):
        loss = criterion(model(x), y)
        total += loss.item()
    print(f"validation loss {(total/valid_steps):.2f}")

train loss 5.40: 100%|██████████| 383/383 [02:21<00:00,  2.71it/s]


validation loss 3.89


train loss 5.44: 100%|██████████| 383/383 [02:11<00:00,  2.92it/s]


validation loss 3.86


train loss 5.50: 100%|██████████| 383/383 [02:09<00:00,  2.96it/s]


validation loss 3.88


train loss 5.34: 100%|██████████| 383/383 [02:06<00:00,  3.02it/s]


validation loss 3.85


train loss 5.45: 100%|██████████| 383/383 [02:11<00:00,  2.90it/s]


validation loss 3.86


train loss 5.40: 100%|██████████| 383/383 [02:19<00:00,  2.75it/s]


validation loss 3.92


train loss 5.41: 100%|██████████| 383/383 [02:18<00:00,  2.77it/s]


validation loss 3.87


train loss 5.32: 100%|██████████| 383/383 [02:12<00:00,  2.90it/s]


validation loss 3.86


train loss 5.24: 100%|██████████| 383/383 [02:04<00:00,  3.08it/s]


validation loss 3.86


train loss 5.38: 100%|██████████| 383/383 [02:04<00:00,  3.07it/s]


validation loss 3.91


In [18]:
embeddings = next(model.parameters()).detach().numpy()
norms = ((embeddings ** 2).sum(axis=1) ** 0.5).reshape(-1, 1)
embeddings_norm = embeddings / norms
embeddings_norm.shape

(3867, 300)

In [19]:
for word, score in get_similar("father").items():
    print(f"{word}: {score:.3f}")

mother: 0.739
brother: 0.712
daughter: 0.699
son: 0.682
wife: 0.675
parents: 0.669
friend: 0.663
husband: 0.635
pitman: 0.621
marriage: 0.616


In [20]:
emb = embeddings[vocab["king"]] - embeddings[vocab["man"]] + embeddings[vocab["woman"]]
norm = (emb ** 2).sum() ** 0.5
emb_norm = (emb / norm).flatten()
dists = np.matmul(embeddings_norm, emb_norm).flatten()
word_ids = np.argsort(-dists)[:10]

for word_id in word_ids:
    print(f"{vocab.lookup_token(word_id)}: {dists[word_id]:.3f}")

king: 0.685
queen: 0.607
woman: 0.597
daughter: 0.593
henry: 0.584
calvert: 0.579
edward: 0.555
reign: 0.555
lord: 0.554
elizabeth: 0.537
