In [1]:
from datasets import load_dataset
# tqdm.auto selects the Jupyter widget progress bar when ipywidgets is available
# and falls back to the plain text bar otherwise.
from tqdm.auto import tqdm

# NOTE(review): trust_remote_code=True lets `datasets` execute the dataset
# repository's loading code locally -- keep it only for sources you trust.
dataset = load_dataset("openwebtext", trust_remote_code=True)  # 45 mins first time, after that 1.5 mins

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import json

# Load the FlyVec embeddings; this cell only uses their top-level keys (the words).
flyvec_embeddings_path = 'simple-flyvec-embeddings.json'
# JSON text is UTF-8; pass the encoding explicitly so the platform default
# (e.g. cp1252 on Windows) cannot fail on or garble non-ASCII words.
with open(flyvec_embeddings_path, 'r', encoding='utf-8') as file:
    embeddings = json.load(file)

# Word -> integer id, assigned in the file's key order (json.load preserves it).
vocab = {word: idx for idx, word in enumerate(embeddings)}

In [None]:
from collections import Counter

import numpy as np

from encoder import Encoder
from context_model import ContextModel
import utils

# Encoder over `vocab`; the training loop uses it to preprocess, tokenize and
# one-hot encode passages.
encoder = Encoder(vocab=vocab)

# ContextModel hyperparameters, keyed by the constructor's keyword arguments.
context_model_kwargs = {
    "K_size": 400,             # number of neurons
    "vocab_size": len(vocab),  # size of the vocab
    "k": 5,                    # number of top-k neurons updated per step
    "lr": 0.1,                 # learning rate
    "norm_rate": 5,            # normalization rate
}
model = ContextModel(**context_model_kwargs)


import os

window_size = 10  # tokens per non-overlapping training window

# Checkpointing: 80,000 passages is roughly 1% of the 8,013,769-passage train
# split, so checkpoint N is labelled "Npct".
CHECKPOINT_EVERY = 80_000
CHECKPOINT_DIR = "trained_models/context_openwebtext_checkpoints"
# Create the directory up front so a fresh checkout doesn't crash at the first save.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def iter_windows(tokens, size):
    """Yield consecutive, non-overlapping windows of exactly ``size`` tokens.

    Trailing tokens that cannot fill a whole window are dropped.

    Args:
        tokens: Sequence (or iterable) of tokens.
        size: Window length; must be positive.

    Yields:
        list: ``size`` consecutive tokens.
    """
    tokens = list(tokens)
    for start in range(0, len(tokens) - size + 1, size):
        yield tokens[start:start + size]


id_counter = Counter()  # token id -> occurrences across all trained windows
windows_count = 0
passage_count = 0

for passage in tqdm(dataset['train'], desc="Processing Passages"):
    passage_count += 1

    words = encoder.preprocess(passage['text'], remove_stopwords=True)

    for window in iter_windows(words, window_size):
        tokenized_window = encoder.tokenize(window)
        model.update(encoder.one_hot(tokenized_window))
        windows_count += 1
        id_counter.update(tokenized_window)

    # Periodic checkpoint (~1% of the dataset each).
    if passage_count % CHECKPOINT_EVERY == 0:
        pct = passage_count // CHECKPOINT_EVERY
        utils.save_model(model, f"{CHECKPOINT_DIR}/model_checkpoint_{pct}pct.pt")

# The periodic checkpoints miss passages after the last multiple of
# CHECKPOINT_EVERY, so persist the fully trained model explicitly.
utils.save_model(model, f"{CHECKPOINT_DIR}/model_final.pt")

Processing Passages:  18%|█▊        | 1437786/8013769 [9:55:20<45:22:54, 40.25it/s] 


KeyboardInterrupt: 

In [14]:
from itertools import islice

# Count of each vocab word's token id seen during training (0 if never seen),
# ordered from most to least frequent; ties keep vocab order.
word_counter = {word: id_counter.get(token_id, 0) for word, token_id in vocab.items()}
word_counter = dict(sorted(word_counter.items(), key=lambda item: item[1], reverse=True))

# Display only the most frequent entries instead of dumping the whole vocabulary.
TOP_N_DISPLAY = 50
dict(islice(word_counter.items(), TOP_N_DISPLAY))

{'<UNK>': 133635521,
 '<NUM>': 29368999,
 'new': 2139213,
 'like': 2087492,
 'people': 2049514,
 'time': 1977936,
 'first': 1664966,
 'year': 1485465,
 'years': 1213932,
 'may': 1095476,
 'way': 1077164,
 'last': 1069148,
 'well': 1064590,
 'back': 1057194,
 'world': 1037575,
 'see': 1018095,
 'state': 1001386,
 'know': 912465,
 'game': 905369,
 'made': 905246,
 'work': 905141,
 'think': 890999,
 'still': 872491,
 'right': 870287,
 'going': 863836,
 'government': 828714,
 'day': 821223,
 'take': 806901,
 'says': 794998,
 'want': 793269,
 'team': 738008,
 'trump': 728453,
 'say': 726921,
 'long': 723920,
 'need': 717550,
 'part': 699527,
 'president': 671164,
 'city': 664278,
 'life': 647836,
 'public': 634444,
 'high': 631308,
 'told': 630133,
 'end': 611370,
 'police': 607518,
 'next': 603023,
 'best': 598467,
 'old': 575923,
 'system': 575325,
 'found': 563720,
 'states': 561891,
 'great': 555454,
 'according': 553467,
 'man': 551494,
 'come': 550918,
 'including': 547455,
 'called':

In [None]:
import utils

# Find words with embeddings most similar to the target word embedding.
target_word = 'concert'
hash_length = 70
top_N_closest = 20

# Set to a saved model file (e.g. a checkpoint written by the training loop,
# "trained_models/context_openwebtext_checkpoints/model_checkpoint_<N>pct.pt")
# to query it instead of the in-memory `model`. Loading into a separate name
# keeps the trained `model` from being overwritten.
CHECKPOINT_PATH = None
query_model = utils.load_model(CHECKPOINT_PATH) if CHECKPOINT_PATH else model

utils.calc_print_sim_words(
    vocab=vocab,
    word_counts=word_counter,
    model=query_model,
    word=target_word,
    hash_len=hash_length,
    top_N=top_N_closest,
)

Word            Similarity Frequency 
-----------------------------------
concert             1.000      24579
missed              0.770      55568
pop                 0.770      62099
diet                0.770      38875
mood                0.770      24341
lyrics              0.770      17302
swift               0.770      22656
coat                0.765      14766
remaining           0.765      62704
noticeable          0.765       9035
essay               0.760      19074
instrumental        0.760      10018
impoverished        0.760       5335
types               0.760      93433
father              0.760     168250
located             0.760      70000
songs               0.760      57015
injection           0.760      13946
invest              0.760      29146
memo                0.760      19636
