In [2]:
import json
import re
from collections import Counter
import numpy as np
import random
import torch
from tqdm import tqdm
import nltk
from nltk.corpus import stopwords
#disable ssl
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
nltk.download('stopwords')

from flyvec_model import FlyvecModel
import preprocess_books as prep
import utils

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/naturalhg/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [3]:
# Load data and combine books into a string
combined_books_text = prep.load_book_data('data/train.json')

# Clean/filter text
words_list, word_counts, vocab = prep.preprocess_text(combined_books_text)

# Create training data: np array shape [N, window_size]
train_data = prep.prepare_training_data(words_list, window_size=10)

print(f'train data shape: {train_data.shape}')
print(f'train sample: {random.choice(train_data)}')

train data shape: (232446, 10)
train sample: ['noise' 'heat' 'joy' '<unk>' '<unk>' 'noon' 'day' 'received' 'note'
 'went']


In [11]:
# Create model
model = FlyvecModel(
    K_size=350,            # Number of neurons
    vocab_size=len(vocab),  # Size of vocab
    k=1,                    # Update top-k neurons
    lr=.2,                  # Learning rate
    norm_rate=1,            # Normalization rate
    create_target_vector=True
)

# Create encoder
enc = utils.Encoder(vocab)

# Train model
num_epochs = 10

for i in range(num_epochs):
    for num, sample in enumerate(tqdm(train_data, desc=f'Epoch {i+1}/{num_epochs}', ncols=100, leave=True)):
        enc_sample = enc.one_hot(sample, create_target_vector=True)
        model.update(enc_sample)

# Save model
utils.save_model(model, f'trained_models/original_model_epoch{num_epochs}_books.pt')

Epoch 1/10: 100%|████████████████████████████████████████| 232446/232446 [00:18<00:00, 12739.85it/s]
Epoch 2/10: 100%|████████████████████████████████████████| 232446/232446 [00:19<00:00, 11979.43it/s]
Epoch 3/10: 100%|████████████████████████████████████████| 232446/232446 [00:20<00:00, 11532.27it/s]
Epoch 4/10: 100%|████████████████████████████████████████| 232446/232446 [00:20<00:00, 11599.49it/s]
Epoch 5/10: 100%|████████████████████████████████████████| 232446/232446 [00:20<00:00, 11133.17it/s]
Epoch 6/10: 100%|████████████████████████████████████████| 232446/232446 [00:20<00:00, 11342.56it/s]
Epoch 7/10: 100%|████████████████████████████████████████| 232446/232446 [00:20<00:00, 11068.89it/s]
Epoch 8/10: 100%|████████████████████████████████████████| 232446/232446 [00:20<00:00, 11226.99it/s]
Epoch 9/10: 100%|████████████████████████████████████████| 232446/232446 [00:20<00:00, 11491.91it/s]
Epoch 10/10: 100%|███████████████████████████████████████| 232446/232446 [00:20<00:00, 1150

In [12]:
# Find words with embeddings most similar to the target word embedding
target_word = 'ship'
hash_length = 40
top_N_closest = 20

#model = utils.load_model('trained_models/original_model_epoch3_books.pt')

utils.calc_print_sim_words(
    vocab=vocab,
    word_counts=word_counts,
    model=model,
    word=target_word,
    hash_len=hash_length,
    top_N=top_N_closest,
    create_target_vector=True
)

Word            Similarity Frequency 
-----------------------------------
ship                1.000       1203
springing           0.880         61
wreck               0.869        101
spots               0.869         75
assist              0.863        139
consideration       0.863        258
nations             0.863        111
cause               0.863        661
cattle              0.863         96
intense             0.863        152
mantle              0.863         61
flies               0.863         92
cartridges          0.863         52
deaf                0.857         65
secure              0.857        213
journal             0.857        145
purely              0.857         76
trembled            0.857        182
prevailed           0.857         77
submit              0.857         85
