In [None]:
"""
This contains code to use SVD to decompose hidden states based on whether they're used by routing or not.
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import scipy
import cupy
import cuml
import sklearn

import importlib
import gc
import pickle
import os

from tqdm import tqdm
from termcolor import colored
import plotly.express as px
from plotly.subplots import make_subplots

from utils.memory import check_memory, clear_all_cuda_memory
from utils.quantize import compare_bf16_fp16_batched
from utils.svd import decompose_orthogonal, decompose_sideways
from utils.vis import combine_plots

main_device = 'cuda:0'
seed = 1234

# Actually apply the seed: pandas .sample() calls without an explicit random_state draw from numpy's
# global RNG, so seeding here makes the random cluster/token samples below reproducible on Restart & Run All.
np.random.seed(seed)
torch.manual_seed(seed)

clear_all_cuda_memory()
check_memory()

## Load model & data

In [None]:
"""
Load the base tokenizer/model
"""
model_ix = 0
models_list = [
    ('allenai/OLMoE-1B-7B-0125-Instruct', 'olmoe', 0),
    ('Qwen/Qwen1.5-MoE-A2.7B-Chat', 'qwen1.5moe', 0),
    ('deepseek-ai/DeepSeek-V2-Lite', 'dsv2', 1),
    ('Qwen/Qwen3-30B-A3B', 'qwen3moe', 0)
]

model_id, model_prefix, model_pre_mlp_layers = models_list[model_ix]
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).cuda().eval()

In [None]:
"""
Load dataset
"""
def load_data(model_prefix):
    all_pre_mlp_hs = torch.load(f'./../export-data/activations/{model_prefix}/all-pre-mlp-hidden-states.pt')
    with open(f'./../export-data/activations/{model_prefix}/metadata.pkl', 'rb') as f:
        metadata = pickle.load(f)

    return all_pre_mlp_hs, metadata['sample_df'], metadata['topk_df'], metadata['all_pre_mlp_hidden_states_layers']

# Raw exports; the *_import names are cleaned up and deleted in later cells
(all_pre_mlp_hs_import, sample_df_import, topk_df_import, act_map) = load_data(model_prefix)

In [None]:
"""
Let's clean up the mappings here. We'll get everything to a sample_ix level first.
"""
MAX_SAMPLES = 2_000_000

sample_df_raw =\
    sample_df_import\
    .assign(sample_ix = lambda df: df.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup())\
    .assign(seq_id = lambda df: df.groupby(['batch_ix', 'sequence_ix']).ngroup())\
    .reset_index()

topk_df =\
    topk_df_import\
    .merge(sample_df_raw[['sample_ix', 'batch_ix', 'sequence_ix', 'token_ix']], how = 'inner', on = ['sequence_ix', 'token_ix', 'batch_ix'])\
    .drop(columns = ['sequence_ix', 'token_ix', 'batch_ix'])\
    .assign(layer_ix = lambda df: df['layer_ix'] + model_pre_mlp_layers)

topk1_df =\
    topk_df\
    .pipe(lambda df: df[df['topk_ix'] == 1])

sample_df =\
    sample_df_raw\
    .drop(columns = ['batch_ix', 'sequence_ix'])

def get_sample_df_for_layer(sample_df, topk_df, layer_ix):
    """
    Helper to take the sample df and merge layer-level expert selection information.

    Params:
        @sample_df: One row per sample, with a 'sample_ix' column.
        @topk_df: Routing records with 'sample_ix', 'layer_ix', 'topk_ix' and 'expert' columns.
        @layer_ix: Layer whose routing is attached.

    Returns:
        sample_df inner-joined to the top-1 expert at layer_ix ('expert') and left-joined to the top-1 experts at
        layer_ix - 1 ('prev_expert') and layer_ix - 2 ('prev2_expert') and the top-2 expert at layer_ix ('expert2').
        'leading_path' is the string '<prev2_expert>-<prev_expert>', or NaN when either is missing.
    """
    def _experts_at(target_layer, topk_rank, col_name):
        # One column of expert ids for a given layer and top-k rank, keyed by sample_ix
        return (
            topk_df
            .pipe(lambda df: df[(df['layer_ix'] == target_layer) & (df['topk_ix'] == topk_rank)])
            .rename(columns = {'expert': col_name})
            [['sample_ix', col_name]]
        )

    def _as_str(col):
        # Expert ids may be numeric (and turn float once a left join introduces NaN), and numeric + '-' raises.
        # Cast through nullable Int64 so ids render as '3' rather than '3.0'.
        if pd.api.types.is_numeric_dtype(col):
            col = col.astype('Int64')
        return col.astype(str)

    layer_df = (
        sample_df
        .merge(_experts_at(layer_ix, 1, 'expert'), how = 'inner', on = 'sample_ix')
        .merge(_experts_at(layer_ix - 1, 1, 'prev_expert'), how = 'left', on = 'sample_ix')
        .merge(_experts_at(layer_ix - 2, 1, 'prev2_expert'), how = 'left', on = 'sample_ix')
        .merge(_experts_at(layer_ix, 2, 'expert2'), how = 'left', on = 'sample_ix')
        .assign(leading_path = lambda df: (
            (_as_str(df['prev2_expert']) + '-' + _as_str(df['prev_expert']))
            # Keep a missing path as NaN instead of strings like '<NA>-3'
            .where(df['prev2_expert'].notna() & df['prev_expert'].notna())
        ))
    )

    return layer_df

# The raw frames are no longer needed once sample_ix-level frames exist
del sample_df_import, sample_df_raw, topk_df_import

# Cap everything at MAX_SAMPLES (activation rows are presumably in sample_ix order — confirm against the export)
all_pre_mlp_hs_import = all_pre_mlp_hs_import[:MAX_SAMPLES]
sample_df = sample_df.loc[sample_df['sample_ix'] < MAX_SAMPLES]
topk_df = topk_df.loc[topk_df['sample_ix'] < MAX_SAMPLES]
topk1_df = topk1_df.loc[topk1_df['sample_ix'] < MAX_SAMPLES]

gc.collect()
display(topk_df, sample_df)

In [None]:
"""
Convert activations to fp16 (for compatibility with cupy later) + dict
"""
all_pre_mlp_hs = all_pre_mlp_hs_import.to(torch.float16)
# compare_bf16_fp16_batched(all_pre_mlp_hs_import, all_pre_mlp_hs)
del all_pre_mlp_hs_import
all_pre_mlp_hs = {(layer_ix + model_pre_mlp_layers): all_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(act_map)}

gc.collect()

## SVD Decomposition

In [None]:
"""
Let's take the pre-MLP hidden states and split them using SVD into parallel and orthogonal components.
"""
h_para_by_layer = {}
h_orth_by_layer = {}

for layer_ix in tqdm(list(all_pre_mlp_hs.keys())):
    h_para_by_layer[layer_ix], h_orth_by_layer[layer_ix] = decompose_orthogonal(
        all_pre_mlp_hs[layer_ix].to(torch.float32),
        model.model.layers[layer_ix].mlp.gate.weight.detach().cpu().to(torch.float32),
        'svd'
    )

## Orth vs Para Rotation

In [None]:
# Average each layer's h_para over the first 100 samples first, then compare the layer means pairwise
para_means = torch.stack(
    [h_para_by_layer[k][:100].mean(dim = 0) for k in tqdm(sorted(h_para_by_layer))]
)

# Optionally remove the cross-layer mean before comparing:
# para_means = para_means - para_means.mean(dim = 0)

para_sim = sklearn.metrics.pairwise.cosine_similarity(para_means.numpy())
para_sim[0, :]

In [None]:
np.mean(para_sim[np.triu_indices_from(para_sim, k = 1)])  # mean cosine similarity over distinct layer pairs

In [None]:
# Average each layer's h_orth over the first 100 samples, then compare the layer means pairwise
orth_means = torch.stack(
    [h_orth_by_layer[k][:100].mean(dim = 0) for k in tqdm(sorted(h_orth_by_layer))]
)

# Optionally remove the cross-layer mean before comparing:
# orth_means = orth_means - orth_means.mean(dim = 0)

orth_sim = sklearn.metrics.pairwise.cosine_similarity(orth_means.numpy())
orth_sim[0, :]

In [None]:
np.mean(orth_sim[np.triu_indices_from(orth_sim, k = 1)])  # mean cosine similarity over distinct layer pairs

In [None]:
# Angle between a token's h_para at consecutive layers, split by whether its top-1 expert id is unchanged.
curr_layer = 8              # layer ℓ
prev_layer = curr_layer - 1 # layer ℓ-1
assert prev_layer in h_para_by_layer and curr_layer in h_para_by_layer, 'both layers need an SVD decomposition'

def _top1_experts(layer_ix):
    """Top-1 expert id per sample at layer_ix, in sample_ix order so entries line up with hidden-state rows."""
    layer_top1 = topk1_df[topk1_df['layer_ix'] == layer_ix].sort_values('sample_ix')
    return torch.as_tensor(layer_top1['expert'].to_numpy())

h_para_prev = h_para_by_layer[prev_layer].to(torch.float32) # shape (T,Dh)  from layer ℓ-1
h_para_curr = h_para_by_layer[curr_layer].to(torch.float32) # shape (T,Dh)  from layer ℓ
expert_id_prev = _top1_experts(prev_layer)
expert_id_curr = _top1_experts(curr_layer)
assert len(expert_id_prev) == len(expert_id_curr) == h_para_curr.shape[0], 'expected one top-1 expert per sample per layer'

cos = (h_para_prev * h_para_curr).sum(1) / (
        h_para_prev.norm(dim=1) * h_para_curr.norm(dim=1) + 1e-9)
theta = torch.acos(torch.clamp(cos, -1.0, 1.0))  # radians
# NOTE(review): expert ids are presumably per-layer indices, so "same id" is a heuristic, not "same expert module"
stay = (expert_id_prev == expert_id_curr)
print(theta[stay].mean(), theta[~stay].mean())

# Large gap ⇒ token-wise header rewrite when the path branches.

## Router Prediction

In [None]:
# TODO: Should we then project the deltas onto the next layer's router? Would that be worth the effort?

## Orth vs Para Clusters

In [None]:
"""
Helper functions for clustering
"""
def print_samples(df, grouping_cols):
    """
    Takes a wide dataframe and groups it, then prints random groups
    """
    res =\
        df\
        .groupby(grouping_cols, as_index = False)\
        .agg(
            n_samples = ('token', 'size'),
            samples = ('token', lambda s: s.sample(n = min(len(s), 10)).tolist())
        )\
        .pipe(lambda df: df[df['n_samples'] >= 5])\
        .sample(35)
    
    display(res)

In [None]:
"""
Let's cluster the para and ortho using k-means and see what clusters we get
"""
def cluster_kmeans(layer_hs: torch.Tensor, n_clusters = 512):
    """
    K-means clustering
    """
    kmeans_model = cuml.cluster.KMeans(n_clusters = n_clusters, max_iter = 1000, random_state = 123)
    kmeans_model.fit(cupy.asarray(layer_hs.to(torch.float32)))
    clear_all_cuda_memory(False)

    return kmeans_model.labels_.tolist()

def get_cluster(sample_df, hidden_states_by_layer, n_clusters = 256):
    """
    Get k-means clusters across hidden state layers.

    Params:
        @sample_df: One row per sample with 'sample_ix' and 'token'. Hidden-state row k is assumed to be sample_ix k
            (the same assumption the MAX_SAMPLES truncation relies on).
        @hidden_states_by_layer: Dict of layer index -> (n_samples, D) tensor.
        @n_clusters: Number of k-means clusters per layer.

    Returns:
        A dataframe with one 'layer_<ix>_id' column per layer followed by the sample_df columns.
    """
    cluster_ids_by_layer = [
        {'layer_ix': layer_ix, 'cluster_ids': cluster_kmeans(layer_hs, n_clusters)}
        for layer_ix, layer_hs in tqdm(hidden_states_by_layer.items())
    ]

    cluster_ids_only = pd.concat(
        [pd.DataFrame({'layer_' + str(x['layer_ix']) + '_id': x['cluster_ids']}) for x in cluster_ids_by_layer],
        axis = 1
    )

    # pd.concat(axis = 1) aligns on the index, not on position: put sample_df in sample_ix order with a fresh
    # 0..n-1 index so row k lines up with hidden-state row k instead of silently misaligning / NaN-padding.
    aligned_sample_df = sample_df.sort_values('sample_ix').reset_index(drop = True)
    assert len(aligned_sample_df) == len(cluster_ids_only), 'sample_df rows must match hidden-state rows'
    cluster_ids_df = pd.concat([cluster_ids_only, aligned_sample_df], axis = 1)

    # Cluster-size overview; fall back to the first clustered layer when layer 1 was not clustered
    display_col = 'layer_1_id' if 'layer_1_id' in cluster_ids_only.columns else cluster_ids_only.columns[0]
    display(
        cluster_ids_df.groupby(display_col, as_index = False).agg(n_samples = ('token', 'size')).sort_values(by = 'n_samples', ascending = False)
    )

    return cluster_ids_df

# Cluster both components at every layer, then inspect token groups sharing clusters at layers 1 and 2
para_clusters_df = get_cluster(sample_df, h_para_by_layer)
orth_clusters_df = get_cluster(sample_df, h_orth_by_layer)

for clusters_df in (para_clusters_df, orth_clusters_df):
    print_samples(clusters_df, ['layer_1_id', 'layer_2_id'])

In [None]:
for clusters_df in (para_clusters_df, orth_clusters_df):
    print_samples(clusters_df, ['layer_6_id', 'layer_7_id'])

In [None]:
"""
Count how many clusters are token-specific
"""
def get_single_token_cluster_counts(cluster_df, layer_ix):
    """
    Get how many tokens belong to a single cluster
    """
    res =\
        cluster_df\
        .groupby([f'layer_{str(layer_ix)}_id'], as_index = False)\
        .agg(
            n_samples = ('token', 'size'),
            samples = ('token', lambda s: s.sample(n = min(len(s), 20)).tolist())
        )\
        .pipe(lambda df: df[df['n_samples'] >= 5])\
        .assign(is_eq = lambda df: df.samples.apply(lambda s: 1 if len(set(s)) == 1 else 0))\
        .groupby('is_eq', as_index = False)\
        .agg(count = ('is_eq', 'count'))

    return(res)

for clusters_df in (para_clusters_df, orth_clusters_df):
    display(get_single_token_cluster_counts(clusters_df, 7))

In [None]:
"""
Count entropy distribution
"""
def get_entropy_distribution(cluster_df, layer_ix, min_cluster_size = 1):
    cluster_id_col = f'layer_{str(layer_ix)}_id'

    def calculate_dominance(series):
        """Calculates the proportion of the most frequent item."""
        if series.empty:
            return np.nan
        counts = series.value_counts()
        return counts.iloc[0] / counts.sum()

    def calculate_normalized_entropy(series):
        """Calculates entropy normalized by log2(n_unique_tokens)."""
        if series.empty:
            return np.nan
        counts = series.value_counts()
        n_unique = len(counts)
        
        if n_unique <= 1:
            return 0.0 # Perfectly pure cluster has zero entropy

        ent = scipy.stats.entropy(counts, base=2)
        
        # Normalize by log2 of the number of unique elements
        return ent / np.log2(n_unique)

    # Perform aggregation
    agg_metrics =\
        cluster_df\
        .groupby(cluster_id_col, as_index = False)\
        .agg(
            n_samples=('token', 'size'),
            n_unique_tokens=('token', 'nunique'),
            dominance=('token', calculate_dominance),
            normalized_entropy=('token', calculate_normalized_entropy)
        )\
        .pipe(lambda df: df[df['n_samples'] >= min_cluster_size])

    return agg_metrics

# Mean normalized token entropy of the layer-1 clusters for each component (lower = purer clusters)
para_entropy = get_entropy_distribution(para_clusters_df, 1)
orth_entropy = get_entropy_distribution(orth_clusters_df, 1)

for label, entropy_df in (('Para', para_entropy), ('Orth', orth_entropy)):
    print(f"{label} entropy: {entropy_df['normalized_entropy'].mean()}")

## Reconstruction/probing tests

In [None]:
"""
Logistic regression - predict topk using h_orth?
"""
# Test layer 
test_layer = 0

def run_lr(x_cp, y_cp):
    x_train, x_test, y_train, y_test = cuml.train_test_split(x_cp, y_cp, test_size = 0.1, random_state = 123)
    lr_model = cuml.linear_model.LogisticRegression(penalty = 'l2', max_iter = 10000, fit_intercept = False)
    lr_model.fit(x_train, y_train)
    accuracy = lr_model.score(x_test, y_test)
    print(f"Accuracy: {accuracy:.2%}")

# Top-1 expert at test_layer per sample, sorted by sample_ix so labels line up with the hidden-state rows
expert_ids =\
    topk_df\
    .pipe(lambda df: df[(df['layer_ix'] == test_layer) & (df['topk_ix'] == 1)])\
    .sort_values('sample_ix')\
    ['expert'].tolist()

x_cp_para = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())
assert len(expert_ids) == x_cp_para.shape[0], 'expected exactly one top-1 expert per sample'
expert_ids_cp = cupy.asarray(expert_ids)

run_lr(x_cp_para, expert_ids_cp)
run_lr(x_cp_orth, expert_ids_cp)

In [None]:
test_layer = 2

# Top-1 expert at test_layer per sample, sorted by sample_ix so labels line up with the hidden-state rows
expert_ids =\
    topk_df\
    .pipe(lambda df: df[(df['layer_ix'] == test_layer) & (df['topk_ix'] == 1)])\
    .sort_values('sample_ix')\
    ['expert'].tolist()

x_cp_para = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())
assert len(expert_ids) == x_cp_para.shape[0], 'expected exactly one top-1 expert per sample'
expert_ids_cp = cupy.asarray(expert_ids)

run_lr(x_cp_para, expert_ids_cp)
run_lr(x_cp_orth, expert_ids_cp)

In [None]:
"""
Use h_para and h_orth to predict NEXT layer expert ids
"""
test_layer = 1

expert_ids =\
    topk_df\
    .pipe(lambda df: df[df['layer_ix'] == test_layer + 1])\
    .pipe(lambda df: df[df['topk_ix'] == 1])\
    ['expert'].tolist()

expert_ids_cp = cupy.asarray(expert_ids)
x_cp_para = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())
# x_cp_ccat = cupy.asarray(torch.cat(
#     [h_para_by_layer[test_layer].to(torch.float16).detach().cpu(), h_orth_by_layer[test_layer].to(torch.float16).detach().cpu()],
#     dim = 1
#     ))

run_lr(x_cp_para, expert_ids_cp)
run_lr(x_cp_orth, expert_ids_cp)
# run_lr(x_cp_ccat, expert_ids_cp)

In [None]:
"""
Predict token ID
"""
display(
    sample_df.groupby('token', as_index = False).agg(n = ('token', 'count')).sort_values(by = 'n', ascending = False).head(30)
)

test_layer = 0

y_df =\
    sample_df\
    .assign(is_sample = lambda df: np.where(df['token'].isin([' the']), 1, 0))\
    ['is_sample'].tolist()

y_cp = cupy.asarray(y_df)
x_cp_para = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())

run_lr(x_cp_para, y_cp)
run_lr(x_cp_orth, y_cp)

## Stability analysis

In [None]:
"""
Analyze stability over layers
"""
def calculate_layer_transition_stability(h_orth_layers: dict, h_para_layers: dict, layer_l: int):
    """
    Calculates the stability of h_orth and h_para representations between layer_l and layer_l+1 using cosine similarity and Euclidean distance.

    Params:
        @h_orth_layers: Dictionary where keys are layer indices (int) and values are (n_samples, D) tensors for h_orth.
        @h_para_layers: Dictionary where keys are layer indices (int) andvalues are (n_samples, D) tensors for h_para.
        layer_l: The starting layer index for the transition (e.g., 6 for L6->L7).

    Returns:
        A dictionary containing:
        - 'cosine_similarity_orth': (n_samples,) tensor of cosine similarities for h_orth.
        - 'cosine_similarity_para': (n_samples,) tensor of cosine similarities for h_para.
        - 'euclidean_distance_orth': (n_samples,) tensor of L2 distances for h_orth.
        - 'euclidean_distance_para': (n_samples,) tensor of L2 distances for h_para.
    """
    layer_lp1 = layer_l + 1

    # Get tensors for the specified layers
    h_orth_l = h_orth_layers[layer_l]
    h_orth_lp1 = h_orth_layers[layer_lp1]
    h_para_l = h_para_layers[layer_l]
    h_para_lp1 = h_para_layers[layer_lp1]

    # --- Calculate Cosine Similarities (Higher is more stable) ---
    # dim = 1 calculates similarity row-wise
    sim_orth = torch.nn.functional.cosine_similarity(h_orth_l, h_orth_lp1, dim = 1)
    sim_para = torch.nn.functional.cosine_similarity(h_para_l, h_para_lp1, dim = 1)

    # --- Calculate Euclidean Distances (Lower is more stable) ---
    # torch.linalg.norm computes norms. ord=2 is L2 norm.
    dist_orth = torch.linalg.norm(h_orth_l - h_orth_lp1, ord = 2, dim = 1)
    dist_para = torch.linalg.norm(h_para_l - h_para_lp1, ord = 2, dim = 1)

    return {
        'cosine_similarity_orth': sim_orth.to(torch.float16),
        'cosine_similarity_para': sim_para.to(torch.float16),
        'euclidean_distance_orth': dist_orth.to(torch.float16),
        'euclidean_distance_para': dist_para.to(torch.float16),
    }

stability_results = calculate_layer_transition_stability(
    h_orth_layers = {
        0: h_orth_by_layer[0].to(torch.float32).detach().cpu(),
        1: h_orth_by_layer[1].to(torch.float32).detach().cpu()
    },
    h_para_layers = {
        0: h_para_by_layer[0].to(torch.float32).detach().cpu(), 
        1: h_para_by_layer[1].to(torch.float32).detach().cpu() 
    },
    layer_l = 0
)

In [None]:
print("Checking sim_orth for non-finite values:", torch.isfinite(sim_orth).all())
print("Checking dist_orth for non-finite values:", torch.isfinite(dist_orth).all())
    

print("Zero vectors in h_orth_l:", torch.where(torch.linalg.norm(h_orth_l, dim=1) == 0))
print("Zero vectors in h_orth_lp1:", torch.where(torch.linalg.norm(h_orth_lp1, dim=1) == 0))

In [None]:
sim_orth_np = stability_results['cosine_similarity_orth'].numpy()
sim_para_np = stability_results['cosine_similarity_para'].numpy()
dist_orth_np = stability_results['euclidean_distance_orth'].numpy()
dist_para_np = stability_results['euclidean_distance_para'].numpy()

def _print_summary(label, values):
    """Print mean / median / std of one metric for one component."""
    print(f"  {label}: Mean={np.mean(values):.4f}, Median={np.median(values):.4f}, Std={np.std(values):.4f}")

print("\nCosine Similarity (Higher is More Stable):")
_print_summary('h_orth', sim_orth_np)
_print_summary('h_para', sim_para_np)
print("\nEuclidean Distance (Lower is More Stable):")
_print_summary('h_orth', dist_orth_np)
_print_summary('h_para', dist_para_np)

# Must match the layer_l used to build stability_results (only used in the plot title)
layer_idx = 0

def _long_frame(orth_values, para_values, metric):
    """Stack one metric's h_orth and h_para values into long format for plotly."""
    return pd.DataFrame({
        'value': np.concatenate([orth_values, para_values]),
        'component': ['h_orth'] * len(orth_values) + ['h_para'] * len(para_values),
        'metric': metric
    })

data_sim = _long_frame(sim_orth_np, sim_para_np, 'cos')
data_dist = _long_frame(dist_orth_np, dist_para_np, 'euc')
df_plot = pd.concat([data_sim, data_dist], ignore_index=True)

fig = (
    px.histogram(
        df_plot,
        x = 'value', # Values for the histogram
        color = 'component', # Separate histograms for 'h_orth' and 'h_para'
        facet_col = 'metric', # One subplot column per metric
        histnorm = 'probability density', # Normalize histograms
        barmode = 'overlay', # Overlay histograms within each subplot
        opacity = 0.75,
        title = f'Stability Comparison: Layer {layer_idx} to {layer_idx+1}',
        labels = {'component': 'Component Type'}
    )
    # px links facet x-axes by default; cosine similarity lives in [-1, 1] while distances are unbounded,
    # so give each facet its own x range instead of squashing one of them
    .update_xaxes(matches = None)
    .update_xaxes(title_text = "Cosine Similarity", col = 1)
    .update_xaxes(title_text = "Euclidean Distance", col = 2)
)

fig.show()

# --- Statistical Test (one-sided Mann-Whitney U) ---
u_stat_sim, p_value_sim = scipy.stats.mannwhitneyu(sim_orth_np, sim_para_np, alternative='greater') # Test if orth > para
u_stat_dist, p_value_dist = scipy.stats.mannwhitneyu(dist_para_np, dist_orth_np, alternative='greater') # Test if para > orth
print("\nMann-Whitney U Test Results:")
print(f"  Cosine Similarity (H1: orth > para): p-value = {p_value_sim:.2e} | rejection = h_orth sim is higher")
print(f"  Euclidean Distance (H1: para > orth): p-value = {p_value_dist:.2e} | rejection = h_orth distance is lower")

## Logit lens

"""
Logit lens - take a single prompt and see what the different hidden states are predicting 
"""

sample_ix = []

pre_mlp_hidden_states = []