In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import os
import jax
import sys
import optax
import pandas as pd
import numpy as np
import jax.nn as jnn
import jax.numpy as jnp
import plotly.express as px

from datasets import load_dataset
from collections import Counter
from tqdm.notebook import tqdm

from sklearn.metrics.pairwise import cosine_similarity

np.random.seed(69)

NLP_PATH = os.path.abspath(os.path.join(os.getcwd(), "../../"))

In [5]:
if NLP_PATH not in sys.path:
    sys.path.append(NLP_PATH)

from lib.text import create_vocabulary, remove_stopwords_and_common_words, generate_training_text, subsample_tokens, shuffle_dataset, preprocess_text

DATA_PATH = os.path.join(NLP_PATH, "data")

In [7]:
def find_most_similar_words(word, vocabulary, embeddings, top_n=5):

    if word not in vocabulary:
        raise ValueError(f"Word '{word}' not found in the vocabulary.")

    word_idx = vocabulary[word]
    target_embedding = embeddings[word_idx].reshape(1, -1)  # Shape: (1, EMBEDDING_DIM)

    similarities = cosine_similarity(target_embedding, embeddings)[0]  # Shape: (VOCAB_SIZE,)

    similar_indices = similarities.argsort()[::-1]
    similar_indices = [idx for idx in similar_indices if idx != word_idx]

    reverse_vocab = {idx: w for w, idx in vocabulary.items()}
    similar_words = [(reverse_vocab[idx], float(similarities[idx])) for idx in similar_indices[:top_n]]

    return similar_words

def resolve_analogy(word_a, word_b, word_c, vocabulary, embeddings, top_n=1):

    for word in [word_a, word_b, word_c]:
        if word not in vocabulary:
            raise ValueError(f"Word '{word}' not found in the vocabulary.")

    idx_a, idx_b, idx_c = vocabulary[word_a], vocabulary[word_b], vocabulary[word_c]
    embedding_a, embedding_b, embedding_c = embeddings[idx_a], embeddings[idx_b], embeddings[idx_c]

    analogy_vector = embedding_b - embedding_a + embedding_c

    similarities = cosine_similarity(analogy_vector.reshape(1, -1), embeddings)[0]  # Shape: (VOCAB_SIZE,)

    sorted_indices = similarities.argsort()[::-1]
    excluded_indices = {idx_a, idx_b, idx_c}
    sorted_indices = [idx for idx in sorted_indices if idx not in excluded_indices]

    reverse_vocab = {idx: word for word, idx in vocabulary.items()}
    similar_words = [(reverse_vocab[idx], float(similarities[idx])) for idx in sorted_indices[:top_n]]

    return similar_words

### Load the Wikipedia dataset

In [96]:
TOP_K_ARTICLES = 50_000
TRAINING_K_ARTICLES = np.arange(0, TOP_K_ARTICLES).tolist()

en_wikipedia_dataset = load_dataset("wikipedia", "20220301.en")

train_text_tokens = []
for wiki_text in tqdm(en_wikipedia_dataset["train"].select(range(TOP_K_ARTICLES)), desc="Processing Wikipedia dataset", total=TOP_K_ARTICLES):
    train_text_tokens.extend(preprocess_text(wiki_text['text']))

Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

Processing Wikipedia dataset:   0%|          | 0/50000 [00:00<?, ?it/s]

#### Configure hyperparameters

In [97]:
TOP_K = 30_000
STRIDE = 1
WINDOWS_SIZE = 10

### Data processing

In [98]:
# create vocabulary from the dataset
train_dataset = remove_stopwords_and_common_words(train_text_tokens)
train_dataset = subsample_tokens(train_dataset, 1e-5)
vocabulary, _ = create_vocabulary(train_dataset, top_k=TOP_K)

train_dataset = [word for word in train_dataset if vocabulary.get(word, 0)]
train_token_counter = Counter(train_dataset)
len_train_dataset_tokens = len(train_dataset)

Subsampling tokens:   0%|          | 0/78333238 [00:00<?, ?it/s]

### Word2Vec Continuous Bag of Words (CBOW)


In [99]:
train = train_dataset

print(f"Train dataset: {len(train)}")
print(f"Vocabulary size: {len(vocabulary)}")

Train dataset: 33387950
Vocabulary size: 30000


In [100]:
BATCH_SIZE = 2048

print("Generating dataset")
full_dataset = list(generate_training_text(train, vocabulary, window_size=WINDOWS_SIZE, stride=STRIDE, batch_size=BATCH_SIZE, to_ids=True))

Generating dataset


In [101]:
EMBEDDING_DIM = 300

embedding_init = jax.nn.initializers.glorot_uniform()

cbow_params = {
    "embedding": embedding_init(jax.random.PRNGKey(69), (len(vocabulary), EMBEDDING_DIM)),
    "output": embedding_init(jax.random.PRNGKey(69), (EMBEDDING_DIM, len(vocabulary)))
}

@jax.jit
def context_projection(params, context_samples):
    context_vector_state = params["embedding"][context_samples]

    return jnp.mean(context_vector_state, axis=1)  # (BATCH, EMBEDDING_DIM)


@jax.jit
def forward(params, context_vector):
    context_projection_result = context_projection(params, context_vector) # (BATCH_SIZE, EMBEDDING_DIM)
    context_projection_dot_result = jnp.einsum("be,ev->bv", context_projection_result, params["output"]) # (BATCH_SIZE, VOCAB_SIZE)

    return context_projection_dot_result # (BATCH_SIZE, VOCAB_SIZE)


@jax.jit
def loss_fn(params, context_vector, target):
    logits = forward(params, context_vector)  # (BATCH_SIZE, VOCAB_SIZE)

    # compute the loss
    target_ohe = jnn.one_hot(target, len(vocabulary)) # (BATCH_SIZE, VOCAB_SIZE)
    loss_result = optax.losses.softmax_cross_entropy(logits, target_ohe).mean() # (BATCH_SIZE,)

    return loss_result

In [102]:
LR = 1e-3
EPOCHS = 50

optimizer = optax.adam(learning_rate=LR)
opt_state = optimizer.init(cbow_params)

training_loss = []

In [103]:
with tqdm() as training_progress:
    for epoch_id in range(EPOCHS):
        training_progress.set_description(f"Training: Epoch {epoch_id + 1}/{EPOCHS}")
        training_progress.reset(total=len(full_dataset))

        training_epoch_loss = []
        for (context_vector, target_vector) in shuffle_dataset(full_dataset):
            value_of_loss, grads = jax.value_and_grad(loss_fn)(cbow_params, context_vector, target_vector)
            updates, opt_state = optimizer.update(grads, opt_state)
            cbow_params = optax.apply_updates(cbow_params, updates)

            training_epoch_loss.append(value_of_loss)

            training_progress.set_postfix({"loss": value_of_loss})
            training_progress.update(1)

        training_loss.append(np.mean(training_epoch_loss))

0it [00:00, ?it/s]

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7fe118cd4d30>>
Traceback (most recent call last):
  File "/home/dincaus/miniconda3/envs/jax_cuda/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


KeyboardInterrupt: 

In [104]:
px.line(
    pd.DataFrame({"loss": training_loss}),
    title="Training loss"
).update_xaxes(title="Epoch").update_yaxes(title="Loss")

In [105]:
find_most_similar_words("queen", vocabulary, cbow_params["embedding"], top_n=10)

[('king', 0.481521874666214),
 ('princess', 0.454241544008255),
 ('queens', 0.43920138478279114),
 ('duchess', 0.42338699102401733),
 ('monarch', 0.41053974628448486),
 ('prince', 0.39734184741973877),
 ('empress', 0.3920631408691406),
 ('royal', 0.38248980045318604),
 ('consort', 0.37136077880859375),
 ('regency', 0.36154845356941223)]

In [106]:
resolve_analogy(
    "king", "queen", "man",
    vocabulary, cbow_params["embedding"], top_n=5
)

[('woman', 0.4113513231277466),
 ('girl', 0.33931976556777954),
 ('loves', 0.3171383738517761),
 ('love', 0.308957576751709),
 ('femme', 0.2997463345527649)]

In [107]:
resolve_analogy(
    "paris", "france", "berlin",
    vocabulary, cbow_params["embedding"], top_n=5
)

[('germany', 0.47626009583473206),
 ('germanys', 0.43036797642707825),
 ('berlins', 0.388517290353775),
 ('gdr', 0.35874420404434204),
 ('austria', 0.34934213757514954)]

In [111]:
resolve_analogy(
    "bank", "money", "ammunition",
    vocabulary, cbow_params["embedding"], top_n=5
)

[('rifles', 0.3053884208202362),
 ('gun', 0.2996152639389038),
 ('bullets', 0.2995460629463196),
 ('projectiles', 0.2961429953575134),
 ('guns', 0.28552597761154175)]