In [None]:
"""
This contains code to use SVD to decompose hidden states based on whether they're used by routing or not.
"""
None

In [None]:
"""
Imports
"""
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import cupy
import cuml

import importlib
import gc
import pickle
import os
import regex
import scipy

from tqdm import tqdm
import plotly.express as px
from plotly.subplots import make_subplots
from safetensors.torch import load_file

from utils.memory import check_memory, clear_all_cuda_memory
from utils.loader import load_model_and_tokenizer
from utils.quantize import compare_bf16_fp16_batched
from utils.svd import decompose_orthogonal, decompose_sideways, get_svd_proj

# Device passed to load_model_and_tokenizer below.
main_device = 'cuda:0'
# Seed for some of the torch.Generator instances created below (others hard-code 123).
seed = 1234

# Helpers from utils.memory; presumably free cached CUDA memory and report usage - TODO confirm.
clear_all_cuda_memory()
check_memory()

# NOTE(review): absolute, machine-specific workspace path; adjust when running elsewhere.
ws = '/workspace/interpretable-moes-analysis'
# Directory that the CSV exports in this notebook are written to.
svd_dir = f'{ws}/experiments/geometry/svd'

## Load model & data

In [None]:
"""
Load the base tokenizer/model
"""
model_prefix = 'gpt-oss-20b'
# load_model_and_tokenizer (utils.loader) returns five values; the dense / MoE layer counts are reused throughout the notebook.
tokenizer, model, model_architecture, model_n_moe_layers, model_n_dense_layers = load_model_and_tokenizer(model_prefix, device = main_device)

In [None]:
"""
Get 0-indexed MoE layer indices
"""
# NOTE(review): pickle.load can execute arbitrary code; only open metadata files produced by this project.
with open(f'{ws}/experiments/geometry/activations/{model_prefix}/metadata.pkl', 'rb') as f:
    # Read with .get, so layer_indices is None when the 'layer_mappings' key is absent.
    layer_indices = pickle.load(f).get('layer_mappings')

layer_indices

In [None]:
"""
Get projection matrices
"""
def get_router_objects(model, model_prefix, model_n_dense_layers, model_n_moe_layers):
    """
    Collect, for every MoE layer, the centered router weight matrix and the V factor of its SVD.

    Params:
        @model: loaded model whose decoder layers hold the router weights
        @model_prefix: model key; selects the attribute path to the router weight
        @model_n_dense_layers: # of dense layers in the model
        @model_n_moe_layers: # of MoE layers in the model

    Returns:
        v_mats: dict layer_ix -> V returned by get_svd_proj for the centered router matrix
        v_mats_demeaned: dict layer_ix -> the same V (see the note in the loop body)
        router_mats: dict layer_ix -> centered router matrix (float32, on CPU)

    Only layers contained in the module-level `layer_indices` are processed.
    """
    v_mats = {}
    v_mats_demeaned = {}
    router_mats = {}

    for layer_ix in tqdm(range(0, model_n_dense_layers + model_n_moe_layers)):
        if layer_ix not in layer_indices:
            continue
        if model_prefix == 'kimivl':
            router_mat = model.language_model.model.layers[layer_ix].mlp.gate.weight
        elif model_prefix == 'granite':
            router_mat = model.model.layers[layer_ix].block_sparse_moe.router.layer.weight
        else:
            router_mat = model.model.layers[layer_ix].mlp.router.weight

        router_mat = router_mat.detach().cpu().to(torch.float32)
        # Subtract the mean over dim 0 (presumably the expert axis - TODO confirm) from every column.
        router_mat = router_mat - router_mat.mean(dim = 0, keepdim = True) # Center logits
        _, _, V = get_svd_proj(router_mat)

        # NOTE(review): this function used to center router_mat a second time before a second SVD.
        # Centering an already-centered matrix is a no-op (up to float rounding), so that SVD only
        # reproduced V. The two dicts therefore share one result. If `v_mats` was meant to hold the
        # basis of the UNcentered router matrix, compute it before the centering line above.
        v_mats[layer_ix] = V
        v_mats_demeaned[layer_ix] = V
        router_mats[layer_ix] = router_mat

    return v_mats, v_mats_demeaned, router_mats

# Build the per-layer SVD bases and centered router matrices for every layer listed in layer_indices.
v_mats, v_mats_demeaned, router_mats = get_router_objects(model, model_prefix, model_n_dense_layers, model_n_moe_layers)

In [None]:
"""
Load dataset
"""
def load_data(model_prefix, max_data_files):
    """
    Read the activation shards saved by `export-activations.ipynb` and stitch them together.

    Params:
        @model_prefix: model key; selects the activations sub-directory under `ws`
        @max_data_files: shard folders 00 .. max_data_files - 1 are probed; missing folders are skipped

    Returns:
        (prompts_df, sample_df, topk_df, pre_mlp_hs): the prompts table, the concatenated sample and
        top-k tables (fresh 0..N-1 index), and the `all_pre_mlp_hs` tensors concatenated along dim 0.
    """
    base_dir = f'{ws}/experiments/geometry/activations/{model_prefix}'

    # Keep only the shard folders that actually exist on disk
    shard_dirs = []
    for shard_ix in range(max_data_files):
        shard_dir = f'{base_dir}/{shard_ix:02d}'
        if os.path.isdir(shard_dir):
            shard_dirs.append(shard_dir)

    samples, topks, hidden_states = [], [], []
    for shard_dir in tqdm(shard_dirs):
        samples.append(pd.read_feather(f'{shard_dir}/samples.feather'))
        topks.append(pd.read_feather(f'{shard_dir}/topks.feather'))
        shard_tensors = load_file(f'{shard_dir}/activations.safetensors', device = 'cpu')
        hidden_states.append(shard_tensors['all_pre_mlp_hs'])

    sample_df = pd.concat(samples, ignore_index = True)
    topk_df = pd.concat(topks, ignore_index = True)
    pre_mlp_hs = torch.cat(hidden_states, dim = 0)

    prompts_df = pd.read_feather(f'{base_dir}/prompts.feather')

    gc.collect()
    return prompts_df, sample_df, topk_df, pre_mlp_hs

# max_data_files: 5 for every model except glm4moe, which uses 3.
# NOTE(review): the call below passes 3 while model_prefix is 'gpt-oss-20b'; confirm this is intended.
prompts_df, sample_df_import, topk_df_import, all_pre_mlp_hs_import = load_data(model_prefix, 3)

In [None]:

"""
Let's clean up the mappings here. We'll get everything to a sample_ix level first.
"""
# Give every row of the imported sample table a dense 0..N-1 `sample_ix`.
# reset_index() without drop=True also keeps the previous index as an extra `index` column.
sample_df =\
    sample_df_import\
    .assign(sample_ix = lambda df: range(0, len(df)))\
    .reset_index()

# Attach sample_ix to each routing row via (prompt_ix, token_ix), drop token_ix, then shift layer_ix
# by the number of dense layers.
# NOTE(review): the inner merge assumes (prompt_ix, token_ix) is unique in sample_df; duplicates would multiply rows.
topk_df =\
    topk_df_import\
    .merge(sample_df[['sample_ix', 'prompt_ix', 'token_ix']], how = 'inner', on = ['prompt_ix', 'token_ix'])\
    .drop(columns = ['token_ix'])\
    .assign(layer_ix = lambda df: df['layer_ix'] + model_n_dense_layers)

def get_sample_df_for_layer(sample_df, topk_df, layer_ix):
    """
    Helper to take the sample df and merge layer-level expert selection information

    Params:
        @sample_df: one row per sample; must contain `sample_ix`
        @topk_df: long routing table with columns `sample_ix`, `layer_ix`, `topk_ix`, `expert`
        @layer_ix: layer whose routing decisions are attached

    Returns:
        The rows of sample_df that have a topk_ix == 1 expert at `layer_ix` (inner join), with added columns:
          - expert / expert2: experts with topk_ix 1 and 2 at `layer_ix`
          - prev_expert / prev2_expert: topk_ix 1 experts at layer_ix - 1 and layer_ix - 2 (missing if absent)
          - leading_path: '<prev2_expert>-<prev_expert>', missing when either part is missing
    """
    def _slot(at_layer, topk_ix, col_name):
        # Rows of topk_df for one (layer, slot), with `expert` renamed to the output column name
        rows = topk_df[(topk_df['layer_ix'] == at_layer) & (topk_df['topk_ix'] == topk_ix)]
        return rows.rename(columns = {'expert': col_name})[['sample_ix', col_name]]

    def _label(v):
        # A left merge upcasts integer ids to float when it introduces NaN; render 3.0 as '3'
        if isinstance(v, float) and v.is_integer():
            return str(int(v))
        return str(v)

    def _leading_path(df):
        # Built explicitly: adding '-' to a numeric expert column raises TypeError in pandas
        paths = [
            f'{_label(p2)}-{_label(p1)}' if pd.notna(p2) and pd.notna(p1) else float('nan')
            for p2, p1 in zip(df['prev2_expert'], df['prev_expert'])
        ]
        return pd.Series(paths, index = df.index, dtype = object)

    layer_df =\
        sample_df\
        .merge(_slot(layer_ix, 1, 'expert'), how = 'inner', on = 'sample_ix')\
        .merge(_slot(layer_ix - 1, 1, 'prev_expert'), how = 'left', on = 'sample_ix')\
        .merge(_slot(layer_ix - 2, 1, 'prev2_expert'), how = 'left', on = 'sample_ix')\
        .merge(_slot(layer_ix, 2, 'expert2'), how = 'left', on = 'sample_ix')\
        .assign(leading_path = _leading_path)

    return layer_df

# The raw imports are no longer needed once sample_df / topk_df exist; drop them to free memory.
del sample_df_import, topk_df_import

gc.collect()
# Rich display of both tables for a visual sanity check.
display(topk_df)
display(sample_df)

In [None]:
"""
Convert activations to fp16 (for compatibility with cupy later) + dict
"""
# Cast the stacked activations to float16.
# NOTE(review): if the import is bf16 (as the commented-out comparison suggests), values outside the
# float16 range become inf on this cast - confirm that magnitudes are safe.
all_pre_mlp_hs = all_pre_mlp_hs_import.to(torch.float16)
# compare_bf16_fp16_batched(all_pre_mlp_hs_import, all_pre_mlp_hs)
del all_pre_mlp_hs_import
# Re-key as a dict: slice `save_ix` along dim 1 belongs to the save_ix-th entry of layer_indices.
# NOTE(review): keys here are layer_ix + model_n_dense_layers, while v_mats (get_router_objects) is keyed
# by the un-offset layer_ix. The two only coincide when model_n_dense_layers == 0 - confirm for other models.
all_pre_mlp_hs = {(layer_ix + model_n_dense_layers): all_pre_mlp_hs[:, save_ix, :] for save_ix, layer_ix in enumerate(layer_indices)}
gc.collect()

# Trace calculations

In [None]:
"""
How much the routing matrix itself rotates
"""
def get_proj_overlap(layer_ix):
    """
    Compare the router bases of layer `layer_ix` and layer `layer_ix + 1` (read from the global `v_mats`).

    Returns a dict with:
        overlap: sum of squared entries of V_l^T V_{l+1}, divided by the smaller of the two ranks
        rank: # of columns of V_l
        rd_ratio: rank / model.config.hidden_size
        sv_quantiles: min / 25% / 50% / 75% / max of the singular values of V_l^T V_{l+1}
    All floats are rounded to 2 decimals.
    """
    v_cur = v_mats[layer_ix]
    v_next = v_mats[layer_ix + 1]
    rank_cur, rank_next = v_cur.shape[1], v_next.shape[1]

    cross = v_cur.T @ v_next
    _, sing_vals, _ = get_svd_proj(cross, svd_tol = 0)
    sv_quantiles = [round(q, 2) for q in np.quantile(sing_vals, [0, .25, .5, .75, 1]).tolist()]

    overlap = (cross * cross).sum() / min(rank_cur, rank_next)

    return {
        'overlap': round(overlap.item(), 2),
        'rank': rank_cur,
        'rd_ratio': round(rank_cur / model.config.hidden_size, 2),
        'sv_quantiles': sv_quantiles,
    }

# The last key is skipped because get_proj_overlap reads layer_ix + 1 (assumes consecutive layer keys).
# The overlap is computed for every layer but only printed for even-numbered ones.
for layer_ix in list(all_pre_mlp_hs.keys())[:-1]:
    proj_overlap = get_proj_overlap(layer_ix)
    if layer_ix % 2 == 0:
        print(f'Layer {layer_ix}: {proj_overlap}')

# Cosine similarity

In [None]:
"""
Blind/vis helper
"""
def get_vis_blind(hl, vmat, type = 'f32'):
    """
    Split hidden states into the component inside the column space of `vmat` ("vis") and the rest ("blind").

    Params:
        @hl: hidden states, one row per sample; the last dim must match vmat.shape[0]
        @vmat: basis matrix with one basis vector per column. hlvis = hl @ vmat @ vmat.T, which is an
          orthogonal projection only if the columns are orthonormal (assumed - TODO confirm against get_svd_proj)
        @type: 'f32' upcasts hl to float32 before projecting; 'f16' casts vmat to float16 and projects in half precision

    Returns:
        (hlvis, hlblind) where hlblind = hl - hlvis

    Raises:
        ValueError: if `type` is neither 'f32' nor 'f16'
    """
    if type == 'f32':
        hlvis = (hl.float() @ vmat) @ vmat.T
    elif type == 'f16':
        # Cast once and reuse for both matmuls
        vmat_f16 = vmat.to(torch.float16)
        hlvis = (hl @ vmat_f16) @ vmat_f16.T
    else:
        raise ValueError("Type must be 'f16' or 'f32'")

    hlblind = hl - hlvis
    return hlvis, hlblind

# Smoke test on the hard-coded key 10 (must exist in both dicts); as the cell's last expression the result is displayed.
get_vis_blind(all_pre_mlp_hs[10], v_mats[10])

In [None]:
"""
E1: Get standard cosine sims cos(h^{vis}_{l}, h^{vis}_{l+1}), cos(h^{blind}_{l}, h^{blind}_{l+1})
"""
def get_cos_sim(all_pre_mlp_hs, v_mats, layer_ix, k_ahead = 1, samples = -1):
    """
    E1: mean per-token cosine similarity between layer `layer_ix` and layer `layer_ix + k_ahead`,
    for the full hidden state, its vis part and its blind part. Each layer is projected with its OWN basis.

    Params:
        @all_pre_mlp_hs: dict layer_ix -> hidden states, one row per token
        @v_mats: dict layer_ix -> basis passed to get_vis_blind
        @layer_ix: source layer
        @k_ahead: offset of the target layer
        @samples: # of rows drawn with replacement (fixed seed 123); -1 uses every row

    Returns:
        dict of means rounded to 2 decimals. Suffix `_dm` = column means removed before comparing;
        suffix `_shuf` = target rows permuted, i.e. a mismatched-token baseline.

    NOTE: a later cell defines another function with this same name, which replaces this one.
    """
    rng = torch.Generator().manual_seed(123)
    target_ix = layer_ix + k_ahead
    n_rows = all_pre_mlp_hs[layer_ix].shape[0]

    if samples == -1:
        sample_ixs = list(range(0, n_rows))
    else:
        sample_ixs = torch.randint(low = 0, high = n_rows, size = (samples,), generator = rng)

    hs_src = all_pre_mlp_hs[layer_ix][sample_ixs, :]
    hs_tgt = all_pre_mlp_hs[target_ix][sample_ixs, :]
    vis_src, blind_src = get_vis_blind(hs_src, v_mats[layer_ix])
    vis_tgt, blind_tgt = get_vis_blind(hs_tgt, v_mats[target_ix])

    # Drawn after the row sample so the generator is consumed in a fixed order
    shuffle = torch.randperm(vis_src.shape[0], generator = rng)

    def mean_cos(a, b):
        return round(F.cosine_similarity(a, b, dim = -1).mean().item(), 2)

    def demean(x):
        return x - x.mean(0)

    vis_src_dm, vis_tgt_dm = demean(vis_src), demean(vis_tgt)
    blind_src_dm, blind_tgt_dm = demean(blind_src), demean(blind_tgt)

    return {
        'h': mean_cos(hs_src, hs_tgt),
        'hvis': mean_cos(vis_src, vis_tgt),
        'hblind': mean_cos(blind_src, blind_tgt),
        'hvis_dm': mean_cos(vis_src_dm, vis_tgt_dm),
        'hblind_dm': mean_cos(blind_src_dm, blind_tgt_dm),
        'h_shuf': mean_cos(hs_src, hs_tgt[shuffle]),
        'hvis_shuf': mean_cos(vis_src, vis_tgt[shuffle]),
        'hblind_shuf': mean_cos(blind_src, blind_tgt[shuffle]),
        'hvis_dm_shuf': mean_cos(vis_src_dm, vis_tgt_dm[shuffle]),
        'hblind_dm_shuf': mean_cos(blind_src_dm, blind_tgt_dm[shuffle]),
    }

# Run E1 on every even-numbered layer that still has a layer k_ahead further on; 10,000 rows are sampled per layer.
k_ahead = 1
for layer_ix in list(all_pre_mlp_hs.keys())[:-k_ahead]:
    if layer_ix % 2 == 0:
        cos_sims = get_cos_sim(all_pre_mlp_hs, v_mats, layer_ix, k_ahead = k_ahead, samples = 10_000)
        print(f'Layer {layer_ix}: {cos_sims}')

In [None]:
"""
E2: Get projection-controlled cosine sims cos(h_l P_l1, h_l1 P_l1), cos(h_l (I-P_l1), h_l1 (I-P_l1))
"""
def get_cos_sim(all_pre_mlp_hs, v_mats, layer_ix, k_ahead = 1, samples = -1):
    """
    E2: projection-controlled cosine similarities. Both layers are projected with the basis of the
    TARGET layer (`layer_ix + k_ahead`), giving a "ctrl" part (inside that basis) and a "rest" part.

    Params:
        @all_pre_mlp_hs: dict layer_ix -> hidden states, one row per token
        @v_mats: dict layer_ix -> basis passed to get_vis_blind
        @layer_ix: source layer
        @k_ahead: offset of the target layer
        @samples: # of rows drawn with replacement (seeded by the global `seed`); -1 uses every row

    Returns:
        dict of means rounded to 2 decimals. Suffix `_dm` = column means removed before comparing;
        suffix `_shuf` = target rows permuted, i.e. a mismatched-token baseline.

    NOTE: this reuses the name of the E1 function defined in an earlier cell and replaces it.
    """
    rng = torch.Generator().manual_seed(seed)
    target_ix = layer_ix + k_ahead
    n_rows = all_pre_mlp_hs[layer_ix].shape[0]

    if samples == -1:
        sample_ixs = list(range(0, n_rows))
    else:
        sample_ixs = torch.randint(low = 0, high = n_rows, size = (samples,), generator = rng)

    target_basis = v_mats[target_ix]
    ctrl_src, rest_src = get_vis_blind(all_pre_mlp_hs[layer_ix][sample_ixs, :], target_basis)
    ctrl_tgt, rest_tgt = get_vis_blind(all_pre_mlp_hs[target_ix][sample_ixs, :], target_basis)

    # Drawn after the row sample so the generator is consumed in a fixed order
    shuffle = torch.randperm(ctrl_src.shape[0], generator = rng)

    def mean_cos(a, b):
        return round(F.cosine_similarity(a, b, dim = -1).mean().item(), 2)

    def demean(x):
        return x - x.mean(0)

    ctrl_src_dm, ctrl_tgt_dm = demean(ctrl_src), demean(ctrl_tgt)
    rest_src_dm, rest_tgt_dm = demean(rest_src), demean(rest_tgt)

    # No base and base_shuf since same as E1
    return {
        'ctrl': mean_cos(ctrl_src, ctrl_tgt),
        'rest': mean_cos(rest_src, rest_tgt),
        'ctrl_dm': mean_cos(ctrl_src_dm, ctrl_tgt_dm),
        'rest_dm': mean_cos(rest_src_dm, rest_tgt_dm),
        'ctrl_shuf': mean_cos(ctrl_src, ctrl_tgt[shuffle]),
        'rest_shuf': mean_cos(rest_src, rest_tgt[shuffle]),
        'ctrl_dm_shuf': mean_cos(ctrl_src_dm, ctrl_tgt_dm[shuffle]),
        'rest_dm_shuf': mean_cos(rest_src_dm, rest_tgt_dm[shuffle]),
    }

# Run E2 on every even-numbered layer that still has a layer k_ahead further on; 10,000 rows are sampled per layer.
k_ahead = 1
for layer_ix in list(all_pre_mlp_hs.keys())[:-k_ahead]:
    if layer_ix % 2 == 0:
        cos_sims = get_cos_sim(all_pre_mlp_hs, v_mats, layer_ix, k_ahead = k_ahead, samples = 10_000)
        print(f'Layer {layer_ix}: {cos_sims}')

In [None]:
"""
E3: Bootstrap quantiles + matched-token cosine similarity (E1 but bootstrapped)
"""
def precompute_token_cos_sims(all_pre_mlp_hs, v_mats, k_ahead = 1):
    """
    Per-token cosine similarity between each layer l and layer l + k_ahead, for the vis and blind parts.
    Projections are done in float16 (get_vis_blind with type = 'f16').

    Returns a dict keyed by the source layer index. Each value holds (N, ) tensors under
    'vis', 'blind' (matched tokens) and 'vis_rand', 'blind_rand' (target rows randomly permuted).
    Layers without a partner k_ahead further on are left out.
    """
    rng = torch.Generator().manual_seed(seed)
    cos_by_layer = {}

    for src_ix in tqdm(sorted(all_pre_mlp_hs.keys())):
        tgt_ix = src_ix + k_ahead
        if tgt_ix not in all_pre_mlp_hs:
            continue

        hs_src = all_pre_mlp_hs[src_ix]
        hs_tgt = all_pre_mlp_hs[tgt_ix]
        vis_src, blind_src = get_vis_blind(hs_src, v_mats[src_ix], type = 'f16')
        vis_tgt, blind_tgt = get_vis_blind(hs_tgt, v_mats[tgt_ix], type = 'f16')

        # One fresh permutation per layer pair, used for the random-pairing baseline
        pairing = torch.randperm(hs_src.shape[0], generator = rng)

        cos_by_layer[src_ix] = {
            'vis': F.cosine_similarity(vis_src, vis_tgt, dim = -1),
            'blind': F.cosine_similarity(blind_src, blind_tgt, dim = -1),
            'vis_rand':  F.cosine_similarity(vis_src, vis_tgt[pairing], dim = -1),
            'blind_rand': F.cosine_similarity(blind_src, blind_tgt[pairing], dim = -1),
        }

    return cos_by_layer

def get_bootstrap_cos_sims(cos_sims_by_layer, doc_ids, bs_samples = 200, samples_per_bs = 1_000):
    """
    Block bootstrap over precomputed per-token cosine similarities.
    Params:
        @cos_sims_by_layer: output of precompute_token_cos_sims
        @doc_ids: list[int] of prompt/document ID per token (same order as sample_df)
        @bs_samples: # of bootstrap iterations
        @samples_per_bs: target tokens per bootstrap draw (sample prompts until we reach this threshold)
    Returns:
        list[dict], one entry per layer of cos_sims_by_layer in sorted order. Each entry holds 'layer' plus,
        for every channel in ('vis', 'blind', 'vis_rand', 'blind_rand'), the keys '<channel>_mean',
        '<channel>_lo' and '<channel>_hi': the mean over bootstrap draws and the 2.5% / 97.5% quantiles,
        each rounded to 3 decimals.
    """
    g = torch.Generator().manual_seed(123)
    layer_indices = sorted(cos_sims_by_layer.keys())
    n_pairs = len(layer_indices)
    channels = ['vis', 'blind', 'vis_rand', 'blind_rand']

    # Group token positions by document with a single stable sort instead of one full scan per document.
    # Documents end up in ascending id order and positions ascending within each document.
    docs = torch.as_tensor(doc_ids)
    _, doc_sizes = torch.unique(docs, return_counts = True)
    token_order = torch.sort(docs, stable = True).indices
    idxs_by_doc = torch.split(token_order, doc_sizes.tolist())

    results = {ch: torch.empty((bs_samples, n_pairs), dtype = torch.float) for ch in channels}
    for b in tqdm(range(bs_samples)):
        # Draw whole documents (with replacement) until enough tokens are collected
        take_idxs = []
        total = 0
        while total < samples_per_bs:
            j = int(torch.randint(low = 0, high = len(idxs_by_doc), size=(1,), generator = g))
            take_idxs.append(idxs_by_doc[j])
            total += idxs_by_doc[j].numel()
        # The last document drawn is truncated so every draw has exactly samples_per_bs tokens
        sample_indices = torch.cat(take_idxs)[:samples_per_bs]
        for ch in channels:
            vals = torch.stack([cos_sims_by_layer[l][ch][sample_indices] for l in layer_indices], dim = 1)
            results[ch][b] = vals.mean(dim=0)

    output = []
    for i, layer_ix in enumerate(layer_indices):
        row = {'layer': layer_ix}
        for ch in channels:
            col = results[ch][:, i]
            row[f'{ch}_mean'] = round(col.mean().item(), 3)
            lo, hi = torch.quantile(col, torch.tensor([0.025, 0.975], dtype=col.dtype))
            row[f'{ch}_lo'] = round(lo.item(), 3)
            row[f'{ch}_hi'] = round(hi.item(), 3)
        output.append(row)
    return output

# Per-token cosine similarities for every adjacent layer pair
cos_sims_by_layer = precompute_token_cos_sims(all_pre_mlp_hs, v_mats, k_ahead = 1)

# Bootstrap over prompts: 200 draws of 100 tokens each (the function default would be 1,000 tokens per draw)
bs_results = get_bootstrap_cos_sims(
    cos_sims_by_layer,
    doc_ids = sample_df['prompt_ix'].tolist(),
    bs_samples = 200,
    samples_per_bs = 100,
)

# One line per layer: matched-token mean with its 95% interval, and the random-pairing mean for reference
for row in bs_results:
    print(
        f"Layer {row['layer']:>3d}: "
        f"vis={row['vis_mean']:.2f} [{row['vis_lo']:.2f}, {row['vis_hi']:.2f}] (rand={row['vis_rand_mean']:.2f}) | "
        f"blind={row['blind_mean']:.2f} [{row['blind_lo']:.2f}, {row['blind_hi']:.2f}] (rand={row['blind_rand_mean']:.2f})"
        # f"full={row['full_mean']:.3f} [{row['full_lo']:.3f}, {row['full_hi']:.3f}] (rand={row['full_rand_mean']:.3f})"
    )

In [None]:
"""
Export
"""
# Transition-stability export: one row per (l, l + 1) layer pair.
# The statistics come from `bs_results` (bootstrap cell above). The 'vis' channel is written under the
# existing 'para' column names and 'blind' under 'orth', so the CSV schema is unchanged.
# NOTE(review): this cell used to read para_means / orth_means / *_cis_* variables that no cell in this
# notebook defines, so it failed on a fresh kernel. Confirm that vis -> para and blind -> orth is the
# mapping downstream consumers expect.
export_df = pd.DataFrame({
    'layer_ix_1': list(range(model_n_dense_layers + 1, len(all_pre_mlp_hs) + model_n_dense_layers)), # +1 to 1 index
    'para_mean_across_layers': [row['vis_mean'] for row in bs_results],
    'orth_mean_across_layers': [row['blind_mean'] for row in bs_results],
    'para_cis_hi': [row['vis_hi'] for row in bs_results],
    'para_cis_lo': [row['vis_lo'] for row in bs_results],
    'orth_cis_hi': [row['blind_hi'] for row in bs_results],
    'orth_cis_lo': [row['blind_lo'] for row in bs_results]
})

display(export_df)

export_df.to_csv(f'{svd_dir}/svd-transition-stability-{model_prefix}.csv', index = False)

## Reconstruction/probing tests

In [None]:
"""
LR helpers
"""
def run_probe(x_cp, y_cp):
    """
    Train an L2-regularised logistic-regression probe on an 80/20 split (fixed random_state 123)
    and score its predictions on the held-out 20%.

    Returns:
        (acc, acc_lo, acc_hi, nmi, nmi_lo, nmi_hi) from get_acc_with_ci and get_nmi_with_ci
    """
    x_fit, x_holdout, y_fit, y_holdout = cuml.train_test_split(x_cp, y_cp, test_size = 0.2, random_state = 123)

    probe = cuml.linear_model.LogisticRegression(penalty = 'l2', max_iter = 500, fit_intercept = True)
    probe.fit(x_fit, y_fit)
    preds = probe.predict(x_holdout)

    # The CI helpers work on numpy arrays, so move labels and predictions off the GPU
    labels_np = cupy.asnumpy(y_holdout)
    preds_np = cupy.asnumpy(preds)

    accuracy, accuracy_lo, accuracy_hi = get_acc_with_ci(labels_np, preds_np)
    norm_mi, norm_mi_lo, norm_mi_hi = get_nmi_with_ci(labels_np, preds_np)

    return accuracy, accuracy_lo, accuracy_hi, norm_mi, norm_mi_lo, norm_mi_hi

# def run_probe_with_mi(x_cp, y_cp):
#     """
#     Fit an LR probe; return normalized MI
#     """
#     x_train, x_test, y_train, y_test = cuml.train_test_split(x_cp, y_cp, test_size = 0.2, random_state = 123)
#     lr_model = cuml.linear_model.LogisticRegression(penalty = 'l2', max_iter = 100, fit_intercept = True)
#     lr_model.fit(x_train, y_train)
#     accuracy = lr_model.score(x_test, y_test)
#     train_acc = lr_model.score(x_train, y_train)
#     y_actual_np = cupy.asnumpy(y_test)
#     y_pred_np = cupy.asnumpy(lr_model.predict(x_test))
#     mi = sklearn.metrics.mutual_info_score(y_actual_np, y_pred_np) # nats
#     max_entropy = sklearn.metrics.mutual_info_score(y_actual_np, y_actual_np) # H(y)
#     return accuracy, mi, max_entropy, train_acc

def get_acc_with_ci(y_true, y_pred, alpha = 0.01):
    """
    Accuracy with a Wilson score interval.

    Params:
        @y_true: numpy array of labels
        @y_pred: numpy array of predictions, same shape as y_true
        @alpha: two-sided significance level (0.01 -> 99% interval)

    Returns:
        (accuracy, lower bound, upper bound)
    """
    n_obs = y_true.size
    n_hits = (y_true == y_pred).sum()
    acc = n_hits / n_obs

    z_crit = scipy.stats.norm.ppf(1 - alpha/2)
    z_sq = z_crit*z_crit

    shrink = 1 + z_sq/n_obs
    midpoint = (acc + z_sq/(2*n_obs)) / shrink
    margin = z_crit * np.sqrt((acc*(1-acc) + z_sq/(4*n_obs))/n_obs) / shrink

    return acc, midpoint - margin, midpoint + margin

def get_nmi_with_ci(y_true, y_pred, alpha = 0.01, base = 2.0):
    """
    Normalized mutual information I(Y; Yhat) / H(Y) with an asymptotic confidence interval
    (delta method applied to the plug-in entropy estimators).

    Params:
        @y_true: array-like of labels
        @y_pred: array-like of predictions, same length as y_true
        @alpha: two-sided significance level (0.01 -> 99% interval)
        @base: logarithm base for MI and entropy (2.0 -> bits)

    Returns:
        (ratio, lower bound, upper bound) as Python floats
    """
    labels = np.asarray(y_true)
    preds = np.asarray(y_pred)
    n_obs = labels.size
    label_vals, label_codes = np.unique(labels, return_inverse = True)
    pred_vals, pred_codes = np.unique(preds, return_inverse = True)
    n_labels, n_preds = label_vals.size, pred_vals.size

    # Contingency table of (label, prediction) counts, then joint and marginal probabilities
    counts = np.zeros((n_labels, n_preds), float)
    np.add.at(counts, (label_codes, pred_codes), 1.0)
    joint = counts / n_obs
    p_label = joint.sum(1, keepdims = True) # Cx1
    p_pred = joint.sum(0, keepdims = True) # 1xH

    # Pointwise log-ratio log(P / (py * ph)) on the support, in nats; zero elsewhere
    support = (joint > 0) & (p_label > 0) & (p_pred > 0)
    p_label_full = np.broadcast_to(p_label, joint.shape)
    p_pred_full = np.broadcast_to(p_pred, joint.shape)
    log_ratio = np.zeros_like(joint)
    log_ratio[support] = np.log(joint[support]) - np.log(p_label_full[support]) - np.log(p_pred_full[support])

    # Gradients wrt the joint cell probabilities (nats)
    grad_mi = log_ratio.copy()
    grad_h = np.tile(-(np.log(p_label) + 1.0), (1, n_preds)) # CxH

    # Multinomial delta method: E[g g'] - E[g] E[g'] under the joint distribution (nats^2)
    mean_mi = (grad_mi*joint).sum()
    mean_h = (grad_h*joint).sum()
    spread_mi = (grad_mi*grad_mi*joint).sum() - mean_mi**2
    spread_h = (grad_h*grad_h*joint).sum() - mean_h**2
    spread_cross = (grad_mi*grad_h*joint).sum() - mean_mi*mean_h

    # Convert to the requested base (bits by default)
    log_base = np.log(base)
    mi = (joint * log_ratio).sum() / log_base
    h_label =  -(p_label * np.log(p_label)).sum() / log_base
    var_mi = spread_mi/(n_obs * log_base**2)
    var_h = spread_h/(n_obs * log_base**2)
    cov_mi_h = spread_cross/(n_obs * log_base**2)

    # Delta method for the ratio MI / H
    ratio = mi/h_label
    d_mi = 1.0/h_label
    d_h = -mi/(h_label**2)
    se_ratio = np.sqrt(d_mi*d_mi*var_mi + d_h*d_h*var_h + 2*d_mi*d_h*cov_mi_h)

    z_crit = scipy.stats.norm.ppf(1 - alpha/2)
    return float(ratio), float(ratio - z_crit*se_ratio), float(ratio + z_crit*se_ratio)

In [None]:
"""
Logistic regression - predict expert ID
"""
# NOTE(review): `h_para_by_layer` / `h_orth_by_layer` are not defined in any cell visible above this one;
# on a fresh kernel this cell needs them to be created first (presumably the vis / blind split per layer - confirm).
current_layer_accuracy = []
for test_layer in tqdm(list(h_para_by_layer.keys())):
    
    # Experts with topk_ix == 1 at this layer, in topk_df row order.
    # NOTE(review): assumes that order matches the row order of the hidden-state tensors - confirm.
    expert_ids =\
        topk_df\
        .pipe(lambda df: df[df['layer_ix'] == test_layer])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        ['expert'].tolist()

    expert_ids_cp = cupy.asarray(expert_ids)
    x_cp_para = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
    x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())

    # Same probe, once per component
    para_acc, para_acc_lo, para_acc_hi, para_nmi, para_nmi_lo, para_nmi_hi = run_probe(x_cp_para, expert_ids_cp)
    orth_acc, orth_acc_lo, orth_acc_hi, orth_nmi, orth_nmi_lo, orth_nmi_hi = run_probe(x_cp_orth, expert_ids_cp)

    current_layer_accuracy.append({
        'test_layer_1': test_layer + model_n_dense_layers + 1,
        'para_acc': para_acc, 'para_acc_lo': para_acc_lo, 'para_acc_hi': para_acc_hi,
        'para_nmi': para_nmi, 'para_nmi_lo': para_nmi_lo, 'para_nmi_hi': para_nmi_hi,
        'orth_acc': orth_acc, 'orth_acc_lo': orth_acc_lo, 'orth_acc_hi': orth_acc_hi,
        'orth_nmi': orth_nmi, 'orth_nmi_lo': orth_nmi_lo, 'orth_nmi_hi': orth_nmi_hi,
    })

pd.DataFrame(current_layer_accuracy)

In [None]:
# """
# Use h_para and h_orth to predict NEXT layer expert ids (note - this does not remove expert info, remove below)
# """
# next_layer_accuracy = []
# for test_layer in tqdm(list(h_para_by_layer.keys())[:-1]):
    
#     expert_ids =\
#         topk_df\
#         .pipe(lambda df: df[df['layer_ix'] == test_layer + 1])\
#         .pipe(lambda df: df[df['topk_ix'] == 1])\
#         ['expert'].tolist()

#     expert_ids_cp = cupy.asarray(expert_ids)
#     x_cp_para = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
#     x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())

#     para_acc, para_acc_lo, para_acc_hi, para_nmi, para_nmi_lo, para_nmi_hi = run_probe(x_cp_para, expert_ids_cp)
#     orth_acc, orth_acc_lo, orth_acc_hi, orth_nmi, orth_nmi_lo, orth_nmi_hi = run_probe(x_cp_orth, expert_ids_cp)

#     next_layer_accuracy.append({
#         'test_layer_1': test_layer + model_n_dense_layers + 1,
#         'para_acc': para_acc, 'para_acc_lo': para_acc_lo, 'para_acc_hi': para_acc_hi,
#         'para_nmi': para_nmi, 'para_nmi_lo': para_nmi_lo, 'para_nmi_hi': para_nmi_hi,
#         'orth_acc': orth_acc, 'orth_acc_lo': orth_acc_lo, 'orth_acc_hi': orth_acc_hi,
#         'orth_nmi': orth_nmi, 'orth_nmi_lo': orth_nmi_lo, 'orth_nmi_hi': orth_nmi_hi,
#     })

# pd.DataFrame(next_layer_accuracy)

In [None]:
"""
Use h_para and h_orth to predict NEXT layer expert ids.
Remove the expert centroids of the CURRENT layer first, to prevent the vis component from piggy-backing on spurious correlations between layers.
"""
# Per-layer dicts: expert id -> mean h_para / h_orth vector of the tokens whose top-1 expert it is.
# These centroids are also read by the PREV-layer probe cell below.
# NOTE(review): `h_para_by_layer` / `h_orth_by_layer` are not defined in any cell visible above - confirm where they come from.
centroids_para = {}
centroids_orth = {}

# Get current-layer expert IDs for layer
for layer_ix in list(h_para_by_layer.keys()):

    cur_layer_expert_ids =\
        topk_df\
        .pipe(lambda df: df[df['layer_ix'] == layer_ix])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        ['expert'].tolist()

    cur_layer_expert_ids_cp = cupy.asarray(cur_layer_expert_ids)

    # H_para/h_orth for layer
    h_para_cp = cupy.asarray(h_para_by_layer[layer_ix].to(torch.float16).detach().cpu())
    h_orth_cp = cupy.asarray(h_orth_by_layer[layer_ix].to(torch.float16).detach().cpu())

    # Compute centroids per expert id
    centroids_para[layer_ix] = {}
    centroids_orth[layer_ix] = {}

    for e in set(cur_layer_expert_ids):
        idx_cp = cupy.where(cur_layer_expert_ids_cp == e)[0]
        centroids_para[layer_ix][e] = h_para_cp[idx_cp].mean(axis = 0)
        centroids_orth[layer_ix][e] = h_orth_cp[idx_cp].mean(axis = 0)

# The last layer is skipped because it has no next layer to predict
next_layer_accuracy_cond = []
for test_layer in tqdm(list(h_para_by_layer.keys())[:-1]):

    # Target = next-layer slot-1 expert IDs (same as before)
    y_cp =\
        topk_df\
        .pipe(lambda df: df[df['layer_ix'] == test_layer + 1])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        ['expert'].to_numpy()
    y_cp = cupy.asarray(y_cp)

    # Current-layer top-1 expert IDs - needed for residual lookup
    cur_exp_cp =\
        topk_df\
        .pipe(lambda df: df[df['layer_ix'] == test_layer])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        ['expert'].to_numpy()
    cur_exp_cp = cupy.asarray(cur_exp_cp)

    # Pull h_para / h_orth tensors and convert to cupy
    h_para_cp = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
    h_orth_cp = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())

    # Subtract each token's current-layer expert centroid.
    # NOTE(review): the comprehensions iterate a cupy array one element at a time and stack one row per token;
    # this is likely slow for large N - a lookup-table gather would avoid it.
    mu_para_mat = cupy.stack([centroids_para[test_layer][int(e)] for e in cur_exp_cp])
    mu_orth_mat = cupy.stack([centroids_orth[test_layer][int(e)] for e in cur_exp_cp])
    h_para_res = h_para_cp - mu_para_mat
    h_orth_res = h_orth_cp - mu_orth_mat

    # Run the unchanged probe
    para_acc, para_acc_lo, para_acc_hi, para_nmi, para_nmi_lo, para_nmi_hi = run_probe(h_para_res, y_cp)
    orth_acc, orth_acc_lo, orth_acc_hi, orth_nmi, orth_nmi_lo, orth_nmi_hi = run_probe(h_orth_res, y_cp)

    next_layer_accuracy_cond.append({
        'test_layer_1': test_layer + model_n_dense_layers + 1,
        'para_acc': para_acc, 'para_acc_lo': para_acc_lo, 'para_acc_hi': para_acc_hi,
        'para_nmi': para_nmi, 'para_nmi_lo': para_nmi_lo, 'para_nmi_hi': para_nmi_hi,
        'orth_acc': orth_acc, 'orth_acc_lo': orth_acc_lo, 'orth_acc_hi': orth_acc_hi,
        'orth_nmi': orth_nmi, 'orth_nmi_lo': orth_nmi_lo, 'orth_nmi_hi': orth_nmi_hi,
    })

display(pd.DataFrame(next_layer_accuracy_cond))

In [None]:
"""
Use h_para and h_orth to predict PREV layer expert ids.
"""
# Reads `centroids_para` / `centroids_orth` built in the next-layer probe cell, so that cell must run first.
# The first layer is skipped because it has no previous layer to predict.
prev_layer_accuracy_cond = []
for test_layer in tqdm(list(h_para_by_layer.keys())[1:]):

    # Target = previous-layer slot-1 expert IDs
    y_cp =\
        topk_df\
        .pipe(lambda df: df[df['layer_ix'] == test_layer - 1])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        ['expert'].to_numpy()
    y_cp = cupy.asarray(y_cp)

    # Current-layer top-1 expert IDs - needed for residual lookup
    cur_exp_cp =\
        topk_df\
        .pipe(lambda df: df[df['layer_ix'] == test_layer])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        ['expert'].to_numpy()
    cur_exp_cp = cupy.asarray(cur_exp_cp)

    # Pull h_para / h_orth tensors and convert to cupy
    h_para_cp = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
    h_orth_cp = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())

    # Subtract each token's current-layer expert centroid
    mu_para_mat = cupy.stack([centroids_para[test_layer][int(e)] for e in cur_exp_cp])
    mu_orth_mat = cupy.stack([centroids_orth[test_layer][int(e)] for e in cur_exp_cp])
    h_para_res = h_para_cp - mu_para_mat
    h_orth_res = h_orth_cp - mu_orth_mat

    # Run the unchanged probe
    para_acc, para_acc_lo, para_acc_hi, para_nmi, para_nmi_lo, para_nmi_hi = run_probe(h_para_res, y_cp)
    orth_acc, orth_acc_lo, orth_acc_hi, orth_nmi, orth_nmi_lo, orth_nmi_hi = run_probe(h_orth_res, y_cp)

    prev_layer_accuracy_cond.append({
        'test_layer_1': test_layer + model_n_dense_layers + 1,
        'para_acc': para_acc, 'para_acc_lo': para_acc_lo, 'para_acc_hi': para_acc_hi,
        'para_nmi': para_nmi, 'para_nmi_lo': para_nmi_lo, 'para_nmi_hi': para_nmi_hi,
        'orth_acc': orth_acc, 'orth_acc_lo': orth_acc_lo, 'orth_acc_hi': orth_acc_hi,
        'orth_nmi': orth_nmi, 'orth_nmi_lo': orth_nmi_lo, 'orth_nmi_hi': orth_nmi_hi,
    })

display(pd.DataFrame(prev_layer_accuracy_cond))

In [None]:
"""
Use h_para and h_orth to predict path motif layer expert ids.
"""
# Target here is the (previous expert -> current expert) pair, encoded as one integer class per distinct path.
# No centroid removal is applied in this cell.
# NOTE(review): the next cell reuses the name `path_motif_accuracy_cond` and overwrites these results.
path_motif_accuracy_cond = []
for test_layer in tqdm(list(h_para_by_layer.keys())[1:]):


    prev_layer_expert_ids =\
        topk_df\
        .pipe(lambda df: df[df['layer_ix'] == test_layer - 1])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        ['expert'].tolist()
    
    cur_layer_expert_ids =\
        topk_df\
        .pipe(lambda df: df[df['layer_ix'] == test_layer])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        ['expert'].tolist()

    # Pair the two id lists position by position (assumes both are in the same sample order - confirm)
    y_df =\
        pd.DataFrame({'cur_layer_expert_ids': cur_layer_expert_ids, 'prev_layer_expert_ids': prev_layer_expert_ids})\
        .assign(path = lambda df: df['prev_layer_expert_ids'].astype(str) + '->' + df['cur_layer_expert_ids'].astype(str))
        
    # Class ids follow order of first appearance, so they differ from layer to layer
    y_map = {path: i for i, path in enumerate(y_df['path'].unique())}
    y_cp = cupy.asarray(y_df['path'].map(y_map))
    
    x_cp_para = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
    x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())

    # Probe
    para_acc, para_acc_lo, para_acc_hi, para_nmi, para_nmi_lo, para_nmi_hi = run_probe(x_cp_para, y_cp)
    orth_acc, orth_acc_lo, orth_acc_hi, orth_nmi, orth_nmi_lo, orth_nmi_hi = run_probe(x_cp_orth, y_cp)

    path_motif_accuracy_cond.append({
        'test_layer_1': test_layer + model_n_dense_layers + 1,
        'para_acc': para_acc, 'para_acc_lo': para_acc_lo, 'para_acc_hi': para_acc_hi,
        'para_nmi': para_nmi, 'para_nmi_lo': para_nmi_lo, 'para_nmi_hi': para_nmi_hi,
        'orth_acc': orth_acc, 'orth_acc_lo': orth_acc_lo, 'orth_acc_hi': orth_acc_hi,
        'orth_nmi': orth_nmi, 'orth_nmi_lo': orth_nmi_lo, 'orth_nmi_hi': orth_nmi_hi,
    })

    # Progress output: prints the whole running table after every layer
    print(pd.DataFrame(path_motif_accuracy_cond))

display(pd.DataFrame(path_motif_accuracy_cond))

In [None]:
"""
Use h_para and h_orth to predict PREV layer expert ids, restricted to samples whose current-layer top-1 expert is `this_cur_layer_expert_id`.
"""
# NOTE(review): reuses the name `path_motif_accuracy_cond`, overwriting the results of the previous cell.
path_motif_accuracy_cond = []

# Only tokens whose current-layer top-1 expert equals this id are probed (hard-coded to expert 1)
this_cur_layer_expert_id = 1
for test_layer in tqdm(list(h_para_by_layer.keys())[1:]):

    # sample_ix of the tokens routed to the chosen expert at this layer
    valid_samples =\
        topk_df\
        .pipe(lambda df: df[df['layer_ix'] == test_layer])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        .pipe(lambda df: df[df['expert'] == this_cur_layer_expert_id])\
        ['sample_ix'].tolist()

    # Target = previous-layer top-1 expert of those tokens, in topk_df row order.
    # NOTE(review): x below is ordered by `valid_samples`; the two orders agree only if topk_df lists samples
    # in the same order at both layers - confirm.
    prev_layer_expert_ids =\
        topk_df\
        .pipe(lambda df: df[df['sample_ix'].isin(valid_samples)])\
        .pipe(lambda df: df[df['layer_ix'] == test_layer - 1])\
        .pipe(lambda df: df[df['topk_ix'] == 1])\
        ['expert'].to_numpy()
        
    y_cp = cupy.asarray(prev_layer_expert_ids)
    
    x_cp_para = cupy.asarray(h_para_by_layer[test_layer][valid_samples, :].to(torch.float16).detach().cpu())
    x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer][valid_samples, :].to(torch.float16).detach().cpu())

    # Probe
    para_acc, para_acc_lo, para_acc_hi, para_nmi, para_nmi_lo, para_nmi_hi = run_probe(x_cp_para, y_cp)
    orth_acc, orth_acc_lo, orth_acc_hi, orth_nmi, orth_nmi_lo, orth_nmi_hi = run_probe(x_cp_orth, y_cp)

    path_motif_accuracy_cond.append({
        'test_layer_1': test_layer + model_n_dense_layers + 1,
        'para_acc': para_acc, 'para_acc_lo': para_acc_lo, 'para_acc_hi': para_acc_hi,
        'para_nmi': para_nmi, 'para_nmi_lo': para_nmi_lo, 'para_nmi_hi': para_nmi_hi,
        'orth_acc': orth_acc, 'orth_acc_lo': orth_acc_lo, 'orth_acc_hi': orth_acc_hi,
        'orth_nmi': orth_nmi, 'orth_nmi_lo': orth_nmi_lo, 'orth_nmi_hi': orth_nmi_hi,
    })

display(pd.DataFrame(path_motif_accuracy_cond))

In [None]:
# Debug peek: shows the value left over from the last loop iteration of the previous cell.
prev_layer_expert_ids

In [None]:
"""
Export results
"""
# Stack the current-layer and (centroid-removed) next-layer probe results, tagged by target and model.
# The prev-layer and path-motif results are not part of this export.
layer_transitions_export_df = pd.concat([
    pd.DataFrame(current_layer_accuracy).assign(target = 'current_layer'),
    pd.DataFrame(next_layer_accuracy_cond).assign(target = 'next_layer')
]).assign(model = model_prefix)

display(layer_transitions_export_df)

layer_transitions_export_df.to_csv(f'{svd_dir}/svd-probe-expert-id-{model_prefix}.csv', index = False)

In [None]:
"""
Predict language - presplit, separate TIDs
"""
# def run_lr_with_mi_presplit(x_train, x_test, y_train, y_test):
#     lr_model = cuml.linear_model.LogisticRegression(penalty = 'l2', max_iter = 1000, fit_intercept = True)
#     lr_model.fit(x_train, y_train)
#     accuracy = lr_model.score(x_test, y_test)
#     train_acc = lr_model.score(x_train, y_train)
#     y_actual_np = cupy.asnumpy(y_test)
#     y_pred_np = cupy.asnumpy(lr_model.predict(x_test))
#     mi = sklearn.metrics.mutual_info_score(y_actual_np, y_pred_np) # nats
#     max_entropy = sklearn.metrics.mutual_info_score(y_actual_np, y_actual_np) # H(y)
#     return accuracy, mi.item(), max_entropy.item(), train_acc

# lang_probe_accs = []
# # Split train/test, different TIDs in each
# gss = sklearn.model_selection.GroupShuffleSplit(n_splits = 1, test_size = 0.2, random_state = 123)
# train_ix, test_ix = next(gss.split(sample_df, groups = sample_df['token_id']))

# train_sample_df = sample_df.take(train_ix)
# test_sample_df = sample_df.take(test_ix)

# # Prep y values
# source_mapping = {source: i for i, source in enumerate(sample_df['source'].unique())}

# y_train = cupy.asarray(train_sample_df.assign(source = lambda df: df['source'].map(source_mapping))['source'].tolist())
# y_test = cupy.asarray(test_sample_df.assign(source = lambda df: df['source'].map(source_mapping))['source'].tolist())

# for test_layer in tqdm(list(h_para_by_layer.keys())[::2]):

#     x_para = cupy.asarray(h_para_by_layer[test_layer].to(torch.float16).detach().cpu())
#     x_orth = cupy.asarray(h_orth_by_layer[test_layer].to(torch.float16).detach().cpu())
    
#     x_train_para = x_para[train_sample_df['sample_ix'].tolist(), :]
#     x_test_para = x_para[test_sample_df['sample_ix'].tolist(), :]
#     x_train_orth = x_orth[train_sample_df['sample_ix'].tolist(), :]
#     x_test_orth = x_orth[test_sample_df['sample_ix'].tolist(), :]

#     para_res = run_lr_with_mi_presplit(x_train_para, x_test_para, y_train, y_test)
#     orth_res = run_lr_with_mi_presplit(x_train_orth, x_test_orth, y_train, y_test)

#     lang_probe_accs.append({
#         'test_layer_1': test_layer + model_n_dense_layers + 1,
#         'para_acc': para_res[0],
#         'para_train_acc': para_res[3],
#         'para_mi_bits': para_res[1]/np.log(2.0),
#         'para_entropy_bits': para_res[2]/np.log(2.0),
#         'para_mi_pct': para_res[1]/para_res[2],
#         'orth_acc': orth_res[0],
#         'orth_train_acc': orth_res[3],
#         'orth_mi_bits': orth_res[1]/np.log(2.0),
#         'orth_entropy_bits': orth_res[2]/np.log(2.0),
#         'orth_mi_pct': orth_res[1]/orth_res[2]
#     })

#     display(pd.DataFrame(lang_probe_accs))

In [None]:
"""
Predict Language
"""
lang_probe_accs = []

# Label preparation is identical for every layer, so it is done once up front
# instead of being rebuilt on each loop iteration.
source_mapping = {source: i for i, source in enumerate(sample_df['source'].unique())}

# All samples are used; the label is the integer code of each row's `source`.
# Optional filters, currently disabled (apply them ABOVE the assign, while `source`
# still holds the raw string values):
#   .pipe(lambda df: df[df['source'].isin(['en', 'es'])])
#   .pipe(lambda df: df[df['token'].apply(lambda x: bool(regex.search(r'\p{L}', x)))])
lang_label_df = sample_df.assign(source = lambda df: df['source'].map(source_mapping))

# `sample_ix` is used as a row index into the per-layer tensors
# (assumes it matches their row order — TODO confirm).
selected_indices = lang_label_df['sample_ix'].tolist()
lang_labels = lang_label_df['source'].tolist()

for test_layer in tqdm(list(h_para_by_layer.keys())):

    # The device label array is still created fresh per layer, so every probe call
    # receives its own copy exactly as before; only the pandas work was hoisted.
    y_cp = cupy.asarray(lang_labels)
    x_cp_para = cupy.asarray(h_para_by_layer[test_layer][selected_indices, :].to(torch.float16).detach().cpu())
    x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer][selected_indices, :].to(torch.float16).detach().cpu())

    # Same labels, two feature sets. `run_probe` yields six values per call, unpacked here as
    # accuracy / NMI with `_lo` / `_hi` companions (presumably interval bounds — confirm in run_probe).
    para_acc, para_acc_lo, para_acc_hi, para_nmi, para_nmi_lo, para_nmi_hi = run_probe(x_cp_para, y_cp)
    orth_acc, orth_acc_lo, orth_acc_hi, orth_nmi, orth_nmi_lo, orth_nmi_hi = run_probe(x_cp_orth, y_cp)

    lang_probe_accs.append({
        # Offset by the dense-layer count plus one (presumably a 1-indexed absolute layer number).
        'test_layer_1': test_layer + model_n_dense_layers + 1,
        'para_acc': para_acc, 'para_acc_lo': para_acc_lo, 'para_acc_hi': para_acc_hi,
        'para_nmi': para_nmi, 'para_nmi_lo': para_nmi_lo, 'para_nmi_hi': para_nmi_hi,
        'orth_acc': orth_acc, 'orth_acc_lo': orth_acc_lo, 'orth_acc_hi': orth_acc_hi,
        'orth_nmi': orth_nmi, 'orth_nmi_lo': orth_nmi_lo, 'orth_nmi_hi': orth_nmi_hi,
    })

display(pd.DataFrame(lang_probe_accs))

In [None]:
# """
# Predict Language + DEMEAN
# """
# lang_probe_accs = []

# for test_layer in tqdm(list(h_para_by_layer.keys())[::2]):

#     source_mapping = {source: i for i, source in enumerate(sample_df['source'].unique())}

#     # Probe en/es token predictiveness
#     y_df =\
#         sample_df\
#         .pipe(lambda df: df[df['source'].isin(['en', 'es'])])\
#         .assign(source = lambda df: df['source'].map(source_mapping))
#         #\ .pipe(lambda df: df[df['token'].apply(lambda x: bool(regex.search(r'\p{L}', x)))])

#     selected_indices = y_df['sample_ix'].tolist()
#     y_df = y_df['source'].tolist()
#     y_cp = cupy.asarray(y_df)

#     ### Demean by current-layer top-1 expert
#     cur_exp_cp = topk_df.pipe(lambda df: df[df['layer_ix'] == test_layer]).pipe(lambda df: df[df['topk_ix'] == 1])['expert'].to_numpy()
#     cur_exp_cp = cupy.asarray(cur_exp_cp)[selected_indices]  # aligned!

#     # Subtract centroids
#     x_cp_para = cupy.asarray(h_para_by_layer[test_layer][selected_indices, :].to(torch.float16).detach().cpu()) # Pull base data
#     x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer][selected_indices, :].to(torch.float16).detach().cpu())
#     mu_para_sel = cupy.stack([centroids_para[test_layer][int(e)] for e in cur_exp_cp.get()]) # Get centroids by cur expert
#     mu_orth_sel = cupy.stack([centroids_orth[test_layer][int(e)] for e in cur_exp_cp.get()])
#     x_cp_para = x_cp_para - mu_para_sel # Get demeaned residuals
#     x_cp_orth = x_cp_orth - mu_orth_sel

#     para_acc, para_acc_lo, para_acc_hi, para_nmi, para_nmi_lo, para_nmi_hi = run_probe(x_cp_para, y_cp)
#     orth_acc, orth_acc_lo, orth_acc_hi, orth_nmi, orth_nmi_lo, orth_nmi_hi = run_probe(x_cp_orth, y_cp) 

#     lang_probe_accs.append({
#         'test_layer_1': test_layer + model_n_dense_layers + 1,
#         'para_acc': para_acc, 'para_acc_lo': para_acc_lo, 'para_acc_hi': para_acc_hi,
#         'para_nmi': para_nmi, 'para_nmi_lo': para_nmi_lo, 'para_nmi_hi': para_nmi_hi,
#         'orth_acc': orth_acc, 'orth_acc_lo': orth_acc_lo, 'orth_acc_hi': orth_acc_hi,
#         'orth_nmi': orth_nmi, 'orth_nmi_lo': orth_nmi_lo, 'orth_nmi_hi': orth_nmi_hi,
#     })

# display(pd.DataFrame(lang_probe_accs))

In [None]:
"""
Export
"""
# Show how many samples each source contributes (the count column is named `z`).
source_counts = sample_df.groupby('source', as_index = False).agg(z = ('sample_ix', 'count'))
display(source_counts)

# Turn the per-layer language-probe records into a frame, show it, and write it under `svd_dir`.
lang_export_df = pd.DataFrame(lang_probe_accs)
display(lang_export_df)

lang_export_path = f'{svd_dir}/svd-probe-lang-{model_prefix}.csv'
lang_export_df.to_csv(lang_export_path, index = False)

In [None]:
"""
Predict TID
"""
tid_probe_accs = []

# Number of most frequent 'en' tokens that keep their own class label.
N_TOP_TOKEN_IDS = 500
# Catch-all label for every other token (assumed not to collide with a real token id — confirm).
OTHER_TOKEN_ID = 999999

# None of the label preparation depends on the layer, so it is computed once here
# instead of re-running the groupby/sort on every loop iteration.
top_tids = (
    sample_df
    .pipe(lambda df: df[df['source'] == 'en'])
    .groupby(['token_id', 'token'], as_index = False)
    .agg(n = ('token', 'count'))
    .sort_values(by = 'n', ascending = False)
    .head(N_TOP_TOKEN_IDS)
)

# Every sample is kept (all sources); token ids outside `top_tids` are bucketed into OTHER_TOKEN_ID.
# Disabled alternative that drops those rows instead of bucketing them:
#   .pipe(lambda df: df[df['token_id'].isin(top_tids['token_id'].tolist())])
valid_samples = sample_df.assign(
    token_id = lambda df: np.where(df['token_id'].isin(top_tids['token_id']), df['token_id'], OTHER_TOKEN_ID)
)

# `sample_ix` is used as a row index into the per-layer tensors
# (assumes it matches their row order — TODO confirm).
tid_sample_indices = valid_samples['sample_ix'].tolist()
tid_labels = valid_samples['token_id'].tolist()

for test_layer in tqdm(list(h_para_by_layer.keys())):

    clear_all_cuda_memory(False)

    # The device label array is still created fresh per layer, so each probe call gets its own copy.
    y_cp = cupy.asarray(tid_labels)
    x_cp_para = cupy.asarray(h_para_by_layer[test_layer][tid_sample_indices, :].to(torch.float16).detach().cpu())
    x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer][tid_sample_indices, :].to(torch.float16).detach().cpu())

    # `run_probe` yields six values per call, unpacked here as accuracy / NMI with
    # `_lo` / `_hi` companions (presumably interval bounds — confirm in run_probe).
    para_acc, para_acc_lo, para_acc_hi, para_nmi, para_nmi_lo, para_nmi_hi = run_probe(x_cp_para, y_cp)
    orth_acc, orth_acc_lo, orth_acc_hi, orth_nmi, orth_nmi_lo, orth_nmi_hi = run_probe(x_cp_orth, y_cp)

    tid_probe_accs.append({
        # Offset by the dense-layer count plus one (presumably a 1-indexed absolute layer number).
        'test_layer_1': test_layer + model_n_dense_layers + 1,
        'para_acc': para_acc, 'para_acc_lo': para_acc_lo, 'para_acc_hi': para_acc_hi,
        'para_nmi': para_nmi, 'para_nmi_lo': para_nmi_lo, 'para_nmi_hi': para_nmi_hi,
        'orth_acc': orth_acc, 'orth_acc_lo': orth_acc_lo, 'orth_acc_hi': orth_acc_hi,
        'orth_nmi': orth_nmi, 'orth_nmi_lo': orth_nmi_lo, 'orth_nmi_hi': orth_nmi_hi,
    })

pd.DataFrame(tid_probe_accs)

In [None]:
"""
Export
"""
# Turn the per-layer token-id probe records into a frame, show it, and write it under `svd_dir`.
tid_export_df = pd.DataFrame(tid_probe_accs)
display(tid_export_df)

tid_export_path = f'{svd_dir}/svd-probe-tid-{model_prefix}.csv'
tid_export_df.to_csv(tid_export_path, index = False)

In [None]:
"""
Predict Position
"""
position_probe_accs = []

# Only samples whose `token_ix` is in 0 .. MAX_TOKEN_IX - 1 are probed.
MAX_TOKEN_IX = 500
# Width of each position bucket; the label is `token_ix // POSITION_BIN_WIDTH`.
POSITION_BIN_WIDTH = 100

# The sample selection and labels do not depend on the layer, so they are built once
# here instead of being recomputed on every loop iteration.
valid_samples = sample_df[sample_df['token_ix'].isin(range(MAX_TOKEN_IX))]

# `sample_ix` is used as a row index into the per-layer tensors
# (assumes it matches their row order — TODO confirm).
position_sample_indices = valid_samples['sample_ix'].tolist()
position_labels = (valid_samples['token_ix'] // POSITION_BIN_WIDTH).tolist()

for test_layer in tqdm(list(h_para_by_layer.keys())):

    clear_all_cuda_memory(False)

    # The device label array is still created fresh per layer, so each probe call gets its own copy.
    y_cp = cupy.asarray(position_labels)
    x_cp_para = cupy.asarray(h_para_by_layer[test_layer][position_sample_indices, :].to(torch.float16).detach().cpu())
    x_cp_orth = cupy.asarray(h_orth_by_layer[test_layer][position_sample_indices, :].to(torch.float16).detach().cpu())

    # `run_probe` yields six values per call, unpacked here as accuracy / NMI with
    # `_lo` / `_hi` companions (presumably interval bounds — confirm in run_probe).
    para_acc, para_acc_lo, para_acc_hi, para_nmi, para_nmi_lo, para_nmi_hi = run_probe(x_cp_para, y_cp)
    orth_acc, orth_acc_lo, orth_acc_hi, orth_nmi, orth_nmi_lo, orth_nmi_hi = run_probe(x_cp_orth, y_cp)

    position_probe_accs.append({
        # Offset by the dense-layer count plus one (presumably a 1-indexed absolute layer number).
        'test_layer_1': test_layer + model_n_dense_layers + 1,
        'para_acc': para_acc, 'para_acc_lo': para_acc_lo, 'para_acc_hi': para_acc_hi,
        'para_nmi': para_nmi, 'para_nmi_lo': para_nmi_lo, 'para_nmi_hi': para_nmi_hi,
        'orth_acc': orth_acc, 'orth_acc_lo': orth_acc_lo, 'orth_acc_hi': orth_acc_hi,
        'orth_nmi': orth_nmi, 'orth_nmi_lo': orth_nmi_lo, 'orth_nmi_hi': orth_nmi_hi,
    })

pd.DataFrame(position_probe_accs)

In [None]:
"""
Export
"""
# Turn the per-layer position probe records into a frame, show it, and write it under `svd_dir`.
position_export_df = pd.DataFrame(position_probe_accs)
display(position_export_df)

position_export_path = f'{svd_dir}/svd-probe-pos-{model_prefix}.csv'
position_export_df.to_csv(position_export_path, index = False)