In [1]:
from datasets import load_dataset
from tqdm import tqdm
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
dataset = load_dataset("openwebtext", trust_remote_code=True) # 45 mins first time, after that 1.5 mins

Loading dataset shards:   0%|          | 0/80 [00:00<?, ?it/s]

In [2]:
import json

# Load vocab
flyvec_embeddings_path = 'simple-flyvec-embeddings.json'
with open(flyvec_embeddings_path, 'r') as file:
    embeddings = json.load(file)


vocab = {word: idx for idx, word in enumerate(embeddings.keys())}

In [3]:
from encoder import Encoder
from flyvec_model import FlyvecModel
from collections import Counter
import numpy as np
import utils

encoder = Encoder(vocab=vocab)


# Create model
model = FlyvecModel(
    K_size=400,            # Number of neurons
    vocab_size=len(vocab),  # Size of vocab
    k=1,                    # Update top-k neurons
    lr=.2,                  # Learning rate
    norm_rate=1,            # Normalization rate
    create_target_vector=True
)

window_size = 10

id_counter = Counter()
windows_count = 0
passage_count = 0

for passage in tqdm(dataset['train'], desc="Processing Passages"):
    passage_count += 1

    text = passage['text']
    preprocessed_text = encoder.preprocess(text, remove_stopwords=True)

    words_arr = np.array(preprocessed_text)
    words_arr = words_arr[:len(words_arr) - len(words_arr) % window_size]
    train_data = words_arr.reshape(-1, window_size)

    for window in train_data:
        tokenized_window = encoder.tokenize(window.tolist())
        one_hot = encoder.one_hot(tokenized_window, create_target_vector=True)
        model.update(one_hot)
        windows_count += 1
        id_counter.update(tokenized_window)

    # Save the model every 80,000 (1% of dataset)
    if passage_count % 80000 == 0:
        pct = passage_count // 80000
        utils.save_model(model, f"trained_models/original_openwebtext_checkpoints/model_checkpoint_{pct}pct.pt")   

Processing Passages:   5%|▌         | 408220/8013769 [44:16<13:44:48, 153.68it/s]


KeyboardInterrupt: 

In [4]:
word_counter = {word: id_counter.get(id, 0) for word, id in vocab.items()}
word_counter = dict(sorted(word_counter.items(), key=lambda x: x[1], reverse=True))
word_counter

{'<UNK>': 38155221,
 '<NUM>': 8305884,
 'new': 612615,
 'like': 596510,
 'people': 584814,
 'time': 567202,
 'first': 473229,
 'year': 422146,
 'years': 346766,
 'may': 314285,
 'way': 307553,
 'well': 304598,
 'last': 304433,
 'back': 302307,
 'world': 296989,
 'see': 289925,
 'state': 284012,
 'know': 261798,
 'work': 258707,
 'made': 258177,
 'game': 257141,
 'think': 254245,
 'still': 248900,
 'right': 247901,
 'going': 247409,
 'government': 236486,
 'day': 234180,
 'take': 230045,
 'says': 227783,
 'want': 226943,
 'team': 209455,
 'say': 208699,
 'long': 207710,
 'trump': 205410,
 'need': 204438,
 'part': 199503,
 'president': 190425,
 'city': 188521,
 'life': 187150,
 'told': 180098,
 'high': 179943,
 'public': 179528,
 'police': 175530,
 'end': 175080,
 'best': 170617,
 'next': 170271,
 'old': 164880,
 'system': 164616,
 'found': 161412,
 'states': 159463,
 'great': 159280,
 'according': 157966,
 'man': 157629,
 'come': 157524,
 'including': 156748,
 'called': 155835,
 'home':

In [8]:
# Find words with embeddings most similar to the target word embedding
target_word = 'bird'
hash_length = 40
top_N_closest = 20

#model = utils.load_model('trained_models/.pt')
import utils

utils.calc_print_sim_words(
    vocab=vocab,
    word_counts=word_counter,
    model=model,
    word=target_word,
    hash_len=hash_length,
    top_N=top_N_closest,
    create_target_vector=True
)

Word            Similarity Frequency 
-----------------------------------
bird                1.000       9728
adversity           0.885        993
threeday            0.885          1
citizens            0.880      31104
here’s              0.880          0
shop                0.880      17188
jake                0.880       5190
rewards             0.880       6772
est                 0.880       4538
urls                0.880       1261
group’s             0.875          0
responsibility      0.875      19924
tournament          0.875      17125
drawing             0.875      11858
admitting           0.875       2856
gaining             0.875       5233
out”                0.875          0
iso                 0.875       2720
sizeable            0.875        819
biblical            0.875       3602
