In [None]:
from ray.rllib.utils.numpy import torch

'''
Question: if you take any vector from the embedding matrix,
is it true that its dot product is maximal with itself, and smaller with every other vector?
In other words, if you encode a token into a vector, is it decoded back to the same token by argmax?
Obviously, if all vectors are normalized, then yes - the dot product with itself gives 1, and with others - less than 1.
But for real embedding matrices this is not the case.
In Llama, about 500 out of 128,000 vectors are not decoded back.
These are mostly either non-English tokens, or whitespace tokens from Python code like \t\t\t\t\t\t\t\t\r\n, or reserved_special_token entries.
If you do the same with the head of the model, there are more such errors - ~800; some long English tokens are added there.
It is also interesting that after normalizing every vector to unit length,
everything fell into place for the input embedding matrix, but the head still gave 200 discrepancies,
apparently due to numerical error, unless I made a mistake somewhere.
'''

In [None]:
from time import time
import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns


In [None]:
def get_norms(tensor):
    """Print summary statistics of the row-wise L2 norms of ``tensor``.

    The norm is taken over ``dim=1``, so each row is treated as one vector.
    Four lines are printed: the shape of the norm vector, the mean norm, the
    standard deviation of the norms (``Tensor.std`` defaults) and the ratio
    std / mean.

    Args:
        tensor: a tensor with at least two dimensions; ``dim=1`` is reduced.

    Returns:
        None. The statistics are only printed.
    """
    norms = torch.norm(tensor, p=2, dim=1)
    print(norms.shape)
    average_norm = norms.mean().item()
    # Convert to a Python float, like the mean, so the report shows plain
    # numbers rather than a ``tensor(...)`` repr.
    std_norm = norms.std().item()
    # Plain floats raise on division by zero; keep the NaN result that the
    # tensor division used to give when every norm is zero.
    ratio = std_norm / average_norm if average_norm != 0 else float("nan")
    print(f"Average Norm: {average_norm}")
    print(f"Standard Deviation of Norms: {std_norm}")
    print(f"std/mean: {ratio}")


def find_self_embeds(embedding: torch.Tensor, tokenizer, query_matrix: torch.Tensor | None = None) -> tuple[list[int], list[int], list]:

    '''
    For each vector of embedding we find the closest vector in a matrix.
    Ideally it should be the same vector. If not, we count and return number of such discrepancies.
    We do it in batches, since all vectors can not fit the GPU
    '''

    start = 0
    end = len(embedding)
    all_indices = np.array(range(end))
    result_indices = []
    batch_size = 10000
    batches = [
        range(i, min(i + batch_size, end)) for i in range(start, end, batch_size)
    ]
    eqs = []
    if query_matrix is None:
        query_matrix = embedding
    query_matrix = query_matrix.cuda()
    embedding = embedding.cuda()

    for batch in tqdm(batches):
        # indices = np.array(list(batch))
        X = query_matrix[batch]
        w = torch.argmax(X @ embedding.T, dim=-1).cpu() #.numpy()
        result_indices.extend(w)
    result_indices = np.array(result_indices)
    eqs = all_indices == result_indices
    fail_indices = all_indices[~eqs]
    fail_result_indices = result_indices[~eqs]

    failed_emb_toks = [tokenizer.decode(idx) for idx in fail_indices]
    failed_res_emb_toks = [tokenizer.decode(idx) for idx in fail_result_indices]
    failed_pairs = list(zip(failed_emb_toks, failed_res_emb_toks))

    return fail_indices, fail_result_indices, failed_pairs


In [None]:
def get_tokens_from_vectors(embedding_matrix, batch_size, num_batches, do_rms=False, model_norm=None):
    """Decode random Gaussian vectors to token ids via argmax over ``embedding_matrix``.

    ``num_batches * batch_size`` vectors are sampled from a normal distribution
    whose mean is the average of the per-column means of ``embedding_matrix`` and
    whose standard deviation is the average of its per-column standard
    deviations. Each vector is (optionally) passed through ``model_norm`` and
    then mapped to the row of ``embedding_matrix`` with the largest dot product.

    The global torch RNG is seeded with 1234 on every call, so repeated calls
    with the same arguments return the same tokens. Vectors are drawn one batch
    at a time so that only ``batch_size`` vectors are held in memory at once.

    Args:
        embedding_matrix: 2-D tensor of shape (num_tokens, dim).
        batch_size: number of random vectors scored per step.
        num_batches: number of batches to draw.
        do_rms: if True, apply ``model_norm`` to every batch before scoring.
        model_norm: callable applied to a batch when ``do_rms`` is True.

    Returns:
        A list of ``num_batches * batch_size`` Python ints (row indices).

    Raises:
        ValueError: if ``do_rms`` is True but ``model_norm`` is None.
    """
    if do_rms and model_norm is None:
        raise ValueError("model_norm must be provided when do_rms=True")

    input_dim = embedding_matrix.shape[1]
    # .item() yields plain floats so sampling does not depend on the device or
    # autograd state of the statistics.
    emb_mean = torch.mean(torch.mean(embedding_matrix, dim=0)).item()
    emb_std = torch.mean(torch.std(embedding_matrix, dim=0)).item()
    torch.manual_seed(1234)

    tokens = []
    with torch.no_grad():
        for _ in tqdm(range(num_batches)):
            # Sample on the CPU (seeded above), then move to wherever the
            # embedding matrix lives instead of assuming a CUDA device.
            batch = torch.normal(mean=emb_mean, std=emb_std, size=(batch_size, input_dim))
            batch = batch.to(device=embedding_matrix.device, dtype=embedding_matrix.dtype)
            if do_rms:
                batch = model_norm(batch)
            predictions = torch.matmul(batch, embedding_matrix.T)
            tokens.extend(torch.argmax(predictions, dim=1).tolist())
    return tokens

def plot_dist(tokens):
    """Draw a bar chart of how often each token index occurs and return the counts.

    Args:
        tokens: iterable of hashable, sortable token indices. May be empty, in
            which case an empty chart is drawn.

    Returns:
        collections.Counter mapping each token index to its number of occurrences.
    """
    token_counts = Counter(tokens)
    sorted_items = sorted(token_counts.items())
    if sorted_items:
        token_ids, counts = zip(*sorted_items)
    else:
        # zip(*[]) yields nothing, so the two-name unpacking would fail.
        token_ids, counts = (), ()

    plt.bar(token_ids, counts, width=1.0, edgecolor="black", color='blue')
    plt.title("Token Distribution")
    plt.xlabel("Token Index")
    plt.ylabel("Frequency")
    plt.show()

    return token_counts

def plot_emb_dist(emb, step=10):
    """Overlay kernel-density plots of the values found at selected columns of ``emb``.

    One KDE curve is drawn for every ``step``-th column (0, step, 2*step, ...),
    each computed over all rows of the matrix.

    Args:
        emb: 2-D tensor; it is detached and copied to the CPU before plotting.
        step: stride between plotted columns. Larger values give a sparser,
            faster plot.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    # NumPy array for plotting.
    values = emb.detach().cpu().numpy()

    # Derive the column range from the matrix itself so that matrices narrower
    # than 4096 columns do not raise an IndexError.
    positions_to_plot = range(0, values.shape[1], step)

    plt.figure(figsize=(12, 8))

    for pos in positions_to_plot:
        sns.kdeplot(values[:, pos], linewidth=1)

    plt.title("Distribution of Values at Different Positions")
    plt.xlabel("Value")
    plt.ylabel("Density")
    # No legend: the curves carry no labels, so plt.legend() would only emit a
    # "no artists with labels" warning.
    plt.show()

In [None]:
def get_model_and_embed(model_name, device=None):
    """Load a causal LM and return its embedding matrices, final norm and tokenizer.

    The model is expected to expose ``model.model.embed_tokens``,
    ``model.model.norm`` and ``model.lm_head``; other layouts raise
    ``AttributeError``. All parameters are frozen (``requires_grad = False``).

    Args:
        model_name: identifier passed to ``from_pretrained`` for both the model
            and the tokenizer.
        device: device for the returned tensors and the norm module. Defaults
            to ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.

    Returns:
        ``(model, embedding, head, model_norm, mean_norm, tokenizer)`` where
        ``embedding`` is the input embedding weight, ``head`` the ``lm_head``
        weight, ``model_norm`` the final norm module (moved to ``device`` in
        place) and ``mean_norm`` the mean L2 row norm of ``embedding`` over the
        rows whose norm is at least ``10e-5``.
    """
    if device is None:
        # Fall back to the CPU instead of failing on machines without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Analysis only: no gradients are needed for any weight.
    for param in model.parameters():
        param.requires_grad = False

    embedding = model.model.embed_tokens.weight.to(device)  # Embedding layer weights
    head = model.lm_head.weight.to(device)
    model_norm = model.model.norm.to(device)

    emb_norms = torch.norm(embedding, dim=1)
    # 10e-5 == 1e-4: rows with a (near-)zero norm are left out of the mean.
    filtered_norms = emb_norms[emb_norms >= 10e-5]
    mean_norm = torch.mean(filtered_norms)

    return model, embedding, head, model_norm, mean_norm, tokenizer

In [None]:
# Loading embeddings:
# Checkpoint under study. Alternatives listed by the author: "gpt2", "meta-llama/Llama-2-7b-hf".
# NOTE: get_model_and_embed reads model.model.embed_tokens / model.model.norm / model.lm_head,
# so a checkpoint must expose those attributes to be usable here.
model_name = "meta-llama/Meta-Llama-3.1-8B"
model, embedding, head, model_norm, mean_norm, tokenizer = get_model_and_embed(model_name)

# Row-wise L2 normalisation. F.normalize divides by max(norm, eps), so a row with
# zero norm stays zero instead of turning into NaN as it would with a plain division.
embedding_norm = torch.nn.functional.normalize(embedding, p=2, dim=1)
head_norm = torch.nn.functional.normalize(head, p=2, dim=1)
# The model's final norm layer applied to every row of each matrix.
embedding_rms_norm = model_norm(embedding)
head_rms_norm = model_norm(head)

In [None]:
# Decode 1000 batches of 1000 random vectors against the LM head, without
# applying the final norm to them first.
tokens_dist = get_tokens_from_vectors(
    head,
    batch_size=1000,
    num_batches=1000,
    do_rms=False,
    model_norm=model_norm,
)

In [None]:
# Bar chart of how often each token index was selected; keeps the counts for later cells.
token_counts = plot_dist(tokens=tokens_dist)

In [None]:
# Re-key the counts by decoded token text. Counts are summed per string: a plain
# dict comprehension would keep only the last count whenever two token ids
# decode to the same text.
decoded_counts = Counter()
for token_id, count in token_counts.items():
    decoded_counts[tokenizer.decode(token_id)] += count
# Most frequent first; plain dict so later cells see the same type as before.
decoded_counts = dict(sorted(decoded_counts.items(), key=lambda item: item[1], reverse=True))

In [None]:
# Frequencies only, in the same (descending) order as decoded_counts.
tok_freqs = [*decoded_counts.values()]

In [None]:
# Keep only frequencies in [0, 100] so the histogram is not dominated by a few
# very frequent tokens, then plot it with a logarithmic count axis.
tok_freqs_filt = list(filter(lambda freq: 0 <= freq <= 100, tok_freqs))
sns.histplot(
    tok_freqs_filt,
    bins=300,
    kde=False,
    log=True,
)

In [None]:
    # Query with the final-norm-transformed embedding rows against the raw embedding matrix.
    failed_emb, failed_res_emb, failed_pairs = find_self_embeds(
        embedding,
        tokenizer,
        query_matrix=embedding_rms_norm,
    )

In [None]:
# Query with the final-norm-transformed head rows against the raw head matrix.
# Note: this reuses (overwrites) the failed_emb / failed_res_emb / failed_pairs names.
failed_emb, failed_res_emb, failed_pairs = find_self_embeds(
    head,
    tokenizer,
    query_matrix=head_rms_norm,
)

In [None]:
# find_self_embeds takes (embedding, tokenizer, query_matrix=None) and returns
# (fail_indices, fail_result_indices, failed_pairs); only the failing indices
# are needed here.

# Raw input embedding matched against itself.
failed_emb, _, _ = find_self_embeds(embedding, tokenizer)
print(len(failed_emb))
# L2-normalised input embedding matched against itself.
failed_emb_norm, _, _ = find_self_embeds(embedding_norm, tokenizer)
print(len(failed_emb_norm))

# Raw LM head matched against itself.
failed_head, _, _ = find_self_embeds(head, tokenizer)

failed_emb_toks = [tokenizer.decode(int(idx)) for idx in failed_emb]
failed_head_toks = [tokenizer.decode(int(idx)) for idx in failed_head]

# Tokens that fail for the embedding, the head, or both.
joined_set = set(failed_emb_toks) | set(failed_head_toks)

In [None]:
# L2-normalised rows used as queries against the raw embedding matrix.
# The tokenizer is the second positional argument; the query matrix comes third.
failed_emb, _, _ = find_self_embeds(embedding, tokenizer, embedding_norm)

In [None]:
# Show the loaded model's repr as the cell output.
model