In [None]:
"""
This contains code to understand the behavior of routing weights.
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import scipy
from tqdm import tqdm
from termcolor import colored
import importlib

from utils.memory import check_memory, clear_all_cuda_memory

# https://docs.rapids.ai/install/
import cupy
import cuml

import plotly.express as px
import pickle

# Device string for the first CUDA GPU.
main_device = 'cuda:0'
# Seed value for random number generation.
seed = 1234
# Helpers imported from utils.memory. NOTE(review): presumably these release cached CUDA memory
# and report current usage - confirm against utils/memory.py.
clear_all_cuda_memory()
check_memory()

## Load base model

In [None]:
"""
Load the base tokenizer/model
"""
model_id = 'allenai/OLMoE-1B-7B-0125-Instruct'
model_prefix = 'olmoe' # short tag used to build the data file names loaded further down

# Left-padded tokenizer that does not add BOS/EOS tokens.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    add_eos_token = False,
    add_bos_token = False,
    padding_side = 'left',
    trust_remote_code = True
)

# Load the weights in bfloat16, then move to the current CUDA device and switch to eval mode.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype = torch.bfloat16,
    trust_remote_code = True
)
model = model.cuda().eval()

## Load data

In [None]:
"""
Load dataset
"""
def load_data(model_prefix):
    """
    Read the saved pre-MLP hidden states and the accompanying metadata for one model.

    Params:
        @model_prefix: Tag that prefixes the file names under `data/` (e.g. 'olmoe').

    Returns:
        A tuple of (object returned by `torch.load` for the hidden-states file,
        metadata['sample_df'], metadata['topk_df']).
    """
    hidden_states_path = f'data/{model_prefix}-all-pre-mlp-hidden-states.pt'
    metadata_path = f'data/{model_prefix}-metadata.pkl'

    pre_mlp_hidden_states = torch.load(hidden_states_path)

    # NOTE(review): unpickling can execute arbitrary code - only load metadata files you trust.
    with open(metadata_path, 'rb') as metadata_file:
        meta = pickle.load(metadata_file)

    return pre_mlp_hidden_states, meta['sample_df'], meta['topk_df']

# Load the hidden states and the two metadata tables ('sample_df', 'topk_df') for the model chosen above.
all_pre_mlp_hs, sample_df_import, topk_df_import = load_data(model_prefix)

In [None]:
"""
Let's clean up the mappings here. We'll get everything to a sample_ix level first.
"""
# Give every (batch_ix, sequence_ix, token_ix) triple one sample_ix and every
# (batch_ix, sequence_ix) pair one seq_id.
sample_df_raw = (
    sample_df_import
    .assign(
        sample_ix = lambda df: df.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup(),
        seq_id = lambda df: df.groupby(['batch_ix', 'sequence_ix']).ngroup()
    )
    .reset_index()
)

# Attach sample_ix to the top-k table, then drop the composite key it replaces.
topk_df = (
    topk_df_import
    .merge(
        sample_df_raw[['sample_ix', 'batch_ix', 'sequence_ix', 'token_ix']],
        how = 'inner',
        on = ['sequence_ix', 'token_ix', 'batch_ix']
    )
    .drop(columns = ['sequence_ix', 'token_ix', 'batch_ix'])
)

# token_ix stays on the sample table; the batch/sequence pair is represented by seq_id.
sample_df = sample_df_raw.drop(columns = ['batch_ix', 'sequence_ix'])

display(topk_df)
display(sample_df)

## Analyze routing weights

In [None]:
"""
Norms by expert and layer
"""
# L1 norm of every row of each layer's routing matrix (`mlp.gate.weight`); each row is labelled as one expert.
norms_by_expert_layer = pd.concat([
    pd.DataFrame({
        'layer_ix': layer_ix,
        'norm': torch.linalg.norm(layer.mlp.gate.weight, dim = 1, ord = 1).to(torch.float16).cpu().detach().numpy(),
        'expert': list(range(1, layer.mlp.gate.weight.shape[0] + 1)) # 1-based expert labels
    })
    for layer_ix, layer in enumerate(model.model.layers)
])

# Heatmap 1: raw norms (rows = layers, columns = experts).
plot_df = norms_by_expert_layer.pivot(index = 'layer_ix', columns = 'expert', values = 'norm')
px.imshow(
    plot_df,
    x = plot_df.columns, y = plot_df.index,
    color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Norm'}, title = "Norm by Expert and Layer"
).update_layout(autosize = False, width = 800).show()

# Express each norm as its relative deviation from the mean norm of its layer (0 = layer average).
scaled_df =\
    norms_by_expert_layer\
    .assign(layer_mean = lambda df: df.groupby('layer_ix')['norm'].transform('mean'))\
    .assign(norm_scaled = lambda df: df['norm'] / df['layer_mean'] - 1)

# Heatmap 2: relative norms. It gets its own title and colorbar label; previously it reused the
# raw-norm title, which made the two figures indistinguishable.
scaled_plot_df = scaled_df.pivot(index = 'layer_ix', columns = 'expert', values = 'norm_scaled')
px.imshow(
    scaled_plot_df,
    x = scaled_plot_df.columns, y = scaled_plot_df.index,
    color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Relative norm'}, title = "Norm by Expert and Layer (relative to layer mean)"
).update_layout(autosize = False, width = 800).show()

In [None]:
"""
For a single layer, what do the weights and RMSnorms look like?
"""
plot_layer_ix = 9
show_dims = list(range(0, 400))

# --- RMSNorm scale vector (post-attention layernorm weight) of the chosen layer ---
rms_tensor = model.model.layers[plot_layer_ix].post_attention_layernorm.weight
rms_df = pd.DataFrame({
    'gamma': rms_tensor.to(torch.float16).cpu().detach().numpy(),
    'coef': 1, # constant key so the vector pivots into a single heatmap row
    'dimension': np.arange(rms_tensor.shape[0])
})
plot_df = (
    rms_df
    .loc[lambda df: df['dimension'].isin(show_dims)]
    .pivot(index = 'coef', columns = 'dimension', values = 'gamma')
)

px.imshow(
    plot_df,
    x = plot_df.columns,
    y = plot_df.index,
    aspect = 'auto',
    color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Norm'},
    title = "RMSNorm Scaling Values"
).update_layout(autosize = False, width = 1400, height = 400).show()


# --- Routing weights of the same layer, one row per (expert, dimension) pair ---
wt_tensor = model.model.layers[plot_layer_ix].mlp.gate.weight
wt_df = pd.DataFrame({
    'value': wt_tensor.view(-1).to(torch.float16).cpu().detach().numpy(),
    # Row-major flattening: flat position i belongs to expert i // D and dimension i % D,
    # where D = wt_tensor.shape[1].
    'expert': np.repeat(np.arange(wt_tensor.shape[0]), wt_tensor.shape[1]),
    'dimension': np.tile(np.arange(wt_tensor.shape[1]), wt_tensor.shape[0])
})

plot_df = (
    wt_df
    .loc[lambda df: df['dimension'].isin(show_dims)]
    .pivot(index = 'expert', columns = 'dimension', values = 'value')
)

px.imshow(
    plot_df,
    x = plot_df.columns,
    y = plot_df.index,
    aspect = 'auto',
    color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Norm'},
    title = "Routing Weights"
).update_layout(autosize = False, width = 1400).show()

# --- Routing weights multiplied by the RMSNorm scale of the matching dimension ---
scaled_df = (
    wt_df
    .merge(rms_df, on = 'dimension', how = 'inner')
    .assign(gamma_scaled_value = lambda df: df['gamma'] * df['value'])
)
plot_df = (
    scaled_df
    .loc[lambda df: df['dimension'].isin(show_dims)]
    .pivot(index = 'expert', columns = 'dimension', values = 'gamma_scaled_value')
)
px.imshow(
    plot_df,
    x = plot_df.columns,
    y = plot_df.index,
    aspect = 'auto',
    color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Norm'},
    title = "Scaled Routing Weights"
).update_layout(autosize = False, width = 1400).show()

In [None]:
"""
Mean norms across layers and dimension (averaged across experts)
"""
dfs_list = []
for layer_ix, layer in enumerate(model.model.layers):
    wt_tensor = layer.mlp.gate.weight.to(torch.float16).cpu().detach()
    rms_tensor = layer.post_attention_layernorm.weight.to(torch.float16).cpu().detach()
    # Multiply the routing weights by the RMSNorm scale, then average the absolute values over experts.
    scaled = (wt_tensor * rms_tensor).abs().mean(dim = 0)
    dfs_list.append(pd.DataFrame({
        'mean_norm': scaled.numpy(),
        'layer_ix': layer_ix,
        'dim': np.arange(1, scaled.shape[0] + 1) # 1-based dimension labels
    }))

my_df = pd.concat(dfs_list)

# Additionally divide by the layer average so that layers are comparable.
my_df_ex_scale = (
    my_df
    .assign(layer_mean = lambda df: df.groupby('layer_ix')['mean_norm'].transform('mean'))
    .assign(mean_norm = lambda df: df['mean_norm'] / df['layer_mean'])
)

# Only the first 200 dimensions are plotted.
plot_df = (
    my_df_ex_scale[my_df_ex_scale['dim'] <= 200]
    .pivot(index = 'layer_ix', columns = 'dim', values = 'mean_norm')
)

px.imshow(
    plot_df,
    x = plot_df.columns,
    y = plot_df.index,
    zmin = 0,
    zmax = 8,
    aspect = 'auto', # Allow non-square boxes
    color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Norm'},
    title = "Mean norms by dimension and layer"
).update_layout(autosize = False, width = 1400, coloraxis = dict()).show()


In [None]:
"""
At dimension x layer-level, analyze activations (averaged across samples) versus routing weights (averaged across experts).
"""
# `dim` labels in this cell are 1-based (see the 'dim' column below), so select 1..800 to plot
# the first 800 dimensions. The previous range(0, 800) silently dropped dimension 800.
show_dims = list(range(1, 801))

# Mean absolute activation over the first 500k samples.
pre_mlp_for_layer_norms_all = all_pre_mlp_hs[0:500_000, :, :].abs().mean(dim = 0) # Collapse to n_layers x D dimensional

dfs_list = []
for layer_ix, pre_mlp_for_layer_norm in enumerate(pre_mlp_for_layer_norms_all.unbind(dim = 0)):
    wt_tensor = model.model.layers[layer_ix].mlp.gate.weight[:, :].to(torch.float16).cpu().detach() # (n_experts, D)
    act_tensor = pre_mlp_for_layer_norm.cpu().detach() # D-dimensional
    scaled = (wt_tensor * act_tensor) # Multiply by activation tensor
    scaled = scaled.abs().mean(dim = 0) # Take mean L1 norm
    dfs_list.append(pd.DataFrame({
        'act_norm': act_tensor.numpy(),
        # Convert to numpy like the sibling columns so the column is numeric rather than a torch tensor.
        'wt_norm': wt_tensor.abs().mean(dim = 0).numpy(),
        'mean_scaled_norm': scaled.numpy(),
        'layer_ix': layer_ix,
        'dim': list(range(1, scaled.shape[0] + 1)) # 1-based dimension labels
    }))

pre_mlp_df = pd.concat(dfs_list)
del dfs_list, pre_mlp_for_layer_norms_all # free the large intermediates

# Heatmap: mean |weight * activation| by layer and dimension
plot_df = pre_mlp_df.pipe(lambda df: df[df['dim'].isin(show_dims)]).pivot(index = 'layer_ix', columns = 'dim', values = 'mean_scaled_norm')
px.imshow(
    plot_df,
    x = plot_df.columns, y = plot_df.index,
    zmin = 0,
    zmax = .2,
    aspect = 'auto', color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Norm'}, title = "Mean scaled wt * activation norms by dimension and layer"
).update_layout(autosize = False, width = 1400, coloraxis = dict()).show()

# Heatmap: mean |activation| by layer and dimension
plot_df = pre_mlp_df.pipe(lambda df: df[df['dim'].isin(show_dims)]).pivot(index = 'layer_ix', columns = 'dim', values = 'act_norm')
px.imshow(
    plot_df,
    zmin = 0, zmax = 8,
    x = plot_df.columns, y = plot_df.index,
    aspect = 'auto', color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Norm'}, title = "Mean activation norms by dimension and layer"
).update_layout(autosize = False, width = 1400, coloraxis = dict()).show()

# Heatmap: mean |routing weight| (averaged over experts) by layer and dimension
plot_df = pre_mlp_df.pipe(lambda df: df[df['dim'].isin(show_dims)]).pivot(index = 'layer_ix', columns = 'dim', values = 'wt_norm')
px.imshow(
    plot_df,
    x = plot_df.columns, y = plot_df.index,
    aspect = 'auto', color_continuous_scale = 'ylgnbu',
    labels = {'color': 'Norm'}, title = "Mean weight norms by dimension and layer"
).update_layout(autosize = False, width = 1400, coloraxis = dict()).show()

# Plot activation vs routing weight norms (layer 6 only, log-log axes)
px.scatter(
    pre_mlp_df.pipe(lambda df: df[df['layer_ix'] == 6]),
    x = 'wt_norm', y = 'act_norm', color = 'wt_norm',
    log_y = True,
    log_x = True,
    color_continuous_scale = 'viridis', title = 'Per-Dimension Plot - Activation L1 Norm versus Routing Weight L1 Norm'
).update_layout(autosize = False, width = 1400, coloraxis = dict()).show()

# Kurtosis of the per-dimension activation norms at layer 6 (displayed as the cell output)
scipy.stats.kurtosis(pre_mlp_df.pipe(lambda df: df[df['layer_ix'] == 6]['act_norm'].tolist() ))

In [None]:
"""
Linear regression - test ability to reconstruct topk expert id
"""
layer_to_test = 5

# Top-1 expert chosen for every sample at the tested layer.
# NOTE(review): this assumes the filtered rows of topk_df are in the same order as the rows of
# all_pre_mlp_hs - confirm, or sort by sample_ix if sample_ix is the row position.
expert_ids =\
    topk_df\
    .pipe(lambda df: df[df['layer_ix'] == layer_to_test])\
    .pipe(lambda df: df[df['topk_ix'] == 1])\
    ['expert'].tolist()

assert len(expert_ids) == all_pre_mlp_hs.shape[0], 'Expected exactly one top-1 expert label per hidden-state row'

expert_ids_cp = cupy.asarray(expert_ids)

lr_model = cuml.linear_model.LogisticRegression(
    penalty = 'l2', 
    max_iter = 10000,
    fit_intercept = False
)

# Number of hidden dimensions given to the classifier (2% of the hidden size).
n_dims = all_pre_mlp_hs.shape[2] // 50

# 0-based dimension indices of this layer, ordered by mean |weight * activation| (largest first).
dims = [
    x - 1
    for x in pre_mlp_df.pipe(lambda df: df[df['layer_ix'] == layer_to_test]).sort_values(by = 'mean_scaled_norm', ascending = False)['dim'].tolist()
]

layer_hs = cupy.asarray(all_pre_mlp_hs[:, layer_to_test, dims[0:n_dims]].to(torch.float16).detach().cpu())
lr_model.fit(layer_hs, expert_ids_cp)
accuracy = lr_model.score(layer_hs, expert_ids_cp)
print(f"Accuracy: {accuracy:.2%}")

# Baseline: the same number of dimensions drawn at random.
# - Sample from the hidden-dimension range itself: pre_mlp_df['dim'] repeats every dimension once
#   per layer, so sampling from it without replacement could still return duplicate dimensions.
# - Use the notebook seed so the baseline is reproducible.
rand_dims = np.random.default_rng(seed).choice(all_pre_mlp_hs.shape[2], size = n_dims, replace = False).tolist()
rand_hs = cupy.asarray(all_pre_mlp_hs[:, layer_to_test, rand_dims].to(torch.float16).detach().cpu())
lr_model.fit(rand_hs, expert_ids_cp)
# Score on the random-dimension features the model was just fitted on (previously scored on layer_hs).
accuracy = lr_model.score(rand_hs, expert_ids_cp)
print(f"Baseline accuracy: {accuracy:.2%}")

In [None]:
"""
Compare PCA top dimensions versus scaled activation top dimensions
"""
# Standardize every hidden dimension over the first 200k samples of the tested layer.
layer_hs = cupy.asarray(all_pre_mlp_hs[0:200_000, layer_to_test, :].to(torch.float16).detach().cpu())
mean_vals = layer_hs.mean(axis = 0)
std_vals = layer_hs.std(axis = 0)
std_vals = cupy.where(std_vals == 0, cupy.asarray(1e-7), std_vals) # avoid dividing by zero
layer_hs_std = (layer_hs - mean_vals) / std_vals

pca = cuml.decomposition.PCA(n_components = 10, random_state = 123)
pca.fit(layer_hs_std)

# Per-dimension score: sum of squared loadings across the fitted components.
pc_loadings = pca.components_
sumsq = (pc_loadings ** 2).sum(axis = 0)

ranking = cupy.argsort(-sumsq)  # dimension indices in descending score order
pca_top_dims = ranking.tolist()

# Join the PCA score with the mean scaled norm of the same (1-based) dimension at the tested layer.
plot_df = (
    pd.DataFrame({
        'pca_sumsq': cupy.asarray(sumsq).tolist(),
        'dim': list(range(1, len(sumsq) + 1))
    })
    .merge(
        pre_mlp_df[pre_mlp_df['layer_ix'] == layer_to_test][['dim', 'mean_scaled_norm']],
        on = 'dim',
        how = 'inner'
    )
)

px.scatter(
    plot_df,
    x = 'mean_scaled_norm',
    y = 'pca_sumsq',
    log_y = True,
    log_x = True
).show()

In [None]:
"""
What % of hidden states is explained by PCA?
"""
# 1) Gather the first 200k hidden states of layer 5
clear_all_cuda_memory()
layer_hs = cupy.asarray(all_pre_mlp_hs[0:200_000, 5, :].to(torch.float16).detach().cpu())

# 2) Fit a 10-component PCA on the raw (unstandardized) hidden states
pca_model = cuml.PCA(iterated_power = 20, n_components = 10, verbose = True)
pca_model.fit(layer_hs)

print("Explained variance ratio:", pca_model.explained_variance_ratio_)
print("Cumulative ratio:", np.cumsum(pca_model.explained_variance_ratio_.get())[-1])

# 3) Copy components & variance ratios to host memory
components = pca_model.components_.get()  # shape = (10, D)
expl_ratios = pca_model.explained_variance_ratio_.get()  # shape = (10,)

# 4) Dimension-level importance: squared loadings weighted by each component's variance ratio
sq_loadings = np.square(components)  # shape (10, D)
dim_importance = sq_loadings.T @ expl_ratios  # shape (D,)

# 5) Identify the top_k most important dims and their share of the total importance
top_k = 10
idx_sorted = np.argsort(dim_importance)[::-1]
top_dims = idx_sorted[:top_k]
sum_top = dim_importance[top_dims].sum()
sum_all = dim_importance.sum()
frac_top = sum_top / sum_all

print(f"Top {top_k} dims by PCA-based importance: {top_dims}")
print(f"Sum of their importances: {sum_top:.4f}")
print(f"Fraction of total importance: {frac_top:.4f}")

## Load balancing

In [None]:
# Every layer index / expert id that appears anywhere in the top-k table.
unique_layers = np.array(sorted(set(topk_df['layer_ix'])))
unique_experts = np.array(sorted(set(topk_df['expert'])))

# Per (layer, expert), over all top-k slots: distinct token count and total gating weight.
topk_grouped_0 = (
    topk_df
    .groupby(['layer_ix', 'expert'], as_index = False)
    .agg(
        token_count = ('sample_ix', 'nunique'),
        weight_sum = ('weight', 'sum')
    )
)

# Full layer x expert grid, with zeros for pairs that never occur (shown as the cell output).
(
    pd.DataFrame({'layer_ix': unique_layers})
    .merge(pd.DataFrame({'expert': unique_experts}), how = 'cross')
    .merge(topk_grouped_0, how = 'left', on = ['layer_ix', 'expert'])
    .fillna({'token_count': 0, 'weight_sum': 0})
)

In [None]:
"""
Calculate load balancing metrics
"""
# Per (layer, expert), restricted to rows with topk_ix == 1: distinct token count and total gating weight.
topk_grouped_0 = (
    topk_df[topk_df['topk_ix'] == 1]
    .groupby(['layer_ix', 'expert'], as_index = False)
    .agg(
        token_count = ('sample_ix', 'nunique'),
        weight_sum = ('weight', 'sum')
    )
)

unique_layers = np.array(sorted(set(topk_df['layer_ix'])))
unique_experts = np.array(sorted(set(topk_df['expert'])))

# Build the full layer x expert grid so that pairs which never occur show up with zero counts,
# then turn counts and weights into per-layer fractions.
topk_grouped = (
    pd.DataFrame({'layer_ix': unique_layers})
    .merge(pd.DataFrame({'expert': unique_experts}), how = 'cross')
    .merge(topk_grouped_0, how = 'left', on = ['layer_ix', 'expert'])
    .fillna({'token_count': 0, 'weight_sum': 0})
    .assign(
        # per-layer totals used as denominators
        layer_token_sums = lambda df: df.groupby('layer_ix')['token_count'].transform('sum'),
        layer_weight_sums = lambda df: df.groupby('layer_ix')['weight_sum'].transform('sum'),
        # share of the layer's tokens / gating weight that goes to this expert
        token_frac = lambda df: df['token_count'] / df['layer_token_sums'],
        weight_frac = lambda df: df['weight_sum'] / df['layer_weight_sums']
    )
)

def shannon_entropy(probs):
    """
    Shannon entropy, in bits, of a discrete distribution.

    Params:
        @probs: Array of probabilities; must support boolean-mask indexing (e.g. a numpy array).

    Returns:
        -sum(p * log2(p)) taken over the strictly positive entries only, so zeros never reach log2.
    """
    positive = probs[probs > 0]
    return -np.sum(positive * np.log2(positive))

# Entropy of the token-share and weight-share distributions across experts, one record per layer.
entropies = [
    {
        'layer_ix': layer,
        'token_entropy': shannon_entropy(layer_df['token_frac'].values),
        'weight_entropy': shannon_entropy(layer_df['weight_frac'].values)
    }
    for layer, layer_df in topk_grouped.groupby('layer_ix')
]
entropy_df = pd.DataFrame(entropies)

def kl_divergence(p, q):
    """
    KL divergence D(p || q) in bits between two discrete distributions.

    Params:
        @p: Array of probabilities (e.g. a numpy array).
        @q: Array of reference probabilities with the same shape as `p`.

    Returns:
        sum(p * log2(p / q)) over the positions where both p and q are strictly positive.
    """
    both_positive = (p > 0) & (q > 0)
    p_kept = p[both_positive]
    q_kept = q[both_positive]
    return np.sum(p_kept * np.log2(p_kept / q_kept))

# KL divergence of each layer's token-share and weight-share distributions from the uniform distribution.
kl_list = []
for layer, layer_df in topk_grouped.groupby('layer_ix'):
    p_token = layer_df['token_frac'].values
    q = np.full_like(p_token, 1/len(p_token)) # uniform reference over this layer's experts

    kl_list.append({
        'layer_ix': layer,
        'token_kl': kl_divergence(p_token, q),
        'weight_kl': kl_divergence(layer_df['weight_frac'].values, q)
    })
kl_df = pd.DataFrame(kl_list)

px.line(
    kl_df,
    x = 'layer_ix',
    y = ['weight_kl', 'token_kl'],
    title = 'KL Divergence from Uniform'
).update_layout(autosize = False, width = 800, height = 400).show()

px.line(
    entropy_df,
    x = 'layer_ix',
    y = ['weight_entropy', 'token_entropy'],
    title = 'Shannon Entropy'
).update_layout(autosize = False, width = 800, height = 400).show()

# Token count per layer, one line per expert (expert ids 0-99 only).
px.line(
    topk_grouped[topk_grouped['expert'].isin(list(range(0, 100)))],
    x = 'layer_ix',
    y = 'token_count',
    color = 'expert'
).update_layout(autosize = False, width = 800, height = 400).show()