In [1]:
"""
This contains code to use SVD to decompose hidden states based on whether they're used by routing or not.
"""
None

In [2]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import scipy
import cupy
import cuml
import sklearn

import importlib
import gc
import pickle
import os

from tqdm import tqdm
from termcolor import colored
import plotly.express as px
from plotly.subplots import make_subplots

from utils.memory import check_memory, clear_all_cuda_memory
from utils.quantize import compare_bf16_fp16_batched
from utils.svd import decompose_orthogonal, decompose_sideways
from utils.vis import combine_plots

main_device = 'cuda:0'
seed = 1234
clear_all_cuda_memory()
check_memory()

All CUDA memory cleared on all devices.
Device 0: NVIDIA H100 PCIe
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 79.10 GB



## Load model & data

In [3]:
"""
Load the base tokenizer/model
"""
model_ix = 2
models_list = [
    ('allenai/OLMoE-1B-7B-0125-Instruct', 'olmoe', 0),
    ('Qwen/Qwen1.5-MoE-A2.7B-Chat', 'qwen1.5moe', 0),
    ('deepseek-ai/DeepSeek-V2-Lite', 'dsv2', 1),
    ('Qwen/Qwen3-30B-A3B', 'qwen3moe', 0)
]

model_id, model_prefix, model_pre_mlp_layers = models_list[model_ix]
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).cuda().eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
"""
Load dataset
"""
def load_data(model_prefix, max_data_files):
    """
    Load data saved by `export-activations-sm.ipynb`
    """
    folders = [f'./../export-data/activations-sm/{model_prefix}/{i:02d}' for i in range(max_data_files)]
    folders = [f for f in folders if os.path.isdir(f)]

    all_pre_mlp_hs = []
    sample_df = []
    topk_df = []

    for f in tqdm(folders):
        sample_df.append(pd.read_pickle(f'{f}/samples.pkl'))
        topk_df.append(pd.read_pickle(f'{f}/topks.pkl'))
        all_pre_mlp_hs.append(torch.load(f'{f}/all-pre-mlp-hidden-states.pt'))

    sample_df = pd.concat(sample_df)
    topk_df = pd.concat(topk_df)
    all_pre_mlp_hs = torch.concat(all_pre_mlp_hs)    

    with open(f'./../export-data/activations-sm/{model_prefix}/metadata.pkl', 'rb') as f:
        metadata = pickle.load(f)
    
    gc.collect()
    return sample_df, topk_df, all_pre_mlp_hs, metadata['all_pre_mlp_hidden_states_layers']

# Due to mem constraints, for Qwen3Moe max_data_files = 2
sample_df_import, topk_df_import, all_pre_mlp_hs_import, act_map = load_data(model_prefix, 2)

100%|██████████| 2/2 [00:10<00:00,  5.17s/it]


In [5]:
"""
Let's clean up the mappings here. We'll get everything to a sample_ix level first.
"""
sample_df_raw =\
    sample_df_import\
    .assign(sample_ix = lambda df: df.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup())\
    .assign(seq_id = lambda df: df.groupby(['batch_ix', 'sequence_ix']).ngroup())\
    .reset_index()

topk_df =\
    topk_df_import\
    .merge(sample_df_raw[['sample_ix', 'batch_ix', 'sequence_ix', 'token_ix']], how = 'inner', on = ['sequence_ix', 'token_ix', 'batch_ix'])\
    .drop(columns = ['sequence_ix', 'token_ix', 'batch_ix'])\
    .assign(layer_ix = lambda df: df['layer_ix'] + model_pre_mlp_layers)

topk1_df =\
    topk_df\
    .pipe(lambda df: df[df['topk_ix'] == 1])

sample_df =\
    sample_df_raw\
    .drop(columns = ['batch_ix', 'sequence_ix'])

def get_sample_df_for_layer(sample_df, topk_df, layer_ix):
    """
    Helper to take the sample df and merge layer-level expert selection information
    """
    topk_layer_df = topk_df.pipe(lambda df: df[df['layer_ix'] == layer_ix])
    topk_l1_layer_df = topk_df.pipe(lambda df: df[df['layer_ix'] == layer_ix - 1])
    topk_l2_layer_df = topk_df.pipe(lambda df: df[df['layer_ix'] == layer_ix - 2])

    layer_df =\
        sample_df\
        .merge(topk_layer_df.pipe(lambda df: df[df['topk_ix'] == 1])[['sample_ix', 'expert']], how = 'inner', on = 'sample_ix')\
        .merge(topk_l1_layer_df.pipe(lambda df: df[df['topk_ix'] == 1]).rename(columns = {'expert': 'prev_expert'})[['sample_ix', 'prev_expert']], how = 'left', on = 'sample_ix')\
        .merge(topk_l2_layer_df.pipe(lambda df: df[df['topk_ix'] == 1]).rename(columns = {'expert': 'prev2_expert'})[['sample_ix', 'prev2_expert']], how = 'left', on = 'sample_ix')\
        .merge(topk_layer_df.pipe(lambda df: df[df['topk_ix'] == 2]).rename(columns = {'expert': 'expert2'})[['sample_ix', 'expert2']], how = 'left', on = 'sample_ix')\
        .assign(leading_path = lambda df: df['prev2_expert'] + '-' + df['prev_expert'])
    
    return layer_df

del sample_df_import, sample_df_raw, topk_df_import

gc.collect()
display(topk_df)
display(sample_df)

Unnamed: 0,layer_ix,topk_ix,expert,weight,sample_ix
0,1,1,34,0.12,0
1,1,2,29,0.10,0
2,1,3,23,0.10,0
3,1,4,11,0.09,0
4,1,5,28,0.05,0
...,...,...,...,...,...
24699007,26,2,32,0.07,158326
24699008,26,3,50,0.07,158326
24699009,26,4,51,0.05,158326
24699010,26,5,27,0.04,158326


Unnamed: 0,index,token_ix,token_id,output_id,output_prob,token,source,sample_ix,seq_id
0,0,7,3942,11628,0.16,LO,es,0,0
1,1,8,42791,5564,0.07,QUE,es,1,0
2,2,9,417,5771,0.59,F,es,2,0
3,3,10,42549,77552,0.50,ALT,es,3,0
4,4,11,77552,185,0.18,ABA,es,4,0
...,...,...,...,...,...,...,...,...,...
158322,81312,507,44,660,1.00,M,es,158322,499
158323,81313,508,660,1708,0.98,ens,es,158323,499
158324,81314,509,1708,658,1.00,aj,es,158324,499
158325,81315,510,658,256,1.00,ep,es,158325,499


In [6]:
"""
Convert activations to fp16 (for compatibility with cupy later) + dict
"""
all_pre_mlp_hs = all_pre_mlp_hs_import.to(torch.float16)
# compare_bf16_fp16_batched(all_pre_mlp_hs_import, all_pre_mlp_hs)
del all_pre_mlp_hs_import
all_pre_mlp_hs = {(layer_ix + model_pre_mlp_layers): all_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(act_map)}

gc.collect()

0

## SVD Decomposition

In [7]:
"""
Let's take the pre-MLP hidden states and split them using SVD into parallel and orthogonal components.
"""
h_para_by_layer = {}
h_orth_by_layer = {}

for layer_ix in tqdm(list(all_pre_mlp_hs.keys())):
    h_para_by_layer[layer_ix], h_orth_by_layer[layer_ix] = decompose_orthogonal(
        all_pre_mlp_hs[layer_ix].to(torch.float32),
        model.model.layers[layer_ix].mlp.gate.weight.detach().cpu().to(torch.float32),
        'svd'
    )

100%|██████████| 26/26 [01:10<00:00,  2.71s/it]


## Orth vs Para Rotation

In [8]:
bootstrap_samples = 50

def get_sample_res(hs_by_layer, samples_to_test = 1):
    
    samples = np.random.randint(0, hs_by_layer[1].shape[0], samples_to_test)

    # Cast into sample-level list
    sample_tensors = torch.stack([layer_hs[samples, :] for _, layer_hs in hs_by_layer.items()], dim = 1).unbind(dim = 0)

    sims = []
    for s in sample_tensors:
        cos_sim = sklearn.metrics.pairwise.cosine_similarity(s)
        sims.append(np.diag(cos_sim, 1))

    return np.mean(np.stack(sims, axis = 0), axis = 0)

para_res = np.stack([get_sample_res(h_para_by_layer) for _ in range(bootstrap_samples)], axis = 0) # bootstrap_samples x layer_diffs

para_mean_across_layers = para_res.mean(axis = 0)
para_cis_across_layers = 1.96 * np.std(para_res, axis = 0)

para_mean_overall = np.mean(para_mean_across_layers)
para_mean_ci = 1.96 * np.std(np.mean(para_res, axis = 1)).item()

# print(f"Mean across layer transitions: {para_mean_across_layers}")
print(f"Mean across layer transitions + samples: {para_mean_overall:.2f} +/- {para_mean_ci:.2f}")

Mean across layer transitions + samples: 0.53 +/- 0.11


In [9]:
orth_res = np.stack([get_sample_res(h_orth_by_layer) for _ in range(bootstrap_samples)], axis = 0) # bootstrap_samples x layer_diffs

orth_mean_across_layers = orth_res.mean(axis = 0)
orth_cis_across_layers = 1.96 * np.std(orth_res, axis = 0)

orth_mean_overall = np.mean(orth_mean_across_layers)
orth_mean_ci = 1.96 * np.std(np.mean(orth_res, axis = 1)).item()

# print(f"Mean across layer transitions: {orth_mean_across_layers}")
print(f"Mean across layer transitions + samples: {orth_mean_overall:.2f} +/- {orth_mean_ci:.2f}")

Mean across layer transitions + samples: 0.84 +/- 0.04


In [10]:
export_df = pd.DataFrame({
    'layer_ix_1': list(range(model_pre_mlp_layers + 1 + 1, len(all_pre_mlp_hs) + model_pre_mlp_layers + 1)), # +1 since these represent the transition-ends, and +1 to 1 index
    'para_mean_across_layers': para_mean_across_layers,
    'orth_mean_across_layers': orth_mean_across_layers,
    'para_cis': para_cis_across_layers,
    'orth_cis': orth_cis_across_layers
})

export_df.to_csv(f'exports/transition-stability-{model_prefix}.csv', index = False)

## Reconstruction/probing tests

In [11]:
"""
Logistic regression - predict topk using h_orth?
"""
def run_lr(x_cp, y_cp):
    x_train, x_test, y_train, y_test = cuml.train_test_split(x_cp, y_cp, test_size = 0.1, random_state = 123)
    lr_model = cuml.linear_model.LogisticRegression(penalty = 'l2', max_iter = 1000, fit_intercept = False)
    lr_model.fit(x_train, y_train)
    accuracy = lr_model.score(x_test, y_test)
    return accuracy

current_layer_accuracy = []
for test_layer in tqdm(list(h_para_by_layer.keys())):
    expert_ids =\
        topk_df\
        .pipe(lambda df: df[df['layer_ix'] == test_layer])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        ['expert'].tolist()

    expert_ids_cp = cupy.asarray(expert_ids)
    x_cp_para = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
    x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())

    current_layer_accuracy.append({
        'test_layer': test_layer + model_pre_mlp_layers + 1,
        'para_acc': run_lr(x_cp_para, expert_ids_cp),
        'orth_acc': run_lr(x_cp_orth, expert_ids_cp)
    })

current_layer_accuracy

 46%|████▌     | 12/26 [00:23<00:34,  2.50s/it]



 62%|██████▏   | 16/26 [00:35<00:29,  2.95s/it]



 65%|██████▌   | 17/26 [00:39<00:29,  3.26s/it]



 69%|██████▉   | 18/26 [00:43<00:28,  3.62s/it]



 77%|███████▋  | 20/26 [00:53<00:25,  4.24s/it]



 81%|████████  | 21/26 [00:58<00:22,  4.54s/it]



 85%|████████▍ | 22/26 [01:03<00:18,  4.66s/it]



 88%|████████▊ | 23/26 [01:08<00:13,  4.60s/it]



 92%|█████████▏| 24/26 [01:12<00:09,  4.54s/it]



 96%|█████████▌| 25/26 [01:17<00:04,  4.48s/it]



100%|██████████| 26/26 [01:21<00:00,  3.13s/it]






[{'test_layer': 3,
  'para_acc': 0.9633021728145528,
  'orth_acc': 0.777033855482567},
 {'test_layer': 4,
  'para_acc': 0.969365841334007,
  'orth_acc': 0.6615715007579586},
 {'test_layer': 5,
  'para_acc': 0.9692395149065185,
  'orth_acc': 0.6445174330469934},
 {'test_layer': 6,
  'para_acc': 0.9711344113188479,
  'orth_acc': 0.6361798888327438},
 {'test_layer': 7,
  'para_acc': 0.9733451237998989,
  'orth_acc': 0.6090828701364326},
 {'test_layer': 8,
  'para_acc': 0.9713239009600808,
  'orth_acc': 0.5960712481051036},
 {'test_layer': 9,
  'para_acc': 0.9739135927235978,
  'orth_acc': 0.572321879737241},
 {'test_layer': 10,
  'para_acc': 0.9737872662961091,
  'orth_acc': 0.5857124810510359},
 {'test_layer': 11,
  'para_acc': 0.9735977766548762,
  'orth_acc': 0.5781960586154623},
 {'test_layer': 12,
  'para_acc': 0.9733451237998989,
  'orth_acc': 0.5608261748357757},
 {'test_layer': 13,
  'para_acc': 0.9737872662961091,
  'orth_acc': 0.6160308236483072},
 {'test_layer': 14,
  'para_acc

In [12]:
"""
Use h_para and h_orth to predict NEXT layer expert ids
"""
next_layer_accuracy = []
for test_layer in tqdm(list(h_para_by_layer.keys())[:-1]):
    expert_ids =\
        topk_df\
        .pipe(lambda df: df[df['layer_ix'] == test_layer + 1])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        ['expert'].tolist()

    expert_ids_cp = cupy.asarray(expert_ids)
    x_cp_para = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
    x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())

    next_layer_accuracy.append({
        'test_layer': test_layer + model_pre_mlp_layers + 1,
        'para_acc': run_lr(x_cp_para, expert_ids_cp),
        'orth_acc':run_lr(x_cp_orth, expert_ids_cp)
    })

next_layer_accuracy

 44%|████▍     | 11/25 [00:47<00:39,  2.85s/it]



 52%|█████▏    | 13/25 [00:51<00:29,  2.44s/it]



 60%|██████    | 15/25 [00:56<00:23,  2.33s/it]



 64%|██████▍   | 16/25 [00:58<00:21,  2.38s/it]



 72%|███████▏  | 18/25 [01:04<00:18,  2.61s/it]



 76%|███████▌  | 19/25 [01:07<00:17,  2.92s/it]



 80%|████████  | 20/25 [01:10<00:15,  3.04s/it]



 84%|████████▍ | 21/25 [01:15<00:13,  3.37s/it]



 92%|█████████▏| 23/25 [01:23<00:07,  3.74s/it]



 96%|█████████▌| 24/25 [01:27<00:03,  3.78s/it]



100%|██████████| 25/25 [01:30<00:00,  3.64s/it]






[{'test_layer': 3,
  'para_acc': 0.46058615462354724,
  'orth_acc': 0.6954269833249116},
 {'test_layer': 4,
  'para_acc': 0.5368873168266801,
  'orth_acc': 0.7076174835775644},
 {'test_layer': 5,
  'para_acc': 0.5812910560889338,
  'orth_acc': 0.6779939363314805},
 {'test_layer': 6,
  'para_acc': 0.563100050530571,
  'orth_acc': 0.6469176351692774},
 {'test_layer': 7,
  'para_acc': 0.5716270843860536,
  'orth_acc': 0.6489388580090955},
 {'test_layer': 8,
  'para_acc': 0.552678120262759,
  'orth_acc': 0.6342849924204144},
 {'test_layer': 9,
  'para_acc': 0.5528044466902476,
  'orth_acc': 0.6443279434057605},
 {'test_layer': 10,
  'para_acc': 0.5862809499747347,
  'orth_acc': 0.6362430520464881},
 {'test_layer': 11,
  'para_acc': 0.5495199595755432,
  'orth_acc': 0.6308741788782213},
 {'test_layer': 12,
  'para_acc': 0.6043456291056089,
  'orth_acc': 0.6695300656897423},
 {'test_layer': 13,
  'para_acc': 0.5383400707427994,
  'orth_acc': 0.654434057604851},
 {'test_layer': 14,
  'para_ac

In [13]:
export_df = pd.concat([
    pd.DataFrame(current_layer_accuracy).assign(target = 'current_layer'),
    pd.DataFrame(next_layer_accuracy).assign(target = 'next_layer'),
]).assign(model = model_prefix)
display(export_df)

export_df.to_csv(f'exports/transition-probe-{model_prefix}.csv', index = False)

Unnamed: 0,test_layer,para_acc,orth_acc,target,model
0,3,0.963302,0.777034,current_layer,dsv2
1,4,0.969366,0.661572,current_layer,dsv2
2,5,0.96924,0.644517,current_layer,dsv2
3,6,0.971134,0.63618,current_layer,dsv2
4,7,0.973345,0.609083,current_layer,dsv2
5,8,0.971324,0.596071,current_layer,dsv2
6,9,0.973914,0.572322,current_layer,dsv2
7,10,0.973787,0.585712,current_layer,dsv2
8,11,0.973598,0.578196,current_layer,dsv2
9,12,0.973345,0.560826,current_layer,dsv2


In [14]:
"""
Predict TID
"""
tid_probe_accs = []

for test_layer in tqdm(list(h_para_by_layer.keys())[::4]):

    clear_all_cuda_memory(False)

    top_tids =\
        sample_df\
        .pipe(lambda df: df[df['source'] == 'en'])\
        .groupby(['token_id', 'token'], as_index = False)\
        .agg(n = ('token', 'count')).sort_values(by = 'n', ascending = False)\
        .head(20_000)

    valid_samples =\
        sample_df\
        .pipe(lambda df: df[df['token_id'].isin(top_tids['token_id'].tolist())])

    y_df =\
        valid_samples\
        ['token_id']\
        .tolist()

    y_cp = cupy.asarray(y_df)
    x_cp_para = cupy.asarray(h_para_by_layer[test_layer][valid_samples['sample_ix'].tolist(), :].to(torch.float16).detach().cpu())
    x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer][valid_samples['sample_ix'].tolist(), :].to(torch.float16).detach().cpu())

    tid_probe_accs.append({
        'test_layer': test_layer + model_pre_mlp_layers + 1,
        'para_acc': run_lr(x_cp_para, y_cp),
        'orth_acc': run_lr(x_cp_orth, y_cp)
    })

 43%|████▎     | 3/7 [16:57<26:41, 400.42s/it]



 57%|█████▋    | 4/7 [24:58<21:37, 432.50s/it]



 71%|███████▏  | 5/7 [34:03<15:45, 472.99s/it]



 86%|████████▌ | 6/7 [44:32<08:45, 525.89s/it]



100%|██████████| 7/7 [52:26<00:00, 449.50s/it]






In [15]:
tid_export_df = pd.DataFrame(tid_probe_accs)
display(tid_export_df)

tid_export_df.to_csv(f'exports/transition-probe-tid-{model_prefix}.csv', index = False)

Unnamed: 0,test_layer,para_acc,orth_acc
0,3,0.885774,0.938131
1,7,0.651347,0.859007
2,11,0.540993,0.805387
3,15,0.546801,0.795539
4,19,0.65564,0.790909
5,23,0.556145,0.78165
6,27,0.532323,0.737205


In [16]:
"""
Predict Language
"""
lang_probe_accs = []

for test_layer in tqdm(list(h_para_by_layer.keys())[::2]):

    source_mapping = {source: i for i, source in enumerate(sample_df['source'].unique())}

    y_df =\
        sample_df\
        .assign(source = lambda df: df['source'].map(source_mapping))\
        ['source']\
        .tolist()

    y_cp = cupy.asarray(y_df)
    x_cp_para = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
    x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())

    lang_probe_accs.append({
        'test_layer': test_layer + model_pre_mlp_layers + 1,
        'para_acc': run_lr(x_cp_para, y_cp),
        'orth_acc': run_lr(x_cp_orth, y_cp)
    })

 23%|██▎       | 3/13 [00:02<00:10,  1.06s/it]



 31%|███       | 4/13 [00:04<00:12,  1.37s/it]



 38%|███▊      | 5/13 [00:06<00:13,  1.73s/it]



 46%|████▌     | 6/13 [00:09<00:14,  2.08s/it]



 62%|██████▏   | 8/13 [00:15<00:11,  2.36s/it]



 77%|███████▋  | 10/13 [00:20<00:07,  2.47s/it]



 85%|████████▍ | 11/13 [00:22<00:04,  2.48s/it]



 92%|█████████▏| 12/13 [00:25<00:02,  2.53s/it]



100%|██████████| 13/13 [00:28<00:00,  2.19s/it]






In [17]:
lang_export_df = pd.DataFrame(lang_probe_accs)
display(lang_export_df)

lang_export_df.to_csv(f'exports/transition-probe-lang-{model_prefix}.csv', index = False)

Unnamed: 0,test_layer,para_acc,orth_acc
0,3,0.854535,0.989831
1,5,0.796046,0.987873
2,7,0.747221,0.989641
3,9,0.743052,0.991536
4,11,0.72524,0.991978
5,13,0.749368,0.993178
6,15,0.730483,0.991536
7,17,0.790993,0.991473
8,19,0.796046,0.990904
9,21,0.835965,0.991031
