In [None]:
"""
This notebook examines how MoE routing correlates with the pre-MLP hidden state. Requires that `export-data/export-activations-sm.ipynb` has been run (it writes the `export-data/activations-sm/{model_prefix}/` folders loaded below).
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import scipy
import cupy
import cuml
import sklearn

import importlib
import gc
import pickle
import os

from tqdm import tqdm
from termcolor import colored
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.io as pio

from utils.memory import check_memory, clear_all_cuda_memory
from utils.quantize import compare_bf16_fp16_batched
from utils.vis import combine_plots

main_device = 'cuda:0' # NOTE(review): cells below call `.cuda()` (current default device) rather than using this value
seed = 1234
# Seed every RNG the analysis touches so stochastic steps (e.g. the random-dimension probe
# baselines drawn with np.random.choice) reproduce under Restart & Run All.
np.random.seed(seed)
torch.manual_seed(seed)
cupy.random.seed(seed)
clear_all_cuda_memory()
check_memory()

## Load model & data

In [None]:
"""
Load the base tokenizer/model
"""
model_ix = 2
# Each entry: (Hugging Face model id, short prefix used for local/export folder names, layer offset).
# The layer offset is unpacked below as `model_pre_mlp_layers` and added to exported layer indices.
# NOTE(review): presumably the number of leading layers without a routed MLP (1 for DeepSeek-V2-Lite) — confirm.
models_list = [
    ('allenai/OLMoE-1B-7B-0125-Instruct', 'olmoe', 0),
    ('Qwen/Qwen1.5-MoE-A2.7B-Chat', 'qwen1.5moe', 0),
    ('deepseek-ai/DeepSeek-V2-Lite', 'dsv2', 1),
    ('Qwen/Qwen3-30B-A3B', 'qwen3moe', 0)
]

def load_model_and_tokenizer(model_id, model_prefix, models_dir = '/workspace/models'):
    """
    Load a tokenizer (left padding, no auto BOS/EOS) and a bf16 causal LM in eval mode on CUDA.

    Uses the local copy at `{models_dir}/{model_prefix}` when it is complete; otherwise loads
    `model_id` (e.g. from the Hugging Face Hub) and saves a local copy for next time.

    Returns:
        (tokenizer, model)
    """
    local_path = os.path.join(models_dir, model_prefix)
    # Trust the local copy only if both configs are present. The tokenizer is saved *after* the
    # model below, so its config doubles as a "save finished" marker; a bare directory left by an
    # interrupted save is ignored instead of failing to load on every later run.
    has_local_copy = all(
        os.path.isfile(os.path.join(local_path, name))
        for name in ('config.json', 'tokenizer_config.json')
    )
    model_path = local_path if has_local_copy else model_id
    tokenizer = AutoTokenizer.from_pretrained(model_path, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype = torch.bfloat16, trust_remote_code = True).cuda().eval()
    if not has_local_copy:
        os.makedirs(local_path, exist_ok = True)
        model.save_pretrained(local_path)
        tokenizer.save_pretrained(local_path)
    return tokenizer, model

model_id, model_prefix, model_pre_mlp_layers = models_list[model_ix]
tokenizer, model = load_model_and_tokenizer(model_id, model_prefix)

In [None]:
"""
Load dataset
"""
def load_data(model_prefix, max_data_files):
    """
    Load data saved by `export-activations-sm.ipynb`.

    Looks for numbered folders `00`, `01`, ... (up to `max_data_files` of them) under
    `./../export-data/activations-sm/{model_prefix}/`, skips any that don't exist, and
    concatenates what each one holds, in folder order.

    Params:
        model_prefix: export folder name for the model (e.g. 'olmoe', 'dsv2').
        max_data_files: how many numbered folders to look for.

    Returns:
        (sample_df, topk_df, all_pre_mlp_hs, layers): the concatenated `samples.pkl` and
        `topks.pkl` frames, the `all-pre-mlp-hidden-states.pt` tensors concatenated along dim 0,
        and metadata['all_pre_mlp_hidden_states_layers'] read from `metadata.pkl`.

    Raises:
        FileNotFoundError: if none of the numbered folders exist.
    """
    base_dir = f'./../export-data/activations-sm/{model_prefix}'
    folders = [f'{base_dir}/{i:02d}' for i in range(max_data_files)]
    folders = [folder for folder in folders if os.path.isdir(folder)]
    if not folders:
        # Fail with the location searched instead of pandas' opaque "No objects to concatenate"
        raise FileNotFoundError(f'No numbered activation folders found under {base_dir} (looked for {max_data_files})')

    all_pre_mlp_hs = []
    sample_df = []
    topk_df = []

    # NOTE: pd.read_pickle / torch.load / pickle.load can execute arbitrary code; only load exports you produced yourself.
    for folder in tqdm(folders):
        sample_df.append(pd.read_pickle(f'{folder}/samples.pkl'))
        topk_df.append(pd.read_pickle(f'{folder}/topks.pkl'))
        all_pre_mlp_hs.append(torch.load(f'{folder}/all-pre-mlp-hidden-states.pt'))

    sample_df = pd.concat(sample_df)
    topk_df = pd.concat(topk_df)
    all_pre_mlp_hs = torch.concat(all_pre_mlp_hs)

    with open(f'{base_dir}/metadata.pkl', 'rb') as fh:
        metadata = pickle.load(fh)

    gc.collect()
    return sample_df, topk_df, all_pre_mlp_hs, metadata['all_pre_mlp_hidden_states_layers']

# act_map: metadata['all_pre_mlp_hidden_states_layers'], i.e. the layer index of each saved hidden-state slice
sample_df_import, topk_df_import, all_pre_mlp_hs_import, act_map = load_data(model_prefix, max_data_files = 2)

In [None]:
"""
Let's clean up the mappings here. We'll get everything to a sample_ix level first.
"""
# NOTE(review): assumes (batch_ix, sequence_ix, token_ix) is unique across all loaded data folders;
# if batch_ix restarts per folder, sample_ix would collide — confirm against the export.
sample_df_raw = (
    sample_df_import
    # sample_ix: one id per token (batch_ix, sequence_ix, token_ix); seq_id: one id per (batch_ix, sequence_ix)
    .assign(sample_ix = lambda frame: frame.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup())
    .assign(seq_id = lambda frame: frame.groupby(['batch_ix', 'sequence_ix']).ngroup())
    # NOTE(review): reset_index() without drop=True keeps the old index as an 'index' column
    .reset_index()
)

# Replace topk's composite token key with sample_ix, and shift layer_ix by the model's layer offset
topk_df = (
    topk_df_import
    .merge(sample_df_raw[['sample_ix', 'batch_ix', 'sequence_ix', 'token_ix']], how = 'inner', on = ['sequence_ix', 'token_ix', 'batch_ix'])
    .drop(columns = ['sequence_ix', 'token_ix', 'batch_ix'])
    .assign(layer_ix = lambda frame: frame['layer_ix'] + model_pre_mlp_layers)
)

topk1_df = topk_df[topk_df['topk_ix'] == 1]

sample_df = sample_df_raw.drop(columns = ['batch_ix', 'sequence_ix'])

def get_sample_df_for_layer(sample_df, topk_df, layer_ix):
    """
    Helper to take the sample df and merge layer-level expert selection information.

    Returns `sample_df` joined on `sample_ix` with:
        expert       - top-1 expert at `layer_ix` (inner join: samples without one are dropped)
        prev_expert  - top-1 expert at `layer_ix - 1` (left join; may be missing)
        prev2_expert - top-1 expert at `layer_ix - 2` (left join; may be missing)
        expert2      - top-2 expert at `layer_ix` (left join; may be missing)
        leading_path - string 'prev2_expert-prev_expert', <NA> if either is missing
    """
    def _expert_at(layer, rank, name):
        # (sample_ix, <name>) for the expert chosen at `rank` (topk_ix) in `layer`
        selected = topk_df[(topk_df['layer_ix'] == layer) & (topk_df['topk_ix'] == rank)]
        return selected[['sample_ix', 'expert']].rename(columns = {'expert': name})

    def _as_label(col):
        # Expert ids are numeric (and become float once a left join introduces NaN), so they cannot be
        # concatenated with '-' directly. Go through nullable Int64 so 3.0 renders as '3' and missing stays <NA>.
        if pd.api.types.is_numeric_dtype(col):
            col = col.astype('Int64')
        return col.astype('string')

    layer_df = (
        sample_df
        .merge(_expert_at(layer_ix, 1, 'expert'), how = 'inner', on = 'sample_ix')
        .merge(_expert_at(layer_ix - 1, 1, 'prev_expert'), how = 'left', on = 'sample_ix')
        .merge(_expert_at(layer_ix - 2, 1, 'prev2_expert'), how = 'left', on = 'sample_ix')
        .merge(_expert_at(layer_ix, 2, 'expert2'), how = 'left', on = 'sample_ix')
        .assign(leading_path = lambda df: _as_label(df['prev2_expert']) + '-' + _as_label(df['prev_expert']))
    )

    return layer_df

# The raw imports and intermediate frame are superseded by `topk_df` / `sample_df`; free them
del sample_df_import, sample_df_raw, topk_df_import
gc.collect()

display(topk_df, sample_df)

In [None]:
"""
Convert activations to fp16 (for compatibility with cupy later) + dict
"""
all_pre_mlp_hs = all_pre_mlp_hs_import.to(torch.float16)
# compare_bf16_fp16_batched(all_pre_mlp_hs_import, all_pre_mlp_hs)
# The tensor is indexed [:, save_ix, :] (samples, saved layer, hidden dim); act_map[save_ix] is the exported layer index
all_pre_mlp_hs = {(layer_ix + model_pre_mlp_layers): all_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(act_map)}

# fp16 saturates at ~65504 while bf16 does not, so large activations become +/-inf in the cast and
# would silently poison every norm/mean below. Checked one layer at a time to keep the temporary
# boolean mask small, and before deleting the bf16 source so a failure is recoverable.
inf_counts = {layer_ix: torch.isinf(layer_hs).sum().item() for layer_ix, layer_hs in all_pre_mlp_hs.items()}
bad_layers = {layer_ix: n_inf for layer_ix, n_inf in inf_counts.items() if n_inf > 0}
if bad_layers:
    raise OverflowError(f'Inf values after casting pre-MLP hidden states to fp16 (layer -> count): {bad_layers}')

del all_pre_mlp_hs_import
gc.collect()

## Analyze routing weights

In [None]:
"""
Norms by expert and layer
"""
norms_by_expert_layer = pd.concat([
    pd.DataFrame({
        'layer_ix': layer_ix,
        # L1 norm of each expert's router row. Upcast to fp32 so the result isn't rounded to bf16/fp16
        # (~3 significant digits), which would blur the small per-layer deviations plotted below.
        'norm': torch.linalg.norm(model.model.layers[layer_ix].mlp.gate.weight.float(), dim = 1, ord = 1).cpu().detach().numpy(),
        # NOTE(review): experts are numbered from 1 here; check this matches the ids used in topk_df
        'expert': list(range(1, model.model.layers[layer_ix].mlp.gate.weight.shape[0] + 1))
    })
    for layer_ix in all_pre_mlp_hs.keys()
])

plot_df = norms_by_expert_layer.pivot(index = 'layer_ix', columns = 'expert', values = 'norm')
px.imshow(
    plot_df,
    x = plot_df.columns, y = plot_df.index,
    color_continuous_scale = 'ylgnbu',
    labels = {'color': 'L1 norm'}, title = "Router Weight L1 Norm by Expert and Layer"
).update_layout(autosize = False, width = 800).show()

# Relative deviation of each expert's norm from its layer mean (0 = average for that layer)
scaled_df =\
    norms_by_expert_layer\
    .assign(layer_mean = lambda df: df.groupby('layer_ix')['norm'].transform('mean'))\
    .assign(norm_scaled = lambda df: df['norm'] / df['layer_mean'] - 1)

scaled_plot_df = scaled_df.pivot(index = 'layer_ix', columns = 'expert', values = 'norm_scaled')
px.imshow(
    scaled_plot_df,
    x = scaled_plot_df.columns, y = scaled_plot_df.index,
    # Signed quantity centred on 0, so use a diverging scale anchored there
    color_continuous_scale = 'rdbu', color_continuous_midpoint = 0,
    labels = {'color': 'Norm / layer mean - 1'}, title = "Router Weight L1 Norm Relative to Layer Mean"
).update_layout(autosize = False, width = 800).show()

In [None]:
"""
For a single layer, what do the weights and RMSnorms look like?
"""
plot_layer_ix = 9
show_dims = list(range(0, 400)) # `dimension` is 0-indexed in this cell, so this is the first 400 dims

# RMSNorm: per-dimension scale (gamma) of post_attention_layernorm
rms_tensor = model.model.layers[plot_layer_ix].post_attention_layernorm.weight
rms_df = pd.DataFrame({
    'gamma': rms_tensor.to(torch.float16).cpu().detach().numpy(),
    'coef': 1, # single dummy row index for the 1-row heatmap
    'dimension': list(range(0, rms_tensor.shape[0]))
})
plot_df = rms_df.pipe(lambda df: df[df['dimension'].isin(show_dims)]).pivot(index = 'coef', columns = 'dimension', values = 'gamma')

px.imshow(
    plot_df,
    x = plot_df.columns, y = plot_df.index,
    aspect = 'auto', color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Gamma'}, title = "RMSNorm Scaling Values"
).update_layout(autosize = False, width = 1400, height = 400).show()

# Weights: flatten the (n_experts, D) router matrix row-major into one row per (expert, dimension).
# Index arithmetic is vectorized; a Python list comprehension over n_experts * D elements is needlessly slow.
wt_tensor = model.model.layers[plot_layer_ix].mlp.gate.weight
n_wt_experts, n_wt_dims = wt_tensor.shape
flat_ix = np.arange(n_wt_experts * n_wt_dims)
wt_df = pd.DataFrame({
    'value': wt_tensor.reshape(-1).to(torch.float16).cpu().detach().numpy(),
    'expert': flat_ix // n_wt_dims, # 0-indexed expert
    'dimension': flat_ix % n_wt_dims
})

plot_df = wt_df.pipe(lambda df: df[df['dimension'].isin(show_dims)]).pivot(index = 'expert', columns = 'dimension', values = 'value')

px.imshow(
    plot_df,
    x = plot_df.columns, y = plot_df.index,
    aspect = 'auto', color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Weight'}, title = "Routing Weights"
).update_layout(autosize = False, width = 1400).show()

# Scale weights by RMSNorm gamma (per dimension)
scaled_df = wt_df.merge(rms_df, on = 'dimension', how = 'inner').assign(gamma_scaled_value = lambda df: df['gamma'] * df['value'])
plot_df = scaled_df.pipe(lambda df: df[df['dimension'].isin(show_dims)]).pivot(index = 'expert', columns = 'dimension', values = 'gamma_scaled_value')
px.imshow(
    plot_df,
    x = plot_df.columns, y = plot_df.index,
    aspect = 'auto', color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Gamma * weight'}, title = "Scaled Routing Weights"
).update_layout(autosize = False, width = 1400).show()

In [None]:
"""
Mean norms across layers and dimension (averaged across experts)
"""
dfs_list = []
for layer_ix in all_pre_mlp_hs.keys():
    layer = model.model.layers[layer_ix]
    wt_tensor = layer.mlp.gate.weight.to(torch.float16).cpu().detach() # (n_experts, D)
    rms_tensor = layer.post_attention_layernorm.weight.to(torch.float16).cpu().detach() # (D,)
    # |weight * gamma| averaged over experts -> one mean L1 value per dimension
    scaled = (wt_tensor * rms_tensor).abs().mean(dim = 0)
    dfs_list.append(pd.DataFrame({
        'mean_norm': scaled.numpy(),
        'layer_ix': layer_ix,
        'dim': np.arange(1, scaled.shape[0] + 1) # 1-indexed dimension
    }))

my_df = pd.concat(dfs_list)
# Divide by the layer average so every layer's dims average to 1 and layers are comparable
my_df_ex_scale = my_df.assign(
    layer_mean = lambda d: d.groupby('layer_ix')['mean_norm'].transform('mean'),
    mean_norm = lambda d: d['mean_norm'] / d['layer_mean']
)

plot_df = my_df_ex_scale[my_df_ex_scale['dim'] <= 200].pivot(index = 'layer_ix', columns = 'dim', values = 'mean_norm')

heatmap = px.imshow(
    plot_df,
    x = plot_df.columns, y = plot_df.index,
    zmin = 0, zmax = 8,
    aspect = 'auto', # Allow non-square boxes
    color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Norm'}, title = "Mean norms by dimension and layer"
)
heatmap.update_layout(autosize = False, width = 1400, coloraxis = dict()).show()

In [None]:
"""
At dimension x layer-level, analyze activations (averaged across samples) versus routing weights (averaged across experts).
"""
# `dim` below is 1-indexed, so this selects the first 800 dimensions (dims 1..800)
show_dims = list(range(1, 801))

dfs_list = []
for layer_ix, pre_mlp_for_layer in tqdm(all_pre_mlp_hs.items()):
    wt_tensor = model.model.layers[layer_ix].mlp.gate.weight.to(torch.float16).cuda().detach() # (n_experts, D)
    act_tensor = pre_mlp_for_layer.abs().mean(dim = 0).cuda().detach() # (n_samples, D) -> (D,): mean |activation| per dim
    scaled = (wt_tensor * act_tensor).abs().mean(dim = 0) # (D,): mean over experts of |weight * mean |activation||
    dfs_list.append(pd.DataFrame({
        'layer_ix': layer_ix,
        'act_norm': act_tensor.cpu().numpy(), # (D,)
        'wt_norm': wt_tensor.abs().mean(dim = 0).cpu().numpy(), # (n_experts, D) -> (D,): mean |weight| over experts
        'mean_scaled_norm': scaled.cpu().numpy(), # (D,)
        'dim': list(range(1, scaled.shape[0] + 1)) # 1-indexed dimension
    }))

pre_mlp_df = pd.concat(dfs_list)
del dfs_list

def _plot_dim_layer_heatmap(value_col, title, zmin = None, zmax = None):
    """Show a (layer x dim) heatmap of `pre_mlp_df[value_col]` over the `show_dims` dims; returns the pivoted frame."""
    heat_df = pre_mlp_df.pipe(lambda df: df[df['dim'].isin(show_dims)]).pivot(index = 'layer_ix', columns = 'dim', values = value_col)
    px.imshow(
        heat_df,
        x = heat_df.columns, y = heat_df.index,
        zmin = zmin, zmax = zmax,
        aspect = 'auto', color_continuous_scale = 'ylgnbu',
        labels = {'color': 'Norm'}, title = title
    ).update_layout(autosize = False, width = 1400).show()
    return heat_df

plot_df = _plot_dim_layer_heatmap('mean_scaled_norm', "Mean scaled wt * activation norms by dimension and layer", zmin = 0, zmax = .2)
plot_df = _plot_dim_layer_heatmap('act_norm', "Mean activation norms by dimension and layer", zmin = 0, zmax = 8)
plot_df = _plot_dim_layer_heatmap('wt_norm', "Mean weight norms by dimension and layer")

In [None]:
"""
Build and export the correlation plot
"""
font_size = 22

def get_layer_acts_and_wts(pre_mlp_df, layer_ix):
    """
    Rows of `pre_mlp_df` for `layer_ix` with strictly positive `wt_norm` and `act_norm`
    (so their logs are finite), plus `log_wt` = log10(wt_norm) and `log_act` = log10(act_norm).
    The original index is preserved.
    """
    subset = pre_mlp_df[pre_mlp_df['layer_ix'] == layer_ix]
    both_positive = (subset['wt_norm'] > 0) & (subset['act_norm'] > 0)
    subset = subset[both_positive]
    return subset.assign(
        log_wt = np.log10(subset['wt_norm']),
        log_act = np.log10(subset['act_norm'])
    )

def get_ols(layer_df):
    """
    Log-log OLS of router-weight norm on activation norm, fit on central quantiles only.

    A row is kept only if both `act_norm` and `wt_norm` lie strictly between their own 1st and 99th
    percentiles (computed over `layer_df`). Then `log_wt` is regressed on `log_act` with
    scipy.stats.linregress.

    Returns:
        dict with 'b1' (slope), 'b0' (intercept) and 'r' (Pearson correlation) as Python floats.
    """
    keep = np.ones(len(layer_df), dtype = bool)
    for col in ('act_norm', 'wt_norm'):
        lower = np.quantile(layer_df[col], .01)
        upper = np.quantile(layer_df[col], .99)
        keep &= ((layer_df[col] > lower) & (layer_df[col] < upper)).to_numpy()
    trimmed = layer_df[keep]
    fit = scipy.stats.linregress(trimmed['log_act'], trimmed['log_wt'])
    return {'b1': fit.slope.item(), 'b0': fit.intercept.item(), 'r': fit.rvalue.item()}

layer_acts_by_layer = {layer_ix: get_layer_acts_and_wts(pre_mlp_df, layer_ix) for layer_ix in all_pre_mlp_hs.keys()}
ols_acts_by_layer = {layer_ix: get_ols(layer_act) for layer_ix, layer_act in layer_acts_by_layer.items()}

# Layers at the 25th / 50th / 75th percentile of the available layer indices
layers_to_plot = [
    int(np.floor(np.quantile(list(all_pre_mlp_hs.keys()), 0.25)).item()),
    int(np.floor(np.quantile(list(all_pre_mlp_hs.keys()), 0.50)).item()),
    int(np.floor(np.quantile(list(all_pre_mlp_hs.keys()), 0.75)).item())
]

color_map = {
    str(layer_ix): px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)]
    for i, layer_ix in enumerate(sorted(layers_to_plot))
}

plot_df = pd.concat([layer_acts_by_layer[layer_ix] for layer_ix in layers_to_plot]).assign(layer_ix = lambda df: df['layer_ix'].astype(str))

# Scatter: drop only the most extreme 0.05% on each side so the axes aren't stretched by a few points
fig = px.scatter(
    plot_df\
        .pipe(lambda df: df[(df['act_norm'] > np.quantile(df['act_norm'], .0005))  & (df['wt_norm'] > np.quantile(df['wt_norm'], .0005))])\
        .pipe(lambda df: df[(df['act_norm'] < np.quantile(df['act_norm'], .9995))  & (df['wt_norm'] < np.quantile(df['wt_norm'], .9995))]),
    x = 'act_norm', y = 'wt_norm',
    color = 'layer_ix',
    log_x = True, log_y = True,
    color_discrete_map = color_map,
    template = 'plotly_white',
    title = 'Per-dimension activation vs. router weight (log - log)',
    )\
    .update_traces(marker = dict(size = 5, opacity = 0.4, line = dict(width = 0)))

# Add regression line (spanning the central 1%-99% quantile range) and a correlation label per layer
for i, layer_ix in enumerate(layers_to_plot):
    layer_plot_df =\
        plot_df[plot_df['layer_ix'] == str(layer_ix)]\
        .pipe(lambda df: df[
            (df['act_norm'] > np.quantile(df['act_norm'], .01)) & 
            (df['wt_norm'] > np.quantile(df['wt_norm'], .01)) & 
            (df['act_norm'] < np.quantile(df['act_norm'], .99)) & 
            (df['wt_norm'] < np.quantile(df['wt_norm'], .99))
        ])

    x_fit_log_act = np.array([layer_plot_df['log_act'].min() - 0.1, layer_plot_df['log_act'].max() + 0.1])
    y_fit_log_wt = ols_acts_by_layer[layer_ix]['b0'] + ols_acts_by_layer[layer_ix]['b1']  * x_fit_log_act

    fig.add_scatter(
        x = 10 ** x_fit_log_act,
        y = 10 ** y_fit_log_wt,
        mode = 'lines',
        line = dict(width = 2, color = color_map[str(layer_ix)]),
        name = f'OLS Layer {layer_ix}',
        showlegend = False
    )

    # Label sits on the fit line just past the 99th-percentile activation; the additive terms are hand-tuned
    # vertical nudges (linear units) that keep the three labels from overlapping
    x_anchor = np.quantile(layer_plot_df['act_norm'], .99) + 0.05
    y_anchor = 10 ** (ols_acts_by_layer[layer_ix]['b0'] + ols_acts_by_layer[layer_ix]['b1'] * np.log10(x_anchor)) + ((0.0 if i == 0 else -0.03) if i != 1 else 0.13) # Or .13 => 0.03
    fig.add_annotation(
        # Annotation coordinates on log axes are given in log10 units
        x = np.log10(x_anchor),
        y = np.log10(y_anchor),
        text = f"<b>layer {str(layer_ix + 1)} (<i>&#961;={ols_acts_by_layer[layer_ix]['r']:.2f}</i>)</b>",
        xanchor = 'center', yanchor = 'bottom',
        showarrow = False,
        font = dict(family = 'CMU Serif', size = font_size - 1, color = color_map[str(layer_ix)]),
        bgcolor='rgba(255, 255, 255, 0.9)', # White with 90% opacity
        borderpad=2 # Small padding around text
    )


# Clean layout
fig.update_layout(
    width = 800, height = 500,
    margin = dict(l = 10, r = 10, t = 40, b = 10),
    font_family = 'CMU Serif',
    showlegend = False,
    font_size = font_size,
    coloraxis_showscale = False,
    # x is `act_norm` (hidden state) and y is `wt_norm` (router weight), matching px.scatter above
    xaxis_title = 'Hidden state magnitude',
    yaxis_title = 'Router weight magnitude',
    title = None,
    xaxis=dict(type = "log", dtick=1),     # keep decades only
    yaxis=dict(type = "log", dtick=1)
)

fig.show()

# pio.write_image(fig, f"exports/router-corr-{model_prefix}-md.pdf", width = 800, height = 500)
# pio.write_image(fig, f"exports/router-corr-{model_prefix}-md.svg", width = 800, height = 500)
# pio.write_image(
#     fig\
#     .update_layout(
#         xaxis_title = 'Hidden state magnitude' if model_ix in [2, 3] else '',
#         yaxis_title = 'Router weight magnitude' if model_ix in [0, 2] else ''
#     ),
#     f"exports/router-corr-{model_prefix}.pdf", width = 800, height = 450
# )

In [None]:
"""
Get all layer corrs + export
"""
# Alternative export: mean rho per model and early/mid/late phase
# layer_corrs =\
#     pd.DataFrame([{'layer_ix': layer_ix + 1, 'rho': ols_res['r']} for layer_ix, ols_res in ols_acts_by_layer.items()])\
#     .assign(model = model_prefix)\
#     .assign(phase = lambda df: pd.cut(
#             df['layer_ix'] - 1,
#             bins = [0, df['layer_ix'].max() / 3, 2 * df['layer_ix'].max() / 3, df['layer_ix'].max()],
#             labels = ['early', 'mid', 'late'],
#             include_lowest = True,
#             right = False
#         )
#     )\
#     .groupby(['model', 'phase'], sort = False, as_index = False)\
#     .agg(rho_mean = ('rho', 'mean'))

# One row per 1-indexed layer with its log-log correlation; keep every 4th layer (1, 5, 9, ...)
layer_corrs = pd.DataFrame(
    [{'layer_ix': layer_ix + 1, 'rho': ols_res['r']} for layer_ix, ols_res in ols_acts_by_layer.items()]
)
layer_corrs = layer_corrs[layer_corrs['layer_ix'] % 4 == 1].assign(model = model_prefix)

display(layer_corrs)
# layer_corrs.to_csv(f'exports/router-corrs-{model_prefix}.csv', mode = 'w', header = True, index = False)

In [None]:
"""
Linear regression - test ability to reconstruct topk expert id + export
"""
def lr_for_layer(layer_to_test, top_dim_divisor = 50, n_baseline_draws = 10):
    """
    Probe whether top-1 expert ids at `layer_to_test` are linearly recoverable from a few hidden-state dims.

    Uses the top 1/`top_dim_divisor` of dimensions (2% by default), ranked by `mean_scaled_norm` in
    `pre_mlp_df`. It fits an L2 logistic regression (cuML, no intercept) and compares it with the same
    probe fit on an equal number of randomly chosen dimensions, averaged over `n_baseline_draws` draws.
    Both accuracies are measured on the samples the probe was trained on (there is no held-out split).

    Returns:
        dict with 'layer_ix_1' (1-indexed layer), 'accuracy' and 'baseline_accuracy' (rounded to 4 dp).

    Raises:
        ValueError: if the number of top-1 expert labels differs from the number of hidden-state rows.
    """
    layer_acts = all_pre_mlp_hs[layer_to_test]
    n_rows, n_dims_total = layer_acts.shape
    n_keep = max(1, n_dims_total // top_dim_divisor)

    expert_ids =\
        topk_df\
        .pipe(lambda df: df[df['layer_ix'] == layer_to_test])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        ['expert'].tolist()
    # Labels are paired with hidden-state rows purely by position, so the counts must agree
    if len(expert_ids) != n_rows:
        raise ValueError(f'Layer {layer_to_test}: {len(expert_ids)} top-1 expert labels but {n_rows} hidden-state rows')
    expert_ids_cp = cupy.asarray(expert_ids)

    lr_model = cuml.linear_model.LogisticRegression(
        penalty = 'l2', 
        max_iter = 10000,
        fit_intercept = False
    )

    # 0-indexed dims, most important first (`dim` in pre_mlp_df is 1-indexed)
    dims = [
        x - 1
        for x in pre_mlp_df.pipe(lambda df: df[df['layer_ix'] == layer_to_test]).sort_values(by = 'mean_scaled_norm', ascending = False)['dim'].tolist()
    ]

    layer_hs = cupy.asarray(layer_acts[:, dims[0:n_keep]].to(torch.float16).detach().cpu())
    lr_model.fit(layer_hs, expert_ids_cp)
    accuracy = lr_model.score(layer_hs, expert_ids_cp)
    print(f"Accuracy: {accuracy:.2%}")

    baseline_accs = []
    for _ in range(n_baseline_draws):
        # Draw distinct dimension indices directly: pre_mlp_df['dim'] repeats each dim once per layer,
        # so sampling from it "without replacement" could still pick the same dim twice.
        rand_dims = np.random.choice(n_dims_total, size = n_keep, replace = False).tolist()
        rand_hs = cupy.asarray(layer_acts[:, rand_dims].to(torch.float16).detach().cpu())
        lr_model.fit(rand_hs, expert_ids_cp)
        # Score on the same random features the baseline was fit on (previously scored on `layer_hs`)
        baseline_accs.append(lr_model.score(rand_hs, expert_ids_cp))
    baseline_acc = np.mean(baseline_accs)

    print(f"Baseline accuracy: {baseline_acc:.2%}")

    return {
        'layer_ix_1': layer_to_test + 1,
        'accuracy': np.round(accuracy, 4).item(),
        'baseline_accuracy': baseline_acc.round(4).item()
    }

# Only 1-indexed layers 1, 5, 9, ... are reported (i.e. 0-indexed layer_ix % 4 == 0), so probe just
# those instead of fitting every even layer and discarding half of the results afterwards.
layer_res =\
    pd.DataFrame([lr_for_layer(layer_ix) for layer_ix in all_pre_mlp_hs if layer_ix % 4 == 0])\
    .assign(model = model_prefix)

display(layer_res)
# layer_res.to_csv(f'exports/router-probe-{model_prefix}.csv', mode = 'w', header = True, index = False)
# layer_res.to_csv(f'exports/router-probe-{model_prefix}.csv', mode = 'w', header = True, index = False)

In [None]:
"""
Saturation curve - mid layer - export data
"""
layer_to_test = int(np.floor(np.quantile(list(all_pre_mlp_hs.keys()), 0.50)).item())

def lr_for_layer_at_percent(layer_to_test, percent):
    """
    Probe how well top-1 expert ids at `layer_to_test` can be reconstructed from the top `percent`%
    of hidden-state dimensions, ranked by `mean_scaled_norm` in `pre_mlp_df`.

    Fits an L2 logistic regression (cuML, no intercept) and reports accuracy on the same samples it
    was trained on (there is no held-out split). At least one dimension is always used.

    Returns:
        dict with 'layer_ix_1' (1-indexed layer), 'pct_dims' (= percent) and 'accuracy' (rounded to 4 dp).

    Raises:
        ValueError: if the number of top-1 expert labels differs from the number of hidden-state rows.
    """
    layer_acts = all_pre_mlp_hs[layer_to_test]

    expert_ids =\
        topk_df\
        .pipe(lambda df: df[df['layer_ix'] == layer_to_test])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        ['expert'].tolist()
    # Labels are paired with hidden-state rows purely by position, so the counts must agree
    if len(expert_ids) != layer_acts.shape[0]:
        raise ValueError(f'Layer {layer_to_test}: {len(expert_ids)} top-1 expert labels but {layer_acts.shape[0]} hidden-state rows')
    expert_ids_cp = cupy.asarray(expert_ids)

    lr_model = cuml.linear_model.LogisticRegression(
        penalty = 'l2', 
        max_iter = 10000,
        fit_intercept = False
    )

    # 0-indexed dims, most important first (`dim` in pre_mlp_df is 1-indexed)
    dims = [
        x - 1
        for x in pre_mlp_df.pipe(lambda df: df[df['layer_ix'] == layer_to_test]).sort_values(by = 'mean_scaled_norm', ascending = False)['dim'].tolist()
    ]

    # Number of top-ranked dimensions to use (at least one, so percent = 0 still fits a model)
    n_dims = max(1, int(layer_acts.shape[1] * percent / 100))

    layer_hs = cupy.asarray(layer_acts[:, dims[0:n_dims]].to(torch.float16).detach().cpu())
    lr_model.fit(layer_hs, expert_ids_cp)
    accuracy = lr_model.score(layer_hs, expert_ids_cp)
    print(f"Accuracy: {accuracy:.2%}")

    return {
        'layer_ix_1': layer_to_test + 1,
        'pct_dims': percent,
        'accuracy': np.round(accuracy, 4).item(),
    }

# Test different percentages: 0, 2, ..., 100 (0% still uses one dimension)
percentages = list(range(0, 102, 2))

saturation_data = [lr_for_layer_at_percent(layer_to_test, pct) for pct in tqdm(percentages)]
saturation_df = pd.DataFrame(saturation_data).assign(model = model_prefix, layer_ix_1 = layer_to_test + 1)

display(saturation_df)
os.makedirs('exports', exist_ok = True) # to_csv does not create missing directories
saturation_df.to_csv(f'exports/router-saturation-{model_prefix}.csv', mode = 'w', header = True, index = False)

In [None]:
"""
Compare PCA top dimensions versus scaled activation top dimensions
"""
layer_to_test = list(all_pre_mlp_hs.keys())[5]
pca_max_rows = 200_000 # cap on rows fed to PCA, to bound GPU memory

# Upcast to fp32 before standardizing / fitting PCA: the stored activations are fp16 (~3 significant
# digits), which is too coarse for means and stds over this many rows.
layer_hs = cupy.asarray(all_pre_mlp_hs[layer_to_test][0:pca_max_rows, :].to(torch.float32).detach().cpu())
mean_vals = cupy.mean(layer_hs, axis = 0)
std_vals = cupy.std(layer_hs, axis = 0)
# Guard against division by zero for constant dimensions (epsilon kept in the data dtype to avoid upcasting)
std_vals = cupy.where(std_vals == 0, cupy.asarray(1e-7, dtype = layer_hs.dtype), std_vals)
layer_hs_std = (layer_hs - mean_vals) / std_vals

pca = cuml.decomposition.PCA(n_components = 10, random_state = 123)
pca.fit(layer_hs_std)

pc_loadings = pca.components_ # (n_components, D)
sumsq = (pc_loadings ** 2).sum(axis = 0) # (D,): how much each dimension loads on the top PCs

ranking = cupy.argsort(-sumsq)  # descending order
pca_top_dims = ranking.tolist()

plot_df =\
    pd.DataFrame({'pca_sumsq': sumsq.tolist(), 'dim': list(range(1, len(sumsq) + 1))})\
    .merge(
        pre_mlp_df.pipe(lambda df: df[df['layer_ix'] == layer_to_test])[['dim', 'mean_scaled_norm']],
        on = 'dim',
        how = 'inner'
    )

px.scatter(
    plot_df,
    x = 'mean_scaled_norm',
    y = 'pca_sumsq',
    labels = {'mean_scaled_norm': 'Mean scaled wt * activation norm', 'pca_sumsq': 'Sum of squared loadings (top 10 PCs)'},
    title = f'Layer {layer_to_test}: PCA loading vs. router-scaled activation, per dimension'
).show()

In [None]:
"""
What % of hidden states is explained by PCA?
"""
# 1) Gather some data: up to 200k rows of the 6th saved layer. Chosen by position, so it exists whatever
#    layer indices were exported (a raw key like 5 may be missing or offset by model_pre_mlp_layers).
#    Upcast to fp32 for the PCA.
clear_all_cuda_memory()
pca_layer_ix = list(all_pre_mlp_hs.keys())[5]
layer_hs = cupy.asarray(all_pre_mlp_hs[pca_layer_ix][0:200_000, :].to(torch.float32).detach().cpu())

# 2) Fit PCA (on the raw, unstandardized activations)
pca_model = cuml.PCA(iterated_power = 20, n_components = 10, verbose = True)
pca_model.fit(layer_hs)

print(f"Layer {pca_layer_ix}")
print("Explained variance ratio:", pca_model.explained_variance_ratio_)
print("Cumulative ratio:", np.cumsum(pca_model.explained_variance_ratio_.get())[-1])

# 3) Retrieve components & variance ratio
components = pca_model.components_.get()  # shape = (10, D)
expl_ratios = pca_model.explained_variance_ratio_.get()  # shape = (10,)

# 4) Dimension-level importance: squared loadings weighted by each PC's explained-variance ratio
sq_loadings = components**2        # shape (10, D)
dim_importance = sq_loadings.T @ expl_ratios   # shape (D,)

# 5) Identify the top_k dims by that importance
top_k = 10
idx_sorted = np.argsort(dim_importance)[::-1]
top_dims = idx_sorted[:top_k]
sum_top = dim_importance[top_dims].sum()
sum_all = dim_importance.sum()
frac_top = sum_top / sum_all

print(f"Top {top_k} dims by PCA-based importance: {top_dims}")
print(f"Sum of their importances: {sum_top:.4f}")
print(f"Fraction of total importance: {frac_top:.4f}")

## Load balancing

In [None]:
unique_layers = np.array(sorted(list(set(topk_df['layer_ix']))))
unique_experts = np.array(sorted(list(set(topk_df['expert'])))) 

# All top-k ranks: per (layer, expert), distinct tokens routed there at any rank and their summed gate weight
topk_grouped_0 = (
    topk_df
    .groupby(['layer_ix', 'expert'], as_index = False)
    .agg(
        token_count = ('sample_ix', 'nunique'),
        weight_sum = ('weight', 'sum')
    )
)

# Full layer x expert grid so experts that never received a token appear with zeros
(
    pd.DataFrame({'layer_ix': unique_layers})
    .merge(pd.DataFrame({'expert': unique_experts}), how = 'cross')
    .merge(topk_grouped_0, how = 'left', on = ['layer_ix', 'expert'])
    .fillna({'token_count': 0, 'weight_sum': 0})
)

In [None]:
"""
Calculate load balancing metrics
"""
# Top-1 routing only: per (layer, expert), distinct tokens whose top-1 expert it is and their summed gate weight
topk_grouped_0 = (
    topk_df
    .pipe(lambda df: df[df['topk_ix'] == 1])
    .groupby(['layer_ix', 'expert'], as_index = False)
    .agg(
        token_count = ('sample_ix', 'nunique'),
        weight_sum = ('weight', 'sum')
    )
)

unique_layers = np.array(sorted(list(set(topk_df['layer_ix']))))
unique_experts = np.array(sorted(list(set(topk_df['expert'])))) 

# Fill in missing expert/layers (zero counts/weights), then per-layer totals and each expert's share of them
topk_grouped = (
    pd.DataFrame({'layer_ix': unique_layers})
    .merge(pd.DataFrame({'expert': unique_experts}), how = 'cross')
    .merge(topk_grouped_0, how = 'left', on = ['layer_ix', 'expert'])
    .fillna({'token_count': 0, 'weight_sum': 0})
    .assign(
        layer_token_sums = lambda df: df.groupby('layer_ix')['token_count'].transform('sum'),
        layer_weight_sums = lambda df: df.groupby('layer_ix')['weight_sum'].transform('sum'),
        token_frac = lambda df: df['token_count'] / df['layer_token_sums'], # fraction of the layer's tokens that pick this expert
        weight_frac = lambda df: df['weight_sum'] / df['layer_weight_sums'] # fraction of the layer's gate weight on this expert
    )
)

def shannon_entropy(probs):
    """
    Shannon entropy in bits, H(p) = -sum_i p_i * log2(p_i).

    Zero entries are skipped (0 * log 0 is taken as 0). `probs` must support boolean-mask
    indexing, e.g. a numpy array.
    """
    nonzero = probs[probs > 0]
    return -(nonzero * np.log2(nonzero)).sum()

# Per-layer routing entropy, from token shares and from gate-weight shares
entropies = [
    {
        'layer_ix': layer,
        'token_entropy': shannon_entropy(layer_df['token_frac'].values),
        'weight_entropy': shannon_entropy(layer_df['weight_frac'].values)
    }
    for layer, layer_df in topk_grouped.groupby('layer_ix')
]
entropy_df = pd.DataFrame(entropies)

def kl_divergence(p, q):
    """
    KL divergence D_KL(p || q) in bits for two probability vectors (arrays or sequences).

    Terms with p_i == 0 contribute 0. If some p_i > 0 has q_i <= 0, the divergence is infinite,
    so +inf is returned instead of silently dropping that term.
    """
    p = np.asarray(p, dtype = float)
    q = np.asarray(q, dtype = float)
    support = p > 0
    if np.any(support & (q <= 0)):
        return np.inf
    return np.sum(p[support] * np.log2(p[support] / q[support]))

kl_list = []
for layer, layer_df in topk_grouped.groupby('layer_ix'):
    p_token = layer_df['token_frac'].values
    # Reference: uniform distribution over every expert in the grid (including experts with zero tokens)
    q = np.full_like(p_token, 1/len(p_token))

    token_kl = kl_divergence(p_token, q)
    weight_kl = kl_divergence(layer_df['weight_frac'].values, q)

    kl_list.append({
        'layer_ix': layer,
        'token_kl': token_kl,
        'weight_kl': weight_kl
    })
kl_df = pd.DataFrame(kl_list)

px.line(
    kl_df,
    x = 'layer_ix', y = ['weight_kl', 'token_kl'],
    labels = {'value': 'KL divergence (bits)', 'variable': 'Metric'},
    title = 'KL Divergence from Uniform'
).update_layout(autosize = False, width = 800, height = 400).show()

px.line(
    entropy_df,
    x = 'layer_ix', y = ['weight_entropy', 'token_entropy'],
    labels = {'value': 'Entropy (bits)', 'variable': 'Metric'},
    title = 'Shannon Entropy'
).update_layout(autosize = False, width = 800, height = 400).show()

# Per-expert top-1 token counts across layers (expert ids 0-99 only, to keep the plot readable)
px.line(
    topk_grouped.pipe(lambda df: df[df['expert'].isin(list(range(0, 100)))]),
    x = 'layer_ix',
    y = 'token_count',
    color = 'expert',
    labels = {'token_count': 'Top-1 token count'},
    title = 'Top-1 Token Count by Expert and Layer'
).update_layout(autosize = False, width = 800, height = 400).show()