In [1]:
import pickle

import torch
import torch.optim as optim
import torch.nn as nn
import wandb
from tqdm import tqdm
import collections

import more_itertools
import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

In [3]:
# map tokens to existings 

In [4]:
# get existing tokens
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading corpus")
with open("data/processed.pkl", "rb") as f:
    (
        corpus,
        tokens,  # corpus as tokens
        words_to_ids,
        ids_to_words,
    ) = pickle.load(f)
print("Loaded corpus")

Loading corpus
Loaded corpus


In [5]:
len(tokens)

16680599

In [6]:
# ids_to_words

In [7]:
# get fine-tune vocab
with open("1m_titles.txt") as f:
    titles: str = f.read()

# get fine-tune tokens 
def preprocess(text: str) -> list[str]:
    text = text.lower()
    text = text.replace(".", " <PERIOD> ")
    text = text.replace(",", " <COMMA> ")
    text = text.replace('"', " <QUOTATION_MARK> ")
    text = text.replace(";", " <SEMICOLON> ")
    text = text.replace("!", " <EXCLAMATION_MARK> ")
    text = text.replace("?", " <QUESTION_MARK> ")
    text = text.replace("(", " <LEFT_PAREN> ")
    text = text.replace(")", " <RIGHT_PAREN> ")
    text = text.replace("--", " <HYPHENS> ")
    text = text.replace("?", " <QUESTION_MARK> ")
    text = text.replace(":", " <COLON> ")
    words = text.split()
    stats = collections.Counter(words)
    words = [word for word in words if stats[word] > 5]
    return words


ft_corpus: list[str] = preprocess(titles)
print(f"fine-tune corpus created: ", ft_corpus[:5])

def create_lookup_tables(words: list[str]) -> tuple[dict[str, int], dict[int, str]]:
    word_counts = collections.Counter(words)
    vocab = sorted(word_counts, key=lambda k: word_counts.get(k), reverse=True)
    int_to_vocab = {ii + 1: word for ii, word in enumerate(vocab)}
    int_to_vocab[0] = "<PAD>"
    vocab_to_int = {word: ii for ii, word in int_to_vocab.items()}
    return vocab_to_int, int_to_vocab


# Subsample corpus
def subsample_corpus(corpus_tokens, threshold):
    token_counts = collections.Counter(corpus_tokens)

    def keep_probability(token) -> float:
        return 1 - np.sqrt((total_token_len * threshold) / token_counts[token])

    tqdm_tkn = tqdm(corpus_tokens, desc="Subsampling")

    subsampled = [
        token for token in tqdm_tkn if keep_probability(token) > np.random.ranf()
    ]
    return subsampled


ft_words_to_ids, ft_ids_to_words = create_lookup_tables(ft_corpus)
print("Created fine tune lookup tables")
ft_tokens = [ft_words_to_ids[word] for word in ft_corpus]
ft_token_counts = collections.Counter(ft_tokens)
total_token_len = len(ft_tokens)

ft_tokens_sub = subsample_corpus(corpus_tokens=ft_tokens, threshold=1e-5)

context_window = 5
if context_window % 2 == 0:
    raise Exception(
        f"Context Window must be an odd number, currently: {context_window}"
    )
center_i = context_window // 2

windows = more_itertools.windowed(ft_tokens_sub, context_window)

wind_tq = tqdm(windows, desc="Sliding window", total=len(ft_tokens_sub))
neg_sample_count = 20

inputs = []
targets = []
neg_samples = []
for tkn_wind in wind_tq:
    inputs.append(tkn_wind[center_i])
    targets.append(tkn_wind[:center_i] + tkn_wind[center_i + 1 :])
    negs = [
        words_to_ids[corpus[id]]
        for id in np.random.randint(0, total_token_len, neg_sample_count)
    ]
    neg_samples.append(negs)


input_tensor = torch.LongTensor(inputs)
target_tensor = torch.LongTensor(targets)
negs_tensor = torch.LongTensor(neg_samples)

fine-tune corpus created:  ['a', 'p2p', 'orchestrator', 'darcs', '2']
Created fine tune lookup tables


Subsampling: 100%|██████████| 8188131/8188131 [00:17<00:00, 467869.63it/s]
Sliding window: 100%|█████████▉| 5871651/5871655 [02:13<00:00, 43911.85it/s]


In [8]:
# ft_ids_to_words
def map_fine_tune_tokens(ft_tokens, ids_to_words):
    words_to_ids = {word: id_ for id_, word in ids_to_words.items()}
    
    ft_ids_to_words = {}

    for token in ft_tokens:
        ft_ids_to_words[token] = words_to_ids.get(token, 0)
    
    return ft_ids_to_words

In [9]:
ft_ids_to_words = map_fine_tune_tokens(ft_tokens, ids_to_words)

In [10]:
len(ft_ids_to_words)

45181

In [17]:
class Word2Vec(nn.Module):
    def __init__(self, embedding_dim, vocab_size):
        super().__init__()
        self.center_embed = nn.Embedding(vocab_size, embedding_dim)
        self.context_projection_embed = nn.Embedding(vocab_size, embedding_dim)
        # self.sig = nn.Sigmoid()
        self.loss = nn.BCEWithLogitsLoss()

    def get_loss(self, inpt, trgs, rand):
        emb = self.center_embed(inpt)

        ctx = self.context_projection_embed(trgs)

        neg = self.context_projection_embed(rand)

        pos_logits = torch.bmm(ctx, emb.unsqueeze(-1)).squeeze()
        neg_logits = torch.bmm(neg, emb.unsqueeze(-1)).squeeze()

        pos_labels = torch.ones_like(pos_logits)
        neg_labels = torch.zeros_like(neg_logits)

        logits = torch.cat([pos_logits, neg_logits], dim=1)
        labels = torch.cat([pos_labels, neg_labels], dim=1)

        return self.loss(logits, labels)

    def forward(self, id):
        return self.center_embed(id)

In [18]:
# initialise model 
vocab_size = len(ids_to_words)
embed_dim = 50
model = Word2Vec(embed_dim, len(ids_to_words)).to(device)

In [19]:
torch.cuda.is_available()
device = "cuda" if torch.cuda.is_available() else "cpu"

In [20]:
# initialise model with saved weights and biases 
# Model setup
# Load the w2v weights via the model
device = "cuda" if torch.cuda.is_available() else "cpu"
w2v = Word2Vec(embed_dim, vocab_size).to(device)
model_path = "models/w2v_epoch_11.pth"

w2v.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
print("W2V loaded")

W2V loaded


In [21]:
print("Model initialised")

optimizer = optim.Adam(model.parameters(), lr=0.01)

context_window = 5
batch_size = 500_000

load_path = "1m_titles_set.pth"
# input_tensor, target_tensor, negs_tensor = torch.load(load_path)
dataset = torch.utils.data.TensorDataset(input_tensor, target_tensor, negs_tensor)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
print("Loaded dataset")

Model initialised
Loaded dataset


In [22]:
wandb.init(
    project="word2vec",
    name="recovering",
    config={
        "batch_size": batch_size,
        "context_window": context_window,
        "embed_dims": embed_dim,
    },
)
for epoch in range(1):
    prgs = tqdm(dataloader, desc=f"Epoch {epoch+1}")

    for inputs, targets, negs in prgs:
        inputs, targets, negs = inputs.to(device), targets.to(device), negs.to(device)
        optimizer.zero_grad()

        loss = model.get_loss(inputs, targets, negs)

        loss.backward()

        optimizer.step()
        wandb.log({"loss": loss.item()})

    if (not (epoch + 1) % 5):
        save_path = f"checkpoints/w2v_epoch_{epoch+1}.pth"
        torch.save(model.state_dict(), save_path)
wandb.finish()

Epoch 1: 100%|██████████| 12/12 [01:27<00:00,  7.29s/it]


0,1
loss,█▇▇▆▅▅▄▃▃▂▂▁

0,1
loss,2.36106


In [None]:
# get embeddings 
