In [1]:
from functools import partial
import numpy as np
from tqdm import tqdm
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator 
from torchtext.data import to_map_style_dataset, utils
from torchtext.datasets import WikiText2
torch.manual_seed(1337)

<torch._C.Generator at 0x104d926d0>

In [2]:
embedding_dim = 300
max_norm = 1
min_freq = 50
num_words = 4
max_sequence_length = 256

data_dir = "../data"
epochs = 10
batch_size = 96
learning_rate = 0.005

In [3]:
class CBOW(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim=embedding_dim, max_norm=max_norm)
        self.fc = nn.Linear(embedding_dim, vocab_size)
    
    def __call__(self, x):
        x = self.embeddings(x)
        x = x.mean(axis=1)
        x = self.fc(x)
        return x

In [4]:
def build_vocab(data_iter, tokenizer): 
    vocab = build_vocab_from_iterator(map(tokenizer, data_iter), specials=["<unk>"], min_freq=min_freq)
    vocab.set_default_index(vocab["<unk>"])
    return vocab


def collate_cbow(batch, text_pipeline):
    inputs, outputs = [], []
    for text in batch:
        token_ids = text_pipeline(text)
        if len(token_ids) < num_words * 2 + 1:
            continue
        token_ids = token_ids[:max_sequence_length]
        for i in range(len(token_ids) - num_words * 2):
            sequence = token_ids[i:num_words*2+i+1]
            outputs.append(sequence.pop(num_words))
            inputs.append(sequence)
        return torch.tensor(inputs, dtype=torch.long), torch.tensor(outputs, dtype=torch.long)


def get_dataloader_and_vocab(data_dir, ds_type, batch_size, shuffle=True, vocab=None):
    data_iter = to_map_style_dataset(WikiText2(root=data_dir, split=ds_type))
    tokenizer = utils.get_tokenizer("basic_english", language="en")
    if not vocab:
        vocab = build_vocab(data_iter, tokenizer)
    collate_fn = partial(collate_cbow, text_pipeline=lambda x: vocab(tokenizer(x)))
    dataloader = DataLoader(data_iter, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
    return dataloader, vocab

In [5]:
train_dataloader, vocab = get_dataloader_and_vocab(data_dir, "train", batch_size)
valid_dataloader, _ = get_dataloader_and_vocab(data_dir, "valid", batch_size)
vocab_size = len(vocab.get_stoi())
print(f"Vocabulary size: {vocab_size}")
train_steps, valid_steps = len(train_dataloader), len(valid_dataloader)

model = CBOW(vocab_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

for _ in range(epochs):
    model.train()
    for i, (x, y) in (t := tqdm(enumerate(train_dataloader), total=train_steps)):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        t.set_description(f"train loss {loss.item():.2f}")
    model.eval()
    total = 0.0
    for i, (x, y) in enumerate(valid_dataloader):
        loss = criterion(model(x), y)
        total += loss.item()
    print(f"validation loss {(total/valid_steps):.2f}")

Vocabulary size: 4099


train loss 10.09: 100%|█████████████████████████████████████████████████| 383/383 [00:03<00:00, 112.11it/s]


validation loss 3.68


train loss 6.14: 100%|██████████████████████████████████████████████████| 383/383 [00:03<00:00, 114.20it/s]


validation loss 3.97


train loss 6.67: 100%|██████████████████████████████████████████████████| 383/383 [00:03<00:00, 111.54it/s]


validation loss 3.99


train loss 5.46: 100%|██████████████████████████████████████████████████| 383/383 [00:03<00:00, 111.06it/s]


validation loss 4.04


train loss 5.58: 100%|██████████████████████████████████████████████████| 383/383 [00:03<00:00, 109.29it/s]


validation loss 3.83


train loss 5.21: 100%|██████████████████████████████████████████████████| 383/383 [00:03<00:00, 111.19it/s]


validation loss 4.05


train loss 4.98: 100%|██████████████████████████████████████████████████| 383/383 [00:03<00:00, 110.15it/s]


validation loss 3.97


train loss 4.61: 100%|██████████████████████████████████████████████████| 383/383 [00:03<00:00, 108.88it/s]


validation loss 3.81


train loss 5.45: 100%|██████████████████████████████████████████████████| 383/383 [00:03<00:00, 110.66it/s]


validation loss 3.90


train loss 5.55: 100%|██████████████████████████████████████████████████| 383/383 [00:03<00:00, 110.82it/s]


validation loss 3.68


In [6]:
embeddings = next(model.parameters()).detach().numpy()
norms = ((embeddings ** 2).sum(axis=1) ** 0.5).reshape(-1, 1)
embeddings_norm = embeddings / norms
embeddings_norm.shape

(4099, 300)

In [7]:
def get_similar(word, n=10):
    word_id = vocab[word]
    if word_id == 0:
        print("out of vocabulary word")
        return {}
    word_vec = embeddings_norm[word_id].flatten()
    dists = np.matmul(embeddings_norm, word_vec).flatten()
    top_ids = np.argsort(-dists)[1:n+1]
    top_dict = {}
    for sim_word_id in top_ids:
        sim_word = vocab.lookup_token(sim_word_id)
        top_dict[sim_word] = dists[sim_word_id]
    return top_dict


for word, score in get_similar("father").items():
    print(f"{word}: {score:.3f}")

writings: 0.849
honour: 0.825
genre: 0.812
religion: 0.812
mother: 0.806
doubt: 0.802
speaking: 0.799
faced: 0.794
calvert: 0.789
friend: 0.785


In [8]:
emb = embeddings[vocab["king"]] - embeddings[vocab["man"]] + embeddings[vocab["woman"]]
norm = (emb ** 2).sum() ** 0.5
emb_norm = (emb / norm).flatten()
dists = np.matmul(embeddings_norm, emb_norm).flatten()
word_ids = np.argsort(-dists)[:10]

for word_id in word_ids:
    print(f"{vocab.lookup_token(word_id)}: {dists[word_id]:.3f}")

king: 0.867
queen: 0.832
assistant: 0.824
harbour: 0.819
attributed: 0.809
howard: 0.808
haiti: 0.808
legend: 0.804
mode: 0.803
lennon: 0.795


In [9]:
class SkipGram(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim=embedding_dim, max_norm=max_norm)
        self.fc = nn.Linear(embedding_dim, vocab_size)
    
    def forward(self, x):
        x = self.embeddings(x)
        x = self.fc(x)
        return x

In [10]:
def collate_skipgram(batch, text_pipeline):
    inputs, outputs = [], []
    for text in batch:
        token_ids = text_pipeline(text)
        if len(token_ids) < num_words * 2 + 1:
            continue
        token_ids = token_ids[:max_sequence_length]
        for i in range(len(token_ids) - num_words * 2):
            sequence = token_ids[i:num_words*2+i+1]
            input_ = sequence.pop(num_words)
            for output in sequence:
                inputs.append(input_)
                outputs.append(output)
    return torch.tensor(inputs, dtype=torch.long), torch.tensor(outputs, dtype=torch.long)
    
    
def get_dataloader_and_vocab(data_dir, ds_type, batch_size, shuffle=True, vocab=None):
    data_iter = to_map_style_dataset(WikiText2(root=data_dir, split=ds_type))
    tokenizer = utils.get_tokenizer("basic_english", language="en")
    if not vocab:
        vocab = build_vocab(data_iter, tokenizer)
    collate_fn = partial(collate_skipgram, text_pipeline=lambda x: vocab(tokenizer(x)))
    dataloader = DataLoader(data_iter, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
    return dataloader, vocab

In [11]:
train_dataloader, vocab = get_dataloader_and_vocab(data_dir, "train", batch_size)
valid_dataloader, _ = get_dataloader_and_vocab(data_dir, "valid", batch_size)
vocab_size = len(vocab.get_stoi())
print(f"Vocabulary size: {vocab_size}")
train_steps, valid_steps = len(train_dataloader), len(valid_dataloader)

model = SkipGram(vocab_size)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

for _ in range(epochs):
    model.train()
    for i, (x, y) in (t := tqdm(enumerate(train_dataloader), total=train_steps)):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        t.set_description(f"train loss {loss.item():.2f}")
    model.eval()
    total = 0.0
    for i, (x, y) in enumerate(valid_dataloader):
        loss = criterion(model(x), y)
        total += loss.item()
    print(f"validation loss {(total/valid_steps):.2f}")

Vocabulary size: 4099


train loss 5.57: 100%|███████████████████████████████████████████████████| 383/383 [03:14<00:00,  1.97it/s]


validation loss 3.96


train loss 5.48: 100%|███████████████████████████████████████████████████| 383/383 [03:14<00:00,  1.97it/s]


validation loss 4.01


train loss 5.40: 100%|███████████████████████████████████████████████████| 383/383 [03:14<00:00,  1.97it/s]


validation loss 3.98


train loss 5.47: 100%|███████████████████████████████████████████████████| 383/383 [03:14<00:00,  1.97it/s]


validation loss 3.97


train loss 5.50: 100%|███████████████████████████████████████████████████| 383/383 [03:14<00:00,  1.97it/s]


validation loss 3.96


train loss 5.47: 100%|███████████████████████████████████████████████████| 383/383 [03:14<00:00,  1.97it/s]


validation loss 3.96


train loss 5.35: 100%|███████████████████████████████████████████████████| 383/383 [03:14<00:00,  1.97it/s]


validation loss 3.98


train loss 5.40: 100%|███████████████████████████████████████████████████| 383/383 [03:14<00:00,  1.97it/s]


validation loss 3.95


train loss 5.42: 100%|███████████████████████████████████████████████████| 383/383 [03:14<00:00,  1.97it/s]


validation loss 3.95


train loss 5.45: 100%|███████████████████████████████████████████████████| 383/383 [03:15<00:00,  1.96it/s]


validation loss 3.98


In [12]:
embeddings = next(model.parameters()).detach().numpy()
norms = ((embeddings ** 2).sum(axis=1) ** 0.5).reshape(-1, 1)
embeddings_norm = embeddings / norms
embeddings_norm.shape

(4099, 300)

In [13]:
for word, score in get_similar("father").items():
    print(f"{word}: {score:.3f}")

mother: 0.797
wife: 0.765
brother: 0.763
son: 0.741
husband: 0.717
sister: 0.708
pitman: 0.704
isabella: 0.677
daughter: 0.676
rosebery: 0.674


In [14]:
emb = embeddings[vocab["king"]] - embeddings[vocab["man"]] + embeddings[vocab["woman"]]
norm = (emb ** 2).sum() ** 0.5
emb_norm = (emb / norm).flatten()
dists = np.matmul(embeddings_norm, emb_norm).flatten()
word_ids = np.argsort(-dists)[:10]

for word_id in word_ids:
    print(f"{vocab.lookup_token(word_id)}: {dists[word_id]:.3f}")

king: 0.702
woman: 0.629
son: 0.547
banai: 0.516
isabella: 0.513
goddess: 0.513
philip: 0.508
edward: 0.507
queen: 0.506
jesus: 0.504
