In [None]:
import os
import torch as tf
import numpy as np
import seaborn as sns
# Runtime switches consumed via environment variables:
#   - TOKENIZERS_PARALLELISM=false: presumably silences the HF `tokenizers` parallelism/fork
#     warning — confirm.
#   - TORCHINDUCTOR_*_CACHE=1: opt in to TorchInductor's FX-graph and AOTAutograd caches.
# NOTE(review): `torch` is already imported above; confirm these are read lazily (at first
# compile) rather than at import time, otherwise move this block before `import torch`.
os.environ.update(
    TOKENIZERS_PARALLELISM="false",
    TORCHINDUCTOR_FX_GRAPH_CACHE="1",
    TORCHINDUCTOR_AUTOGRAD_CACHE="1",
)

In [None]:
# Experiment sizes used throughout the notebook:
#   N_TOKENS           — number of starting embeddings sampled from the vocab ellipse
#   N_L2_REF_DISTANCES — reference points sampled per token for the baseline distance distribution
N_TOKENS, N_L2_REF_DISTANCES = 16, 1_000

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd


def render_L2_distance_comparison(token_travel_dists, baseline_dists=None):
    """Plot each token's level-set travel distance against a baseline distance distribution.

    For every example (token) a violin shows the distribution of L2 distances from its
    starting embedding to randomly sampled reference points, and a black "x" marks the
    L2 distance the token actually travelled to reach the epsilon level set.

    Args:
        token_travel_dists: 1-D array-like with one travel distance per example.
        baseline_dists: optional 2-D array of shape [n_examples, n_refs] holding the
            reference distances for each example. Defaults to the notebook-global
            ``L2_distance_dists`` so existing callers keep working; pass it explicitly
            to avoid depending on hidden kernel state.

    Raises:
        ValueError: if ``baseline_dists`` is not 2-D, or if the number of travel
            distances differs from the number of rows in ``baseline_dists``.
    """
    if baseline_dists is None:
        # Backward-compatible fallback: earlier versions always read this global.
        baseline_dists = L2_distance_dists

    data = np.asarray(baseline_dists)
    if data.ndim != 2:
        raise ValueError(f"baseline_dists must be 2-D [n_examples, n_refs], got shape {data.shape}")
    # Derive sizes from the data instead of hard-coding 16 x 1000.
    n_examples, n_refs = data.shape

    reference_points = np.asarray(token_travel_dists).reshape(-1)
    if reference_points.shape[0] != n_examples:
        raise ValueError(
            f"expected {n_examples} travel distances (one per example), got {reference_points.shape[0]}"
        )

    # Long format for seaborn: row-major flatten puts example 1's n_refs values first, etc.
    df = pd.DataFrame({
        'Example': np.repeat(np.arange(n_examples) + 1, n_refs),
        'L2 Distance': data.reshape(-1),
    })

    plt.style.use('ggplot')
    sns.set_palette("Set2")

    fig, ax = plt.subplots(figsize=(14, 6))
    sns.violinplot(data=df, x='Example', y='L2 Distance', inner='quartile', color='gray', ax=ax)
    # Empty proxy artist so the violins get a legend entry.
    ax.scatter([], [], color='gray', label='Distribution of L2 Distances To Random Points In Hyper-Ellipse')

    # seaborn places categorical violins at positions 0..n_examples-1 by default, so the
    # markers use the same positions. ('x' is an unfilled marker: edgecolors would be ignored.)
    ax.scatter(np.arange(n_examples), reference_points, color='black', s=100,
               zorder=5, marker='x', linewidth=2,
               label='L2 Distance To Epsilon Level Set Converged Points')

    ax.set_title(
        'Comparison of L2 Distances To Epsilon Level-Set vs. Distribution Of L2 Distance '
        f'To Other Points (N={n_refs})'
    )
    ax.set_xlabel('Example Index')
    ax.set_ylabel('L2 Distance')
    ax.legend(loc="upper right")
    # Extra headroom at the top so the legend does not sit on top of the violins.
    bottom, top = ax.get_ylim()
    ax.set_ylim(bottom, top * 1.2)
    fig.tight_layout()
    plt.show()

In [None]:
from compute_batch_llama_epsilon_level_sets import sample_vocab_ellipse
# Starting embeddings for the level-set search: N_TOKENS random samples (seed 42).
randomly_sampled_input_embeds = sample_vocab_ellipse(
    N_TOKENS,
    random_seed=42,
)

In [None]:
# Baseline reference cloud: N_L2_REF_DISTANCES ellipse samples per token, drawn with a
# different seed (43) than the starting points, grouped as [N_TOKENS, N_L2_REF_DISTANCES, -1]
# (last axis inferred; presumably the embedding dimension).
L2_ref_embeddings = (
    sample_vocab_ellipse(N_TOKENS * N_L2_REF_DISTANCES, random_seed=43)
    .reshape([N_TOKENS, N_L2_REF_DISTANCES, -1])
)

In [None]:
# Euclidean distance from each token's starting embedding to each of its reference points:
# inserting an axis at position 1 lets it broadcast against L2_ref_embeddings, and the norm
# over the last axis leaves a [N_TOKENS, N_L2_REF_DISTANCES] matrix.
L2_distance_dists = np.linalg.norm(
    np.expand_dims(randomly_sampled_input_embeds, axis=1) - L2_ref_embeddings,
    axis=-1,
)

In [None]:

# EXP 1: optimize every sampled token embedding toward HORSE_DISTRIBUTION until it reaches
# the epsilon level set, then compare how far each token travelled with the baseline
# distance distribution computed above.
#
# Observation: with learning rate 1e-2 the converged point on the epsilon level set lies
# FURTHER AWAY than a typical random point, so finding the "closest" (or at least closer)
# level-set points appears to depend on the choice of learning rate.


from llama_experiments import HORSE_DISTRIBUTION
import importlib, compute_batch_llama_epsilon_level_sets
# Reload so edits to the module are picked up without restarting the kernel.
importlib.reload(compute_batch_llama_epsilon_level_sets)
learn_embeddings = compute_batch_llama_epsilon_level_sets.learn_embeddings
BatchArgs = compute_batch_llama_epsilon_level_sets.BatchArgs
import numpy as np

EXP1_LEARNING_RATE = 1e-3

args = BatchArgs(
    max_iters=1_000_000,
    learning_rate=EXP1_LEARNING_RATE,
    epsilon=1e-3,
    epsilon_tolerance_scale=1e-3,
    inputs_embeds=randomly_sampled_input_embeds,
    # Same target distribution for every token: one tiled row per token.
    target_probabilities=np.tile(HORSE_DISTRIBUTION, (N_TOKENS, 1)),
)

result = learn_embeddings(args)

# L2 distance each token travelled from its starting embedding to its converged point
# (norm over the last axis, matching the baseline distance computation).
token_L2_travel_dist_HORSE_LR_1e_3 = np.linalg.norm(
    result.inputs_embeds - result.starting_inputs_embeds, axis=-1
)
# Deprecated alias: the old name claimed LR=1e-1, but this run uses 1e-3. Kept only so any
# cell still referring to it keeps working; prefer the _1e_3 name.
token_L2_travel_dist_HORSE_LR_1e_1 = token_L2_travel_dist_HORSE_LR_1e_3

render_L2_distance_comparison(token_L2_travel_dist_HORSE_LR_1e_3)

In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
plt.style.use('ggplot')

from llama_models import np_vocab_embedding_no_special_tokens


print(type(np_vocab_embedding_no_special_tokens), np_vocab_embedding_no_special_tokens.shape)

# Baseline: L2 distance between random pairs of real vocabulary embeddings, to put the
# level-set travel distances above in context.
N_VOCAB_PAIRS = 1000
# Seeded so the baseline is reproducible on Restart & Run All (distinct from seeds 42/43 above).
vocab_pair_rng = np.random.default_rng(44)
n_vocab = len(np_vocab_embedding_no_special_tokens)
# Sampled with replacement; a pair can occasionally pick the same token twice (distance 0).
idxs_1 = vocab_pair_rng.integers(0, n_vocab, size=N_VOCAB_PAIRS)
idxs_2 = vocab_pair_rng.integers(0, n_vocab, size=N_VOCAB_PAIRS)
random_vocab_pair_norms = np.linalg.norm(
    np_vocab_embedding_no_special_tokens[idxs_1] - np_vocab_embedding_no_special_tokens[idxs_2],
    ord=2,
    axis=-1,
)

fig, ax = plt.subplots()
sns.kdeplot(random_vocab_pair_norms, ax=ax)
ax.set_title("Δθ = ‖θⱼ − θᵢ‖₂")
ax.set_xlabel("L2 distance between random vocabulary embeddings")
plt.show()


In [None]:
import seaborn as sns

# Per-dimension displacement of the learned embeddings, restricted to rows that actually moved.
# NOTE(review): `learned_embeddings` and `initial_embeddings` are not assigned in any visible
# cell, so this cell only runs if they leaked into the kernel from elsewhere (it fails on
# Restart & Run All). Presumably they correspond to `result.inputs_embeds` /
# `result.starting_inputs_embeds` from EXP 1 — TODO confirm and define them explicitly.
diff = learned_embeddings - initial_embeddings
diff = diff.reshape(-1, diff.shape[-1])  # one row per vector along the last axis
# Row norms of the displacement already computed above (no need to re-subtract); the result
# is 1-D and in the same row order as `diff`, so it can be used directly as a mask.
norms = np.linalg.norm(diff, ord=2, axis=-1)
diff = diff[norms > 0]  # drop rows whose embedding did not change at all

fig, ax = plt.subplots()
sns.kdeplot(diff.flatten(), ax=ax)
ax.set_title("Per-Dimension Displacement")
ax.set_xlabel("Δθ per embedding dimension")
plt.show()
