In [None]:
"""
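This notebook uses SVD to split each layer's pre-MLP hidden states into a routing-relevant component (parallel to the router's row space) and a routing-irrelevant component (orthogonal to it), then compares how stable and how informative each component is across layers.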
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import scipy
import cupy
import cuml
import sklearn

import importlib
import gc
import pickle
import os

from tqdm import tqdm
from termcolor import colored
import plotly.express as px
from plotly.subplots import make_subplots

from utils.memory import check_memory, clear_all_cuda_memory
from utils.quantize import compare_bf16_fp16_batched
from utils.svd import decompose_orthogonal, decompose_sideways
from utils.vis import combine_plots

main_device = 'cuda:0'
seed = 1234

# Seed the global RNGs so the bootstrap resampling below (np.random.randint) and any torch randomness
# are reproducible across Restart & Run All. (The cuML train/test splits pass their own random_state.)
np.random.seed(seed)
torch.manual_seed(seed)

clear_all_cuda_memory()
check_memory()

## Load model & data

In [None]:
"""
Load the base tokenizer/model
"""
model_ix = 2
models_list = [
    # (Hugging Face model id, short prefix used in data/export paths, model_pre_mlp_layers)
    # model_pre_mlp_layers is added to the saved layer indices to get the model layer indices used below
    ('allenai/OLMoE-1B-7B-0125-Instruct', 'olmoe', 0),
    ('Qwen/Qwen1.5-MoE-A2.7B-Chat', 'qwen1.5moe', 0),
    ('deepseek-ai/DeepSeek-V2-Lite', 'dsv2', 1),
    ('Qwen/Qwen3-30B-A3B', 'qwen3moe', 0)
]

model_id, model_prefix, model_pre_mlp_layers = models_list[model_ix]
# NOTE: trust_remote_code executes code shipped with the model repo - only use with trusted model ids
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)
# Place the model on the configured device rather than whatever the implicit current CUDA device is
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).to(main_device).eval()

In [None]:
"""
Load dataset
"""
def load_data(model_prefix, max_data_files):
    """
    Load data saved by `export-activations-sm.ipynb`.

    Looks for numbered sub-folders `00`, `01`, ... (up to `max_data_files` of them) under
    `./../export-data/activations-sm/{model_prefix}/`, skips any that don't exist, and concatenates their contents.

    Params:
        @model_prefix: short model name used as the export folder name (e.g. 'olmoe').
        @max_data_files: maximum number of numbered folders to read.

    Returns:
        A tuple of (sample_df, topk_df, all_pre_mlp_hs, layers):
        - sample_df: concatenation of every folder's `samples.pkl` (original indices kept, not reset)
        - topk_df: concatenation of every folder's `topks.pkl`
        - all_pre_mlp_hs: torch.concat of every folder's `all-pre-mlp-hidden-states.pt` along dim 0
        - layers: `metadata['all_pre_mlp_hidden_states_layers']` from `metadata.pkl`

    Raises:
        FileNotFoundError: if none of the numbered folders exist.
    """
    base_dir = f'./../export-data/activations-sm/{model_prefix}'
    folders = [f'{base_dir}/{i:02d}' for i in range(max_data_files)]
    folders = [f for f in folders if os.path.isdir(f)]
    if not folders:
        # Fail clearly here instead of with pd.concat's opaque "No objects to concatenate"
        raise FileNotFoundError(f'No data folders found under {base_dir} (checked {max_data_files} numbered folders)')

    all_pre_mlp_hs = []
    sample_df = []
    topk_df = []

    for folder in tqdm(folders):
        # NOTE: pickle / torch.load can execute arbitrary code - only load trusted exports
        sample_df.append(pd.read_pickle(f'{folder}/samples.pkl'))
        topk_df.append(pd.read_pickle(f'{folder}/topks.pkl'))
        all_pre_mlp_hs.append(torch.load(f'{folder}/all-pre-mlp-hidden-states.pt'))

    sample_df = pd.concat(sample_df)
    topk_df = pd.concat(topk_df)
    all_pre_mlp_hs = torch.concat(all_pre_mlp_hs)

    # Downstream cells pair sample rows with hidden-state rows, so the counts must agree
    assert len(sample_df) == all_pre_mlp_hs.shape[0],\
        f'{len(sample_df)} sample rows vs {all_pre_mlp_hs.shape[0]} hidden-state rows'

    with open(f'{base_dir}/metadata.pkl', 'rb') as fh:
        metadata = pickle.load(fh)

    gc.collect()
    return sample_df, topk_df, all_pre_mlp_hs, metadata['all_pre_mlp_hidden_states_layers']

# Memory constraints (Qwen3-MoE) cap this at 3 data folders; the same cap is used for every model
(
    sample_df_import,
    topk_df_import,
    all_pre_mlp_hs_import,
    act_map,
) = load_data(model_prefix, max_data_files = 3)

In [None]:
"""
Let's clean up the mappings here. We'll get everything to a sample_ix level first.
"""
# Assign integer IDs: `sample_ix` per token and `seq_id` per sequence. groupby's default sort=True means
# ngroup() numbers groups in sorted (batch_ix, sequence_ix[, token_ix]) key order, not in row order.
# NOTE(review): later cells index the hidden-state tensors by `sample_ix` and also pair them positionally with
# topk/sample rows, which assumes the exported rows are already in that sorted order - TODO confirm against the export.
# `reset_index()` (without drop) keeps the incoming index (not reset by load_data's pd.concat) as an `index` column.
sample_df_raw =\
    sample_df_import\
    .assign(sample_ix = lambda df: df.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup())\
    .assign(seq_id = lambda df: df.groupby(['batch_ix', 'sequence_ix']).ngroup())\
    .reset_index()

# Attach sample_ix to every routing row (inner join drops routing rows with no matching sample), drop the raw
# position columns, and shift layer_ix by model_pre_mlp_layers so it lines up with the hidden-state dict keys.
topk_df =\
    topk_df_import\
    .merge(sample_df_raw[['sample_ix', 'batch_ix', 'sequence_ix', 'token_ix']], how = 'inner', on = ['sequence_ix', 'token_ix', 'batch_ix'])\
    .drop(columns = ['sequence_ix', 'token_ix', 'batch_ix'])\
    .assign(layer_ix = lambda df: df['layer_ix'] + model_pre_mlp_layers)

# Top-1 routing rows only. NOTE(review): not referenced by any later cell visible in this notebook.
topk1_df =\
    topk_df\
    .pipe(lambda df: df[df['topk_ix'] == 1])

# Sample-level frame keyed by sample_ix; batch/sequence indices are dropped (token_ix and seq_id are kept)
sample_df =\
    sample_df_raw\
    .drop(columns = ['batch_ix', 'sequence_ix'])

def get_sample_df_for_layer(sample_df, topk_df, layer_ix):
    """
    Helper to take the sample df and merge layer-level expert selection information.

    Params:
        @sample_df: sample-level df with a `sample_ix` column.
        @topk_df: routing df with `sample_ix`, `layer_ix`, `topk_ix` and (numeric) `expert` columns.
        @layer_ix: the layer to attach routing info for.

    Returns:
        `sample_df` inner-joined to the samples that have a top-1 expert at `layer_ix`, with extra columns:
        - `expert`: top-1 expert at `layer_ix`
        - `prev_expert` / `prev2_expert`: top-1 expert at `layer_ix - 1` / `layer_ix - 2` (NaN if absent)
        - `expert2`: top-2 expert at `layer_ix` (NaN if absent)
        - `leading_path`: string '<prev2_expert>-<prev_expert>', missing (<NA>) if either side is missing
    """
    def _slot(layer, slot, name):
        # Rows for routing slot `slot` at `layer`, keeping only sample_ix and the expert (renamed to `name`)
        rows = topk_df[(topk_df['layer_ix'] == layer) & (topk_df['topk_ix'] == slot)]
        return rows.rename(columns = {'expert': name})[['sample_ix', name]]

    def _as_str(s):
        # Expert IDs are numeric (float with NaN after a left merge). Adding '-' to them directly raises TypeError,
        # so render them as integer strings while keeping missing values missing.
        return pd.to_numeric(s).astype('Int64').astype('string')

    layer_df =\
        sample_df\
        .merge(_slot(layer_ix, 1, 'expert'), how = 'inner', on = 'sample_ix')\
        .merge(_slot(layer_ix - 1, 1, 'prev_expert'), how = 'left', on = 'sample_ix')\
        .merge(_slot(layer_ix - 2, 1, 'prev2_expert'), how = 'left', on = 'sample_ix')\
        .merge(_slot(layer_ix, 2, 'expert2'), how = 'left', on = 'sample_ix')\
        .assign(leading_path = lambda df: _as_str(df['prev2_expert']) + '-' + _as_str(df['prev_expert']))

    return layer_df

# The raw imports are superseded by the cleaned frames; drop them to free memory
del sample_df_import
del sample_df_raw
del topk_df_import

gc.collect()
display(topk_df, sample_df)

In [None]:
"""
Convert activations to fp16 (for compatibility with cupy later) + dict
"""
# Downcast to fp16 for the later cupy conversion.
# NOTE(review): fp16's max finite value is ~65504, so any larger bf16 activations would become inf;
# the commented-out check compares the two dtypes.
hs_fp16 = all_pre_mlp_hs_import.to(torch.float16)
# compare_bf16_fp16_batched(all_pre_mlp_hs_import, hs_fp16)
del all_pre_mlp_hs_import

# {model layer index: per-layer slice}. Saved layer positions come from act_map and are offset by model_pre_mlp_layers.
# Presumably the tensor is (n_samples, n_saved_layers, D) - TODO confirm against the export.
all_pre_mlp_hs = {
    saved_layer + model_pre_mlp_layers: hs_fp16[:, col, :]
    for col, saved_layer in enumerate(act_map)
}
del hs_fp16

gc.collect()

## SVD Decomposition

In [None]:
"""
Let's take the pre-MLP hidden states and split them using SVD into parallel and orthogonal components.
"""
h_para_by_layer, h_orth_by_layer = {}, {}

for layer_ix in tqdm(list(all_pre_mlp_hs.keys())):
    # Split this layer's hidden states relative to its router (gate) weight, both in fp32, via SVD
    hs_fp32 = all_pre_mlp_hs[layer_ix].to(torch.float32)
    gate_weight = model.model.layers[layer_ix].mlp.gate.weight.detach().cpu().to(torch.float32)
    para_part, orth_part = decompose_orthogonal(hs_fp32, gate_weight, 'svd')
    h_para_by_layer[layer_ix] = para_part
    h_orth_by_layer[layer_ix] = orth_part

## Orth vs Para Rotation

In [None]:
"""
Get row-space rotation stability
"""
bootstrap_samples = 50

def get_sample_res(hs_by_layer, samples_to_test = 1):
    """
    Cosine similarity between consecutive layers' vectors for a random draw of samples.

    Draws `samples_to_test` random row indices (with replacement, from NumPy's global RNG), shared across layers.
    For each drawn row it computes the cosine similarity between its vectors at consecutive entries of
    `hs_by_layer`, then averages over the drawn rows.

    Params:
        @hs_by_layer: dict of {layer_ix: 2D tensor}; every tensor has the same number of rows (rows are samples).
            Transitions follow the dict's iteration order.
        @samples_to_test: number of random rows to average over.

    Returns:
        1D numpy array of length len(hs_by_layer) - 1: the mean cosine similarity for each consecutive pair.

    Raises:
        ValueError: if `hs_by_layer` is empty.
    """
    layer_hs = list(hs_by_layer.values())
    if not layer_hs:
        raise ValueError('hs_by_layer must contain at least one layer')

    # Use the first entry's row count instead of a hard-coded key (key 1 does not exist for every layer layout)
    samples = np.random.randint(0, layer_hs[0].shape[0], samples_to_test)

    # (samples_to_test, n_layers, D)
    stacked = torch.stack([h[samples, :] for h in layer_hs], dim = 1).to(torch.float32)

    # Only the consecutive-layer similarities are needed, so compute them directly rather than a full
    # n_layers x n_layers matrix -> (samples_to_test, n_layers - 1)
    sims = torch.nn.functional.cosine_similarity(stacked[:, :-1, :], stacked[:, 1:, :], dim = -1)

    return sims.mean(dim = 0).cpu().numpy()

# One row per bootstrap draw, one column per consecutive-layer transition
para_res = np.stack(
    [get_sample_res(h_para_by_layer) for _ in range(bootstrap_samples)],
    axis = 0
)

para_mean_across_layers = np.mean(para_res, axis = 0)
# NOTE(review): get_sample_res is called with its default samples_to_test = 1, so each draw is a single sample and
# 1.96 * std is the spread of per-sample similarities rather than a confidence interval on the mean
para_cis_across_layers = np.std(para_res, axis = 0) * 1.96

para_mean_overall = para_mean_across_layers.mean()
para_mean_ci = 1.96 * np.std(para_res.mean(axis = 1)).item()

# print(f"Mean across layer transitions: {para_mean_across_layers}")
print(f"Mean across layer transitions + samples: {para_mean_overall:.2f} +/- {para_mean_ci:.2f}")

In [None]:
"""
Get null-space rotation stability
"""
# One row per bootstrap draw, one column per consecutive-layer transition
orth_res = np.stack(
    [get_sample_res(h_orth_by_layer) for _ in range(bootstrap_samples)],
    axis = 0
)

orth_mean_across_layers = np.mean(orth_res, axis = 0)
# NOTE(review): with get_sample_res's default samples_to_test = 1 this is a per-sample spread, not a CI on the mean
orth_cis_across_layers = np.std(orth_res, axis = 0) * 1.96

orth_mean_overall = orth_mean_across_layers.mean()
orth_mean_ci = 1.96 * np.std(orth_res.mean(axis = 1)).item()

# print(f"Mean across layer transitions: {orth_mean_across_layers}")
print(f"Mean across layer transitions + samples: {orth_mean_overall:.2f} +/- {orth_mean_ci:.2f}")

In [None]:
"""
Export
"""
# Label each transition by the 1-indexed layer it starts from. Dict keys are the 0-indexed model layers
# (they index model.model.layers during the SVD step), and the last key has no successor.
# Labels are derived from the actual keys so they stay correct even if the saved layers are not contiguous.
transition_start_layers = list(all_pre_mlp_hs.keys())[:-1]
export_df = pd.DataFrame({
    'layer_ix_1': [layer + 1 for layer in transition_start_layers],
    'para_mean_across_layers': para_mean_across_layers,
    'orth_mean_across_layers': orth_mean_across_layers,
    'para_cis': para_cis_across_layers,
    'orth_cis': orth_cis_across_layers
})

os.makedirs('exports', exist_ok = True)
export_df.to_csv(f'exports/svd-transition-stability-{model_prefix}.csv', index = False)

## Reconstruction/probing tests

In [None]:
"""
Logistic regression - predict expert ID
"""
def run_lr(x_cp, y_cp):
    """
    Fit an L2-regularized logistic-regression probe on a fixed 80/20 split (random_state = 123).

    Returns the probe's accuracy on the held-out 20%.
    """
    splits = cuml.train_test_split(x_cp, y_cp, test_size = 0.2, random_state = 123)
    x_train, x_test, y_train, y_test = splits
    probe = cuml.linear_model.LogisticRegression(penalty = 'l2', max_iter = 1000, fit_intercept = True)
    probe.fit(x_train, y_train)
    return probe.score(x_test, y_test)

def run_lr_with_mi(x_cp, y_cp):
    """
    Fit an L2-regularized logistic-regression probe on a fixed 80/20 split (random_state = 123) and score it.

    Params:
        @x_cp: cupy feature matrix, one row per sample.
        @y_cp: cupy array of class labels, one per row of `x_cp`.

    Returns:
        (test_accuracy, mi_nats, label_entropy_nats, train_accuracy), where:
        - mi_nats is the mutual information between true and predicted test labels
        - label_entropy_nats = I(y_test; y_test) = H(y_test), the maximum achievable MI
        Both MI values are plain Python floats.
    """
    x_train, x_test, y_train, y_test = cuml.train_test_split(x_cp, y_cp, test_size = 0.2, random_state = 123)
    lr_model = cuml.linear_model.LogisticRegression(penalty = 'l2', max_iter = 1000, fit_intercept = True)
    lr_model.fit(x_train, y_train)
    accuracy = lr_model.score(x_test, y_test)
    train_acc = lr_model.score(x_train, y_train)
    y_actual_np = cupy.asnumpy(y_test)
    y_pred_np = cupy.asnumpy(lr_model.predict(x_test))
    # float() instead of .item(): mutual_info_score can return a plain Python float (e.g. its early `return 0.0`
    # when either labelling has a single class), and Python floats have no .item()
    mi = float(sklearn.metrics.mutual_info_score(y_actual_np, y_pred_np)) # nats
    max_entropy = float(sklearn.metrics.mutual_info_score(y_actual_np, y_actual_np)) # H(y)
    return accuracy, mi, max_entropy, train_acc

current_layer_accuracy = []
for test_layer in tqdm(list(h_para_by_layer.keys())):
    # Top-1 expert at this layer, in topk_df row order.
    # NOTE(review): assumes that order matches the hidden-state row order - TODO confirm
    expert_ids =\
        topk_df\
        .pipe(lambda df: df[df['layer_ix'] == test_layer])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        ['expert'].tolist()

    expert_ids_cp = cupy.asarray(expert_ids)
    x_cp_para = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
    x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())
    assert x_cp_para.shape[0] == expert_ids_cp.shape[0],\
        f'Layer {test_layer}: {x_cp_para.shape[0]} hidden states vs {expert_ids_cp.shape[0]} expert labels'

    # Distinct names so the bootstrap arrays `para_res` / `orth_res` from earlier cells are not clobbered
    para_probe = run_lr_with_mi(x_cp_para, expert_ids_cp)
    orth_probe = run_lr_with_mi(x_cp_orth, expert_ids_cp)

    current_layer_accuracy.append({
        # Dict keys already include the model_pre_mlp_layers offset (they are 0-indexed model layers), so +1 is all
        # that's needed for 1-indexing; adding the offset again double-counts it
        'test_layer_1': test_layer + 1,
        'para_acc': para_probe[0],
        'para_train_acc': para_probe[3],
        'para_mi_bits': para_probe[1]/np.log(2.0), # Convert from nats to bits
        'para_entropy_bits': para_probe[2]/np.log(2.0),
        'para_mi_pct': para_probe[1]/para_probe[2],
        'orth_acc': orth_probe[0],
        'orth_train_acc': orth_probe[3],
        'orth_mi_bits': orth_probe[1]/np.log(2.0),
        'orth_entropy_bits': orth_probe[2]/np.log(2.0),
        'orth_mi_pct': orth_probe[1]/orth_probe[2]
    })

pd.DataFrame(current_layer_accuracy)

In [None]:
"""
Use h_para and h_orth to predict NEXT layer expert ids (note: this does not remove current-layer expert info; the next cell subtracts per-expert centroids to remove it)
"""
# next_layer_accuracy = []
# for test_layer in tqdm(list(h_para_by_layer.keys())[:-1]):
#     expert_ids =\
#         topk_df\
#         .pipe(lambda df: df[df['layer_ix'] == test_layer + 1])\
#         .pipe(lambda df: df[df['topk_ix'] == 1])\
#         ['expert'].tolist()

#     expert_ids_cp = cupy.asarray(expert_ids)
#     x_cp_para = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
#     x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())

#     para_res = run_lr_with_mi(x_cp_para, expert_ids_cp)
#     orth_res = run_lr_with_mi(x_cp_orth, expert_ids_cp)
    
#     next_layer_accuracy.append({
#         'test_layer_1': test_layer + model_pre_mlp_layers + 1,
#         'para_acc': para_res[0],
#         'para_mi_bits': para_res[1]/np.log(2.0),
#         'para_entropy_bits': para_res[2]/np.log(2.0),
#         'para_mi_pct': para_res[1]/para_res[2],
#         'orth_acc': orth_res[0],
#         'orth_mi_bits': orth_res[1]/np.log(2.0),
#         'orth_entropy_bits': orth_res[2]/np.log(2.0),
#         'orth_mi_pct': orth_res[1]/orth_res[2]
#     })

# pd.DataFrame(next_layer_accuracy)

In [None]:
"""
Use h_para and h_orth to predict NEXT layer expert ids. Remove expert centroids first.
"""
centroids_para = {}
centroids_orth = {}

# Per layer: the mean h_para / h_orth over the samples whose top-1 expert is each expert
for layer_ix in h_para_by_layer:
    layer_top1 = topk_df[(topk_df['layer_ix'] == layer_ix) & (topk_df['topk_ix'] == 1)]
    cur_layer_expert_ids = layer_top1['expert'].tolist()
    cur_layer_expert_ids_cp = cupy.asarray(cur_layer_expert_ids)

    h_para_cp = cupy.asarray(h_para_by_layer[layer_ix].to(torch.float16).detach().cpu())
    h_orth_cp = cupy.asarray(h_orth_by_layer[layer_ix].to(torch.float16).detach().cpu())

    para_by_expert = {}
    orth_by_expert = {}
    for e in set(cur_layer_expert_ids):
        member_rows = cupy.where(cur_layer_expert_ids_cp == e)[0]
        para_by_expert[e] = h_para_cp[member_rows].mean(axis = 0)
        orth_by_expert[e] = h_orth_cp[member_rows].mean(axis = 0)

    centroids_para[layer_ix] = para_by_expert
    centroids_orth[layer_ix] = orth_by_expert

next_layer_accuracy_cond = []
for test_layer in tqdm(list(h_para_by_layer.keys())[:-1]):
    # Target = next-layer top-1 expert IDs.
    # NOTE(review): `test_layer + 1` is the next model layer, which is only the next dict key if saved layers are contiguous
    y_cp =\
        topk_df\
        .pipe(lambda df: df[df['layer_ix'] == test_layer + 1])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        ['expert'].to_numpy()
    y_cp = cupy.asarray(y_cp)

    # Current-layer top-1 expert IDs - needed for residual lookup
    cur_exp_cp =\
        topk_df\
        .pipe(lambda df: df[df['layer_ix'] == test_layer])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        ['expert'].to_numpy()
    cur_exp_cp = cupy.asarray(cur_exp_cp)

    # Pull h_para / h_orth tensors and convert to cupy
    h_para_cp = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
    h_orth_cp = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())
    assert h_para_cp.shape[0] == cur_exp_cp.shape[0] == y_cp.shape[0],\
        f'Layer {test_layer}: row count mismatch between hidden states and expert labels'

    # Subtract each sample's current-expert centroid. Gather rows from a stacked (n_experts, D) centroid matrix in one
    # indexing op; calling int() on each cupy element instead forces a device-to-host copy per sample.
    # Every current-layer expert ID has a centroid (both come from the same top-1 filter), so searchsorted is exact.
    expert_keys = sorted(centroids_para[test_layer].keys())
    centroid_rows = cupy.searchsorted(cupy.asarray(expert_keys), cur_exp_cp)
    mu_para_mat = cupy.stack([centroids_para[test_layer][e] for e in expert_keys])[centroid_rows]
    mu_orth_mat = cupy.stack([centroids_orth[test_layer][e] for e in expert_keys])[centroid_rows]
    h_para_res = h_para_cp - mu_para_mat
    h_orth_res = h_orth_cp - mu_orth_mat

    # Run the unchanged probe (distinct names so earlier bootstrap arrays are not clobbered)
    para_probe = run_lr_with_mi(h_para_res, y_cp)
    orth_probe = run_lr_with_mi(h_orth_res, y_cp)

    next_layer_accuracy_cond.append({
        # Dict keys already include the model_pre_mlp_layers offset, so +1 alone gives the 1-indexed layer
        'test_layer_1': test_layer + 1,
        'para_acc': para_probe[0],
        'para_mi_bits': para_probe[1]/np.log(2.0),
        'para_entropy_bits': para_probe[2]/np.log(2.0),
        'para_mi_pct': para_probe[1]/para_probe[2],
        'orth_acc': orth_probe[0],
        'orth_mi_bits': orth_probe[1]/np.log(2.0),
        'orth_entropy_bits': orth_probe[2]/np.log(2.0),
        'orth_mi_pct': orth_probe[1]/orth_probe[2]
    })

display(pd.DataFrame(next_layer_accuracy_cond))

In [None]:
"""
Export results
"""
current_probe_df = pd.DataFrame(current_layer_accuracy).assign(target = 'current_layer')
next_probe_df = pd.DataFrame(next_layer_accuracy_cond).assign(target = 'next_layer')
layer_transitions_export_df = pd.concat([current_probe_df, next_probe_df]).assign(model = model_prefix)

display(layer_transitions_export_df)

layer_transitions_export_df.to_csv(f'exports/svd-probe-expert-id-{model_prefix}.csv', index = False)

In [None]:
"""
Predict language - pre-split so train and test contain disjoint token IDs (TIDs)
"""
# def run_lr_with_mi_presplit(x_train, x_test, y_train, y_test):
#     lr_model = cuml.linear_model.LogisticRegression(penalty = 'l2', max_iter = 1000, fit_intercept = True)
#     lr_model.fit(x_train, y_train)
#     accuracy = lr_model.score(x_test, y_test)
#     train_acc = lr_model.score(x_train, y_train)
#     y_actual_np = cupy.asnumpy(y_test)
#     y_pred_np = cupy.asnumpy(lr_model.predict(x_test))
#     mi = sklearn.metrics.mutual_info_score(y_actual_np, y_pred_np) # nats
#     max_entropy = sklearn.metrics.mutual_info_score(y_actual_np, y_actual_np) # H(y)
#     return accuracy, mi.item(), max_entropy.item(), train_acc

# lang_probe_accs = []
# # Split train/test, different TIDs in each
# gss = sklearn.model_selection.GroupShuffleSplit(n_splits = 1, test_size = 0.2, random_state = 123)
# train_ix, test_ix = next(gss.split(sample_df, groups = sample_df['token_id']))

# train_sample_df = sample_df.take(train_ix)
# test_sample_df = sample_df.take(test_ix)

# # Prep y values
# source_mapping = {source: i for i, source in enumerate(sample_df['source'].unique())}

# y_train = cupy.asarray(train_sample_df.assign(source = lambda df: df['source'].map(source_mapping))['source'].tolist())
# y_test = cupy.asarray(test_sample_df.assign(source = lambda df: df['source'].map(source_mapping))['source'].tolist())

# for test_layer in tqdm(list(h_para_by_layer.keys())[::2]):

#     x_para = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
#     x_orth = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())
    
#     x_train_para = x_para[train_sample_df['sample_ix'].tolist(), :]
#     x_test_para = x_para[test_sample_df['sample_ix'].tolist(), :]
#     x_train_orth = x_orth[train_sample_df['sample_ix'].tolist(), :]
#     x_test_orth = x_orth[test_sample_df['sample_ix'].tolist(), :]

#     para_res = run_lr_with_mi_presplit(x_train_para, x_test_para, y_train, y_test)
#     orth_res = run_lr_with_mi_presplit(x_train_orth, x_test_orth, y_train, y_test)

#     lang_probe_accs.append({
#         'test_layer_1': test_layer + model_pre_mlp_layers + 1,
#         'para_acc': para_res[0],
#         'para_train_acc': para_res[3],
#         'para_mi_bits': para_res[1]/np.log(2.0),
#         'para_entropy_bits': para_res[2]/np.log(2.0),
#         'para_mi_pct': para_res[1]/para_res[2],
#         'orth_acc': orth_res[0],
#         'orth_train_acc': orth_res[3],
#         'orth_mi_bits': orth_res[1]/np.log(2.0),
#         'orth_entropy_bits': orth_res[2]/np.log(2.0),
#         'orth_mi_pct': orth_res[1]/orth_res[2]
#     })

#     display(pd.DataFrame(lang_probe_accs))

In [None]:
"""
Predict Language
"""
lang_probe_accs = []

# Integer label per sample for its `source`. This does not depend on the layer, so build it once outside the loop.
source_mapping = {source: i for i, source in enumerate(sample_df['source'].unique())}
y_df =\
    sample_df\
    .assign(source = lambda df: df['source'].map(source_mapping))\
    ['source']\
    .tolist()
# NOTE(review): labels follow sample_df row order, which is assumed to match the hidden-state row order - TODO confirm
y_cp = cupy.asarray(y_df)

for test_layer in tqdm(list(h_para_by_layer.keys())):

    x_cp_para = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
    x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())
    assert x_cp_para.shape[0] == y_cp.shape[0],\
        f'Layer {test_layer}: {x_cp_para.shape[0]} hidden states vs {y_cp.shape[0]} labels'

    para_probe = run_lr_with_mi(x_cp_para, y_cp)
    orth_probe = run_lr_with_mi(x_cp_orth, y_cp)

    lang_probe_accs.append({
        # Dict keys already include the model_pre_mlp_layers offset, so +1 alone gives the 1-indexed layer
        'test_layer_1': test_layer + 1,
        'para_acc': para_probe[0],
        'para_train_acc': para_probe[3],
        'para_mi_bits': para_probe[1]/np.log(2.0),
        'para_entropy_bits': para_probe[2]/np.log(2.0),
        'para_mi_pct': para_probe[1]/para_probe[2],
        'orth_acc': orth_probe[0],
        'orth_train_acc': orth_probe[3],
        'orth_mi_bits': orth_probe[1]/np.log(2.0),
        'orth_entropy_bits': orth_probe[2]/np.log(2.0),
        'orth_mi_pct': orth_probe[1]/orth_probe[2]
    })

display(pd.DataFrame(lang_probe_accs))

In [None]:
"""
Export
"""
# Number of samples per source
source_counts = sample_df.groupby('source', as_index = False).agg(z = ('sample_ix', 'count'))
display(source_counts)

lang_export_df = pd.DataFrame(lang_probe_accs)
display(lang_export_df)
lang_export_df.to_csv(f'exports/svd-probe-lang-{model_prefix}.csv', index = False)

In [None]:
"""
Predict TID
"""
tid_probe_accs = []

# The 500 most frequent (token_id, token) pairs in the 'en' source. Layer-independent, so computed once.
top_tids =\
    sample_df\
    .pipe(lambda df: df[df['source'] == 'en'])\
    .groupby(['token_id', 'token'], as_index = False)\
    .agg(n = ('token', 'count')).sort_values(by = 'n', ascending = False)\
    .head(500)

# Keep every sample, but relabel any token outside top_tids as the catch-all class 999999
valid_samples =\
    sample_df\
    .assign(token_id = lambda df: np.where(df['token_id'].isin(top_tids['token_id']), df['token_id'], 999999))
    # Alternative: drop non-top tokens instead of relabeling them
    # .pipe(lambda df: df[df['token_id'].isin(top_tids['token_id'].tolist())])

y_df =\
    valid_samples\
    ['token_id']\
    .tolist()
valid_sample_ixs = valid_samples['sample_ix'].tolist()

# Every other layer, to bound runtime
for test_layer in tqdm(list(h_para_by_layer.keys())[::2]):

    clear_all_cuda_memory(False)

    y_cp = cupy.asarray(y_df)
    x_cp_para = cupy.asarray(h_para_by_layer[test_layer][valid_sample_ixs, :].to(torch.float16).detach().cpu())
    x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer][valid_sample_ixs, :].to(torch.float16).detach().cpu())

    para_probe = run_lr_with_mi(x_cp_para, y_cp)
    orth_probe = run_lr_with_mi(x_cp_orth, y_cp)

    tid_probe_accs.append({
        # Dict keys already include the model_pre_mlp_layers offset, so +1 alone gives the 1-indexed layer
        'test_layer_1': test_layer + 1,
        'para_acc': para_probe[0],
        'para_train_acc': para_probe[3],
        'para_mi_bits': para_probe[1]/np.log(2.0),
        'para_entropy_bits': para_probe[2]/np.log(2.0),
        'para_mi_pct': para_probe[1]/para_probe[2],
        'orth_acc': orth_probe[0],
        'orth_train_acc': orth_probe[3],
        'orth_mi_bits': orth_probe[1]/np.log(2.0),
        'orth_entropy_bits': orth_probe[2]/np.log(2.0),
        'orth_mi_pct': orth_probe[1]/orth_probe[2]
    })

pd.DataFrame(tid_probe_accs)

In [None]:
"""
Export
"""
tid_export_df = pd.DataFrame(tid_probe_accs)
display(tid_export_df)

tid_export_path = f'exports/svd-probe-tid-{model_prefix}.csv'
tid_export_df.to_csv(tid_export_path, index = False)