In [None]:
from transformers import AutoTokenizer, Trainer, AutoModelForCausalLM
import torch
from torch import optim

# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'


# GPT-NeoX-20B is loaded only to extract its input-embedding matrix; the model
# object is deleted at the end of this cell.
model = AutoModelForCausalLM.from_pretrained(
    'EleutherAI/gpt-neox-20b',
    dtype=torch.float16,
    low_cpu_mem_usage=True,
    )

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# `.detach()` drops the autograd link to the nn.Parameter (requires_grad=False).
# `Tensor.to` is not in-place, so the result must be assigned; a bare
# `embed_mat.to(device)` statement has no effect.
embed_mat = model.get_input_embeddings().weight.detach().to(device)
del model, tokenizer

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'


model = AutoModelForCausalLM.from_pretrained('openai-community/gpt2', dtype=torch.float16)

tokenizer = AutoTokenizer.from_pretrained('openai-community/gpt2')
# GPT-2 ships without a pad token. '<|pad|>' is not in GPT-2's vocabulary, and
# assigning the attribute alone does not add it (that would need
# add_special_tokens + resize_token_embeddings). Reuse EOS as padding instead,
# as the other cells in this notebook do.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

tokenizer.pad_token

In [None]:
import torch.nn.functional as F
from hf_olmo import OLMoTokenizerFast

# def create_neighbor_lookup_table(embed_matrix, num_neighbors):
#     seq_len = embed_matrix.shape[0]
#     scaffold = []
#     for token_idx in range(seq_len):
#         target = embed_matrix[token_idx]
#         norms = torch.norm(embed_matrix - target, dim=-1, p=2)
#         norms[token_idx] = float('inf')
#         knn = norms.topk(num_neighbors, largest=False, sorted=True) # chooses smallest, sorts from smallest to largest
#         distances, indexes = knn[0], knn[1]
#         # distance_weights = F.softmax(-distances, dim=-1) # for weighting by distance
#         # neg_log_likelihoods = -torch.log(distance_weights) * distance_weights
#         # assert neg_log_likelihoods.shape[0] == num_neighbors, "check neg_log_likelihood/weights sizes"
#         scaffold.append([(distances, indexes)])
#     return scaffold

def create_neighbor_lookup_table(embed_matrix, k):
    """Return the k nearest neighbours of every row of ``embed_matrix`` by angle.

    Neighbours are ranked by the angle between row vectors (arccos of their
    cosine similarity). A row is never reported as its own neighbour.

    Args:
        embed_matrix: 2-D tensor with one embedding per row.
        k: number of neighbours per row. Values above ``rows - 1`` would pull in
            the excluded self entry (angle ``inf``).

    Returns:
        (cos_angles, cos_indices): tensors of shape (rows, k) holding the
        angles in radians (smallest first) and the matching row indices.

    Note: several (rows, rows) matrices are materialised, so memory grows
    quadratically with the number of rows.
    """
    with torch.no_grad():
        unit = F.normalize(embed_matrix, p=2, dim=-1)
        cosine_similarity_scores = unit @ unit.T
        # Keep arguments strictly inside (-1, 1) so acos stays finite.
        cosine_similarity_scores = torch.clamp(cosine_similarity_scores, min=-1.0 + 1e-7, max=1.0 - 1e-7)
        angles = torch.acos(cosine_similarity_scores)
        # A vector's angle with itself is ~0 and would always win; exclude it.
        # topk must run on this masked tensor, not on a recomputed acos.
        angles.fill_diagonal_(float('inf'))
        cos_angles, cos_indices = torch.topk(angles, k=k, dim=-1, largest=False, sorted=True)
        return cos_angles, cos_indices
    
#ver 3
def create_lookup_table(model, num_neighbors, tokenizer=None): #perform on h100 or higher for improved efficacy.
    """Build a per-token top-``num_neighbors`` table of most similar vocabulary tokens.

    Every vocabulary string is embedded as a query and compared against the
    whole vocabulary embedded as documents; the token itself is excluded.

    Args:
        model: object exposing ``encode_query``, ``encode_document`` and
            ``similarity`` (presumably a sentence-transformers model — confirm).
        num_neighbors: neighbours to keep per token.
        tokenizer: object with ``get_vocab()`` returning {token: id}. Defaults
            to the ``allenai/OLMo-1B`` OLMo tokenizer.

    Returns:
        (similarities, neighbor_ids): CPU tensors of shape
        (vocab_size, num_neighbors). Row i belongs to the i-th token in
        ascending id order. Values are sorted largest first, and
        ``neighbor_ids`` holds vocabulary token ids.
    """
    if tokenizer is None:
        tokenizer = OLMoTokenizerFast.from_pretrained("allenai/OLMo-1B")
    sorted_vocab = sorted(tokenizer.get_vocab().items(), key=lambda item: item[1])  # (token, id) by id
    token_strings = [token for token, _ in sorted_vocab]
    token_ids = torch.tensor([token_id for _, token_id in sorted_vocab])

    # Embed the whole vocabulary once instead of once per query token.
    document_embeddings = model.encode_document(token_strings)

    similarity_rows = []  # ordered similarities, one row per token
    index_rows = []  # matching neighbour token ids
    for position, token in enumerate(token_strings):
        # similarity() of a single query may come back as (1, V); flatten so
        # that indexing below addresses vocabulary positions, not rows.
        similarities = model.similarity(model.encode_query(token), document_embeddings).reshape(-1)
        # Mask by position in the document list (ids may not be contiguous).
        similarities[position] = float('-inf')
        knn = torch.topk(similarities, k=num_neighbors, largest=True, sorted=True)  # greatest to least
        similarity_rows.append(knn.values.cpu())
        index_rows.append(token_ids[knn.indices.cpu()])

    return torch.stack(similarity_rows), torch.stack(index_rows)




In [None]:
# Smoke test: the single nearest neighbour of every embedding row.
test_scaffold = create_neighbor_lookup_table(embed_mat, k=1)

test = torch.ones(3, 2)
test

In [None]:
# Take GPT-2's input-embedding matrix explicitly rather than assuming it is the
# first entry of state_dict(). state_dict() tensors are detached, so detach
# here too to keep the same no-grad semantics.
embed_mat = model.get_input_embeddings().weight.detach()
embed_mat.shape[0]


In [None]:
import torch
# A 1-D tensor of three ones; len() of a tensor is the size of its first dim.
test = torch.ones(size=(3,))
len(test)

In [None]:
import torch

# Draw one random row of the embedding matrix and find its two nearest rows (L2).
rand_idx = int((torch.rand(1) * embed_mat.shape[0]).item())


test_target = embed_mat[rand_idx]
dist = torch.linalg.vector_norm(embed_mat - test_target, ord=2, dim=1)
dist[rand_idx] = float('inf')  # a row is trivially closest to itself; skip it
knn = torch.topk(dist, 2, largest=False)

knn



In [None]:
import torch

x = torch.tensor([10, 3, 7, 20, 15])

# Three largest entries, largest first (topk sorts its output by default).
values, indices = x.topk(3)

print("input:", x)
print("topk values:", values)
print("topk indices:", indices)
print("values via indices:", x.index_select(0, indices))


In [None]:
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

# Load EmbeddingGemma and its tokenizer; only the input-embedding matrix is used below.
model_name = "google/embeddinggemma-300m"
other_model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# get_input_embeddings().weight is always a torch tensor (an nn.Parameter), so
# no non-tensor fallback is needed. `.float()` upcasts before converting,
# since numpy has no bfloat16 dtype.
embed_matrix = other_model.get_input_embeddings().weight
embed_matrix_np = embed_matrix.detach().cpu().float().numpy()

print(f"Embedding matrix shape: {embed_matrix_np.shape}")
print(f"Vocabulary size: {len(tokenizer)}")
print(f"Vocabulary size: {len(tokenizer)}")

In [None]:
def find_nearest_neighbors(token, k=10, tok=None, embeddings=None):
    """
    Find k nearest neighbors to a given token based on cosine similarity.

    Args:
        token: Input token (string). If it encodes to several token ids, only
            the first one is used.
        k: Number of nearest neighbors to return.
        tok: Tokenizer with ``encode``/``decode``; defaults to the
            notebook-level ``tokenizer``.
        embeddings: 2-D numpy array with one row per token id (assumed
            (vocab_size, dim)); defaults to the notebook-level ``embed_matrix_np``.

    Returns:
        List of tuples: (neighbor_token, cosine_similarity, angle_degrees, l2_distance),
        most similar first, with the query token itself excluded.

    Raises:
        ValueError: If ``token`` encodes to no token ids.
    """
    tok = tokenizer if tok is None else tok
    embeddings = embed_matrix_np if embeddings is None else embeddings

    token_ids = tok.encode(token, add_special_tokens=False)
    if len(token_ids) == 0:
        # Raise rather than return a string: callers unpack the result as tuples.
        raise ValueError(f"Could not encode token: {token}")

    token_id = token_ids[0]
    query_embedding = embeddings[token_id].reshape(1, -1)

    # Cosine similarity of the query against every row.
    similarities = cosine_similarity(query_embedding, embeddings)[0]

    # Take k+1 candidates so k remain after dropping the query token itself.
    top_indices = np.argsort(similarities)[::-1][:k + 1]

    results = []
    for idx in top_indices:
        if idx == token_id:
            continue  # Skip the token itself

        cos_sim = similarities[idx]
        # angle = arccos(cosine similarity), clipped against rounding past +/-1
        angle_deg = np.degrees(np.arccos(np.clip(cos_sim, -1.0, 1.0)))
        l2_dist = np.linalg.norm(query_embedding - embeddings[idx])

        results.append((tok.decode([idx]), cos_sim, angle_deg, l2_dist))
        if len(results) == k:
            break

    return results


def find_nearest_neighbors_l2(token, k=10, tok=None, embeddings=None):
    """
    Find k nearest neighbors to a given token based on L2 distance.

    Args:
        token: Input token (string). If it encodes to several token ids, only
            the first one is used.
        k: Number of nearest neighbors to return.
        tok: Tokenizer with ``encode``/``decode``; defaults to the
            notebook-level ``tokenizer``.
        embeddings: 2-D numpy array with one row per token id (assumed
            (vocab_size, dim)); defaults to the notebook-level ``embed_matrix_np``.

    Returns:
        List of tuples: (neighbor_token, angle_degrees, l2_distance), closest
        first, with the query token itself excluded.

    Raises:
        ValueError: If ``token`` encodes to no token ids.
    """
    tok = tokenizer if tok is None else tok
    embeddings = embed_matrix_np if embeddings is None else embeddings

    token_ids = tok.encode(token, add_special_tokens=False)
    if len(token_ids) == 0:
        # Raise rather than return a string: callers unpack the result as tuples.
        raise ValueError(f"Could not encode token: {token}")

    token_id = token_ids[0]
    query_embedding = embeddings[token_id]
    query_norm = np.linalg.norm(query_embedding)

    # Euclidean distance from the query to every row.
    l2_distances = np.linalg.norm(embeddings - query_embedding, axis=1)

    # Take k+1 candidates so k remain after dropping the query token itself.
    top_indices = np.argsort(l2_distances)[:k + 1]

    results = []
    for idx in top_indices:
        if idx == token_id:
            continue  # Skip the token itself

        neighbor_embedding = embeddings[idx]
        cos_sim = np.dot(query_embedding, neighbor_embedding) / (
            query_norm * np.linalg.norm(neighbor_embedding)
        )
        # angle = arccos(cosine similarity), clipped against rounding past +/-1
        angle_deg = np.degrees(np.arccos(np.clip(cos_sim, -1.0, 1.0)))

        results.append((tok.decode([idx]), angle_deg, l2_distances[idx]))
        if len(results) == k:
            break

    return results


def _plot_l2_vs_angle(tokens, angles, l2, title, save_path, show):
    """Shared scatter plot of L2 distance (x) vs angle in degrees (y), one labelled point per token."""
    fig, ax = plt.subplots(figsize=(7, 5))
    sc = ax.scatter(l2, angles, c=l2, cmap="viridis", s=80, edgecolors="k", alpha=0.9)

    for i, tok in enumerate(tokens):
        ax.annotate(str(tok), (l2[i], angles[i]), xytext=(6, 4), textcoords="offset points", fontsize=8)

    ax.set_xlabel("L2 distance")
    ax.set_ylabel("Angle (degrees)")
    ax.set_title(title or "L2 distance vs Angle")
    cb = fig.colorbar(sc, ax=ax)
    cb.set_label("L2 distance")
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=200)

    if show:
        plt.show()

    return fig, ax


def plot_neighbors_l2_vs_angle(neighbors, title=None, save_path=None, show=True):
    """Plot L2 distance (x) vs angle in degrees (y) for L2-ranked neighbours.

    neighbors: list of (token_str, angle_deg, l2_distance), as returned by
        find_nearest_neighbors_l2.
    Returns (fig, ax). Raises ValueError if ``neighbors`` is empty.
    """
    if not neighbors:
        raise ValueError("neighbors list is empty")

    tokens = [n[0] for n in neighbors]
    angles = np.array([float(n[1]) for n in neighbors])
    l2 = np.array([float(n[2]) for n in neighbors])
    return _plot_l2_vs_angle(tokens, angles, l2, title, save_path, show)


def plot_neighbors_l2_vs_angle_cos(neighbors, title=None, save_path=None, show=True):
    """Plot L2 distance (x) vs angle in degrees (y) for cosine-ranked neighbours.

    neighbors: list of (token_str, cos_sim, angle_deg, l2_distance), as returned
        by find_nearest_neighbors; the cosine column is not plotted.
    Returns (fig, ax). Raises ValueError if ``neighbors`` is empty.
    """
    if not neighbors:
        raise ValueError("neighbors list is empty")

    tokens = [n[0] for n in neighbors]
    angles = np.array([float(n[2]) for n in neighbors])
    l2 = np.array([float(n[3]) for n in neighbors])
    return _plot_l2_vs_angle(tokens, angles, l2, title, save_path, show)


In [9]:
token = "to"

neighbors_cos = find_nearest_neighbors(token, k=50)

# Fixed-width table of the cosine-ranked neighbours.
print(f"\nNearest neighbors for token (using cosine similarity): '{token}'", end="\n\n")
print(f"{'Token':<20} {'Cosine Sim':<12} {'Angle (°)':<12} {'L2 Distance':<12}")
print("-" * 60)
for neighbor, cos_sim, angle, l2_dist in neighbors_cos:
    print("{:<20} {:<12.6f} {:<12.4f} {:<12.4f}".format(neighbor, cos_sim, angle, l2_dist))


Nearest neighbors for token (using cosine similarity): 'to'

Token                Cosine Sim   Angle (°)    L2 Distance 
------------------------------------------------------------
To                   0.668752     48.0292      1.2046      
 to                  0.626267     51.2248      1.2688      
TO                   0.617295     51.8811      1.3679      
 To                  0.598961     53.2045      1.3048      
 TO                  0.552411     56.4674      1.3444      
the                  0.508897     59.4096      1.4505      
ta                   0.495218     60.3159      1.6385      
tos                  0.491079     60.5885      1.7003      
of                   0.446269     63.4954      1.5997      
so                   0.446135     63.5040      1.7486      
for                  0.420356     65.1429      1.6808      
with                 0.419574     65.1923      1.6739      
te                   0.416712     65.3728      1.7772      
and                  0.411123     65.

In [None]:
# Plot the cosine-ranked neighbours from the previous cell. The prerequisites
# are checked explicitly so that NameErrors raised inside the plotting code
# itself are not silently swallowed.
if "neighbors_cos" in globals() and "token" in globals():
    fig, ax = plot_neighbors_l2_vs_angle_cos(neighbors_cos, title=f"Neighbors | cos_similarity for '{token}'")
else:
    print("`neighbors_cos` or `token` not defined in notebook; call find_nearest_neighbors first and re-run this cell.")

Appending new objectives would be like optimizing over a separate sequence. Since both metrics yield similar neighbors when k ≈ 5, we can choose either Euclidean distance or cosine similarity arbitrarily. However, to get better results we would have to decide empirically: run the experiment with both Euclidean distance and cosine similarity and see which one works better.

To reconcile both metrics — since either cosine similarity or Euclidean distance may be more informative depending on the context — we can keep only the neighbors on which both metrics agree. That is, keep a token if it appears under both metrics and, as one metric increases, the other also increases (or at least does not decrease).

In [None]:
import torch
import math
from transformers import AutoModelForCausalLM, AutoTokenizer

# Prefer CUDA, then Apple MPS, then CPU (same order as the other cells).
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Running on: {device}")

def generate_with_top_k_stats(prompt, num_generate=10, top_k_candidates=5, model_name='gpt2', seed=None):
    """Sample ``num_generate`` tokens after ``prompt`` and record top-k stats per step.

    Args:
        prompt: Text to condition on.
        num_generate: Number of tokens to sample.
        top_k_candidates: How many most-likely next tokens to record per step.
        model_name: Hugging Face model id to load.
        seed: Optional integer passed to ``torch.manual_seed`` for reproducible
            sampling. ``None`` leaves the RNG untouched.

    Returns:
        (full_text, generation_stats). ``generation_stats`` holds one dict per
        step with keys "step", "chosen_token" and "top_candidates". Each
        candidate is {"token", "prob", "perplexity"}, with perplexity = 1 / prob.

    Uses the notebook-level ``device``.
    """
    print(f"Loading {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    model.to(device)
    model.eval()

    # GPT-2 has no pad token; reuse EOS to silence the padding warning.
    tokenizer.pad_token_id = tokenizer.eos_token_id

    if seed is not None:
        torch.manual_seed(seed)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]

    generation_stats = []

    print(f"\nGenerating {num_generate} tokens for prompt: '{prompt}'\n")

    with torch.no_grad():
        # Feed the full prompt once, then only the newest token, reusing the
        # key/value cache instead of re-running the whole sequence every step.
        step_input = input_ids
        past_key_values = None
        for i in range(num_generate):
            outputs = model(step_input, past_key_values=past_key_values, use_cache=True)
            past_key_values = outputs.past_key_values

            # Distribution over the next token, from the last position.
            next_token_logits = outputs.logits[0, -1, :]
            probs = torch.softmax(next_token_logits, dim=-1)

            # One device->host transfer for all k candidates.
            top_probs, top_indices = torch.topk(probs, top_k_candidates)
            step_candidates = []
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
                step_candidates.append({
                    "token": tokenizer.decode([idx]),
                    "prob": prob,
                    # Perplexity of a single prediction is 1 / prob.
                    "perplexity": 1.0 / prob if prob > 0 else float('inf'),
                })

            # Sample the next token from the full distribution; shape (1,).
            next_token_index = torch.multinomial(probs, num_samples=1)

            generation_stats.append({
                "step": i + 1,
                "chosen_token": tokenizer.decode(next_token_index),
                "top_candidates": step_candidates,
            })

            step_input = next_token_index.unsqueeze(0)  # (1, 1)
            input_ids = torch.cat([input_ids, step_input], dim=1)

    full_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return full_text, generation_stats

# --- Usage: sample 15 tokens from GPT-2, recording the top 10 candidates per step ---
prompt_text = "Given that "
generated_text, stats = generate_with_top_k_stats(
    prompt=prompt_text,
    model_name='openai-community/gpt2',
    top_k_candidates=10,
    num_generate=15,
)

print(f"Final Text: {generated_text}", end="\n\n")
print("-" * 30)
for stat in stats:
    print("Step {}: Chosen -> '{}'".format(stat['step'], stat['chosen_token']))
    for cand in stat['top_candidates']:
        print("    Token: {:<15} | Prob: {:.4f} | Perplexity: {:.4f}".format(repr(cand['token']), cand['prob'], cand['perplexity']))

Running on: mps
Loading openai-community/gpt2...

Generating 15 tokens for prompt: 'Given that '

Final Text: Given that  it's a task for museums, I decided to design an interactive representations

------------------------------
Step 1: Chosen -> ' '
    Token: '\xa0'          | Prob: 0.4823 | Perplexity: 2.0733
    Token: '________'      | Prob: 0.0326 | Perplexity: 30.6886
    Token: 'vern'          | Prob: 0.0278 | Perplexity: 35.9865
    Token: '____'          | Prob: 0.0218 | Perplexity: 45.9751
    Token: '_____'         | Prob: 0.0171 | Perplexity: 58.4956
    Token: 'ive'           | Prob: 0.0143 | Perplexity: 70.1315
    Token: '_______'       | Prob: 0.0136 | Perplexity: 73.3011
    Token: '�'             | Prob: 0.0121 | Perplexity: 82.8051
    Token: 'iced'          | Prob: 0.0120 | Perplexity: 83.4721
    Token: '�'             | Prob: 0.0117 | Perplexity: 85.4473
Step 2: Chosen -> 'it'
    Token: 'the'           | Prob: 0.1345 | Perplexity: 7.4353
    Token: 'it'        