In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import torch
import pickle

from src.pytorch_word2vec import neg_skipgram
from src.dataset import build_corpus

In [None]:
# --- Notebook configuration -------------------------------------------------
# Corpus size key: used below to pick the KL file (data/kl/kl-{SIZE}.txt) and
# passed to build_corpus.
SIZE = 50_000
# CSV with the word-analogy rows evaluated by process_relations.
RELATIONS_CSV = 'data/filtered-questions-words.csv'
# MODEL_NAME / VARIANT select the directory data/{MODEL_NAME}/{VARIANT}/ that
# holds the vocabulary pickle, the checkpoint and the log output.
MODEL_NAME = 'pytorch_model_50k_8'
VARIANT = 'reg1'
# Vocabulary size handed to the model: 161333 for the 10000 corpus, 512461 for
# any other SIZE.
VOCAB_SIZE = {10_000: 161_333}.get(SIZE, 512_461)

# Set to True to run the analogy evaluation and write the log file.
GENERATE_WORD_RELATIONS_LOG = False

# Instantiate the skip-gram model for the configured variant.
# NOTE(review): vocab_size / embedding_dimension presumably have to match the
# values the checkpoint was trained with — confirm against the training run.
pytorch_model = neg_skipgram(vocab_size=VOCAB_SIZE,
                             w2idx_path=f'data/{MODEL_NAME}/{VARIANT}/w2idx.pkl',
                             embedding_dimension=300,
                             regularization=0.01)

# map_location='cpu' lets the checkpoint deserialize on machines that lack the
# device it was saved from; load_state_dict then copies the tensors into the
# model's own parameters, wherever those live.
checkpoint_path = f'data/{MODEL_NAME}/{VARIANT}/7_w2v.pt'
pytorch_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

In [3]:
def check_relation(row):
    """Evaluate one analogy row with the loaded model.

    Builds the query vector ``embed(word_two) - embed(word_one) +
    embed(word_three)`` and asks ``pytorch_model.get_similar`` for the 5 most
    similar words.

    Args:
        row: Mapping-like record (e.g. a DataFrame row) providing the keys
            'word_one', 'word_two', 'word_three' and 'word_four'.

    Returns:
        A tuple ``(is_correct, predicted_word, top_words)``:
        ``is_correct`` is True when ``word_four`` is among the returned words,
        ``predicted_word`` is the first returned word and ``top_words`` is the
        list of returned words. ``(None, None, None)`` is returned when the
        model lookup raises ``KeyError`` or yields no neighbours.

    Raises:
        KeyError: if ``row`` lacks one of the four word keys.
    """
    # Read the columns outside the try block: a missing column is a schema
    # problem and must not be mistaken for an unknown word and silently skipped.
    word_one = row['word_one']
    word_two = row['word_two']
    word_three = row['word_three']
    word_four = row['word_four']

    try:
        embedding_two = pytorch_model.embed(word_two)
        embedding_one = pytorch_model.embed(word_one)
        embedding_three = pytorch_model.embed(word_three)

        predicted_vector = embedding_two - embedding_one + embedding_three

        most_similar = pytorch_model.get_similar(predicted_vector, n=5)
    except KeyError:
        # presumably raised for out-of-vocabulary words — verify against
        # neg_skipgram.embed; the row is reported as "not evaluable".
        return None, None, None

    # Guard against an empty neighbour list instead of failing on [0][0].
    if len(most_similar) == 0:
        return None, None, None

    # Each entry unpacks as (word, similarity); only the words are needed.
    words_only = [word for word, _ in most_similar]

    return word_four in words_only, most_similar[0][0], words_only

def process_relations(df):
    """Run ``check_relation`` over every row of ``df`` and tabulate the results.

    Args:
        df: DataFrame with the columns 'row_id', 'category', 'word_one',
            'word_two', 'word_three' and 'word_four'.

    Returns:
        A DataFrame with one record per evaluated row: the six input columns
        plus 'is_correct', 'predicted_word' and 'top5'. Rows for which
        ``check_relation`` reports no predicted word are left out. The columns
        are always present, even when no row could be evaluated.
    """
    # Fixing the columns keeps df['is_correct'] valid on an empty result.
    columns = ['row_id', 'category',
               'word_one', 'word_two', 'word_three', 'word_four',
               'is_correct', 'predicted_word', 'top5']
    results = []

    for _, row in df.iterrows():
        is_correct, predicted_word, words_only = check_relation(row)
        if predicted_word is None:
            # check_relation could not evaluate this row; skip it.
            continue
        results.append({
            'row_id': row['row_id'],
            'category': row['category'],
            'word_one': row['word_one'],
            'word_two': row['word_two'],
            'word_three': row['word_three'],
            'word_four': row['word_four'],
            'is_correct': is_correct,
            'predicted_word': predicted_word,
            'top5': words_only
        })

    return pd.DataFrame(results, columns=columns)

def save_log():
    """Write the analogy-accuracy summary of the current variant to disk.

    Reads the module-level ``df``, ``csv_df`` and ``word_analogy_acc``, so it
    must be called after those have been computed. Creates the log directory
    if needed, prints its path and overwrites ``log-{VARIANT}.txt`` in it.
    """
    log_dir = f'data/{MODEL_NAME}/{VARIANT}/log/'
    os.makedirs(log_dir, exist_ok=True)
    print(log_dir)

    with open(f'{log_dir}/log-{VARIANT}.txt', 'w') as log_file:
        log_file.write(f"*{VARIANT}*\n")
        # Number of evaluated rows whose expected word was found.
        hits = df['is_correct'].astype(int).sum()
        log_file.write(f"Word analogies accuracy: {word_analogy_acc:.2%}, {hits}/{len(df)}\n")
        log_file.write(f"Analogies CSV total len: {len(csv_df)}\n")

# Optional analogy evaluation, switched on through the config flag.
if GENERATE_WORD_RELATIONS_LOG:
    csv_df = pd.read_csv(RELATIONS_CSV)
    df = process_relations(csv_df)

    # Accuracy = rows flagged correct / rows that could be evaluated.
    correct_total = df['is_correct'].astype(int).sum()
    word_analogy_acc = correct_total / len(df)

    # save_log reads df, csv_df and word_analogy_acc from module scope.
    save_log()

# KL Divergence vs. Vector Norm

In [4]:
# Load the per-word KL divergences stored as a Python expression in a text file.
# NOTE(review): eval() executes whatever the file contains — only use files you
# produced yourself. If the file is a plain literal (dict of numbers),
# ast.literal_eval would be the safe replacement; confirm the file format first.
with open(f'data/kl/kl-{SIZE}.txt', 'r') as kl_file:
    kl_divergences = eval(kl_file.read())

In [None]:
# Fetch the corpus artefacts for the configured SIZE.
# NOTE(review): load=True presumably reads a cached build — confirm in
# src.dataset.build_corpus.
data = build_corpus(
    SIZE,
    return_fields=['corpus', 'word2idx', 'idx2word', 'word_count'],
    load=True,
)
corpus = data['corpus']
w2idx = data['word2idx']
idx2w = data['idx2word']
wc = data['word_count']

# Keep only the entries of word_count whose value is at least 10.
filtered_words = {token: count for token, count in wc.items() if count >= 10}
len(filtered_words)

In [6]:
def norm(embedding):
    """Return ``np.linalg.norm(embedding)`` with NumPy's default ``ord``.

    For a 1-D input this is the Euclidean length of the vector.
    """
    magnitude = np.linalg.norm(embedding)
    return magnitude

# For every frequency-filtered word known to the model, pair the norm of its
# embedding (x) with its KL divergence (y); zip(*...) splits the pairs into
# two parallel tuples.
x, y = zip(*(
    (norm(pytorch_model.embed(token)), kl_divergences[token])
    for token in filtered_words
    if token in pytorch_model
))

In [None]:
import scipy.stats
from matplotlib import rcParams

rcParams['font.family'] = 'Ubuntu'

plot_path = 'plots/kl_norm_our_model_reg1_50k_8e_8w.png'

# Colour every point by how many points share its cell on a
# grid_size x grid_size 2-D histogram.
grid_size = 50
hist, x_edges, y_edges = np.histogram2d(x, y, bins=grid_size)

# searchsorted(..., side='right') - 1 maps a value to the index of its bin;
# the clip keeps values sitting on the outermost edge inside the valid range.
x_idx = np.clip(np.searchsorted(x_edges, x, side='right') - 1, 0, hist.shape[0] - 1)
y_idx = np.clip(np.searchsorted(y_edges, y, side='right') - 1, 0, hist.shape[1] - 1)

density = hist[x_idx, y_idx]

# Use the explicit fig/ax interface throughout instead of mixing it with the
# pyplot state machine.
fig, ax = plt.subplots(figsize=(10, 6))

scatter = ax.scatter(x, y, c=density, cmap='viridis', alpha=0.8, s=50)

cbar = fig.colorbar(scatter, ax=ax)
cbar.set_label('Density', fontsize=12)

m, b, r_value, p_value, std_err = scipy.stats.linregress(x, y)
# A straight line is fully described by its two end points; drawing it through
# every data point only adds redundant overlapping segments.
x_line = np.array([np.min(x), np.max(x)])
ax.plot(x_line, m * x_line + b, color='black', label='Regression Line')
ax.legend()

# Hand-tuned anchor (in data coordinates) for the two text annotations:
# 1.32 * mean(x) horizontally, mean(y) - 2.4 vertically; the r² label sits at
# twice that height.
x_mean, y_mean = np.mean(x) * 1.32, np.mean(y) - 2.4
ax.annotate(f'r²: {r_value**2:.2f}', xy=(x_mean, 2 * y_mean), fontsize=11)
ax.annotate(f'formula: {m:.2f}x + {b:.2f}', xy=(x_mean, y_mean), fontsize=11)

ax.set_title("Norm x KL Divergence - Reg1 model", fontsize=14)
ax.set_xlabel("Vector norm", fontsize=12)
ax.set_ylabel("KL Divergence", fontsize=12)

# Create the output directory first so saving works on a fresh checkout,
# matching what save_log does for its log directory.
os.makedirs(os.path.dirname(plot_path), exist_ok=True)
fig.savefig(plot_path)