In [None]:
"""
This contains code to use SVD to decompose hidden states based on whether they're used by routing or not.
"""
None

In [1]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import scipy
import cupy
import cuml
import sklearn

import importlib
import gc
import pickle
import os

from tqdm import tqdm
from termcolor import colored
import plotly.express as px
from plotly.subplots import make_subplots

from utils.memory import check_memory, clear_all_cuda_memory
from utils.quantize import compare_bf16_fp16_batched
from utils.svd import decompose_orthogonal, decompose_sideways
from utils.vis import combine_plots

# Device that GPU work in this notebook is expected to run on.
main_device = 'cuda:0'

# Seed the RNGs once, up front, so that a Restart & Run All reproduces the same numbers.
# The layer-transition sampling and the pandas `.sample()` calls further down both draw from
# numpy's global generator when no explicit random state is passed.
seed = 1234
np.random.seed(seed)
torch.manual_seed(seed)

# Start from a clean GPU and print the per-device memory summary.
clear_all_cuda_memory()
check_memory()

All CUDA memory cleared on all devices.
Device 0: NVIDIA H100 PCIe
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 79.10 GB



## Load model & data

In [2]:
"""
Load the base tokenizer/model
"""
# Index into `models_list` selecting which model this run analyses.
model_ix = 3

# Each entry is (hub model id, short prefix used in data/export paths, layer offset).
# NOTE(review): the offset is added to every exported layer index below and used to index
# `model.model.layers`; presumably it is the number of leading layers without a router - confirm.
models_list = [
    ('allenai/OLMoE-1B-7B-0125-Instruct', 'olmoe', 0),
    ('Qwen/Qwen1.5-MoE-A2.7B-Chat', 'qwen1.5moe', 0),
    ('deepseek-ai/DeepSeek-V2-Lite', 'dsv2', 1),
    ('Qwen/Qwen3-30B-A3B', 'qwen3moe', 0)
]

model_id, model_prefix, model_pre_mlp_layers = models_list[model_ix]

# NOTE: trust_remote_code = True executes code shipped with the checkpoint; only use trusted model ids.
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)

# Place the model on the configured device instead of whatever the current CUDA device happens to be.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).to(main_device).eval()

Loading checkpoint shards:   0%|          | 0/16 [00:00<?, ?it/s]

In [3]:
"""
Load dataset
"""
def load_data(model_prefix, max_data_files):
    """
    Load data saved by `export-activations-sm.ipynb`

    Params:
        @model_prefix: Sub-folder name under `./../export-data/activations-sm/` for the model.
        @max_data_files: Folders `00` .. `max_data_files - 1` are considered; missing ones are skipped.

    Returns:
        A tuple `(sample_df, topk_df, all_pre_mlp_hs, layers)`: the concatenated `samples.pkl` frames,
        the concatenated `topks.pkl` frames, the concatenated hidden-state tensors, and the
        `all_pre_mlp_hidden_states_layers` entry of `metadata.pkl`.

    Raises:
        FileNotFoundError: If none of the candidate folders exists (previously this surfaced as an
        opaque "No objects to concatenate" error from pandas).
    """
    base_dir = f'./../export-data/activations-sm/{model_prefix}'
    candidate_dirs = [f'{base_dir}/{i:02d}' for i in range(max_data_files)]
    folders = [d for d in candidate_dirs if os.path.isdir(d)]

    if not folders:
        raise FileNotFoundError(
            f'No activation folders found under {base_dir} (checked {len(candidate_dirs)} candidate(s)).'
        )

    all_pre_mlp_hs = []
    sample_df = []
    topk_df = []

    for folder in tqdm(folders):
        sample_df.append(pd.read_pickle(f'{folder}/samples.pkl'))
        topk_df.append(pd.read_pickle(f'{folder}/topks.pkl'))
        all_pre_mlp_hs.append(torch.load(f'{folder}/all-pre-mlp-hidden-states.pt'))

    sample_df = pd.concat(sample_df)
    topk_df = pd.concat(topk_df)
    all_pre_mlp_hs = torch.concat(all_pre_mlp_hs)

    # NOTE: pickle (here and in read_pickle/torch.load above) can execute arbitrary code;
    # only load exports that were produced locally.
    with open(f'{base_dir}/metadata.pkl', 'rb') as metadata_file:
        metadata = pickle.load(metadata_file)

    gc.collect()
    return sample_df, topk_df, all_pre_mlp_hs, metadata['all_pre_mlp_hidden_states_layers']

# Only the first 2 export folders are loaded: due to memory constraints, for Qwen3Moe max_data_files = 2.
# `act_map` is the `all_pre_mlp_hidden_states_layers` metadata entry returned by load_data.
sample_df_import, topk_df_import, all_pre_mlp_hs_import, act_map = load_data(model_prefix, 2)

100%|██████████| 2/2 [00:19<00:00,  9.62s/it]


In [4]:
"""
Let's clean up the mappings here. We'll get everything to a sample_ix level first.
"""
# Token-level table: number every (batch, sequence, token) triple as `sample_ix`
# and every (batch, sequence) pair as `seq_id`.
sample_df_raw = (
    sample_df_import
    .assign(
        sample_ix = lambda df: df.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup(),
        seq_id = lambda df: df.groupby(['batch_ix', 'sequence_ix']).ngroup()
    )
    .reset_index()
)

# Routing table keyed by `sample_ix` only, with layer indices shifted by the model's layer offset.
topk_df = (
    topk_df_import
    .merge(
        sample_df_raw[['sample_ix', 'batch_ix', 'sequence_ix', 'token_ix']],
        how = 'inner',
        on = ['sequence_ix', 'token_ix', 'batch_ix']
    )
    .drop(columns = ['sequence_ix', 'token_ix', 'batch_ix'])
    .assign(layer_ix = lambda df: df['layer_ix'] + model_pre_mlp_layers)
)

# Rows whose topk_ix is 1.
topk1_df = topk_df[topk_df['topk_ix'] == 1]

# The positional batch/sequence keys are no longer needed once sample_ix exists.
sample_df = sample_df_raw.drop(columns = ['batch_ix', 'sequence_ix'])

def get_sample_df_for_layer(sample_df, topk_df, layer_ix):
    """
    Helper to take the sample df and merge layer-level expert selection information

    Params:
        @sample_df: Frame with one row per `sample_ix`.
        @topk_df: Frame with columns `sample_ix`, `layer_ix`, `topk_ix`, `expert`.
        @layer_ix: Layer whose expert selections are attached.

    Returns:
        `sample_df` restricted to samples that have a `topk_ix == 1` row at `layer_ix`, with columns
        - `expert`: the `topk_ix == 1` expert at `layer_ix`
        - `prev_expert` / `prev2_expert`: the `topk_ix == 1` expert at `layer_ix - 1` / `layer_ix - 2`
          (missing when that layer has no rows)
        - `expert2`: the `topk_ix == 2` expert at `layer_ix`
        - `leading_path`: the string "<prev2_expert>-<prev_expert>", missing if either part is missing
    """
    def _expert_at(layer, rank, col_name):
        """Return the (sample_ix, col_name) frame holding the rank-`rank` expert at `layer`."""
        return (
            topk_df
            .pipe(lambda df: df[(df['layer_ix'] == layer) & (df['topk_ix'] == rank)])
            .rename(columns = {'expert': col_name})
            [['sample_ix', col_name]]
        )

    def _as_label(expert_col):
        """Render expert ids as strings, keeping missing values missing."""
        # Expert ids are numeric, and the left merges turn them into floats whenever a row is missing.
        # Going through nullable Int64 makes 74.0 print as "74"; numeric + '-' would raise a TypeError.
        if pd.api.types.is_numeric_dtype(expert_col):
            expert_col = expert_col.astype('Int64')
        return expert_col.astype('string')

    layer_df =\
        sample_df\
        .merge(_expert_at(layer_ix, 1, 'expert'), how = 'inner', on = 'sample_ix')\
        .merge(_expert_at(layer_ix - 1, 1, 'prev_expert'), how = 'left', on = 'sample_ix')\
        .merge(_expert_at(layer_ix - 2, 1, 'prev2_expert'), how = 'left', on = 'sample_ix')\
        .merge(_expert_at(layer_ix, 2, 'expert2'), how = 'left', on = 'sample_ix')\
        .assign(leading_path = lambda df: _as_label(df['prev2_expert']) + '-' + _as_label(df['prev_expert']))

    return layer_df

# Drop the intermediate frames; only `sample_df`, `topk_df` and `topk1_df` are used from here on.
del sample_df_import, sample_df_raw, topk_df_import

gc.collect()
# Show both cleaned tables (pandas truncates the long frames when rendering).
display(topk_df)
display(sample_df)

Unnamed: 0,layer_ix,topk_ix,expert,weight,sample_ix
0,0,1,74,0.21,0
1,0,2,118,0.19,0
2,0,3,24,0.15,0
3,0,4,47,0.14,0
4,0,5,4,0.13,0
...,...,...,...,...,...
59693563,47,4,61,0.09,155451
59693564,47,5,103,0.09,155451
59693565,47,6,12,0.08,155451
59693566,47,7,48,0.08,155451


Unnamed: 0,index,token_ix,token_id,output_id,output_prob,token,source,sample_ix,seq_id
0,0,62,1593,353,0.70,LO,es,0,0
1,1,63,29676,3385,0.24,QUE,es,1,0
2,2,64,434,969,0.87,F,es,2,0
3,3,65,35830,56550,0.94,ALT,es,3,0
4,4,66,56550,25,0.22,ABA,es,4,0
...,...,...,...,...,...,...,...,...,...
155447,79849,507,15,20,0.70,0,es,155447,499
155448,79850,508,21,17449,0.95,6,es,155448,499
155449,79851,509,17449,220,1.00,Sep,es,155449,499
155450,79852,510,220,17,1.00,,es,155450,499


In [5]:
"""
Convert activations to fp16 (for compatibility with cupy later) + dict
"""
# Cast the imported activations to fp16 and release the original name.
# (Optional sanity check of the cast: compare_bf16_fp16_batched(all_pre_mlp_hs_import, all_pre_mlp_hs),
#  to be run before the `del` below.)
all_pre_mlp_hs = all_pre_mlp_hs_import.to(torch.float16)
del all_pre_mlp_hs_import

# Re-key by model layer index: slot `save_ix` along dim 1 holds exported layer `act_map[save_ix]`,
# which is shifted by the model's layer offset.
all_pre_mlp_hs = dict(
    (layer_ix + model_pre_mlp_layers, all_pre_mlp_hs[:, save_ix, :])
    for save_ix, layer_ix in enumerate(act_map)
)

gc.collect()

0

## SVD Decomposition

In [6]:
"""
Let's take the pre-MLP hidden states and split them using SVD into parallel and orthogonal components.
"""
# Per-layer outputs of the decomposition, keyed by the same layer index as `all_pre_mlp_hs`.
h_para_by_layer = {}
h_orth_by_layer = {}

# For every layer, decompose_orthogonal receives the float32 hidden states, the float32 CPU copy of
# that layer's router (`mlp.gate`) weight, and the method name 'svd'; its two return values are
# stored as the "para" and "orth" components.
# NOTE(review): presumably "para" is the part of the hidden state the router can read and "orth" the
# remainder - confirm against utils.svd.decompose_orthogonal.
for layer_ix in tqdm(list(all_pre_mlp_hs.keys())):
    h_para_by_layer[layer_ix], h_orth_by_layer[layer_ix] = decompose_orthogonal(
        all_pre_mlp_hs[layer_ix].to(torch.float32),
        model.model.layers[layer_ix].mlp.gate.weight.detach().cpu().to(torch.float32),
        'svd'
    )

100%|██████████| 48/48 [02:43<00:00,  3.40s/it]


## Orth vs Para Rotation

In [35]:
bootstrap_samples = 50

def get_sample_res(hs_by_layer, samples_to_test = 1, rng = None):
    """
    Average cosine similarity between the hidden states of consecutive layers for random samples.

    Params:
        @hs_by_layer: Dict of layer key -> tensor of shape (n_samples, D). Consecutive dict entries are
            treated as consecutive layers; the keys themselves are not used, so they need not start at 0.
        @samples_to_test: Number of sample rows to draw (with replacement).
        @rng: Optional `np.random.Generator` used for the draw. When None, numpy's global
            generator is used (seed it with `np.random.seed` for reproducible results).

    Returns:
        A numpy array of length `n_layers - 1`; entry i is the mean (over the drawn samples) cosine
        similarity between the hidden state at layer i and at layer i + 1.
    """
    layer_tensors = list(hs_by_layer.values())
    if not layer_tensors:
        raise ValueError('hs_by_layer is empty.')

    # Take the row count from the first layer present rather than assuming a layer keyed 0 exists.
    n_available = layer_tensors[0].shape[0]
    if rng is None:
        samples = np.random.randint(0, n_available, samples_to_test)
    else:
        samples = rng.integers(0, n_available, samples_to_test)

    # samples_to_test x n_layers x D
    sample_hs = torch.stack([layer_hs[samples, :] for layer_hs in layer_tensors], dim = 1)
    if sample_hs.dtype not in (torch.float32, torch.float64):
        sample_hs = sample_hs.to(torch.float32)

    # Only adjacent-layer similarities are needed, so compare each layer with its successor directly
    # instead of building the full layer x layer similarity matrix for every sample.
    sims = torch.nn.functional.cosine_similarity(sample_hs[:, :-1, :], sample_hs[:, 1:, :], dim = -1)

    return sims.mean(dim = 0).detach().cpu().numpy()

# One row per repetition, one column per layer transition (bootstrap_samples x layer_diffs).
para_res = np.stack(
    [get_sample_res(h_para_by_layer) for _ in range(bootstrap_samples)],
    axis = 0
)

# Per-transition mean and 1.96 * standard deviation across repetitions.
# NOTE(review): with the default of one sample per repetition this is the spread of individual
# samples, not a standard error of the mean - confirm that is what the exported "cis" should be.
para_mean_across_layers = np.mean(para_res, axis = 0)
para_cis_across_layers = 1.96 * para_res.std(axis = 0)

# Overall mean, and 1.96 * standard deviation of the per-repetition means.
para_mean_overall = para_mean_across_layers.mean()
para_mean_ci = 1.96 * para_res.mean(axis = 1).std().item()

# print(f"Mean across layer transitions: {para_mean_across_layers}")
print(f"Mean across layer transitions + samples: {para_mean_overall:.2f} +/- {para_mean_ci:.2f}")

Mean across layer transitions + samples: 0.71 +/- 0.04


In [36]:
# One row per repetition, one column per layer transition (bootstrap_samples x layer_diffs).
orth_res = np.stack(
    [get_sample_res(h_orth_by_layer) for _ in range(bootstrap_samples)],
    axis = 0
)

# Per-transition mean and 1.96 * standard deviation across repetitions.
# NOTE(review): with the default of one sample per repetition this is the spread of individual
# samples, not a standard error of the mean - confirm that is what the exported "cis" should be.
orth_mean_across_layers = np.mean(orth_res, axis = 0)
orth_cis_across_layers = 1.96 * orth_res.std(axis = 0)

# Overall mean, and 1.96 * standard deviation of the per-repetition means.
orth_mean_overall = orth_mean_across_layers.mean()
orth_mean_ci = 1.96 * orth_res.mean(axis = 1).std().item()

# print(f"Mean across layer transitions: {orth_mean_across_layers}")
print(f"Mean across layer transitions + samples: {orth_mean_overall:.2f} +/- {orth_mean_ci:.2f}")

Mean across layer transitions + samples: 0.84 +/- 0.05


In [37]:
# One row per layer transition, labelled with the layer the transition arrives at.
# The labels are taken from the actual layer keys (every key except the first) instead of being
# rebuilt with range(), so they stay correct even if the exported layers are not consecutive.
export_df = pd.DataFrame({
    'layer_ix': list(all_pre_mlp_hs.keys())[1:],
    'para_mean_across_layers': para_mean_across_layers,
    'orth_mean_across_layers': orth_mean_across_layers,
    'para_cis': para_cis_across_layers,
    'orth_cis': orth_cis_across_layers
})

# Make sure the output folder exists so a fresh checkout can run this cell.
os.makedirs('exports', exist_ok = True)
export_df.to_csv(f'exports/transition-stability-{model_prefix}.csv', index = False)

## Orth vs Para Clusters

In [None]:
"""
Helper functions for clustering
"""
def print_samples(df, grouping_cols, n_groups = 35, min_group_size = 5):
    """
    Takes a wide dataframe and groups it, then prints random groups

    Params:
        @df: Frame containing `grouping_cols` and a `token` column.
        @grouping_cols: Column name(s) to group by.
        @n_groups: Maximum number of groups to display. If fewer groups qualify, all of them are
            shown (previously this raised because more rows were requested than available).
        @min_group_size: Groups with fewer rows than this are dropped before sampling.

    Displays one row per sampled group with its size (`n_samples`) and up to 10 random tokens
    (`samples`). Returns nothing.
    """
    eligible_groups =\
        df\
        .groupby(grouping_cols, as_index = False)\
        .agg(
            n_samples = ('token', 'size'),
            samples = ('token', lambda s: s.sample(n = min(len(s), 10)).tolist())
        )\
        .pipe(lambda df: df[df['n_samples'] >= min_group_size])

    # Cap the request at the number of eligible groups; sampling without replacement cannot
    # return more rows than exist.
    res = eligible_groups.sample(n = min(len(eligible_groups), n_groups))

    display(res)

In [None]:
"""
Let's cluster the para and ortho using k-means and see what clusters we get
"""
def cluster_kmeans(layer_hs: torch.Tensor, n_clusters = 512):
    """
    Fit cuML k-means on one layer's hidden states and return the cluster label of every row.

    Params:
        @layer_hs: Hidden states for one layer; cast to float32 and handed to cupy before fitting.
        @n_clusters: Number of clusters to fit.

    Returns:
        The fitted model's `labels_` converted to a Python list.
    """
    kmeans_config = {'n_clusters': n_clusters, 'max_iter': 1000, 'random_state': 123}
    km = cuml.cluster.KMeans(**kmeans_config)
    km.fit(cupy.asarray(layer_hs.to(torch.float32)))

    # Clear CUDA memory after each fit (project helper, called with False as before).
    clear_all_cuda_memory(False)

    cluster_labels = km.labels_.tolist()
    return cluster_labels

def get_cluster(sample_df, hidden_states_by_layer, n_clusters = 256):
    """
    Cluster every layer's hidden states with k-means and attach the labels to the sample table.

    Params:
        @sample_df: Sample-level frame; concatenated column-wise next to the cluster labels.
        @hidden_states_by_layer: Dict of layer index -> hidden states for that layer.
        @n_clusters: Number of k-means clusters per layer.

    Returns:
        A frame with one `layer_<ix>_id` column per layer followed by the columns of `sample_df`.
        Also displays the cluster sizes of `layer_1_id`, largest first.
    """
    per_layer_frames = []
    for layer_ix, layer_hs in tqdm(hidden_states_by_layer.items()):
        layer_labels = cluster_kmeans(layer_hs, n_clusters)
        per_layer_frames.append(pd.DataFrame({f'layer_{layer_ix}_id': layer_labels}))

    labels_wide = pd.concat(per_layer_frames, axis = 1)
    cluster_ids_df = pd.concat([labels_wide, sample_df], axis = 1)

    layer_1_sizes = cluster_ids_df.groupby('layer_1_id', as_index = False).agg(n_samples = ('token', 'size'))
    display(layer_1_sizes.sort_values(by = 'n_samples', ascending = False))

    return cluster_ids_df

# Cluster both components of the hidden states, layer by layer (default number of clusters).
para_clusters_df = get_cluster(sample_df, h_para_by_layer)
orth_clusters_df = get_cluster(sample_df, h_orth_by_layer)

# Show random (layer 1 cluster, layer 2 cluster) pairs with example tokens, for each component.
print_samples(para_clusters_df, ['layer_1_id', 'layer_2_id'])
print_samples(orth_clusters_df, ['layer_1_id', 'layer_2_id'])

In [None]:
# Same inspection as above, for the (layer 6 cluster, layer 7 cluster) pairs.
print_samples(para_clusters_df, ['layer_6_id', 'layer_7_id'])
print_samples(orth_clusters_df, ['layer_6_id', 'layer_7_id'])

In [None]:
"""
Count how many clusters are token-specific
"""
def get_single_token_cluster_counts(cluster_df, layer_ix, min_cluster_size = 5):
    """
    Count how many clusters of a layer contain a single distinct token vs. several.

    Params:
        @cluster_df: Frame with a `layer_<layer_ix>_id` cluster column and a `token` column.
        @layer_ix: Layer whose clusters are inspected.
        @min_cluster_size: Clusters with fewer rows than this are ignored.

    Returns:
        A frame with columns `is_eq` (1 = every token in the cluster is identical, 0 = otherwise)
        and `count` (number of clusters in that category).
    """
    res =\
        cluster_df\
        .groupby([f'layer_{str(layer_ix)}_id'], as_index = False)\
        .agg(
            n_samples = ('token', 'size'),
            # Look at every token in the cluster. Judging purity from a random sample of 20 tokens
            # mislabels large clusters with a few outliers and changes from run to run.
            n_unique_tokens = ('token', lambda s: s.nunique(dropna = False))
        )\
        .pipe(lambda df: df[df['n_samples'] >= min_cluster_size])\
        .assign(is_eq = lambda df: (df['n_unique_tokens'] == 1).astype(int))\
        .groupby('is_eq', as_index = False)\
        .agg(count = ('n_samples', 'count'))

    return res

# Compare the single-token vs. mixed cluster counts of the two components at layer 7.
display(get_single_token_cluster_counts(para_clusters_df, 7))
display(get_single_token_cluster_counts(orth_clusters_df, 7))

In [None]:
"""
Count entropy distribution
"""
def get_entropy_distribution(cluster_df, layer_ix, min_cluster_size = 1):
    """
    Per-cluster token purity statistics for one layer.

    Params:
        @cluster_df: Frame with a `layer_<layer_ix>_id` cluster column and a `token` column.
        @layer_ix: Layer whose clusters are summarised.
        @min_cluster_size: Clusters with fewer rows than this are dropped from the result.

    Returns:
        One row per cluster with `n_samples`, `n_unique_tokens`, `dominance` (share of the most
        frequent token) and `normalized_entropy` (base-2 token entropy divided by
        log2(number of distinct tokens); 0.0 when the cluster holds a single distinct token).
    """
    def _top_token_share(tokens):
        """Share of the cluster taken by its most frequent token."""
        if tokens.empty:
            return np.nan
        freq = tokens.value_counts()
        return freq.iloc[0] / freq.sum()

    def _scaled_entropy(tokens):
        """Base-2 entropy of the token counts, scaled by log2 of the number of distinct tokens."""
        if tokens.empty:
            return np.nan
        freq = tokens.value_counts()
        n_distinct = len(freq)

        # A cluster with one distinct token is perfectly pure; this also avoids dividing by log2(1) = 0.
        if n_distinct <= 1:
            return 0.0

        return scipy.stats.entropy(freq, base=2) / np.log2(n_distinct)

    per_cluster = cluster_df.groupby(f'layer_{str(layer_ix)}_id', as_index = False).agg(
        n_samples=('token', 'size'),
        n_unique_tokens=('token', 'nunique'),
        dominance=('token', _top_token_share),
        normalized_entropy=('token', _scaled_entropy)
    )

    return per_cluster[per_cluster['n_samples'] >= min_cluster_size]

# Per-cluster token statistics at layer 1 for both components (all clusters kept: min_cluster_size = 1).
para_entropy = get_entropy_distribution(para_clusters_df, 1)
orth_entropy = get_entropy_distribution(orth_clusters_df, 1)

# Unweighted mean of the per-cluster normalized entropy.
print(f"Para entropy: {para_entropy['normalized_entropy'].mean()}")
print(f"Orth entropy: {orth_entropy['normalized_entropy'].mean()}")

## Reconstruction/probing tests

In [None]:
"""
Logistic regression - predict topk using h_orth?
"""
# Test layer 
test_layer = 0

def run_lr(x_cp, y_cp):
    """
    Fit an L2-penalised logistic regression (no intercept) and print its held-out accuracy.

    Params:
        @x_cp: Feature array.
        @y_cp: Label array, one label per row of `x_cp`.

    10% of the rows are held out (fixed random_state) for scoring. Returns nothing.
    """
    split = cuml.train_test_split(x_cp, y_cp, test_size = 0.1, random_state = 123)
    x_train, x_test, y_train, y_test = split

    classifier = cuml.linear_model.LogisticRegression(penalty = 'l2', max_iter = 10000, fit_intercept = False)
    classifier.fit(x_train, y_train)

    accuracy = classifier.score(x_test, y_test)
    print(f"Accuracy: {accuracy:.2%}")

# Labels: the topk_ix == 1 expert of every sample at the test layer, in topk_df row order.
expert_ids = topk_df.loc[
    (topk_df['layer_ix'] == test_layer) & (topk_df['topk_ix'] == 1),
    'expert'
].tolist()

expert_ids_cp = cupy.asarray(expert_ids)

# Features: the para / orth components at the test layer, as fp16 cupy arrays.
x_cp_para, x_cp_orth = (
    cupy.asarray(hs_by_layer[test_layer].to(torch.float16).detach().cpu())
    for hs_by_layer in (h_para_by_layer, h_orth_by_layer)
)

# Probe each component separately.
run_lr(x_cp_para, expert_ids_cp)
run_lr(x_cp_orth, expert_ids_cp)

In [None]:
# Repeat the same-layer probe at layer 2.
test_layer = 2

# Labels: the topk_ix == 1 expert of every sample at the test layer, in topk_df row order.
expert_ids = topk_df.loc[
    (topk_df['layer_ix'] == test_layer) & (topk_df['topk_ix'] == 1),
    'expert'
].tolist()

expert_ids_cp = cupy.asarray(expert_ids)

# Features: the para / orth components at the test layer, as fp16 cupy arrays.
x_cp_para, x_cp_orth = (
    cupy.asarray(hs_by_layer[test_layer].to(torch.float16).detach().cpu())
    for hs_by_layer in (h_para_by_layer, h_orth_by_layer)
)

run_lr(x_cp_para, expert_ids_cp)
run_lr(x_cp_orth, expert_ids_cp)

In [None]:
"""
Use h_para and h_orth to predict NEXT layer expert ids
"""
# Features come from this layer; labels come from the layer after it.
test_layer = 1

# Labels: the topk_ix == 1 expert of every sample at layer test_layer + 1, in topk_df row order.
expert_ids = topk_df.loc[
    (topk_df['layer_ix'] == test_layer + 1) & (topk_df['topk_ix'] == 1),
    'expert'
].tolist()

expert_ids_cp = cupy.asarray(expert_ids)

# Features: the para / orth components at the test layer, as fp16 cupy arrays.
x_cp_para, x_cp_orth = (
    cupy.asarray(hs_by_layer[test_layer].to(torch.float16).detach().cpu())
    for hs_by_layer in (h_para_by_layer, h_orth_by_layer)
)
# x_cp_ccat = cupy.asarray(torch.cat(
#     [h_para_by_layer[test_layer].to(torch.float16).detach().cpu(), h_orth_by_layer[test_layer].to(torch.float16).detach().cpu()],
#     dim = 1
#     ))

run_lr(x_cp_para, expert_ids_cp)
run_lr(x_cp_orth, expert_ids_cp)
# run_lr(x_cp_ccat, expert_ids_cp)

In [None]:
"""
Predict token ID
"""
# The 30 most frequent tokens, to pick a probing target from.
display(
    sample_df
    .groupby('token', as_index = False)
    .agg(n = ('token', 'count'))
    .sort_values(by = 'n', ascending = False)
    .head(30)
)

test_layer = 0

# Binary labels: 1 where the token is ' the', 0 otherwise (in sample_df row order).
y_df = np.where(sample_df['token'].isin([' the']), 1, 0).tolist()

y_cp = cupy.asarray(y_df)

# Features: the para / orth components at the test layer, as fp16 cupy arrays.
x_cp_para, x_cp_orth = (
    cupy.asarray(hs_by_layer[test_layer].to(torch.float16).detach().cpu())
    for hs_by_layer in (h_para_by_layer, h_orth_by_layer)
)

run_lr(x_cp_para, y_cp)
run_lr(x_cp_orth, y_cp)

## Logit lens

"""
Logit lens - take a single prompt and see what the different hidden states are predicting 
"""

sample_ix = []

pre_mlp_hidden_states = []