In [None]:
"""
This notebook tests the orthogonality of expert specialization in an MoE model (OLMoE):
whether the transformations applied by experts are organised primarily by expert, or by other
factors such as the input source or the expert chosen at the previous layer.
"""
None

In [1]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import scipy
import cupy
import cuml
import sklearn

import importlib
import gc
import pickle

from tqdm import tqdm
from termcolor import colored
import plotly.express as px
from plotly.subplots import make_subplots

from utils.memory import check_memory, clear_all_cuda_memory
from utils.quantize import compare_bf16_fp16_batched
from utils.svd import decompose_orthogonal, decompose_sideways
from utils.vis import combine_plots

main_device = 'cuda:0'
seed = 1234

# Apply the seed (it was previously defined but never used). DataFrame.sample() without a random_state
# draws from numpy's global RNG, so seeding it makes the plotted subsets reproducible across runs.
np.random.seed(seed)
torch.manual_seed(seed)

clear_all_cuda_memory()
check_memory()

## Load base model

In [2]:
"""
Load the base tokenizer/model
"""
model_id = 'allenai/OLMoE-1B-7B-0125-Instruct'
model_prefix = 'olmoe'
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)
# Place the model on the configured main_device (previously .cuda(), which ignored that setting).
# In the cells visible here the model is only read for its routing-gate weights (mlp.gate.weight).
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).to(main_device).eval()

In [3]:
"""
Load dataset
"""
def load_data(model_prefix, data_dir = 'data'):
    """
    Load the exported activations and routing metadata for one model.

    Params:
        @model_prefix: Sub-directory of `data_dir` holding this model's exports (e.g. 'olmoe').
        @data_dir: Root directory of the exported data; defaults to 'data' (the previously hard-coded root).

    Returns:
        A 6-tuple (all_pre_mlp_hs, all_expert_outputs, sample_df, topk_df,
        all_pre_mlp_hidden_states_layers, all_expert_outputs_layers). The first two are whatever
        `torch.load` returns for the saved .pt files; the rest are entries of metadata.pkl.
    """
    base_dir = f'{data_dir}/{model_prefix}'
    all_pre_mlp_hs = torch.load(f'{base_dir}/all-pre-mlp-hidden-states.pt')
    all_expert_outputs = torch.load(f'{base_dir}/all-expert-outputs.pt')
    # NOTE: pickle.load can execute arbitrary code - only load trusted, locally generated files
    with open(f'{base_dir}/metadata.pkl', 'rb') as f:
        metadata = pickle.load(f)

    return all_pre_mlp_hs, all_expert_outputs, metadata['sample_df'], metadata['topk_df'], metadata['all_pre_mlp_hidden_states_layers'], metadata['all_expert_outputs_layers']

# Unpack the exported tensors and metadata. act_map / expert_map list the layer indices stored along dim 1 of the
# two activation tensors; they are used to build the per-layer dicts in the fp16-conversion cell.
all_pre_mlp_hs_import, all_expert_outputs_import, sample_df_import, topk_df_import, act_map, expert_map = load_data(model_prefix)

In [4]:
"""
Let's clean up the mappings here. We'll get everything to a sample_ix level first.
"""
# sample_ix: one id per distinct (batch, sequence, token); seq_id: one id per distinct (batch, sequence).
# ngroup() numbers the groups in sorted-key order.
sample_df_raw = (
    sample_df_import
    .assign(
        sample_ix = lambda d: d.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup(),
        seq_id = lambda d: d.groupby(['batch_ix', 'sequence_ix']).ngroup(),
    )
    .reset_index()
)

# Replace the (batch, sequence, token) key on the routing table with the flat sample_ix
topk_df = (
    topk_df_import
    .merge(
        sample_df_raw[['sample_ix', 'batch_ix', 'sequence_ix', 'token_ix']],
        how = 'inner',
        on = ['sequence_ix', 'token_ix', 'batch_ix'],
    )
    .drop(columns = ['sequence_ix', 'token_ix', 'batch_ix'])
)

# Top-1 routing choices only
topk1_df = topk_df.loc[lambda d: d['topk_ix'] == 1]

# Token-level table keyed by sample_ix; batch/sequence ids are no longer needed
sample_df = sample_df_raw.drop(columns = ['batch_ix', 'sequence_ix'])

def get_sample_df_for_layer(sample_df, topk_df, layer_ix):
    """
    Attach per-layer routing columns to the sample-level frame.

    Returns `sample_df` inner-joined on sample_ix with the top-1 choice at `layer_ix`, so only samples that
    have one are kept. Three columns are added:
        expert      - top-1 expert at `layer_ix`
        prev_expert - top-1 expert at `layer_ix - 1` (left join; NaN when absent, e.g. for layer 0)
        expert2     - top-2 expert at `layer_ix` (left join; NaN when absent)
    """
    def _choice(layer, rank, out_name):
        # (sample_ix, <out_name>) pairs for the expert picked at the given layer and topk rank
        picked = topk_df[(topk_df['layer_ix'] == layer) & (topk_df['topk_ix'] == rank)]
        return picked.rename(columns = {'expert': out_name})[['sample_ix', out_name]]

    return (
        sample_df
        .merge(_choice(layer_ix, 1, 'expert'), how = 'inner', on = 'sample_ix')
        .merge(_choice(layer_ix - 1, 1, 'prev_expert'), how = 'left', on = 'sample_ix')
        .merge(_choice(layer_ix, 2, 'expert2'), how = 'left', on = 'sample_ix')
    )

# Drop the raw imports now that sample_df / topk_df have been derived from them
del sample_df_import, topk_df_import
display(topk_df)
display(sample_df)

In [5]:
"""
Convert activations to fp16 (for compatibility with cupy later) + dict
"""
all_pre_mlp_hs = all_pre_mlp_hs_import.to(torch.float16)
# utils.quantize helper; presumably reports the bf16 -> fp16 conversion error (confirm)
compare_bf16_fp16_batched(all_pre_mlp_hs_import, all_pre_mlp_hs)
del all_pre_mlp_hs_import
# Split into {layer_ix: slice}: dim 1 of the stacked tensor follows the layer order listed in act_map;
# rows of each slice are indexed by sample_ix downstream
all_pre_mlp_hs = {layer_ix: all_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(act_map)}

all_expert_outputs = all_expert_outputs_import.to(torch.float16)
compare_bf16_fp16_batched(all_expert_outputs_import, all_expert_outputs)
del all_expert_outputs_import
# Same per-layer split using expert_map; each value keeps the remaining dims (dim 1 of the value is
# indexed as the topk slot downstream - confirm)
all_expert_outputs = {layer_ix: all_expert_outputs[:, save_ix, :, :] for save_ix, layer_ix in enumerate(expert_map)}

# Reclaim memory from the deleted bf16 originals
gc.collect()

In [6]:
"""
Misc visualization helpers
"""
def reduce_pca(input_tensor: torch.Tensor, n_components = 2):
    """
    Project the rows of `input_tensor` onto their top principal components using cuML (GPU).

    Params:
        @input_tensor: 2-D tensor (n_samples x D); cast to float32 before being moved to the GPU.
        @n_components: Number of principal components to keep.

    Returns:
        numpy array (n_samples x n_components) of projected coordinates.
    """
    hs_cupy = cupy.asarray(input_tensor.to(torch.float32))
    # Named `pca` (not `model`) so it doesn't read like the global language model
    pca = cuml.PCA(
        iterated_power = 100,
        n_components = n_components,
        verbose = True
    )
    # Fit once and project with the same fit. The original code called fit() and then fit_transform(),
    # which fitted the same PCA a second time.
    projected = pca.fit_transform(hs_cupy)
    print(f'Cumulative variance ratio: {np.cumsum(pca.explained_variance_ratio_)[-1]}')
    pred = cupy.asnumpy(projected)
    clear_all_cuda_memory(False)
    return pred

def reduce_umap(input_tensor: torch.Tensor, n_components = 2, metric = 'cosine', n_epochs = 200):
    """
    Embed the rows of `input_tensor` into `n_components` dimensions with cuML UMAP (GPU).

    Params:
        @input_tensor: 2-D tensor (n_samples x D); cast to float32 before being moved to the GPU.
        @n_components: Output dimensionality.
        @metric: Distance metric passed to UMAP (e.g. euclidean, cosine, manhattan, l2, hamming).
        @n_epochs: Number of optimisation epochs (200 is the default for large datasets).

    Returns:
        numpy array (n_samples x n_components) of embedded coordinates.
    """
    umap_kwargs = dict(
        n_components = n_components,
        n_neighbors = 20,      # default 15; smaller preserves more local structure (range ~2-100)
        metric = metric,
        min_dist = 0.2,        # default 0.1; effective distance between embedded points
        n_epochs = n_epochs,
        random_state = None,   # left unseeded to allow parallelism
        verbose = False,
    )
    device_input = cupy.asarray(input_tensor.to(torch.float32))
    embedding = cupy.asnumpy(cuml.UMAP(**umap_kwargs).fit_transform(device_input))
    clear_all_cuda_memory(False)
    return embedding

def plot_manifold(plot_df, color_col, hover_col, title = None):
    """
    Scatter-plot a 2-D embedding with plotly express.

    Params:
        @plot_df: DataFrame with 'd1' / 'd2' coordinate columns plus `color_col` and `hover_col`.
        @color_col: Column used to colour the points.
        @hover_col: Column shown in the hover tooltip.
        @title: Optional figure title.

    Returns:
        The plotly figure (fixed 400px height, marker size 5, opacity 0.9).
    """
    fig = px.scatter(
        plot_df,
        x = 'd1',
        y = 'd2',
        color = color_col,
        hover_data = [hover_col],
        title = title,
        opacity = 0.9,
    )
    fig = fig.update_layout(autosize = False, height = 400)
    return fig.update_traces(marker = dict(size = 5))

## Visualize activation clusters of single (layer, expert)

In [None]:
"""
Visualize clusters
"""
test_layer_ix = 9
test_expert = 1

# Samples whose top-1 expert at test_layer_ix is test_expert (prev_expert / expert2 attached)
test_sample_df =\
    get_sample_df_for_layer(sample_df, topk_df, test_layer_ix)\
    .pipe(lambda df: df[df['expert'] == test_expert])

test_sample_indices = test_sample_df['sample_ix'].tolist()

display(test_sample_df)

# sample_ix doubles as the row index into the layer's activation tensor
test_pre_mlp_hs = all_pre_mlp_hs[test_layer_ix][test_sample_indices]

In [None]:
"""
Plot PCA + UMAP WITHIN test expert only, color by previous expert ID
"""
# Rows of pca_res / umap_res follow test_pre_mlp_hs, which was gathered in test_sample_df's row order,
# so the coordinates are attached to test_sample_df positionally.
# NOTE(review): .sample(5000) raises ValueError if this expert has fewer than 5000 samples at this layer.
pca_res = reduce_pca(test_pre_mlp_hs, 2)
pca_plot_df =\
    test_sample_df.assign(d1 = pca_res[:, 0], d2 = pca_res[:, 1])\
    .sample(5000)\
    .assign(prev_expert = lambda df: df['prev_expert'].astype(str))

combine_plots([
    plot_manifold(pca_plot_df, 'prev_expert', 'token', title = 'By previous expert'),
    plot_manifold(pca_plot_df, 'source', 'token', title = 'By source')
], f"PCA for layer {test_layer_ix}, expert {test_expert}").show()

# Same two views using a cosine-metric UMAP embedding
umap_res = reduce_umap(test_pre_mlp_hs, 2, 'cosine', n_epochs = 2000)
umap_plot_df =\
    test_sample_df.assign(d1 = umap_res[:, 0], d2 = umap_res[:, 1])\
    .sample(5000)\
    .assign(prev_expert = lambda df: df['prev_expert'].astype(str))

combine_plots([
    plot_manifold(umap_plot_df, 'prev_expert', 'token', title = 'By previous expert'),
    plot_manifold(umap_plot_df, 'source', 'token', title = 'By source')
], f"UMAP for layer {test_layer_ix}, expert {test_expert}").show()

In [None]:
"""
Plot PCA + UMAP, but this time include token samples from outside this expert
"""
nontest_sample_df =\
    get_sample_df_for_layer(sample_df, topk_df, test_layer_ix)\
    .pipe(lambda df: df[df['expert'] != test_expert])

nontest_pre_mlp_hs = all_pre_mlp_hs[test_layer_ix][nontest_sample_df['sample_ix'].tolist()]

# Prep all_samples_df. Its rows follow the same order as the activations concatenated below
# (test expert first, then everything else); non-test samples are marked with prev_expert = 'NA'.
all_samples_df = pd.concat([
    test_sample_df.assign(prev_expert = lambda df: df['prev_expert'].astype(str)), 
    nontest_sample_df.assign(prev_expert = 'NA')
]).reset_index(drop = True)

pca_res = reduce_pca(torch.concat([test_pre_mlp_hs, nontest_pre_mlp_hs], dim = 0), 2)
# Up to 10k points from each side; clamped so .sample() doesn't raise when a side has fewer rows
pca_plot_df =\
    all_samples_df.assign(d1 = pca_res[:, 0], d2 = pca_res[:, 1])\
    .pipe(lambda df: pd.concat([
        df[df['prev_expert'] == 'NA'].pipe(lambda d: d.sample(min(len(d), 10_000))),
        df[df['prev_expert'] != 'NA'].pipe(lambda d: d.sample(min(len(d), 10_000)))
    ]))

combine_plots([
    plot_manifold(pca_plot_df, 'prev_expert', 'token', title = 'By previous expert'),
    plot_manifold(pca_plot_df, 'source', 'token', title = 'By source')
], f"PCA for layer {test_layer_ix}, expert {test_expert}").show()

## Cross-layer source mappings

In [None]:
"""
Plot cross-layer source mappings
"""
layers_to_test = [0, 2, 7, 13, 15]

# Per-layer sample frames (expert / prev_expert / expert2 attached), iterated by the plotting cells below
test_layers = [
    {'layer_ix': layer_ix, 'sample_df': get_sample_df_for_layer(sample_df, topk_df, layer_ix)}
    for layer_ix in layers_to_test
]

test_layers[0]

In [None]:
"""
H groupings
"""
for test_layer in test_layers:
    # PCA over every sample's pre-MLP state at this layer; rows are assumed aligned with test_layer['sample_df']
    pca_res = reduce_pca(all_pre_mlp_hs[test_layer['layer_ix']], 2)

    # Sample sizes are clamped so layers/expert subsets with fewer rows than requested don't make .sample() raise
    source_plot_df =\
        test_layer['sample_df'].assign(d1 = pca_res[:, 0], d2 = pca_res[:, 1])\
        .pipe(lambda df: df.sample(min(len(df), 2500)))\
        .sort_values(by = 'source')

    # Only experts 0-4 are shown; coordinates are joined to sample_df on the index (both are fresh 0..n-1 indexes here)
    expert_plot_df =\
        pd.concat([pd.DataFrame({'d1': pca_res[:, 0], 'd2': pca_res[:, 1]}), test_layer['sample_df']], axis = 1)\
        .pipe(lambda df: df[df['expert'].isin(list(range(0, 5)))])\
        .assign(expert = lambda df: df['expert'].astype(str))\
        .pipe(lambda df: df.sample(min(len(df), 2500)))\
        .sort_values(by = 'expert')
    
    combine_plots([
        plot_manifold(expert_plot_df, 'expert', 'token', title = 'By expert'),
        plot_manifold(source_plot_df, 'source', 'token', title = 'By source')
    ], f"<em>H<sub>{str(test_layer['layer_ix'])}</sub></em>").show()

In [None]:
"""
H_orth vs H_para by expert/source
"""
for test_layer in test_layers:

    # Split each pre-MLP state into the part in the routing gate's span (h_para) and its orthogonal complement (h_orth)
    h_para, h_orth = decompose_orthogonal(
        all_pre_mlp_hs[test_layer['layer_ix']].to(torch.float32),
        model.model.layers[test_layer['layer_ix']].mlp.gate.weight.to(torch.float32).detach().cpu(),
        method = 'svd'
    )
    h_para = h_para.to(torch.float32)
    h_orth = h_orth.to(torch.float32)
    
    h_para_pca_res = reduce_pca(h_para, 2)
    h_orth_pca_res = reduce_pca(h_orth, 2)

    # Both panels are coloured by 'source', so sort by 'source' (as the H-groupings cell does) for a stable
    # legend order; these were previously sorted by 'expert', which is not plotted here.
    # Sample sizes are clamped so small layers don't make .sample() raise.
    h_para_plot_df =\
        test_layer['sample_df'].assign(d1 = h_para_pca_res[:, 0], d2 = h_para_pca_res[:, 1])\
        .pipe(lambda df: df.sample(min(len(df), 2500)))\
        .sort_values(by = 'source')

    h_orth_plot_df =\
        test_layer['sample_df'].assign(d1 = h_orth_pca_res[:, 0], d2 = h_orth_pca_res[:, 1])\
        .pipe(lambda df: df.sample(min(len(df), 2500)))\
        .sort_values(by = 'source')
    
    combine_plots([
        plot_manifold(h_para_plot_df, 'source', 'token', f"<em>H<sub>para</sub>({str(test_layer['layer_ix'])})</em>"),
        plot_manifold(h_orth_plot_df, 'source', 'token', f"<em>H<sub>orth</sub>({str(test_layer['layer_ix'])})</em>")
    ]).show() 

## Functional Specialization

In [7]:
"""
Print available layers
"""
# Layer indices for which expert outputs were saved (keys come from expert_map)
list(all_expert_outputs.keys())

In [203]:
"""
Get test samples (pre and post-MLP) for experts where expert was top-1
"""
test_layer_ix = 7

test_sample_df = get_sample_df_for_layer(sample_df, topk_df, test_layer_ix)

# Pre-MLP hidden states for every sample at this layer (the [:, :] slice is the whole tensor)
test_exp_inputs = all_pre_mlp_hs[test_layer_ix][:, :]
# Slot 0 along dim 1 of the expert outputs, treated here as the top-1 expert's output
test_t1_exp_outputs = all_expert_outputs[test_layer_ix][:, 0, :]

# NOTE(review): "delta" = expert output minus expert input. If the saved expert outputs are the expert's
# contribution that is added to the residual stream, they are already a delta - confirm this is intended.
# Downstream cells index test_deltas rows by sample_ix and also align them positionally with test_sample_df.
test_deltas = (test_t1_exp_outputs - test_exp_inputs)
print(f"{test_exp_inputs.shape} | {test_t1_exp_outputs.shape}")

In [182]:
"""
Show PCA of functional transformations with various groupings
"""
# PCA over every sample's delta; layer_pca rows follow test_deltas, assumed row-aligned with test_sample_df
layer_pca = reduce_pca(test_deltas, 2)
test_sample_df_with_pca = test_sample_df.assign(d1 = layer_pca[:, 0], d2 = layer_pca[:, 1])

# Experts to test: the 5 most frequently selected top-1 experts at this layer
test_experts =\
    test_sample_df\
    .groupby('expert', as_index = False).agg(n = ('expert', 'count'))\
    .sort_values('n', ascending = False)\
    .head(5)['expert'].tolist()

# Plot across experts. Every .sample() below is clamped to the rows available, so small subsets don't raise.
pca_plot_df =\
    test_sample_df_with_pca.pipe(lambda df: df[df['expert'].isin(test_experts)])\
    .pipe(lambda df: df.sample(min(len(df), 5_000)))\
    .assign(expert = lambda df: df['expert'].astype(str))

combine_plots([
    plot_manifold(pca_plot_df.sort_values('expert'), 'expert', 'token', title = 'Deltas by expert'),
    plot_manifold(pca_plot_df.sort_values('source'), 'source', 'token', title = 'Deltas by source')
], title = f'Layer {str(test_layer_ix)}').show()

# Plot within experts
for test_expert in sorted(test_experts):

    test_expert_sample_df = test_sample_df_with_pca.pipe(lambda df: df[df['expert'] == test_expert])
    
    source_plot_df =\
        test_expert_sample_df\
        .pipe(lambda df: df.sample(min(len(df), 5_000)))\
        .assign(prev_expert = lambda df: df['prev_expert'].astype(str))

    # 5 most common previous-layer experts among this expert's samples
    top_prev_experts =\
        test_expert_sample_df\
        .groupby('prev_expert', as_index = False).agg(n = ('prev_expert', 'count'))\
        .sort_values('n', ascending = False)\
        .head(5)['prev_expert'].tolist()

    prev_experts_plot_df =\
        test_expert_sample_df\
        .pipe(lambda df: df[df['prev_expert'].isin(top_prev_experts)])\
        .pipe(lambda df: df.sample(min(len(df), 5_000)))\
        .assign(prev_expert = lambda df: df['prev_expert'].astype(str))

    # 5 most common second-choice (topk_ix == 2) experts among this expert's samples
    top_expert_2s =\
        test_expert_sample_df\
        .groupby('expert2', as_index = False).agg(n = ('expert2', 'count'))\
        .sort_values('n', ascending = False)\
        .head(5)['expert2'].tolist()

    expert_2s_plot_df =\
        test_expert_sample_df\
        .pipe(lambda df: df[df['expert2'].isin(top_expert_2s)])\
        .pipe(lambda df: df.sample(min(len(df), 5_000)))\
        .assign(expert2 = lambda df: df['expert2'].astype(str))
    
    combine_plots([
        plot_manifold(prev_experts_plot_df.sort_values('prev_expert'), 'prev_expert', 'token', title = 'Deltas by prev expert'),
        plot_manifold(expert_2s_plot_df.sort_values('expert2'), 'expert2', 'token', title = 'Deltas by topk = 2 expert'),
        plot_manifold(source_plot_df.sort_values('source'), 'source', 'token', title = 'Deltas by source'),
    ], title = f'PCA for layer {str(test_layer_ix)}, expert {str(test_expert)}', cols = 3).show()

In [183]:
"""
Calculate cosine similarity across languages
"""
grped_deltas = []

# Mean delta per source, skipping sources with <= 10 samples. groupby visits sources in the same sorted order
# as sorted(unique()), but in one pass instead of re-filtering the whole frame for every source.
# Deltas are upcast to float32 before averaging (as get_baselines does) rather than averaged in fp16.
for grp_val, grp_rows in test_sample_df.groupby('source'):
    
    this_grp_sample_indices = grp_rows['sample_ix'].tolist()
    if len(this_grp_sample_indices) <= 10:
        continue
        
    this_grp_deltas = test_deltas[this_grp_sample_indices, :].to(torch.float32)
    
    grped_deltas.append({
        'grp': grp_val,
        'grped_vals': this_grp_deltas.mean(dim = 0)
    })

# Pairwise cosine similarity between per-source mean deltas; report mean |sim| over distinct pairs
cos_sim = sklearn.metrics.pairwise.cosine_similarity(
    torch.stack([x['grped_vals'] for x in grped_deltas], dim = 0).numpy()
)

upper_triangle_indices = np.triu_indices(cos_sim.shape[0], k = 1)
off_diagonal_values = cos_sim[upper_triangle_indices]
np.mean(np.abs(off_diagonal_values)).item()

In [184]:
"""
Calculate cosine similarity across experts
"""
grped_deltas = []

# Mean delta per top-1 expert, skipping experts with <= 10 samples. groupby visits experts in the same sorted
# order as sorted(unique()), but in one pass instead of re-filtering the whole frame for every expert.
# Deltas are upcast to float32 before averaging (as get_baselines does) rather than averaged in fp16.
for grp_val, grp_rows in test_sample_df.groupby('expert'):
    
    this_grp_sample_indices = grp_rows['sample_ix'].tolist()
    if len(this_grp_sample_indices) <= 10:
        continue
        
    this_grp_deltas = test_deltas[this_grp_sample_indices, :].to(torch.float32)
    
    grped_deltas.append({
        'grp': grp_val,
        'grped_vals': this_grp_deltas.mean(dim = 0)
    })

# Pairwise cosine similarity between per-expert mean deltas; report mean |sim| over distinct pairs
cos_sim = sklearn.metrics.pairwise.cosine_similarity(
    torch.stack([x['grped_vals'] for x in grped_deltas], dim = 0).numpy()
)
upper_triangle_indices = np.triu_indices(cos_sim.shape[0], k = 1)
off_diagonal_values = cos_sim[upper_triangle_indices]
np.mean(np.abs(off_diagonal_values)).item()

In [210]:
"""
Calculate cosine similarity across expert-source combinations
"""
grped_deltas = []

col = 'prev_expert' # Use source, prev_expert, expert2

# (expert, col) combinations with more than 100 samples, plus the sample_ix list of each
exp_src_combinations =\
    test_sample_df\
    .groupby(['expert', col], as_index = False)\
    .agg(
        size = ('sample_ix', 'count'),
        sample_indices = ('sample_ix', list)
    )\
    .pipe(lambda df: df[df['size'] > 100])

# Mean delta per combination, averaged in float32 (as get_baselines does) instead of fp16
means = [
    test_deltas[sample_indices].to(torch.float32).mean(dim = 0).numpy()
    for sample_indices in exp_src_combinations['sample_indices']
]

exp_src_combinations = exp_src_combinations.assign(mean_vec = means)

In [211]:
"""
Across expert within lang, across lang within expert
"""
def _mean_abs_offdiag_sim(vecs):
    """Mean |cosine similarity| over all distinct pairs of the given vectors."""
    sim = 1 - sklearn.metrics.pairwise_distances(np.stack(vecs, axis = 0), metric = 'cosine')
    off_diag = sim[np.triu_indices(sim.shape[0], k = 1)]
    return np.mean(np.abs(off_diag)).item()

# Previously this took |cosine DISTANCE| (a no-op, since distance is non-negative) and then 1 - mean, which is the
# mean raw similarity. It now takes |similarity|, matching the cross-group cells above and get_baselines.

# Similarity across experts, within each `col` value (needs >= 3 experts)
src_sims = {}
for src in exp_src_combinations[col].unique().tolist():
    all_vecs_for_src = exp_src_combinations.pipe(lambda df: df[df[col] == src])['mean_vec'].tolist()
    if len(all_vecs_for_src) < 3:
        continue
    src_sims[src] = _mean_abs_offdiag_sim(all_vecs_for_src)
# print(src_sims)

# Similarity across `col` values, within each expert (needs >= 3 values)
exp_sims = {}
for exp in exp_src_combinations['expert'].unique().tolist():
    all_vecs_for_exp = exp_src_combinations.pipe(lambda df: df[df['expert'] == exp])['mean_vec'].tolist()
    if len(all_vecs_for_exp) < 3:
        continue
    exp_sims[exp] = _mean_abs_offdiag_sim(all_vecs_for_exp)
# print(exp_sims)

mean_src_sim = np.mean([v for _, v in src_sims.items()]).round(4)
mean_exp_sim = np.mean([v for _, v in exp_sims.items()]).round(4)

print(f"Average delta sim across experts within {col} {mean_src_sim}")
print(f"Average delta sim across {col} within experts {mean_exp_sim}")
if mean_src_sim > mean_exp_sim:
    print(f'Delta is dominated by {col} over expert')
else:
    print(f'Delta is dominated by expert over {col}')

In [207]:
from collections import defaultdict

# Assuming test_sample_df and test_deltas are pre-loaded and aligned
# test_sample_df: DataFrame with columns 'sample_ix', 'expert', 'prev_expert', 'source', etc.
# test_deltas: Tensor (N, D) of delta vectors corresponding to test_sample_df

def _avg_pairwise_sim_within(df, hold_col, vary_col, delta_vectors_np, min_samples_per_group, min_groups_for_sim_calc):
    """
    Average |cosine similarity| between group-mean deltas that share a `hold_col` value but differ in `vary_col`.

    (vary_col, hold_col) groups with fewer than `min_samples_per_group` rows are skipped. For each `hold_col`
    value with at least `min_groups_for_sim_calc` surviving groups, the mean |cos sim| over all distinct pairs
    of their mean vectors is computed. The result is the mean of those values, or None if no value qualified.
    """
    means_by_hold = {}
    for (_, hold_val), idx in df.groupby([vary_col, hold_col])['sample_ix']:
        if len(idx) >= min_samples_per_group:
            means_by_hold.setdefault(hold_val, []).append(delta_vectors_np[idx.tolist()].mean(axis = 0))

    sims = []
    for vecs in means_by_hold.values():
        if len(vecs) < min_groups_for_sim_calc:
            continue
        # Cosine distance -> Cosine similarity = 1 - distance
        cos_sim_matrix = 1 - sklearn.metrics.pairwise_distances(np.stack(vecs), metric = 'cosine')
        # Upper triangle (excluding diagonal) = each unordered pair once
        upper_triangle_indices = np.triu_indices(cos_sim_matrix.shape[0], k = 1)
        if len(upper_triangle_indices[0]) > 0:
            sims.append(np.mean(np.abs(cos_sim_matrix[upper_triangle_indices])))

    return np.mean(sims) if sims else None


def get_baselines(sample_df, delta_vectors, grouping_col, n_shuffles: int = 5, min_samples_per_group: int = 100, min_groups_for_sim_calc: int = 3) -> tuple[float, float]:
    """
    Calculates the average shuffled baseline similarities over multiple runs.

    Each run computes two label-shuffling baselines:
        1. `expert` is shuffled within each `grouping_col` value; |cos sim| is averaged across the (shuffled)
           experts' mean deltas within each `grouping_col` value.
        2. `grouping_col` is shuffled within each expert; |cos sim| is averaged across the (shuffled)
           `grouping_col` values' mean deltas within each expert.

    Params:
        @sample_df: Dataframe with 'sample_ix', 'expert' and `grouping_col` columns; sample_ix indexes rows of delta_vectors
        @delta_vectors: n_samples x D tensor corresponding to sample_df
        @grouping_col: The column name to group by (e.g. 'prev_expert', 'source').
        @n_shuffles: Number of shuffling iterations (iteration i uses random_state=i).
        @min_samples_per_group: Min samples required to compute a group mean.
        @min_groups_for_sim_calc: Min number of mean vectors required to calculate an average similarity.

    Returns:
        A tuple containing:
        - avg_baseline_sim_across_experts: Baseline for "sim across experts within col".
        - avg_baseline_sim_across_col: Baseline for "sim across col within experts".
        Either value is NaN if no shuffle run produced a usable similarity.
    """
    baseline_sims_across_experts = []
    baseline_sims_across_col = []

    # Ensure float32 for calculations
    delta_vectors_np = delta_vectors.float().cpu().numpy()

    print(f"Calculating baselines for grouping col: '{grouping_col}'...")
    for i in tqdm(range(n_shuffles), desc="Shuffle Runs"):
        # --- Baseline 1: Shuffle Expert labels WITHIN each grouping_col group ---
        df_shuffled_experts = sample_df.copy()
        df_shuffled_experts['shuffled_expert'] = df_shuffled_experts.groupby(grouping_col)['expert']\
            .transform(lambda x: x.sample(frac=1, random_state=i).values)
        # Similarity across SHUFFLED experts within each grouping_col value. Values whose groups were all too
        # small are simply absent here; the former per-value .loc[:, val] lookup raised KeyError for them.
        sim = _avg_pairwise_sim_within(df_shuffled_experts, grouping_col, 'shuffled_expert', delta_vectors_np, min_samples_per_group, min_groups_for_sim_calc)
        if sim is not None:
            baseline_sims_across_experts.append(sim)

        # --- Baseline 2: Shuffle grouping_col labels WITHIN each expert group ---
        df_shuffled_col = sample_df.copy()
        df_shuffled_col['shuffled_col'] = df_shuffled_col.groupby('expert')[grouping_col]\
            .transform(lambda x: x.sample(frac=1, random_state=i).values)
        # Similarity across SHUFFLED grouping_col values within each expert (same KeyError-free lookup)
        sim = _avg_pairwise_sim_within(df_shuffled_col, 'expert', 'shuffled_col', delta_vectors_np, min_samples_per_group, min_groups_for_sim_calc)
        if sim is not None:
            baseline_sims_across_col.append(sim)

    # Calculate overall averages, handling cases where shuffles might fail
    avg_baseline_sim_across_experts = np.mean(baseline_sims_across_experts) if baseline_sims_across_experts else np.nan
    avg_baseline_sim_across_col = np.mean(baseline_sims_across_col) if baseline_sims_across_col else np.nan

    print("Baseline calculation finished.")
    return avg_baseline_sim_across_experts, avg_baseline_sim_across_col


# --- Example Usage ---
# Shuffled-label baselines for the current test_sample_df / test_deltas. Each pair is the shuffled counterpart of
# the "Average delta sim across ..." values printed by the earlier cell for the same grouping column.

# Calculate baseline for 'prev_expert' grouping
bl_across_experts_pe, bl_across_col_pe = get_baselines(test_sample_df, test_deltas, grouping_col = 'prev_expert', n_shuffles = 5)
print(f"\nBaseline for grouping_col='prev_expert':")
print(f"  Avg Baseline Sim Across Experts (within prev_expert): {bl_across_experts_pe:.4f}")
print(f"  Avg Baseline Sim Across prev_expert (within expert): {bl_across_col_pe:.4f}")

# Calculate baseline for 'source' grouping
bl_across_experts_src, bl_across_col_src = get_baselines(test_sample_df, test_deltas, grouping_col = 'source', n_shuffles = 5)
print(f"\nBaseline for grouping_col='source':")
print(f"  Avg Baseline Sim Across Experts (within source): {bl_across_experts_src:.4f}")
print(f"  Avg Baseline Sim Across source (within expert): {bl_across_col_src:.4f}")


In [232]:
"""
Logistic Regression
"""
def run_lr(x_cp, y_cp):
    """
    Fit an L2-regularised logistic-regression probe (cuML, GPU) that predicts labels `y_cp` from
    features `x_cp` on a 90/10 split (random_state=123), then print held-out accuracy and -log-loss.

    Params:
        @x_cp: Feature matrix (n_samples x D) in a form cuML accepts (callers pass cupy arrays).
        @y_cp: Class labels (n_samples,) in a form cuML accepts.
    """
    x_train, x_test, y_train, y_test = cuml.train_test_split(x_cp, y_cp, test_size = 0.1, random_state = 123)
    lr_model = cuml.linear_model.LogisticRegression(penalty = 'l2', max_iter = 1000, fit_intercept = False)
    lr_model.fit(x_train, y_train)
    # Accuracy was previously computed but never reported
    accuracy = lr_model.score(x_test, y_test)
    proba = lr_model.predict_proba(x_test)
    nll = -1 * cuml.metrics.log_loss(y_test, proba) # higher is better
    print(f"accuracy: {float(cupy.asnumpy(accuracy)):.4f}")
    print(f"–log-loss: {float(cupy.asnumpy(nll)):.4f}")

# Upcast the fp16 deltas to float32 before moving them to the GPU, as reduce_pca / cluster_kmeans do for cuML inputs
x_cp = cupy.asarray(test_deltas.to(torch.float32))

run_lr(x_cp, cupy.asarray(test_sample_df['expert'].tolist()))
run_lr(x_cp, cupy.asarray(test_sample_df['prev_expert'].tolist()))
# Integer-encode sources; keep the labels on the GPU like the two probes above
_, source_codes = np.unique(test_sample_df['source'].tolist(), return_inverse = True)
run_lr(x_cp, cupy.asarray(source_codes))

## SVD clustering: sideways -> h_orth
1. Decompose the pre-mlp hidden states into h_sideways WITHIN each group of activations that route to a single expert, with respect to just the D-dimensional routing gate for that single expert. Then, still within that single expert, cluster those h_sideways activations.
2. Repeat across all experts.
3. After obtaining these clusters, compute h_orth (the regular decomposition using all activations with respect to the entire routing gate).
4. Calculate the cluster centroids in h_orth space, using the cluster IDs/labels obtained from h_sideways earlier. This yields up to n_experts * n_clusters_per_expert cluster centers.
5. Compare the average pairwise cosine similarity between centroids of the same expert (within-expert) against the average between centroids of different experts (across-expert).

In [None]:
# import importlib, utils.svd as svd
# decompose_sideways = importlib.reload(svd).decompose_sideways

In [None]:
def cluster_kmeans(layer_hs: torch.Tensor, n_clusters = 512):
    """
    K-means cluster the rows of `layer_hs` on the GPU with cuML.

    Params:
        @layer_hs: 2-D tensor (n_samples x D); cast to float32 before being moved to the GPU.
        @n_clusters: Requested number of clusters; clamped to n_samples so small inputs don't fail.

    Returns:
        (cluster_ids, cluster_centers): a list with one integer label per row, and the centroid
        array (n_clusters_used x D) as returned by cuML.
    """
    n_samples = layer_hs.shape[0]
    kmeans_model = cuml.cluster.KMeans(n_clusters = min(n_clusters, n_samples), max_iter = 1000, random_state = 123)
    kmeans_model.fit(cupy.asarray(layer_hs.to(torch.float32)))
    # Read the fitted results before clearing CUDA memory, matching reduce_pca / reduce_umap
    cluster_ids = kmeans_model.labels_.tolist() # n_samples
    cluster_centers = kmeans_model.cluster_centers_ # (n_clusters, D)
    clear_all_cuda_memory(False)
    return cluster_ids, cluster_centers

# A stray call `cluster_kmeans(relevant_pre_mlp_hs)` was removed from here: `relevant_pre_mlp_hs` is never
# defined in this notebook, so it raised NameError on Restart & Run All. cluster_kmeans is called per expert
# in the sideways-clustering cell below.

In [None]:
"""
Prepare sample-level df merged with top-1 expert selections for a single test layer
"""

test_layer_ix = 1

# Top-1 expert at this layer, plus the previous layer's top-1 expert as prev_expert.
# Both are inner joins, so samples missing either choice are dropped.
sample_df_test = (
    sample_df
    .merge(
        topk1_df.loc[lambda d: d['layer_ix'] == test_layer_ix, ['expert', 'sample_ix']],
        how = 'inner',
        on = 'sample_ix',
    )
    .merge(
        topk1_df.loc[lambda d: d['layer_ix'] == test_layer_ix - 1]
            .rename(columns = {'expert': 'prev_expert'})[['sample_ix', 'prev_expert']],
        how = 'inner',
        on = 'sample_ix',
    )
)

sample_df_test

In [None]:
"""
Extract the sideways decomposition within activations routed to a single expert; this specifically REMOVES the part of h directly 
responsible for increasing/decreasing logit specifically for that expert; then cluster them
"""
n_subclusters = 10 # k-means clusters per expert

# Cluster label per sample_ix (-1 = not clustered)
cluster_ids = torch.full([all_pre_mlp_hs[test_layer_ix].shape[0]], -1, dtype = torch.int32)

for this_expert in tqdm(sorted(sample_df_test['expert'].unique().tolist())):

    # Extract sample indices for expert
    this_sample_indices = sample_df_test[sample_df_test['expert'] == this_expert]['sample_ix'].tolist()
    # K-means needs at least n_subclusters points. Smaller experts keep cluster_id -1; clusters that small
    # are thrown out by the centroid cell anyway.
    if len(this_sample_indices) < n_subclusters:
        continue
    
    # D-dimensional routing gate for expert route
    this_gate = model.model.layers[test_layer_ix].mlp.gate.weight[this_expert, :].to(torch.float32).detach().cpu()
    
    # Remove only this expert's axis to expose sub-clusters. The fp16 activations are cast to float32 to match
    # the float32 gate, as the decompose_orthogonal calls do.
    _, h_side = decompose_sideways(all_pre_mlp_hs[test_layer_ix][this_sample_indices].to(torch.float32), this_gate)

    # Cluster within expert
    this_cluster_ids, _ = cluster_kmeans(h_side, n_clusters = n_subclusters)
    cluster_ids[this_sample_indices] = torch.tensor(this_cluster_ids, dtype = cluster_ids.dtype)

# cluster_ids is indexed by sample_ix, so look labels up by sample_ix instead of assigning them by row position
sample_df_test_cl = sample_df_test.assign(cluster_id = cluster_ids.numpy()[sample_df_test['sample_ix'].to_numpy()])

In [None]:
"""
Go back to the original decomposition to get the regular h_orth
"""
# Decompose every sample's pre-MLP state at test_layer_ix against the full float32 routing-gate matrix.
# Only the second output (h_orth) is kept, and it is cast to float32.
_, h_orth = decompose_orthogonal(
    all_pre_mlp_hs[test_layer_ix].to(torch.float32),
    model.model.layers[test_layer_ix].mlp.gate.weight.to(torch.float32).detach().cpu(),
    method = 'svd'
)
h_orth = h_orth.to(torch.float32)

In [None]:
"""
Apply the sub-cluster labels obtained from clustering h_sideways to the corresponding h_orth vectors
This allows us to compare things in h_orth space; then compare cosine similarity. 
"""
centroids = [] # List of normalised centroid tensors
tags = [] # (expert, cluster)

# One pass over (expert, cluster) groups instead of re-filtering the whole frame for every pair.
# groupby yields keys in the same sorted (expert, cluster) order as nested sorted loops would.
for (this_expert, this_cluster), grp in sample_df_test_cl.groupby(['expert', 'cluster_id']):
    if this_cluster == -1: # Samples that were never clustered
        continue
    this_e_c_sample_indices = grp['sample_ix'].tolist()
    
    if len(this_e_c_sample_indices) <= 50: # Throw out tiny clusters
        continue

    v = h_orth[this_e_c_sample_indices].mean(0)
    v = v / v.norm() # normalise for cosine
    centroids.append(v)
    tags.append((int(this_expert), int(this_cluster)))

centroids = torch.stack(centroids) # (n_kept_clusters x D)
cosine_sims = sklearn.metrics.pairwise.cosine_similarity(centroids.numpy()) # pairwise sim between each kept (expert, cluster) pair

# Split the upper-triangle pairs (i < j) into same-expert vs different-expert
tag_experts = np.array([e for e, _ in tags])
pair_i, pair_j = np.triu_indices(len(tags), k = 1)
same_expert = tag_experts[pair_i] == tag_experts[pair_j]
pair_sims = cosine_sims[pair_i, pair_j]
within, across = pair_sims[same_expert], pair_sims[~same_expert]

mean_cos_within  = float(np.mean(within))
mean_cos_across  = float(np.mean(across))
print(f"mean cosine  (within expert):  {mean_cos_within:.3f}")
print(f"mean cosine  (across experts): {mean_cos_across:.3f}")