In [None]:
from datasets import load_dataset
from tqdm import tqdm
import ssl
# Certificate verification is disabled ONLY while downloading the dataset
# (presumably to work around a local CA-bundle problem). Prefer fixing the
# certificates and setting this to False.
ALLOW_UNVERIFIED_SSL = True

_default_https_context = ssl._create_default_https_context
if ALLOW_UNVERIFIED_SSL:
    ssl._create_default_https_context = ssl._create_unverified_context
try:
    # NOTE(review): trust_remote_code=True executes the dataset's loading code
    # fetched from the Hub; keep it only for sources you trust.
    dataset = load_dataset("openwebtext", trust_remote_code=True)  # 45 mins first time, after that 1.5 mins
finally:
    # Restore verification so later HTTPS calls in this kernel are checked again.
    ssl._create_default_https_context = _default_https_context

In [2]:
import json

# Load vocab: each key of the FlyVec embeddings JSON object is a word, and its
# position in the file becomes the word's integer id (json.load returns a dict,
# which preserves the file's key order).
flyvec_embeddings_path = 'simple-flyvec-embeddings.json'
# JSON text is UTF-8 by spec; be explicit so non-ASCII words decode identically on
# every platform (the locale default is not UTF-8 everywhere, e.g. Windows).
with open(flyvec_embeddings_path, 'r', encoding='utf-8') as file:
    embeddings = json.load(file)

vocab = {word: idx for idx, word in enumerate(embeddings)}
assert vocab, f"No words found in {flyvec_embeddings_path}"

In [None]:
from encoder import Encoder
from flyvec_model import FlyvecModel
from collections import Counter
import numpy as np
import utils

import os

encoder = Encoder(vocab=vocab)

# --- Training configuration ---------------------------------------------------
CREATE_TARGET_VECTOR = True   # shared by the model and encoder.one_hot so they always agree
window_size = 10              # words per non-overlapping training window
CHECKPOINT_EVERY = 80_000     # passages between checkpoints (~1% of the dataset)
CHECKPOINT_DIR = "trained_models/original_openwebtext_checkpoints"

# Create model
model = FlyvecModel(
    K_size=400,            # Number of neurons
    vocab_size=len(vocab),  # Size of vocab
    k=1,                    # Update top-k neurons
    lr=.2,                  # Learning rate
    norm_rate=1,            # Normalization rate
    create_target_vector=CREATE_TARGET_VECTOR
)

id_counter = Counter()  # token id -> occurrences across all training windows
windows_count = 0
passage_count = 0

# Create the checkpoint directory up front so the first save (hours in) can't fail on a missing path.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for passage in tqdm(dataset['train'], desc="Processing Passages"):
    passage_count += 1

    words = list(encoder.preprocess(passage['text'], remove_stopwords=True))

    # Consecutive, non-overlapping windows of exactly `window_size` words; the
    # trailing remainder (< window_size words) of each passage is dropped.
    for start in range(0, len(words) - window_size + 1, window_size):
        tokenized_window = encoder.tokenize(words[start:start + window_size])
        one_hot = encoder.one_hot(tokenized_window, create_target_vector=CREATE_TARGET_VECTOR)
        model.update(one_hot)
        windows_count += 1
        id_counter.update(tokenized_window)

    # Save the model every CHECKPOINT_EVERY passages (1% of dataset)
    if passage_count % CHECKPOINT_EVERY == 0:
        pct = passage_count // CHECKPOINT_EVERY
        utils.save_model(model, os.path.join(CHECKPOINT_DIR, f"model_checkpoint_{pct}pct.pt"))

# Persist the passages seen after the last periodic checkpoint as well.
if passage_count % CHECKPOINT_EVERY != 0:
    utils.save_model(model, os.path.join(CHECKPOINT_DIR, "model_checkpoint_final.pt"))

In [None]:
# Map every vocabulary word to how often its token id occurred in the training
# windows (0 if never seen), ordered from most to least frequent.
word_counter = {word: id_counter.get(token_id, 0) for word, token_id in vocab.items()}
word_counter = dict(sorted(word_counter.items(), key=lambda item: item[1], reverse=True))

# Display only the head; the full dict has one entry per vocabulary word.
TOP_WORDS_TO_SHOW = 20
dict(list(word_counter.items())[:TOP_WORDS_TO_SHOW])

In [None]:
# Re-imported so this cell also works when evaluating a saved checkpoint.
import utils

# Find words with embeddings most similar to the target word embedding
target_word = 'bird'
hash_length = 40      # NOTE(review): presumably the number of active bits in the FlyVec hash — confirm in utils
top_N_closest = 20

# Path to a saved checkpoint (e.g. one written by the training cell) to evaluate
# instead of the in-memory `model`; None uses the model trained in this kernel.
MODEL_CHECKPOINT_PATH = None

sim_model = utils.load_model(MODEL_CHECKPOINT_PATH) if MODEL_CHECKPOINT_PATH else model

utils.calc_print_sim_words(
    vocab=vocab,
    word_counts=word_counter,
    model=sim_model,
    word=target_word,
    hash_len=hash_length,
    top_N=top_N_closest,
    create_target_vector=True
)