In [None]:
"""
This contains code to test MoE hidden states & understand routing weights.
"""
None

In [None]:
"""
Imports
"""
import importlib
import pickle

import numpy as np
import pandas as pd
import plotly.express as px
import scipy.stats
import torch
from termcolor import colored
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# https://docs.rapids.ai/install/
import cupy
import cuml

from utils.memory import check_memory, clear_all_cuda_memory

main_device = 'cuda:0'
seed = 1234
# Seed the global RNGs so pandas .sample(), np.random draws and torch/cupy randomness are reproducible across runs
np.random.seed(seed)
torch.manual_seed(seed)
cupy.random.seed(seed)
clear_all_cuda_memory()
check_memory()

## Load base model

In [None]:
"""
Load the base tokenizer/model
"""
model_id = 'allenai/OLMoE-1B-7B-0125-Instruct'
model_prefix = 'olmoe' # filename prefix for the cached data files
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)
# Place the model on the configured main_device instead of whatever device a bare .cuda() defaults to
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).to(main_device).eval()

## Load data

In [None]:
"""
Load dataset
"""
def load_data(model_prefix, data_dir = 'data'):
    """
    Load the cached pre-MLP hidden states and routing metadata for a model.

    Params:
        @model_prefix: Filename prefix of the cached files, e.g. 'olmoe'.
        @data_dir: Directory containing the cached files (default 'data').

    Returns:
        A tuple (all_pre_mlp_hs, sample_df, topk_df):
        - all_pre_mlp_hs: the object saved in '{data_dir}/{model_prefix}-all-pre-mlp-hidden-states.pt'
        - sample_df, topk_df: the 'sample_df' and 'topk_df' entries of the dict pickled in
          '{data_dir}/{model_prefix}-metadata.pkl'
    """
    all_pre_mlp_hs = torch.load(f'{data_dir}/{model_prefix}-all-pre-mlp-hidden-states.pt')
    # pickle.load can execute arbitrary code: only point this at files you generated yourself
    with open(f'{data_dir}/{model_prefix}-metadata.pkl', 'rb') as f:
        metadata = pickle.load(f)

    return all_pre_mlp_hs, metadata['sample_df'], metadata['topk_df']

# Unpack the cached hidden-state tensor and the raw sample / top-k metadata frames
(all_pre_mlp_hs, sample_df_import, topk_df_import) = load_data(model_prefix = model_prefix)

In [None]:
"""
Let's clean up the mappings here. We'll get everything to a sample_ix level first.
"""
# sample_ix: one id per (batch_ix, sequence_ix, token_ix); seq_id: one id per (batch_ix, sequence_ix).
# groupby().ngroup() numbers groups in sorted key order, which is not necessarily row order.
sample_df_raw =\
    sample_df_import\
    .assign(sample_ix = lambda df: df.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup())\
    .assign(seq_id = lambda df: df.groupby(['batch_ix', 'sequence_ix']).ngroup())\
    .reset_index()
# NOTE(review): reset_index() without drop=True keeps the previous index as an 'index' column

# Attach sample_ix to every routing row. validate='many_to_one' raises if one token position maps to
# several sample rows, which would otherwise silently duplicate routing rows
topk_df =\
    topk_df_import\
    .merge(
        sample_df_raw[['sample_ix', 'batch_ix', 'sequence_ix', 'token_ix']],
        how = 'inner',
        on = ['sequence_ix', 'token_ix', 'batch_ix'],
        validate = 'many_to_one'
    )\
    .drop(columns = ['sequence_ix', 'token_ix', 'batch_ix'])

sample_df =\
    sample_df_raw\
    .drop(columns = ['batch_ix', 'sequence_ix'])

display(topk_df)
display(sample_df)

## Analyze routing weights

In [None]:
"""
Norms by expert and layer
"""
norms_by_expert_layer = pd.concat([
    pd.DataFrame({
        'layer_ix': layer_ix,
        # L1 norm of each expert's row of the gate weight; detach first so no autograd graph is built
        'norm': torch.linalg.norm(layer.mlp.gate.weight.detach(), dim = 1, ord = 1).to(torch.float16).cpu().numpy(),
        'expert': list(range(1, layer.mlp.gate.weight.shape[0] + 1)) # experts labelled 1..n_experts in this plot
    })
    for layer_ix, layer in enumerate(model.model.layers)
])

plot_df = norms_by_expert_layer.pivot(index = 'layer_ix', columns = 'expert', values = 'norm')
px.imshow(
    plot_df,
    x = plot_df.columns, y = plot_df.index,
    color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Norm'}, title = "Norm by Expert and Layer"
).update_layout(autosize = False, width = 800).show()

# Relative deviation from the layer's mean expert norm (0 = average expert for that layer)
scaled_df =\
    norms_by_expert_layer\
    .assign(layer_mean = lambda df: df.groupby('layer_ix')['norm'].transform('mean'))\
    .assign(norm_scaled = lambda df: df['norm'] / df['layer_mean'] - 1)

scaled_plot_df = scaled_df.pivot(index = 'layer_ix', columns = 'expert', values = 'norm_scaled')
# Values are signed around 0, so use a diverging scale centered at 0
px.imshow(
    scaled_plot_df,
    x = scaled_plot_df.columns, y = scaled_plot_df.index,
    color_continuous_scale = 'rdbu', color_continuous_midpoint = 0,
    labels = {'color': 'Norm / layer mean - 1'}, title = "Norm by Expert and Layer (relative to layer mean)"
).update_layout(autosize = False, width = 800).show()

In [None]:
"""
For a single layer, what do the weights and RMSnorms look like?
"""
plot_layer_ix = 9
show_dims = list(range(0, 400)) # 0-indexed dimensions, matching the 'dimension' columns below

# Post-attention RMSNorm gain (gamma) for this layer, one value per hidden dimension
rms_tensor = model.model.layers[plot_layer_ix].post_attention_layernorm.weight
rms_df = pd.DataFrame({
    'gamma': rms_tensor.detach().to(torch.float16).cpu().numpy(),
    'coef': 1,
    'dimension': np.arange(rms_tensor.shape[0])
})
plot_df = rms_df[rms_df['dimension'].isin(show_dims)].pivot(index = 'coef', columns = 'dimension', values = 'gamma')

px.imshow(
    plot_df,
    x = plot_df.columns, y = plot_df.index,
    aspect = 'auto', color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Gamma'}, title = "RMSNorm Scaling Values"
).update_layout(autosize = False, width = 1400, height = 400).show()


# Router (gate) weights, flattened row-major into one row per (expert, dimension)
wt_tensor = model.model.layers[plot_layer_ix].mlp.gate.weight
n_experts, hidden_dim = wt_tensor.shape
wt_df = pd.DataFrame({
    'value': wt_tensor.detach().reshape(-1).to(torch.float16).cpu().numpy(),
    # Vectorized forms of i // hidden_dim and i % hidden_dim for flat index i
    'expert': np.repeat(np.arange(n_experts), hidden_dim),
    'dimension': np.tile(np.arange(hidden_dim), n_experts)
})

plot_df = wt_df[wt_df['dimension'].isin(show_dims)].pivot(index = 'expert', columns = 'dimension', values = 'value')

px.imshow(
    plot_df,
    x = plot_df.columns, y = plot_df.index,
    aspect = 'auto', color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Weight'}, title = "Routing Weights"
).update_layout(autosize = False, width = 1400).show()

# Scale weights by RMSNorm gamma (gamma * weight, per dimension)
scaled_df = wt_df.merge(rms_df, on = 'dimension', how = 'inner').assign(gamma_scaled_value = lambda df: df['gamma'] * df['value'])
plot_df = scaled_df[scaled_df['dimension'].isin(show_dims)].pivot(index = 'expert', columns = 'dimension', values = 'gamma_scaled_value')
px.imshow(
    plot_df,
    x = plot_df.columns, y = plot_df.index,
    aspect = 'auto', color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Gamma x weight'}, title = "Scaled Routing Weights"
).update_layout(autosize = False, width = 1400).show()

In [None]:
"""
Mean norms across layers and dimension (averaged across experts)
"""
def _rms_scaled_gate_norms(layer_ix, layer):
    """One layer's per-dimension mean (over experts) of |gate weight * post-attention RMSNorm gain|; 'dim' is 1-indexed."""
    gate_w = layer.mlp.gate.weight.to(torch.float16).cpu().detach() # (n_experts, D)
    gain = layer.post_attention_layernorm.weight.to(torch.float16).cpu().detach()
    per_dim = (gate_w * gain).abs().mean(dim = 0)
    return pd.DataFrame({
        'mean_norm': per_dim.numpy(),
        'layer_ix': layer_ix,
        'dim': list(range(1, per_dim.shape[0] + 1))
    })

dfs_list = [_rms_scaled_gate_norms(layer_ix, layer) for layer_ix, layer in enumerate(model.model.layers)]
my_df = pd.concat(dfs_list)

# Divide by each layer's average so the heatmap compares dimensions within a layer
my_df_ex_scale = (
    my_df
    .assign(layer_mean = lambda df: df.groupby('layer_ix')['mean_norm'].transform('mean'))
    .assign(mean_norm = lambda df: df['mean_norm'] / df['layer_mean'])
)

plot_df = my_df_ex_scale[my_df_ex_scale['dim'] <= 200].pivot(index = 'layer_ix', columns = 'dim', values = 'mean_norm')

px.imshow(
    plot_df,
    x = plot_df.columns,
    y = plot_df.index,
    zmin = 0,
    zmax = 8,
    aspect = 'auto', # non-square cells
    color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Norm'},
    title = "Mean norms by dimension and layer"
).update_layout(autosize = False, width = 1400, coloraxis = dict()).show()


In [None]:
"""
At dimension x layer-level, analyze activations (averaged across samples) versus routing weights (averaged across experts).
"""
# pre_mlp_df['dim'] is 1-indexed (tensor position + 1), so the first 800 dimensions are dims 1..800
show_dims = list(range(1, 801))
n_act_samples = 500_000 # token rows averaged for the per-dimension activation magnitude
scatter_layer_ix = 6 # layer used for the scatter plot and kurtosis below

pre_mlp_for_layer_norms_all = all_pre_mlp_hs[0:n_act_samples, :, :].abs().mean(dim = 0) # Collapse to n_layers x D dimensional

dfs_list = []
for layer_ix, pre_mlp_for_layer_norm in enumerate(pre_mlp_for_layer_norms_all.unbind(dim = 0)):
    wt_tensor = model.model.layers[layer_ix].mlp.gate.weight.detach().to(torch.float16).cpu() # (n_experts, D)
    act_tensor = pre_mlp_for_layer_norm.cpu().detach() # D-dimensional mean |activation|
    scaled = (wt_tensor * act_tensor).abs().mean(dim = 0) # mean over experts of |weight * mean |activation||
    dfs_list.append(pd.DataFrame({
        # .float() before .numpy(): numpy has no bfloat16; every column is converted explicitly
        'act_norm': act_tensor.float().numpy(),
        'wt_norm': wt_tensor.abs().mean(dim = 0).float().numpy(),
        'mean_scaled_norm': scaled.float().numpy(),
        'layer_ix': layer_ix,
        'dim': list(range(1, scaled.shape[0] + 1))
    }))

pre_mlp_df = pd.concat(dfs_list)
del dfs_list, pre_mlp_for_layer_norms_all

shown_df = pre_mlp_df[pre_mlp_df['dim'].isin(show_dims)]

plot_df = shown_df.pivot(index = 'layer_ix', columns = 'dim', values = 'mean_scaled_norm')
px.imshow(
    plot_df,
    x = plot_df.columns, y = plot_df.index,
    zmin = 0,
    zmax = .2,
    aspect = 'auto', color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Norm'}, title = "Mean scaled wt * activation norms by dimension and layer"
).update_layout(autosize = False, width = 1400, coloraxis = dict()).show()

plot_df = shown_df.pivot(index = 'layer_ix', columns = 'dim', values = 'act_norm')
px.imshow(
    plot_df,
    zmin = 0, zmax = 8,
    x = plot_df.columns, y = plot_df.index,
    aspect = 'auto', color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Norm'}, title = "Mean activation norms by dimension and layer"
).update_layout(autosize = False, width = 1400, coloraxis = dict()).show()

plot_df = shown_df.pivot(index = 'layer_ix', columns = 'dim', values = 'wt_norm')
px.imshow(
    plot_df,
    x = plot_df.columns, y = plot_df.index,
    aspect = 'auto', color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Norm'}, title = "Mean weight norms by dimension and layer"
).update_layout(autosize = False, width = 1400, coloraxis = dict()).show()

# Plot activation vs routing weight norms for one layer
scatter_layer_df = pre_mlp_df[pre_mlp_df['layer_ix'] == scatter_layer_ix]
px.scatter(
    scatter_layer_df,
    x = 'wt_norm', y = 'act_norm', color = 'wt_norm',
    log_y = True,
    log_x = True,
    color_continuous_scale = 'viridis', title = 'Per-Dimension Plot - Activation L1 Norm versus Routing Weight L1 Norm'
).update_layout(autosize = False, width = 1400, coloraxis = dict()).show()

# Excess (Fisher) kurtosis of the per-dimension activation magnitudes for that layer
scipy.stats.kurtosis(scatter_layer_df['act_norm'].tolist())

In [None]:
"""
Logistic regression - test ability to reconstruct the top-1 expert id from a small subset of hidden-state dimensions
"""
layer_to_test = 5
n_keep_dims = all_pre_mlp_hs.shape[2] // 50 # number of hidden dimensions given to each classifier

# Top-1 expert chosen at this layer, in topk_df row order (assumed to line up with all_pre_mlp_hs rows)
expert_ids =\
    topk_df\
    .pipe(lambda df: df[df['layer_ix'] == layer_to_test])\
    .pipe(lambda df: df[df['topk_ix'] == 1])\
    ['expert'].tolist()
assert len(expert_ids) == all_pre_mlp_hs.shape[0], 'expected exactly one top-1 expert per hidden-state row'

expert_ids_cp = cupy.asarray(expert_ids)

lr_model = cuml.linear_model.LogisticRegression(
    penalty = 'l2',
    max_iter = 10000,
    fit_intercept = False
)

# Dimensions ranked by mean |weight * activation| for this layer; pre_mlp_df dims are 1-indexed
dims = [
    x - 1
    for x in pre_mlp_df.pipe(lambda df: df[df['layer_ix'] == layer_to_test]).sort_values(by = 'mean_scaled_norm', ascending = False)['dim'].tolist()
]

# NOTE: both accuracies below are in-sample (scored on the rows the model was fit on)
layer_hs = cupy.asarray(all_pre_mlp_hs[:, layer_to_test, dims[0:n_keep_dims]].to(torch.float16).detach().cpu())
lr_model.fit(layer_hs, expert_ids_cp)
accuracy = lr_model.score(layer_hs, expert_ids_cp)
print(f"Accuracy: {accuracy:.2%}")

# Baseline: the same number of DISTINCT dimensions drawn uniformly at random (seeded).
# Sampling positions 0..D-1 directly avoids the per-layer duplicates in pre_mlp_df['dim'].
baseline_rng = np.random.default_rng(seed)
rand_dims = baseline_rng.choice(all_pre_mlp_hs.shape[2], size = n_keep_dims, replace = False).tolist()
rand_hs = cupy.asarray(all_pre_mlp_hs[:, layer_to_test, rand_dims].to(torch.float16).detach().cpu())
lr_model.fit(rand_hs, expert_ids_cp)
# Score on the random-dimension features this model was actually fit on
accuracy = lr_model.score(rand_hs, expert_ids_cp)
print(f"Baseline accuracy: {accuracy:.2%}")

In [None]:
"""
Compare the dimensions that dominate the top PCA components with the dimensions ranked highest by mean scaled (weight * activation) norm
"""
# Standardize this layer's hidden states (first 200k tokens) dimension by dimension before PCA
layer_hs = cupy.asarray(all_pre_mlp_hs[0:200_000, layer_to_test, :].to(torch.float16).detach().cpu())
mean_vals = layer_hs.mean(axis = 0)
std_vals = layer_hs.std(axis = 0)
std_vals = cupy.where(std_vals == 0, cupy.asarray(1e-7), std_vals) # avoid dividing by zero on constant dimensions
layer_hs_std = (layer_hs - mean_vals) / std_vals

pca = cuml.decomposition.PCA(n_components = 10, random_state = 123)
pca.fit(layer_hs_std)

# Per-dimension importance: sum of squared loadings across the fitted components
pc_loadings = pca.components_
sumsq = (pc_loadings ** 2).sum(axis = 0)
ranking = cupy.argsort(-sumsq) # highest importance first
pca_top_dims = ranking.tolist()

# Join PCA importance with the mean scaled norm for the same (1-indexed) dims
plot_df = pd.DataFrame({
    'pca_sumsq': cupy.asarray(sumsq).tolist(),
    'dim': list(range(1, len(sumsq) + 1))
}).merge(
    pre_mlp_df.loc[pre_mlp_df['layer_ix'] == layer_to_test, ['dim', 'mean_scaled_norm']],
    on = 'dim',
    how = 'inner'
)

px.scatter(
    plot_df,
    x = 'mean_scaled_norm', y = 'pca_sumsq',
    log_y = True, log_x = True
).show()

In [None]:
"""
What % of hidden-state variance is explained by the top PCA components, and which dimensions dominate them?
"""
# 1) Gather some data: first 200k tokens of the layer under test (same layer as the cells above)
clear_all_cuda_memory()
layer_hs = cupy.asarray(all_pre_mlp_hs[0:200_000, layer_to_test, :].to(torch.float16).detach().cpu())

# 2) Fit PCA
n_pca_components = 10
pca_model = cuml.PCA(iterated_power = 20, n_components = n_pca_components, verbose = True)
pca_model.fit(layer_hs)

print("Explained variance ratio:", pca_model.explained_variance_ratio_)
print("Cumulative ratio:", np.cumsum(pca_model.explained_variance_ratio_.get())[-1])

# 3) Retrieve components & variance ratio (cupy -> numpy)
components = pca_model.components_.get()  # shape = (n_pca_components, D)
expl_ratios = pca_model.explained_variance_ratio_.get()  # shape = (n_pca_components,)

# 4) Dimension-level importance: squared loadings weighted by each component's explained-variance ratio
sq_loadings = components**2        # shape (n_pca_components, D)
dim_importance = sq_loadings.T @ expl_ratios   # shape (D,)

# 5) Identify the top_k dims (0-indexed tensor positions) and their share of total importance
top_k = 10
idx_sorted = np.argsort(dim_importance)[::-1]
top_dims = idx_sorted[:top_k]
sum_top = dim_importance[top_dims].sum()
sum_all = dim_importance.sum()
frac_top = sum_top / sum_all

print(f"Top {top_k} dims by PCA-based importance: {top_dims}")
print(f"Sum of their importances: {sum_top:.4f}")
print(f"Fraction of total importance: {frac_top:.4f}")

## Load balancing

In [None]:
unique_layers = np.array(sorted(set(topk_df['layer_ix'])))
unique_experts = np.array(sorted(set(topk_df['expert'])))

# Counts and weight sums over every top-k slot (not only top-1)
topk_grouped_0 = (
    topk_df
    .groupby(['layer_ix', 'expert'], as_index = False)
    .agg(
        token_count = ('sample_ix', 'nunique'), # distinct tokens routed to this expert in any slot
        weight_sum = ('weight', 'sum') # total gating weight
    )
)

# Full layer x expert grid so never-selected experts appear with zeros (displayed as the cell output)
(
    pd.DataFrame({'layer_ix': unique_layers})
    .merge(pd.DataFrame({'expert': unique_experts}), how = 'cross')
    .merge(topk_grouped_0, how = 'left', on = ['layer_ix', 'expert'])
    .fillna({'token_count': 0, 'weight_sum': 0})
)

In [None]:
"""
Calculate load balancing metrics
"""
# Top-1 routing only: how many distinct tokens, and how much gating weight, each (layer, expert) receives
topk_grouped_0 = (
    topk_df[topk_df['topk_ix'] == 1]
    .groupby(['layer_ix', 'expert'], as_index = False)
    .agg(
        token_count = ('sample_ix', 'nunique'), # distinct tokens
        weight_sum = ('weight', 'sum') # total gating weight
    )
)

unique_layers = np.array(sorted(set(topk_df['layer_ix'])))
unique_experts = np.array(sorted(set(topk_df['expert'])))

# Cross every layer with every expert so experts that never win top-1 show up with zeros,
# then express counts and weights as fractions of the layer total
topk_grouped = (
    pd.DataFrame({'layer_ix': unique_layers})
    .merge(pd.DataFrame({'expert': unique_experts}), how = 'cross')
    .merge(topk_grouped_0, how = 'left', on = ['layer_ix', 'expert'])
    .fillna({'token_count': 0, 'weight_sum': 0})
    .assign(
        layer_token_sums = lambda df: df.groupby('layer_ix')['token_count'].transform('sum'),
        layer_weight_sums = lambda df: df.groupby('layer_ix')['weight_sum'].transform('sum')
    )
    .assign(
        token_frac = lambda df: df['token_count'] / df['layer_token_sums'], # fraction of the layer's tokens picking this expert
        weight_frac = lambda df: df['weight_sum'] / df['layer_weight_sums']
    )
)

def shannon_entropy(probs, base = 2):
    """
    Shannon entropy of a discrete distribution.

    Params:
        @probs: Array-like of probabilities (list, numpy array or pandas values). Zero entries contribute 0.
        @base: Logarithm base (default 2, i.e. bits).

    Returns:
        -sum(p * log_base(p)) over the strictly positive entries.
    """
    probs = np.asarray(probs, dtype = float)
    # Drop zeros to avoid log(0); by convention 0 * log(0) = 0
    probs = probs[probs > 0]
    log_probs = np.log2(probs) if base == 2 else np.log(probs) / np.log(base)
    return -np.sum(probs * log_probs)

# One row per layer: entropy (bits) of the top-1 token share and of the gating-weight share across experts
entropies = [
    {
        'layer_ix': layer,
        'token_entropy': shannon_entropy(layer_df['token_frac'].values),
        'weight_entropy': shannon_entropy(layer_df['weight_frac'].values)
    }
    for layer, layer_df in topk_grouped.groupby('layer_ix')
]
entropy_df = pd.DataFrame(entropies)

def kl_divergence(p, q):
    """
    KL divergence D_KL(p || q) in bits for discrete distributions.

    Params:
        @p: Array-like of probabilities (the observed distribution).
        @q: Array-like of probabilities of the same length (the reference distribution).

    Returns:
        sum over p > 0 of p * log2(p / q). Returns inf when q is 0 (or negative) anywhere p is positive,
        because p is then not absolutely continuous w.r.t. q.
    """
    p = np.asarray(p, dtype = float)
    q = np.asarray(q, dtype = float)
    support = p > 0 # terms with p == 0 contribute 0
    if np.any(q[support] <= 0):
        return np.inf
    return np.sum(p[support] * np.log2(p[support] / q[support]))

# KL divergence of each layer's top-1 token / weight shares from a uniform split across experts
kl_list = []
for layer, layer_df in topk_grouped.groupby('layer_ix'):
    uniform = np.full_like(layer_df['token_frac'].values, 1/len(layer_df))
    kl_list.append({
        'layer_ix': layer,
        'token_kl': kl_divergence(layer_df['token_frac'].values, uniform),
        'weight_kl': kl_divergence(layer_df['weight_frac'].values, uniform)
    })
kl_df = pd.DataFrame(kl_list)

for line_df, y_cols, line_title in [
    (kl_df, ['weight_kl', 'token_kl'], 'KL Divergence from Uniform'),
    (entropy_df, ['weight_entropy', 'token_entropy'], 'Shannon Entropy')
]:
    px.line(
        line_df,
        x = 'layer_ix', y = y_cols,
        title = line_title
    ).update_layout(autosize = False, width = 800, height = 400).show()

# Top-1 token counts per expert across layers
px.line(
    topk_grouped[topk_grouped['expert'].isin(range(100))],
    x = 'layer_ix',
    y = 'token_count',
    color = 'expert'
).update_layout(autosize = False, width = 800, height = 400).show()

## Basic clustering

In [None]:
""" 
Cross-layer Topk = 1 Clusters
"""
def print_samples(df, grouping_cols, n_groups = 35, min_group_size = 5, n_examples = 10, random_state = None):
    """
    Group a wide token-level dataframe and display a random selection of the groups.

    Params:
        @df: DataFrame with a 'token' column plus the grouping columns.
        @grouping_cols: Columns to group by (e.g. per-layer expert or cluster ids).
        @n_groups: Maximum number of groups to display (default 35); if fewer qualify, all are shown.
        @min_group_size: Only groups with at least this many rows are eligible (default 5).
        @n_examples: Maximum number of example tokens sampled per group (default 10).
        @random_state: Optional seed forwarded to pandas .sample() for reproducible output.

    Displays a DataFrame with the grouping columns, n_samples (group size) and samples
    (list of example tokens). Returns None.
    """
    res =\
        df\
        .groupby(grouping_cols, as_index = False)\
        .agg(
            n_samples = ('token', 'size'),
            samples = ('token', lambda s: s.sample(n = min(len(s), n_examples), random_state = random_state).tolist())
        )\
        .pipe(lambda g: g[g['n_samples'] >= min_group_size])
    # .sample(n) raises when n exceeds the number of rows, so cap it at what is available
    res = res.sample(n = min(n_groups, len(res)), random_state = random_state)
    display(res)

# One row per token: the top-1 expert id chosen at every layer, as columns layer_{i}_id
topk_wide = (
    topk_df[topk_df['topk_ix'] == 1]
    .merge(sample_df[['sample_ix', 'token']], on = 'sample_ix', how = 'inner')
    .pivot(index = ['sample_ix', 'token'], columns = 'layer_ix', values = 'expert')
    .rename(columns = lambda c: f'layer_{c}_id')
    .reset_index()
)

display(
    topk_wide
    .groupby('layer_1_id', as_index = False)
    .agg(n_samples = ('token', 'size'))
    .sort_values(by = 'n_samples', ascending = False)
)
print_samples(topk_wide, ['layer_4_id', 'layer_5_id', 'layer_6_id', 'layer_7_id'])

In [None]:
"""
Within layer clusters
"""
# Pivot by layer and top-k slot to get one column per (layer, slot): expert_l5_k1, ..., expert_l7_k1, ...
# sample_ix and token come back as columns from the pivot index, so no second join to sample_df is needed
layer_topk_df =\
    topk_df\
    .pipe(lambda df: df[df['layer_ix'].isin([5, 7])])\
    .merge(sample_df[['sample_ix', 'token']], on = 'sample_ix', how = 'inner')\
    .assign(layer_topk_ix = lambda df: 'l' + df['layer_ix'].astype(str) + '_k' + df['topk_ix'].astype(str))\
    .pivot(index = ['sample_ix', 'token'], columns = 'layer_topk_ix', values = 'expert')\
    .rename(columns = lambda c: f'expert_{c}')\
    .rename_axis(columns = None)\
    .reset_index()

print_samples(layer_topk_df, ['expert_l5_k1', 'expert_l5_k2', 'expert_l5_k3', 'expert_l5_k4'])
print_samples(layer_topk_df, ['expert_l7_k1', 'expert_l7_k2', 'expert_l7_k3', 'expert_l7_k4'])
print_samples(layer_topk_df, ['expert_l5_k1', 'expert_l5_k2', 'expert_l7_k1', 'expert_l7_k2'])

In [None]:
"""
Base K-Means (note - returns imbalanced clusters)
""" 
def cluster_kmeans(layer_hs: torch.Tensor, n_clusters = 64):
    """
    K-Means (GPU, cuML) on one layer's raw hidden states.

    NOTE(review): a later cell redefines cluster_kmeans(layer_hs_np) with a different signature;
    once that cell has run, this version is shadowed.

    Params:
        @layer_hs: A n_token_samples x D tensor for a single layer
        @n_clusters: The number of clusters to return

    Returns:
        A list of length n_token_samples of cluster ids
    """
    kmeans_model = cuml.cluster.KMeans(n_clusters = n_clusters, max_iter = 1000, random_state = 123, verbose = True)
    gpu_hs = cupy.asarray(layer_hs.to(torch.float16))
    kmeans_model.fit(gpu_hs)
    labels_gpu = kmeans_model.labels_ # (n_samples,)
    clear_all_cuda_memory()
    return labels_gpu.tolist()

# Cluster every layer's raw hidden states into 64 groups
kmeans_res = [
    {'layer_ix': layer_ix, 'cluster_ids': cluster_kmeans(layer_hs, 64)}
    for layer_ix, layer_hs in tqdm(enumerate(all_pre_mlp_hs.unbind(dim = 1)))
]

# One layer_{i}_id column per layer, followed by the sample metadata columns
kmeans_df = pd.concat(
    [pd.DataFrame({f"layer_{res['layer_ix']}_id": res['cluster_ids']}) for res in kmeans_res] + [sample_df],
    axis = 1
)

for size_col in ['layer_1_id', 'layer_3_id']:
    display(kmeans_df.groupby(size_col, as_index = False).agg(n_samples = ('token', 'size')).sort_values(by = 'n_samples', ascending = False))
clear_all_cuda_memory()

print_samples(kmeans_df, ['layer_2_id', 'layer_3_id', 'layer_4_id', 'layer_5_id', 'layer_6_id'])

## Dimension reduction clustering

In [None]:
""" 
Test decomp methods
"""
def reduce_pca(layer_hs: torch.Tensor, n_components = 2):
    """
    Project one layer's hidden states onto their leading principal components (GPU, cuML).

    Params:
        @layer_hs: n_token_samples x D tensor for a single layer
        @n_components: Number of principal components to keep

    Returns:
        A numpy array with the projected hidden states
    """
    # https://docs.rapids.ai/api/cuml/stable/api/#principal-component-analysis
    gpu_hs = cupy.asarray(layer_hs.to(torch.float16))
    pca_reducer = cuml.PCA(iterated_power = 20, n_components = n_components, verbose = True)
    pca_reducer.fit(gpu_hs)
    # Diagnostics available if needed: pca_reducer.explained_variance_ratio_, pca_reducer.mean_
    projected = cupy.asnumpy(pca_reducer.transform(gpu_hs))
    clear_all_cuda_memory()
    return projected

# Sanity-check plot on layer 0: first two PCs, highlighting the token ' of'
pca_test = reduce_pca(all_pre_mlp_hs[:, 0, :], 100)
px.scatter(
    (
        pd.concat(
            [pd.DataFrame({'d1': pca_test[:, 0], 'd2': pca_test[:, 1]}), sample_df.head(pca_test.shape[0])],
            axis = 1
        )
        .sample(5000)
        .assign(is_of = lambda df: np.where(df['token'] == ' of', 1, 0))
    ),
    x = 'd1', y = 'd2', color = 'is_of', hover_data = ['token']
).show()

# Per-layer PCA projections at two target dimensions
pca_10 = [reduce_pca(hs, 10) for hs in tqdm(all_pre_mlp_hs.unbind(dim = 1))]
pca_100 = [reduce_pca(hs, 100) for hs in tqdm(all_pre_mlp_hs.unbind(dim = 1))]

In [None]:
def reduce_umap(layer_hs: torch.Tensor, n_components = 2, metric = 'cosine'):
    """
    Embed one layer's hidden states with UMAP (GPU, cuML).

    Params:
        @layer_hs: n_token_samples x D tensor for a single layer
        @n_components: Output embedding dimension
        @metric: Input-space distance metric, e.g. 'euclidean', 'cosine', 'manhattan', 'l2', 'hamming'

    Returns:
        A numpy array holding the embedding
    """
    # https://docs.rapids.ai/api/cuml/stable/api/#umap
    umap_settings = dict(
        n_components = n_components,
        n_neighbors = 15, # default; smaller preserves more local structure (range ~2-100)
        metric = metric,
        min_dist = 0.1, # default; effective minimum spacing of embedded points
        n_epochs = 200, # default for large datasets
        random_state = None, # left unseeded to allow parallelism
        verbose = False
    )
    gpu_hs = cupy.asarray(layer_hs.to(torch.float16))
    embedding = cupy.asnumpy(cuml.UMAP(**umap_settings).fit_transform(gpu_hs))
    clear_all_cuda_memory()
    return embedding

# Sanity-check plot on layer 0 (~2 min for 300k samples), highlighting the token ' of'
umap_test = reduce_umap(all_pre_mlp_hs[:, 0, :], 2, 'cosine')
px.scatter(
    (
        pd.concat(
            [pd.DataFrame({'d1': umap_test[:, 0], 'd2': umap_test[:, 1]}), sample_df.head(umap_test.shape[0])],
            axis = 1
        )
        .sample(5000)
        .assign(is_of = lambda df: np.where(df['token'] == ' of', 1, 0))
    ),
    x = 'd1', y = 'd2', color = 'is_of', hover_data = ['token']
).show()

# Cosine most closely matches the router's dot-product scoring; euclidean variants
# (reduce_umap(hs, 10 or 100, 'euclidean')) are not computed here
umap_cos_10 = [reduce_umap(hs, 10, 'cosine') for hs in tqdm(all_pre_mlp_hs.unbind(dim = 1))]
umap_cos_100 = [reduce_umap(hs, 100, 'cosine') for hs in tqdm(all_pre_mlp_hs.unbind(dim = 1))]

In [None]:
"""
Kmeans
"""
def cluster_kmeans(layer_hs_np: np.ndarray, n_clusters = 100, max_iter = 500):
    """
    Cluster a layer using KMeans (GPU, cuML).

    NOTE(review): this redefines cluster_kmeans from the "Base K-Means" cell with a different signature;
    after this cell runs, every call resolves to this version.

    Params:
        @layer_hs_np: An np array of size n_samples x Dhat, where Dhat is some possibly compressed hidden state dimension.
        @n_clusters: Number of clusters (default 100).
        @max_iter: Maximum KMeans iterations (default 500).

    Returns:
        A list of length n_samples of cluster ids.
    """
    # https://docs.rapids.ai/api/cuml/stable/api/#kmeans
    hs_cupy = cupy.asarray(layer_hs_np)

    model = cuml.KMeans(
        n_clusters = n_clusters,
        max_iter = max_iter
    )
    cluster_labels = model.fit_predict(hs_cupy).tolist()
    return cluster_labels


def test_kmeans(layer_hs_list, layers_to_group):
    """
    Run cluster_kmeans on selected layers, show cluster sizes for the first one, then print a sample of cross-layer cluster paths.

    Params:
        @layer_hs_list: A list of np arrays, each of size n_samples x Dhat (Dhat may be a compressed hidden-state dimension)
        @layers_to_group: 0-indexed positions in layer_hs_list to cluster and then group across.

    Returns:
        A DataFrame with one 'layer_{i}_id' column per grouped layer, joined column-wise to the first n_samples rows of sample_df.
    """
    cluster_frames = [
        pd.DataFrame({f"layer_{layer_ix}_id": cluster_kmeans(layer_hs_list[layer_ix])})
        for layer_ix in tqdm(layers_to_group)
    ]
    n_rows = layer_hs_list[0].shape[0]
    cl_df = pd.concat(cluster_frames + [sample_df.head(n_rows)], axis = 1)

    first_col = f"layer_{layers_to_group[0]}_id"
    display(cl_df.groupby(first_col, as_index = False).agg(n_samples = ('token', 'size')).sort_values(by = 'n_samples', ascending = False))
    print('Cross-layer clusters:')
    print_samples(cl_df, [f"layer_{layer_ix}_id" for layer_ix in layers_to_group])

    return cl_df

kmeans_path_1 = test_kmeans(umap_cos_100, [2, 3, 4, 5, 6])

In [None]:
"""
Agglomerative
"""
def cluster_aggc(layer_hs_np: np.ndarray, n_clusters = 100, metric = 'cosine'):
    """
    Cluster a layer using agglomerative (hierarchical) clustering (GPU, cuML).

    Params:
        @layer_hs_np: An np array of size n_samples x Dhat, where Dhat is some possibly compressed hidden state dimension.
        @n_clusters: Number of clusters to return (default 100).
        @metric: Distance metric passed to cuML AgglomerativeClustering (default 'cosine').

    Returns:
        A list of length n_samples of cluster ids.
    """
    # cuML AgglomerativeClustering, see https://docs.rapids.ai/api/cuml/stable/api/
    hs_cupy = cupy.asarray(layer_hs_np)

    model = cuml.AgglomerativeClustering(
        n_clusters = n_clusters,
        metric = metric
    )
    cluster_labels = model.fit_predict(hs_cupy).tolist()
    return cluster_labels


def test_aggc(layer_hs_list, layers_to_group):
    """
    Run cluster_aggc on selected layers, show cluster sizes for the first one, then print a sample of cross-layer cluster paths.

    Params:
        @layer_hs_list: A list of np arrays, each of size n_samples x Dhat (Dhat may be a compressed hidden-state dimension)
        @layers_to_group: 0-indexed positions in layer_hs_list to cluster and then group across.

    Returns:
        A DataFrame with one 'layer_{i}_id' column per grouped layer, joined column-wise to the first n_samples rows of sample_df.
    """
    cluster_frames = [
        pd.DataFrame({f"layer_{layer_ix}_id": cluster_aggc(layer_hs_list[layer_ix])})
        for layer_ix in tqdm(layers_to_group)
    ]
    n_rows = layer_hs_list[0].shape[0]
    cl_df = pd.concat(cluster_frames + [sample_df.head(n_rows)], axis = 1)

    first_col = f"layer_{layers_to_group[0]}_id"
    display(cl_df.groupby(first_col, as_index = False).agg(n_samples = ('token', 'size')).sort_values(by = 'n_samples', ascending = False))
    print('Cross-layer clusters:')
    print_samples(cl_df, [f"layer_{layer_ix}_id" for layer_ix in layers_to_group])

    return cl_df

aggc_path_1 = test_aggc(umap_cos_100, [2, 3, 4, 5, 6])

In [None]:
"""
DBScan
"""
def cluster_dbscan(layer_hs_np: np.ndarray, metric = 'euclidean'):
    """
    Cluster one layer with DBSCAN (GPU, cuML).

    Params:
        @layer_hs_np: An np array of size n_samples x Dhat (Dhat may be a compressed hidden-state dimension).
        @metric: Distance metric; "euclidean" or "cosine" are reasonable choices.

    Returns:
        A list of length n_samples of cluster ids.
    """
    # https://docs.rapids.ai/api/cuml/stable/api/#dbscan
    gpu_hs = cupy.asarray(layer_hs_np)
    dbscan = cuml.DBSCAN(
        metric = metric,
        min_samples = 10, # neighbours required for a point to count as a core point
        verbose = False
    )
    return dbscan.fit_predict(gpu_hs).tolist()


def test_dbscan(layer_hs_list, metric, layers_to_group):
    """
    Cluster multiple layers with DBSCAN, print diagnostics, then print cross-layer groups.

    Params:
        @layer_hs_list: A list of np arrays, each of size n_samples x Dhat, where Dhat is some possibly compressed hidden state dimension
        @metric: The distance metric to use. Either "euclidean" or "cosine" are reasonable.
        @layers_to_group: The indices of layer_hs_list (0-indexed) to be used for grouping clusters across layers.

    Returns:
        A DataFrame with one 'layer_{i}_id' column per grouped layer, joined column-wise to the first n_samples rows of sample_df.
    """
    cl_res = [{'layer_ix': l, 'cluster_ids': cluster_dbscan(layer_hs_list[l], metric)} for l in tqdm(layers_to_group)]

    for r in cl_res:
        labels = r['cluster_ids']
        n_unassigned = sum(1 for x in labels if x == -1)
        # Label -1 marks unassigned (noise) points, not a cluster, so exclude it from the cluster count
        n_clusters = len(set(labels) - {-1})
        print(f"Clusters {n_clusters:,} | Unassigned to clusters: {n_unassigned:,}/{len(labels):,}")

    cl_df =\
        pd.concat([pd.DataFrame({'layer_' + str(x['layer_ix']) + '_id': x['cluster_ids']}) for x in cl_res], axis = 1)\
        .pipe(lambda df: pd.concat([df, sample_df.head(layer_hs_list[0].shape[0])], axis = 1))

    display(cl_df.groupby(f"layer_{str(layers_to_group[0])}_id", as_index = False).agg(n_samples = ('token', 'size')).sort_values(by = 'n_samples', ascending = False))
    print('Cross-layer clusters:')
    print_samples(cl_df, [f"layer_{str(l)}_id"  for l in layers_to_group])

    return cl_df

dbscan_paths_1 = test_dbscan(umap_cos_10, 'euclidean', [2, 3, 4, 5, 6])

In [None]:
"""
HDBSCAN
"""

def cluster_hdbscan(layer_hs_np: np.ndarray, metric = 'euclidean', n_ref_clusters = 64, size_ratio = 50):
    """
    Cluster a layer using HDBSCAN (GPU, cuML).

    Cluster sizes are bounded relative to a reference uniform split of the samples into n_ref_clusters
    equal clusters: from 1/size_ratio of that size up to size_ratio times it.

    Params:
        @layer_hs_np: An np array of size n_samples x Dhat, where Dhat is some possibly compressed hidden state dimension.
        @metric: The distance metric to use. Either "euclidean" or "cosine" are reasonable.
        @n_ref_clusters: Number of clusters in the reference uniform split (default 64).
        @size_ratio: Factor by which cluster sizes may fall below / rise above the uniform size (default 50).

    Returns:
        A list of length n_samples of cluster ids; -1 marks points left unassigned.
    """
    # https://docs.rapids.ai/api/cuml/stable/api/#hdbscan
    hs_cupy = cupy.asarray(layer_hs_np)
    n_samples = len(hs_cupy)

    # Integer arithmetic keeps both bounds ints (n // (64 * 1/50) would floor-divide by the float 1.28).
    # HDBSCAN needs min_cluster_size > 1, so clamp it for small inputs.
    min_cluster_size = max(2, n_samples // (n_ref_clusters * size_ratio)) # 1/size_ratio of the uniform size
    max_cluster_size = (n_samples * size_ratio) // n_ref_clusters # size_ratio x the uniform size

    model = cuml.HDBSCAN(
        min_cluster_size = min_cluster_size,
        max_cluster_size = max_cluster_size,
        metric = metric,
        min_samples = 1,
    )
    cluster_labels = model.fit_predict(hs_cupy).tolist()
    return cluster_labels

def test_hdbscan(layer_hs_list, metric, layers_to_group):
    """
    Cluster multiple layers with HDBSCAN, print diagnostics, then print cross-layer groups.

    Params:
        @layer_hs_list: A list of np arrays, each of size n_samples x Dhat, where Dhat is some possibly compressed hidden state dimension
        @metric: The distance metric to use. Either "euclidean" or "cosine" are reasonable.
        @layers_to_group: The indices of layer_hs_list (0-indexed) to be used for grouping clusters across layers.

    Returns:
        A DataFrame with one 'layer_{i}_id' column per grouped layer, joined column-wise to the first n_samples rows of sample_df.
    """
    cl_res = [{'layer_ix': l, 'cluster_ids': cluster_hdbscan(layer_hs_list[l], metric)} for l in tqdm(layers_to_group)]

    for r in cl_res:
        labels = r['cluster_ids']
        n_unassigned = sum(1 for x in labels if x == -1)
        # Label -1 marks unassigned (noise) points, not a cluster, so exclude it from the cluster count
        n_clusters = len(set(labels) - {-1})
        print(f"Clusters {n_clusters:,} | Unassigned to clusters: {n_unassigned:,}/{len(labels):,}")

    cl_df =\
        pd.concat([pd.DataFrame({'layer_' + str(x['layer_ix']) + '_id': x['cluster_ids']}) for x in cl_res], axis = 1)\
        .pipe(lambda df: pd.concat([df, sample_df.head(layer_hs_list[0].shape[0])], axis = 1))

    display(cl_df.groupby(f"layer_{str(layers_to_group[0])}_id", as_index = False).agg(n_samples = ('token', 'size')).sort_values(by = 'n_samples', ascending = False))
    print('Cross-layer clusters:')
    print_samples(cl_df, [f"layer_{str(l)}_id"  for l in layers_to_group])

    return cl_df

hdbscan_paths_1 = test_hdbscan(umap_cos_10, 'euclidean', [2, 3, 4, 5, 6])

## SVD Decomposition

In [None]:
"""
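The goal now is to use SVD to split each hidden state into the part the linear router can see (its row space) and the part it ignores (the orthogonal complement)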
"""

def decompose_orthogonal(hidden_states: torch.Tensor, router_weights: torch.Tensor, method: str = 'svd', svd_tol: float = 1e-6) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Decomposes hidden states into components parallel and orthogonal to the row space of the router weights.

    The component parallel to the row space ('h_row') contains the information seen by the linear router mechanism (logits = W_g @ h).

    The component orthogonal to the row space ('h_orth') contains information ignored by the linear router mechanism, but potentially used by the non-linear expert MLP or downstream layers.

    Params:
        @hidden_states: Tensor of shape (n_samples, D) representing the pre-routing hidden states.
        @router_weights: Tensor of shape (n_experts, D) representing the linear router gate weights for the layer.
        @method: Decomposition method, 'svd' (default) or 'qr'.
        @svd_tol: Tolerance for determining non-zero singular values in SVD to establish the matrix rank.

    Returns:
        A tuple containing:
        - h_row (torch.Tensor): Projection onto the row space ("used" by router). Shape (n_samples, D).
        - h_orth (torch.Tensor): Projection onto the orthogonal complement ("unused" by router). Shape (n_samples, D).

    Example:
        h_row, h_orth = decompose_orthogonal(all_pre_mlp_hs[0:10_000, 1, :].to(torch.float32), model.model.layers[1].mlp.gate.weight.to(torch.float32).detach().cpu(), 'svd')
        dot_products_svd = torch.sum(h_row * h_orth, dim=1)
        print(f"Mean dot product (SVD): {torch.mean(dot_products_svd).item():.4e}")
        print(f"Max absolute dot product (SVD): {torch.max(torch.abs(dot_products_svd)).item():.4e}")

        reconstruction_diff_svd = torch.norm(all_pre_mlp_hs[0:10_000, 1, :].to(torch.float32) - (h_row + h_orth), dim=1)
        print(f"Mean reconstruction norm diff (SVD): {torch.mean(reconstruction_diff_svd).item():.4e}")

        # Can also verify that QR orthogonality/reconstruction is close to 0, and also that SVD and QR results shoudl be close torch.norm(h_svd = h_qr)
    """
    _, D = hidden_states.shape

    assert D == router_weights.shape[1], 'Hidden state dim != router gate dim'

    if method == 'svd':
        # Compute SVD: W_g = U S V^T
        # V^T (Vt) has shape (k, D), where k = min(n_experts, D)
        # The rows of V^T are the right singular vectors (orthonormal)
        # The first 'rank' rows of V^T span the row space of W_g
        U, S, Vt = torch.linalg.svd(router_weights, full_matrices = False) # Use full_matrices = False for efficiency if D > n_experts

        # Determine rank based on tolerance
        rank = torch.sum(S > svd_tol)
        if rank == 0:
             raise Exception('Router weights matrix has rank 0 according to tolerance.')

        # Basis for the row space (columns of Vr)
        # Vt[:rank] selects the first 'rank' rows (shape rank x D)
        # .T makes it (D x rank) - columns are the orthonormal basis vectors
        Vr = Vt[:rank, :].T

        # Project hidden_states onto the row space (Vr)
        # Formula: h_row = Vr @ Vr^T @ h
        # Batched: H_row = (H @ Vr) @ Vr^T
        # (n_samples, D) @ (D, rank) -> (n_samples, rank)
        h_projected_coeffs = hidden_states @ Vr
        # (n_samples, rank) @ (rank, D) -> (n_samples, D)
        h_row = h_projected_coeffs @ Vr.T

    elif method == 'qr':
        # Compute QR decomposition of W_g^T: W_g^T = Q R
        # Q will have shape (D, k), where k = min(D, n_experts)
        # Columns of Q form an orthonormal basis for column space of W_g^T, which is the row space of W_g.
        Q, R = torch.linalg.qr(router_weights.T, mode = 'reduced') # Use 'reduced' mode for efficiency

        # Q's columns are the orthonormal basis (shape D x k)
        # Need to consider rank deficiency if applicable, but QR handles it implicitly by the shape of Q returned by 'reduced' mode.

        # Project hidden_states onto the column space of Q
        # Formula: h_row = Q @ Q^T @ h
        # Batched: H_row = (H @ Q) @ Q^T
        # (n_samples, D) @ (D, k) -> (n_samples, k)
        h_projected_coeffs = hidden_states @ Q
        # (n_samples, k) @ (k, D) -> (n_samples, D)
        h_row = h_projected_coeffs @ Q.T

    else:
        raise ValueError('Method must be svd or qr')

    # The orthogonal component is the residual
    h_orth = hidden_states - h_row

    return h_row, h_orth

test_layers = [*range(6)]

# For each test layer, split the pre-MLP hidden states into the gate's row-space part and its orthogonal remainder
res = [
    decompose_orthogonal(
        all_pre_mlp_hs[:, ix, :].to(torch.float32),
        model.model.layers[ix].mlp.gate.weight.detach().to(torch.float32).cpu(),
        'svd',
    )
    for ix in tqdm(test_layers)
]

# Unzip the (h_row, h_orth) pairs into two per-layer lists
h_row_by_layer, h_orth_by_layer = (list(parts) for parts in zip(*res))

In [None]:
def print_samples(df, grouping_cols, n_groups = 35, n_tokens = 10, min_group_size = 5, random_state = None):
    """
    Group a wide dataframe and display a random selection of its groups, each with a sample of its tokens.

    Params:
        @df: DataFrame with a 'token' column plus the columns named in grouping_cols.
        @grouping_cols: Column name(s) to group by (e.g. per-layer cluster id columns).
        @n_groups: Maximum number of groups to display; if fewer groups qualify, all of them are shown.
        @n_tokens: Maximum number of tokens sampled from each group.
        @min_group_size: Groups with fewer rows than this are dropped before sampling.
        @random_state: Optional seed for both samplings; None gives a different draw on each call.
    """
    grouped = (
        df
        .groupby(grouping_cols, as_index = False)
        .agg(
            n_samples = ('token', 'size'),
            samples = ('token', lambda s: s.sample(n = min(len(s), n_tokens), random_state = random_state).tolist())
        )
    )
    eligible = grouped[grouped['n_samples'] >= min_group_size]
    # pandas raises ValueError when asked for more rows than exist (without replacement), so cap at the eligible group count
    res = eligible.sample(n = min(len(eligible), n_groups), random_state = random_state)
    display(res)

def cluster_kmeans(layer_hs: torch.Tensor, n_clusters = 64, max_iter = 1000, random_state = 123):
    """
    Run cuML k-means on one layer's hidden states and return the cluster label of each row.

    Params:
        @layer_hs: Tensor with one row per sample (presumably (n_samples, D)); cast to float32 before clustering.
        @n_clusters: Number of k-means clusters.
        @max_iter: Maximum number of k-means iterations passed to cuML.
        @random_state: Seed passed to cuML KMeans so repeated runs are reproducible.

    Returns:
        A Python list with one integer cluster id per row of layer_hs.
    """
    kmeans_model = cuml.cluster.KMeans(n_clusters = n_clusters, max_iter = max_iter, random_state = random_state)
    # Cast to float32 and hand to cuML as a cupy array
    kmeans_model.fit(cupy.asarray(layer_hs.to(torch.float32)))
    # Presumably frees cached GPU memory between layers; the fitted labels stay on kmeans_model
    clear_all_cuda_memory()
    return kmeans_model.labels_.tolist()

# Cluster the router-row-space component (h_row) of each layer with k-means
par_kmeans_res = [
    {'layer_ix': layer_ix, 'cluster_ids': cluster_kmeans(layer_hs, 64)}
    # tqdm wraps the list (not the enumerate object) so the progress bar knows its total length
    for layer_ix, layer_hs in enumerate(tqdm(h_row_by_layer))
]

par_kmeans_df =\
    pd.concat([pd.DataFrame({'layer_' + str(x['layer_ix']) + '_id': x['cluster_ids']}) for x in par_kmeans_res], axis = 1)\
    .pipe(lambda df: pd.concat([df, sample_df], axis = 1))

display(par_kmeans_df.groupby('layer_1_id', as_index = False).agg(n_samples = ('token', 'size')).sort_values(by = 'n_samples', ascending = False))

print_samples(par_kmeans_df, ['layer_1_id', 'layer_2_id'])

In [None]:
# Cluster the component orthogonal to the router row space (h_orth) of each layer with k-means
orth_kmeans_res = [
    {'layer_ix': layer_ix, 'cluster_ids': cluster_kmeans(layer_hs, 64)}
    # tqdm wraps the list (not the enumerate object) so the progress bar knows its total length
    for layer_ix, layer_hs in enumerate(tqdm(h_orth_by_layer))
]

orth_kmeans_df =\
    pd.concat([pd.DataFrame({'layer_' + str(x['layer_ix']) + '_id': x['cluster_ids']}) for x in orth_kmeans_res], axis = 1)\
    .pipe(lambda df: pd.concat([df, sample_df], axis = 1))

display(orth_kmeans_df.groupby('layer_1_id', as_index = False).agg(n_samples = ('token', 'size')).sort_values(by = 'n_samples', ascending = False))

print_samples(orth_kmeans_df, ['layer_1_id', 'layer_2_id'])

In [None]:
# Count layer-1 h_orth clusters (with >= 5 members) whose 20 sampled tokens are all the same string (is_eq = 1) vs. not (is_eq = 0).
# Token sampling is seeded with the notebook-level `seed` so the counts are reproducible across runs.
orth_kmeans_df\
    .groupby(['layer_1_id'], as_index = False)\
    .agg(
        n_samples = ('token', 'size'),
        samples = ('token', lambda s: s.sample(n = min(len(s), 20), random_state = seed).tolist())
    )\
    .pipe(lambda df: df[df['n_samples'] >= 5])\
    .assign(is_eq = lambda df: df.samples.apply(lambda s: 1 if len(set(s)) == 1 else 0))\
    .groupby('is_eq', as_index = False)\
    .agg(count = ('is_eq', 'count'))

In [None]:
# Count layer-1 h_row clusters (with >= 5 members) whose 20 sampled tokens are all the same string (is_eq = 1) vs. not (is_eq = 0).
# Token sampling is seeded with the notebook-level `seed` so the counts are reproducible across runs.
par_kmeans_df\
    .groupby(['layer_1_id'], as_index = False)\
    .agg(
        n_samples = ('token', 'size'),
        samples = ('token', lambda s: s.sample(n = min(len(s), 20), random_state = seed).tolist())
    )\
    .pipe(lambda df: df[df['n_samples'] >= 5])\
    .assign(is_eq = lambda df: df.samples.apply(lambda s: 1 if len(set(s)) == 1 else 0))\
    .groupby('is_eq', as_index = False)\
    .agg(count = ('is_eq', 'count'))

In [None]:
print_samples(par_kmeans_df, grouping_cols = ['layer_3_id'])

In [None]:
# Inspect the per-layer k-means labels for the router-row-space component (h_row), joined with the sample metadata
par_kmeans_df

In [None]:
print_samples(orth_kmeans_df, grouping_cols = ['layer_3_id'])

In [None]:
# NOTE(review): this repeats an earlier `par_kmeans_df` cell; since the cell just before samples orth_kmeans_df,
# this was possibly meant to show orth_kmeans_df — confirm
par_kmeans_df

In [None]:
"""
Logistic regression - predict topk using h_orth?
"""
# Test layer 
test_layer = 5

expert_ids =\
    topk_df\
    .pipe(lambda df: df[df['layer_ix'] == test_layer])\
    .pipe(lambda df: df[df['topk_ix'] == 1])\
    ['expert'].tolist()

expert_ids_cp = cupy.asarray(expert_ids)

lr_model = cuml.linear_model.LogisticRegression(
    penalty = 'l2', 
    max_iter = 10000,
    fit_intercept = False
)
layer_hs = cupy.asarray(h_row_by_layer[test_layer].to(torch.float16).detach().cpu())
lr_model.fit(layer_hs, expert_ids_cp)
accuracy = lr_model.score(layer_hs, expert_ids_cp)
print(f"Accuracy: {accuracy:.2%}")

layer_hs = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())
lr_model.fit(layer_hs, expert_ids_cp)
accuracy = lr_model.score(layer_hs, expert_ids_cp)
print(f"Accuracy: {accuracy:.2%}")

In [None]:
sample_df.groupby('token').size().to_frame('n').sort_values(by = 'n', ascending = False).head(30)

In [None]:
test_layer = 5

def run_lr(x_cp, y_cp, test_size = 0.1, random_state = 123):
    """
    Fit an L2-regularized cuML logistic regression on a random train split and print held-out accuracy.

    Params:
        @x_cp: cupy array of features, one row per sample.
        @y_cp: cupy array of class labels, one per row of x_cp.
        @test_size: Fraction of rows held out for scoring.
        @random_state: Seed for the train/test split.
    """
    x_train, x_test, y_train, y_test = cuml.train_test_split(x_cp, y_cp, test_size = test_size, random_state = random_state)
    lr_model = cuml.linear_model.LogisticRegression(penalty = 'l2', max_iter = 10000, fit_intercept = False)
    lr_model.fit(x_train, y_train)
    accuracy = lr_model.score(x_test, y_test)
    # Raw accuracy is hard to interpret for imbalanced labels (which a rare-token flag likely is), so also report the
    # accuracy of always predicting the most common label in the test split
    _, label_counts = cupy.unique(y_test, return_counts = True)
    baseline = float(label_counts.max() / label_counts.sum())
    print(f"Accuracy: {accuracy:.2%} | Majority-class baseline: {baseline:.2%}")

# Binary label per sample: 1 if the token is one of '.', '_', ',', ':' and 0 otherwise
y_df =\
    sample_df\
    .assign(is_sample = lambda df: np.where(df['token'].isin(['.', '_', ',', ':']), 1, 0))\
    ['is_sample'].tolist()

y_cp = cupy.asarray(y_df)
x_cp_para = cupy.asarray(h_row_by_layer[test_layer].to(torch.float16).detach().cpu())
x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())

run_lr(x_cp_para, y_cp)
run_lr(x_cp_orth, y_cp)