In [None]:
"""
This contains code to test orthogonality of expert specialization.
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import scipy
import cupy
import cuml
import sklearn

import importlib
import gc
import pickle

from tqdm import tqdm
from termcolor import colored
import plotly.express as px
from plotly.subplots import make_subplots

from utils.memory import check_memory, clear_all_cuda_memory
from utils.quantize import compare_bf16_fp16_batched
from utils.svd import decompose_orthogonal, decompose_sideways
from utils.vis import combine_plots

# Runtime configuration: primary CUDA device and the global RNG seed.
main_device = 'cuda:0'
seed = 1234

# Apply the seed (it was previously defined but never used). pandas `.sample()` calls below pass no
# random_state and therefore draw from NumPy's global RNG, so seeding it makes the plots reproducible.
np.random.seed(seed)
torch.manual_seed(seed)

# Project helpers from utils.memory: clear CUDA memory, then report current usage.
clear_all_cuda_memory()
check_memory()

## Load base model

In [None]:
"""
Load the base tokenizer/model
"""
# Checkpoint under study, plus the short prefix used to locate its saved data on disk.
model_id = 'allenai/OLMoE-1B-7B-0125-Instruct'
model_prefix = 'olmoe'

# Tokenizer: left-padded, with no BOS/EOS tokens added automatically.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    add_eos_token = False,
    add_bos_token = False,
    padding_side = 'left',
    trust_remote_code = True
)

# Model: loaded in bfloat16, moved to the GPU, and put in eval mode.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype = torch.bfloat16,
    trust_remote_code = True
)
model = model.cuda().eval()

In [None]:
"""
Load dataset
"""
def load_data(model_prefix):
    """
    Read the saved tensors and metadata for one model from `data/{model_prefix}/`.

    Params:
        @model_prefix: Name of the sub-directory of `data/` to read from.

    Returns:
        A 6-tuple of: the pre-MLP hidden states, the expert outputs, and the metadata entries
        'sample_df', 'topk_df', 'all_pre_mlp_hidden_states_layers', 'all_expert_outputs_layers'.
    """
    data_dir = f'data/{model_prefix}'

    hidden_states = torch.load(f'{data_dir}/all-pre-mlp-hidden-states.pt')
    expert_outputs = torch.load(f'{data_dir}/all-expert-outputs.pt')

    # NOTE(review): unpickling can execute arbitrary code; only read metadata files you produced yourself.
    with open(f'{data_dir}/metadata.pkl', 'rb') as metadata_file:
        meta = pickle.load(metadata_file)

    return (
        hidden_states,
        expert_outputs,
        meta['sample_df'],
        meta['topk_df'],
        meta['all_pre_mlp_hidden_states_layers'],
        meta['all_expert_outputs_layers'],
    )

# Unpack the six values returned by load_data; act_map / expert_map are the two saved-layer lists from the metadata.
all_pre_mlp_hs_import, all_expert_outputs_import, sample_df_import, topk_df_import, act_map, expert_map = load_data(model_prefix)

In [None]:
"""
Let's clean up the mappings here. We'll get everything to a sample_ix level first.
"""
# Token-level table: sample_ix numbers each (batch_ix, sequence_ix, token_ix) group,
# seq_id numbers each (batch_ix, sequence_ix) group.
sample_df_raw = (
    sample_df_import
    .assign(sample_ix = lambda df: df.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup())
    .assign(seq_id = lambda df: df.groupby(['batch_ix', 'sequence_ix']).ngroup())
    .reset_index()
)

# Attach sample_ix to the routing table, then drop the three positional keys it replaces.
sample_keys_df = sample_df_raw[['sample_ix', 'batch_ix', 'sequence_ix', 'token_ix']]
topk_df = (
    topk_df_import
    .merge(sample_keys_df, how = 'inner', on = ['sequence_ix', 'token_ix', 'batch_ix'])
    .drop(columns = ['sequence_ix', 'token_ix', 'batch_ix'])
)

# Rows whose topk_ix equals 1
topk1_df = topk_df[topk_df['topk_ix'] == 1]

# Sample table without the batch / sequence keys
sample_df = sample_df_raw.drop(columns = ['batch_ix', 'sequence_ix'])

del sample_df_import, topk_df_import
display(topk_df, sample_df)

In [None]:
"""
Convert activations to fp16 (for compatibility with cupy later) + dict
"""
def _to_fp16_layer_dict(bf16_tensor, layer_map):
    """
    Cast a stacked tensor to float16, pass the (original, converted) pair to the project's
    compare_bf16_fp16_batched helper, and split dim 1 into a {layer_ix: tensor} dict that
    follows the order of `layer_map`.
    """
    fp16_tensor = bf16_tensor.to(torch.float16)
    compare_bf16_fp16_batched(bf16_tensor, fp16_tensor)
    return {layer_ix: fp16_tensor[:, save_ix] for save_ix, layer_ix in enumerate(layer_map)}

all_pre_mlp_hs = _to_fp16_layer_dict(all_pre_mlp_hs_import, act_map)
del all_pre_mlp_hs_import

all_expert_outputs = _to_fp16_layer_dict(all_expert_outputs_import, expert_map)
del all_expert_outputs_import

gc.collect()

## Visualize activation clusters of single (layer, expert)

In [None]:
"""
Visualize clusters
"""
layer_ix = 9
expert_id = 1

# sample_ix of every row in topk1_df at `layer_ix` that was routed to `expert_id`
relevant_sample_ids = (
    topk1_df
    .pipe(lambda df: df[(df['layer_ix'] == layer_ix) & (df['expert'] == expert_id)])
    .sort_values(by = 'sample_ix', ascending = True)
    ['sample_ix']
    .tolist()
)

# Get expert IDs of previous layer
prev_experts_df = (
    topk1_df
    .pipe(lambda df: df[df['layer_ix'] == layer_ix - 1])
    .pipe(lambda df: df[df['sample_ix'].isin(relevant_sample_ids)])
    .rename(columns = {'expert': 'prev_expert'})
    [['sample_ix', 'prev_expert']]
)

# Get sample dfs of relevant sample IDs, include expert IDs of previous layer.
# Sorted by sample_ix with a fresh 0..n-1 index so it can be concatenated positionally with reduced activations.
relevant_samples_df = (
    sample_df[sample_df['sample_ix'].isin(relevant_sample_ids)]
    .merge(prev_experts_df, on = 'sample_ix', how = 'inner')
    .sort_values(by = 'sample_ix')
    .reset_index(drop = True)
)

display(relevant_samples_df)

# Index the activations with the ids that survived the merge, in dataframe row order. Using the
# pre-merge id list could give a tensor whose length or order differs from the dataframe (e.g. a
# sample with no row at the previous layer is dropped by the inner merge).
relevant_pre_mlp_hs = all_pre_mlp_hs[layer_ix][relevant_samples_df['sample_ix'].tolist()]

In [None]:
"""
Helpers
"""
def reduce_pca(input_tensor: torch.Tensor, n_components = 2):
    """
    Reduce a tensor with cuML PCA and return the projected coordinates.

    Params:
        @input_tensor: Tensor to reduce; it is cast to float32 before being handed to CuPy.
        @n_components: Number of principal components to keep.

    Returns:
        The projected data converted with `cupy.asnumpy`, one column per component.
    """
    hs_cupy = cupy.asarray(input_tensor.to(torch.float32))
    pca_model = cuml.PCA(
        iterated_power = 100,
        n_components = n_components,
        verbose = True
    )
    # Fit and project in a single pass; the model used to be fitted twice (fit, then fit_transform).
    pred = cupy.asnumpy(pca_model.fit_transform(hs_cupy))
    # print(f'Explained variance ratio: {pca_model.explained_variance_ratio_}')
    print(f'Cumulative variance ratio: {np.cumsum(pca_model.explained_variance_ratio_)[-1]}')
    clear_all_cuda_memory(False)
    return pred

def reduce_umap(input_tensor: torch.Tensor, n_components = 2, metric = 'cosine', n_epochs = 200):
    """
    Embed a tensor with cuML UMAP and return the embedding.

    Params:
        @input_tensor: Tensor to embed; it is cast to float32 before being handed to CuPy.
        @n_components: Dimensionality of the embedding.
        @metric: Distance metric name passed to UMAP.
        @n_epochs: Number of optimisation epochs passed to UMAP.

    Returns:
        The embedding converted with `cupy.asnumpy`.
    """
    gpu_input = cupy.asarray(input_tensor.to(torch.float32))
    umap_settings = dict(
        n_components = n_components,
        n_neighbors = 20,     # neighbourhood size
        metric = metric,
        min_dist = 0.5,       # minimum spacing between embedded points
        n_epochs = n_epochs,
        random_state = 123,   # fixed seed
        verbose = False
    )
    umap_model = cuml.UMAP(**umap_settings)
    embedding = cupy.asnumpy(umap_model.fit_transform(gpu_input))
    clear_all_cuda_memory(False)
    return embedding

def plot_manifold(plot_df, color_col, hover_col, title = None):
    """
    Build a fixed-height (400) plotly scatter of the 'd1' / 'd2' columns.

    Params:
        @plot_df: Dataframe containing 'd1', 'd2', `color_col` and `hover_col`.
        @color_col: Column used to color the points.
        @hover_col: Column shown as hover data.
        @title: Optional figure title.

    Returns:
        The plotly figure; the caller decides whether to show it.
    """
    fig = px.scatter(plot_df, x = 'd1', y = 'd2', color = color_col, hover_data = [hover_col], title = title)
    fig.update_layout(autosize = False, height = 400)
    return fig

In [None]:
"""
Plot PCA + UMAP, color by previous expert ID
"""
# PCA view. The sample size is capped at the number of rows, since `.sample(5000)` raises on smaller frames.
pca_res = reduce_pca(relevant_pre_mlp_hs, 2)
pca_plot_df = (
    pd.concat([pd.DataFrame({'d1': pca_res[:, 0], 'd2': pca_res[:, 1]}), relevant_samples_df], axis = 1)
    .pipe(lambda df: df.sample(min(5_000, len(df))))
    .assign(prev_expert = lambda df: df['prev_expert'].astype(str))
)
# Show every figure explicitly: a notebook cell only auto-renders its last expression.
plot_manifold(pca_plot_df, 'prev_expert', 'token', title = 'PCA by previous expert').show()
plot_manifold(pca_plot_df, 'source', 'token', title = 'PCA by source').show()

# UMAP view. Built from ump_res; this dataframe used to be built from pca_res by mistake,
# so the "UMAP" figures were copies of the PCA figures.
ump_res = reduce_umap(relevant_pre_mlp_hs, 2, 'cosine')
ump_plot_df = (
    pd.concat([pd.DataFrame({'d1': ump_res[:, 0], 'd2': ump_res[:, 1]}), relevant_samples_df], axis = 1)
    .pipe(lambda df: df.sample(min(5_000, len(df))))
    .assign(prev_expert = lambda df: df['prev_expert'].astype(str))
)
plot_manifold(ump_plot_df, 'prev_expert', 'token', title = 'UMAP by previous expert').show()
plot_manifold(ump_plot_df, 'source', 'token', title = 'UMAP by source').show()

In [None]:
"""
Plot PCA + UMAP, but this time include token samples from outside this expert
"""
# sample_ix of every row in topk1_df at `layer_ix` that was NOT routed to `expert_id`
nonrelevant_sample_ids = (
    topk1_df
    .pipe(lambda df: df[df['layer_ix'] == layer_ix])
    .pipe(lambda df: df[df['expert'] != expert_id])
    .sort_values(by = 'sample_ix', ascending = True)
    ['sample_ix']
    .tolist()
)

# Sample rows for those ids, sorted by sample_ix; the activations are indexed from the dataframe itself
# so the tensor rows and dataframe rows are guaranteed to be in the same order.
nonrelevant_samples_df = (
    sample_df[sample_df['sample_ix'].isin(nonrelevant_sample_ids)]
    .sort_values(by = 'sample_ix')
    .reset_index(drop = True)
)
nonrelevant_pre_mlp_hs = all_pre_mlp_hs[layer_ix][nonrelevant_samples_df['sample_ix'].tolist()]

# Prep all_samples_df samples df: in-expert rows first, then out-of-expert rows tagged prev_expert = 'NA'
# (same order as the tensor concatenation below).
all_samples_df = pd.concat([
    relevant_samples_df.assign(prev_expert = lambda df: df['prev_expert'].astype(str)),
    nonrelevant_samples_df.assign(prev_expert = 'NA')
]).reset_index(drop = True)

pca_res = reduce_pca(torch.concat([relevant_pre_mlp_hs, nonrelevant_pre_mlp_hs], dim = 0), 2)

# Up to 10,000 points from each side; capped at the group size because `.sample(n)` raises when n > len.
max_points_per_group = 10_000
pca_plot_df = (
    pd.concat([pd.DataFrame({'d1': pca_res[:, 0], 'd2': pca_res[:, 1]}), all_samples_df], axis = 1)
    .pipe(lambda df: pd.concat([
        df[df['prev_expert'] == 'NA'].pipe(lambda g: g.sample(min(max_points_per_group, len(g)))),
        df[df['prev_expert'] != 'NA'].pipe(lambda g: g.sample(min(max_points_per_group, len(g))))
    ]))
)

# Show both figures explicitly: a notebook cell only auto-renders its last expression.
plot_manifold(pca_plot_df, 'prev_expert', 'token', title = 'PCA by previous expert (NA = other experts)').show()
plot_manifold(pca_plot_df, 'source', 'token', title = 'PCA by source').show()

## Cross-layer source mappings

In [None]:
"""
Plot cross-layer source mappings
"""
# Layer indices inspected by the cells below; each one is used as a key into all_pre_mlp_hs.
layers_to_test = [0, 2, 7, 13, 15]

def get_sample_df_for_layer(sample_df, topk1_df, layer_ix):
    """
    Attach this layer's expert and the previous layer's expert to each sample.

    Params:
        @sample_df: Sample-level dataframe with a 'sample_ix' column.
        @topk1_df: Routing dataframe with 'layer_ix', 'expert' and 'sample_ix' columns.
        @layer_ix: Layer whose expert assignments are attached as 'expert'.

    Returns:
        `sample_df` inner-joined to the experts at `layer_ix`, then left-joined to the experts at
        `layer_ix - 1` as 'prev_expert' (missing when the previous layer has no row for a sample).
    """
    this_layer_experts = topk1_df.loc[topk1_df['layer_ix'] == layer_ix, ['expert', 'sample_ix']]
    prev_layer_experts = (
        topk1_df.loc[topk1_df['layer_ix'] == layer_ix - 1, ['sample_ix', 'expert']]
        .rename(columns = {'expert': 'prev_expert'})
    )

    with_expert = sample_df.merge(this_layer_experts, how = 'inner', on = 'sample_ix')
    return with_expert.merge(prev_layer_experts, how = 'left', on = 'sample_ix')


def _layer_entry(ix):
    """Bundle one layer index with its per-sample expert dataframe."""
    return {'layer_ix': ix, 'sample_df': get_sample_df_for_layer(sample_df, topk1_df, ix)}

# One entry per layer in layers_to_test, in the same order
test_layers = list(map(_layer_entry, layers_to_test))

test_layers[0]

In [None]:
"""
H by language
"""
for test_layer in test_layers:
    # 2-D PCA of this layer's pre-MLP hidden states
    pca_res = reduce_pca(all_pre_mlp_hs[test_layer['layer_ix']], 2)
    coords_df = pd.DataFrame({'d1': pca_res[:, 0], 'd2': pca_res[:, 1]})

    # Join coordinates to the sample rows by position, then plot a random subset of 2,500 points
    pca_plot_df = pd.concat([coords_df, test_layer['sample_df']], axis = 1).sample(2500)

    fig_title = f"<em>H<sub>{str(test_layer['layer_ix'])}</sub></em>"
    plot_manifold(pca_plot_df, 'source', 'token', fig_title).show()

In [None]:
"""
H by expert
"""
for test_layer in test_layers:
    # 2-D PCA of this layer's pre-MLP hidden states
    pca_res = reduce_pca(all_pre_mlp_hs[test_layer['layer_ix']], 2)
    coords_df = pd.DataFrame({'d1': pca_res[:, 0], 'd2': pca_res[:, 1]})
    joined_df = pd.concat([coords_df, test_layer['sample_df']], axis = 1)

    # Keep experts 0-4 only, cast the id to str, draw 2,500 points, and order them by expert
    first_five_df = joined_df[joined_df['expert'].isin([0, 1, 2, 3, 4])]
    pca_plot_df = (
        first_five_df
        .assign(expert = lambda df: df['expert'].astype(str))
        .sample(2500)
        .sort_values(by = 'expert')
    )

    fig_title = f"<em>H<sub>{str(test_layer['layer_ix'])}</sub></em>"
    plot_manifold(pca_plot_df, 'expert', 'token', fig_title).show()

In [None]:
"""
H_orth vs H_para by expert
"""
def _sampled_plot_df(pca_np, layer_sample_df):
    """Join 2-D coordinates to the sample rows by position, draw 2,500 rows, and order them by expert."""
    coords_df = pd.DataFrame({'d1': pca_np[:, 0], 'd2': pca_np[:, 1]})
    return pd.concat([coords_df, layer_sample_df], axis = 1).sample(2500).sort_values(by = 'expert')

# NOTE(review): the cell header says "by expert", but both panels are colored by 'source' - confirm which is intended.
for test_layer in test_layers:

    # Split the hidden states into the two parts returned by the project helper, using this layer's router weights
    h_para, h_orth = decompose_orthogonal(
        all_pre_mlp_hs[test_layer['layer_ix']].to(torch.float32),
        model.model.layers[test_layer['layer_ix']].mlp.gate.weight.to(torch.float32).detach().cpu(),
        method = 'svd'
    )
    h_para, h_orth = h_para.to(torch.float32), h_orth.to(torch.float32)

    # Reduce both parts first, then sample the plotting frames (para before orth)
    h_para_pca_res = reduce_pca(h_para, 2)
    h_orth_pca_res = reduce_pca(h_orth, 2)

    h_para_plot_df = _sampled_plot_df(h_para_pca_res, test_layer['sample_df'])
    h_orth_plot_df = _sampled_plot_df(h_orth_pca_res, test_layer['sample_df'])

    para_title = f"<em>H<sub>para</sub>({str(test_layer['layer_ix'])})</em>"
    orth_title = f"<em>H<sub>orth</sub>({str(test_layer['layer_ix'])})</em>"
    combine_plots([
        plot_manifold(h_para_plot_df, 'source', 'token', para_title),
        plot_manifold(h_orth_plot_df, 'source', 'token', orth_title)
    ]).show()

## Functional Specialization

In [None]:
"""
Print available layers
"""
# Iterating a dict yields its keys, i.e. the layer indices that have saved expert outputs
list(all_expert_outputs)

In [None]:
# Free the scratch tensor only if it exists: a bare `del` raises NameError on a fresh kernel
# because nothing above this cell defines the name.
if 'test_t1_exp_inputs' in globals():
    del test_t1_exp_inputs

In [None]:
"""
Get test samples (pre and post-MLP) for experts where expert was top-1
"""
test_layer_ix = 11

# Per-sample expert / prev_expert table for the test layer
test_sample_df = get_sample_df_for_layer(sample_df, topk1_df, test_layer_ix)

# Pre-MLP hidden states for every sample, and the expert output stored at index 0 of dim 1
# (presumably the top-1 expert, per the cell header - confirm against how the outputs were saved)
test_exp_inputs = all_pre_mlp_hs[test_layer_ix][:, :]
test_t1_exp_outputs = all_expert_outputs[test_layer_ix][:, 0]

# Change applied by the expert: output minus input
test_deltas = torch.sub(test_t1_exp_outputs, test_exp_inputs)
print(f"{test_exp_inputs.shape} | {test_t1_exp_outputs.shape}")

In [None]:
"""
Show PCA of functional transformation
"""
# 2-D PCA of the deltas, attached to the sample rows by position
layer_pca = reduce_pca(test_deltas, 2)
test_sample_df_with_pca = test_sample_df.assign(d1 = layer_pca[:, 0], d2 = layer_pca[:, 1])

# Upper bound on plotted points. `.sample(n)` raises ValueError when a frame has fewer than n rows,
# so every draw below is capped at the frame length.
max_plot_points = 5_000

# Experts to test: the five experts with the most samples at this layer
test_experts = (
    test_sample_df
    .groupby('expert', as_index = False).agg(n = ('expert', 'count'))
    .sort_values('n', ascending = False)
    .head(5)['expert'].tolist()
)

# Plot across experts
pca_plot_df = (
    test_sample_df_with_pca
    .pipe(lambda df: df[df['expert'].isin(test_experts)])
    .pipe(lambda df: df.sample(min(max_plot_points, len(df))))
    .assign(expert = lambda df: df['expert'].astype(str))
)

combine_plots([
    plot_manifold(pca_plot_df.sort_values('expert'), 'expert', 'token', title = 'Deltas by expert'),
    plot_manifold(pca_plot_df.sort_values('source'), 'source', 'token', title = 'Deltas by source')
], title = f'Layer {str(test_layer_ix)}').show()

# Plot within experts
for test_expert in sorted(test_experts):

    # Filter once per expert; the three frames below all start from this subset
    this_expert_df = test_sample_df_with_pca[test_sample_df_with_pca['expert'] == test_expert]

    source_plot_df = (
        this_expert_df
        .pipe(lambda df: df.sample(min(max_plot_points, len(df))))
        .assign(prev_expert = lambda df: df['prev_expert'].astype(str))
    )

    # The five most common previous-layer experts among this expert's samples
    t5_prev_experts = (
        this_expert_df
        .groupby('prev_expert', as_index = False).agg(n = ('prev_expert', 'count'))
        .sort_values('n', ascending = False)
        .head(5)['prev_expert'].tolist()
    )

    t5_plot_df = (
        this_expert_df
        .pipe(lambda df: df[df['prev_expert'].isin(t5_prev_experts)])
        .assign(prev_expert = lambda df: df['prev_expert'].astype(str))
    )

    combine_plots([
        plot_manifold(t5_plot_df.sort_values('prev_expert'), 'prev_expert', 'token', title = 'Deltas by prev expert'),
        plot_manifold(source_plot_df.sort_values('source'), 'source', 'token', title = 'Deltas by source')
    ], title = f'Layer {str(test_layer_ix)}, Expert {str(test_expert)}').show()

In [None]:
# Remove leftover scratch variables if present. A bare `del` raises NameError on a fresh kernel
# because nothing above this cell defines these names.
for _stale_name in ('test_expert_pca_res', 'test_expert_indices'):
    globals().pop(_stale_name, None)

In [None]:
# NOTE(review): `test_expert_indices` is never defined in this notebook (and is deleted in the cell above),
# so a bare reference raises NameError and halts Run All. Look it up defensively; shows nothing when absent.
globals().get('test_expert_indices')

In [None]:
# Mean delta vector per 'source' group; groups with 10 or fewer samples are skipped.
# Fixes: the list was created as `grped_delta` but appended to as `grped_deltas` (NameError), and the
# deltas were indexed with the undefined `test_expert_indices` instead of this group's sample indices.
grped_deltas = []

for grp_val in sorted(test_sample_df['source'].unique().tolist()):
    this_grp_sample_indices = test_sample_df[test_sample_df['source'] == grp_val]['sample_ix'].tolist()
    if len(this_grp_sample_indices) <= 10:
        continue
    this_grp_deltas = test_deltas[this_grp_sample_indices, :]
    grped_deltas.append({
        'grp': grp_val,
        'grped_vals': this_grp_deltas.mean(dim = 0)
    })

In [None]:
# Shape of the mean delta over the first ten samples (sanity check of the reduction axis)
test_deltas[:10].mean(dim = 0).shape

In [None]:
# NOTE(review): disabled scratch code kept as a bare string literal. The closing triple quote was
# missing, which made this cell a SyntaxError (unterminated string) and broke Restart & Run All.
# The trailing `None` keeps the cell from echoing the string as output.
"""
test_layer_ix = 11

test_sample_df =\
    get_sample_df_for_layer(sample_df, topk1_df, test_layer_ix)\
    .pipe(lambda df: df[df['expert'] == test_expert])

test_sample_indices = test_sample_df['sample_ix'].tolist()
display(test_sample_df)

test_t1_exp_outputs = all_expert_outputs[test_layer_ix][test_sample_indices, 0, :]
test_t1_exp_inputs = all_pre_mlp_hs[test_layer_ix][test_sample_indices, :]

print(f"{test_t1_exp_inputs.shape} | {test_t1_exp_outputs.shape}")
"""
None

In [None]:
"""
Top-1 vs top-2 behavior
"""

## SVD clustering: sideways -> h_orth
1. Decompose the pre-mlp hidden states into h_sideways WITHIN each group of activations that route to a single expert, with respect to just the D-dimensional routing gate for that single expert. Then, still within that single expert, cluster those h_sideways activations.
2. Repeat across all experts.
3. After obtaining these clusters, compute h_orth (the regular decomposition using all activations with respect to the entire routing gate).
4. Calculate the cluster centroids from h_orth, using the cluster ids/labels extracted from h_sideways earlier. This results in n_experts * n_clusters_per_expert cluster centers.
5. Calculate (using cosine similarity?) the within-expert against across-expert averages.

In [None]:
# import importlib, utils.svd as svd
# decompose_sideways = importlib.reload(svd).decompose_sideways

In [None]:
def cluster_kmeans(layer_hs: torch.Tensor, n_clusters = 512):
    """
    Cluster the rows of a tensor with cuML KMeans (max_iter = 1000, random_state = 123).

    Params:
        @layer_hs: Tensor to cluster; it is cast to float32 before being handed to CuPy.
        @n_clusters: Number of clusters to fit.

    Returns:
        A tuple (cluster_ids, cluster_centers): the fitted labels as a Python list with one entry
        per input row, and the model's `cluster_centers_`.
    """
    kmeans_model = cuml.cluster.KMeans(
        n_clusters = n_clusters,
        max_iter = 1000,
        random_state = 123
    )
    kmeans_model.fit(cupy.asarray(layer_hs.to(torch.float32)))
    clear_all_cuda_memory(False)
    return kmeans_model.labels_.tolist(), kmeans_model.cluster_centers_

# Smoke-test the helper on the single-expert activations. Show only summary sizes: echoing the raw
# return value prints one label per sample plus the full centroid matrix.
smoke_cluster_ids, smoke_cluster_centers = cluster_kmeans(relevant_pre_mlp_hs)
len(smoke_cluster_ids), smoke_cluster_centers.shape

In [None]:
"""
Prepare sample-level df merged with top-1 expert selections for a single test layer
"""

test_layer_ix = 1

# Expert chosen at the test layer, and at the layer before it (renamed to prev_expert)
this_layer_experts_df = topk1_df[topk1_df['layer_ix'] == test_layer_ix][['expert', 'sample_ix']]
prev_layer_experts_df = (
    topk1_df[topk1_df['layer_ix'] == test_layer_ix - 1]
    .rename(columns = {'expert': 'prev_expert'})
    [['sample_ix', 'prev_expert']]
)

# Inner joins: only samples with a row at both layers are kept
sample_df_test = (
    sample_df
    .merge(this_layer_experts_df, how = 'inner', on = 'sample_ix')
    .merge(prev_layer_experts_df, how = 'inner', on = 'sample_ix')
)

sample_df_test

In [None]:
"""
Extract the sideways decomposition within activations routed to a single expert; this specifically REMOVES the part of h directly 
responsible for increasing/decreasing logit specifically for that expert; then cluster them
"""
# One cluster id per row of the layer's activation tensor; -1 marks samples that were never clustered
cluster_ids = torch.full([all_pre_mlp_hs[test_layer_ix].shape[0]], -1, dtype = torch.int32)

n_clusters_per_expert = 10

for this_expert in tqdm(sorted(sample_df_test['expert'].unique().tolist())):

    # Extract sample indices for expert
    this_sample_indices = sample_df_test[sample_df_test['expert'] == this_expert]['sample_ix'].tolist()

    # KMeans cannot fit more clusters than samples; leave such experts at the -1 sentinel
    if len(this_sample_indices) < n_clusters_per_expert:
        continue

    # D-dimensional routing gate for expert route
    this_gate = model.model.layers[test_layer_ix].mlp.gate.weight[this_expert, :].to(torch.float32).detach().cpu()

    # Remove only this expert's axis to expose sub-clusters
    _, h_side = decompose_sideways(all_pre_mlp_hs[test_layer_ix][this_sample_indices], this_gate)

    # Cluster within expert
    this_cluster_ids, _ = cluster_kmeans(h_side, n_clusters = n_clusters_per_expert)
    cluster_ids[this_sample_indices] = torch.tensor(this_cluster_ids, dtype = cluster_ids.dtype)

# Look each row's cluster up by its sample_ix. `cluster_ids` has one entry per activation row, while
# sample_df_test only holds the samples that survived the merges, so assigning the whole tensor
# positionally fails on a length mismatch or mislabels rows when the order differs.
sample_df_test_cl = sample_df_test.assign(
    cluster_id = lambda df: cluster_ids[df['sample_ix'].tolist()].numpy()
)

In [None]:
"""
Go back to the original decomposition to get the regular h_orth
"""
# Inputs for the decomposition: every hidden state at the test layer and that layer's full router weight, in float32
layer_hs_fp32 = all_pre_mlp_hs[test_layer_ix].to(torch.float32)
router_weight = model.model.layers[test_layer_ix].mlp.gate.weight.to(torch.float32).detach().cpu()

# Only the second returned component (h_orth) is kept
_, h_orth = decompose_orthogonal(layer_hs_fp32, router_weight, method = 'svd')
h_orth = h_orth.to(torch.float32)

# Drop the temporaries so no extra float32 copy of the activations stays referenced
del layer_hs_fp32, router_weight

In [None]:
"""
Apply the sub-cluster labels obtained from clustering h_sideways to the corresponding h_orth vectors
This allows us to compare things in h_orth space; then compare cosine similarity. 
"""
centroids = [] # List of unit-norm centroid tensors
tags = [] # (expert, cluster)

# One pass over the (expert, cluster_id) groups instead of re-filtering the whole dataframe per pair.
# groupby sorts its keys, so the order is expert ascending, then cluster ascending.
clustered_df = sample_df_test_cl[sample_df_test_cl['cluster_id'] != -1] # -1 = never clustered
for (this_expert, this_cluster), this_e_c_df in clustered_df.groupby(['expert', 'cluster_id']):
    this_e_c_sample_indices = this_e_c_df['sample_ix'].tolist()

    if len(this_e_c_sample_indices) <= 50: # Throw out tiny clusters
        continue

    v = h_orth[this_e_c_sample_indices].mean(0)
    v = v / v.norm() # normalise for cosine
    centroids.append(v)
    tags.append((int(this_expert), int(this_cluster)))

if len(centroids) == 0:
    raise ValueError('No (expert, cluster) group has more than 50 samples; nothing to compare.')

centroids = torch.stack(centroids) # (n_kept_groups, D)
cosine_sims = sklearn.metrics.pairwise.cosine_similarity(centroids.numpy()) # pairwise sim between each expert-cluster pair

# Upper triangle only (i < j), split by whether both centroids belong to the same expert
expert_of_tag = np.array([e for e, _ in tags])
row_ix, col_ix = np.triu_indices(len(tags), k = 1)
pair_sims = cosine_sims[row_ix, col_ix]
same_expert = expert_of_tag[row_ix] == expert_of_tag[col_ix]
within = pair_sims[same_expert].tolist()
across = pair_sims[~same_expert].tolist()

# NaN when a side has no pairs
mean_cos_within  = float(np.mean(within)) if within else float('nan')
mean_cos_across  = float(np.mean(across)) if across else float('nan')
print(f"mean cosine  (within expert):  {mean_cos_within:.3f}")
print(f"mean cosine  (across experts): {mean_cos_across:.3f}")

## SVD clustering: h_orth vs h_para

In [None]:
"""
Prepare sample-level df merged with top-1 expert selections for a single test layer
"""

test_layer_ix = 10

# Rows whose topk_ix equals 1
topk1_df = topk_df[topk_df['topk_ix'] == 1]

# Expert chosen at the test layer, and at the layer before it (renamed to prev_expert)
this_layer_experts_df = topk1_df[topk1_df['layer_ix'] == test_layer_ix][['expert', 'sample_ix']]
prev_layer_experts_df = (
    topk1_df[topk1_df['layer_ix'] == test_layer_ix - 1]
    .rename(columns = {'expert': 'prev_expert'})
    [['sample_ix', 'prev_expert']]
)

# Inner joins: only samples with a row at both layers are kept
sample_df_test = (
    sample_df
    .merge(this_layer_experts_df, how = 'inner', on = 'sample_ix')
    .merge(prev_layer_experts_df, how = 'inner', on = 'sample_ix')
)

sample_df_test

In [None]:
"""
Go back to the original decomposition to get the regular h_orth
"""
# Inputs for the decomposition: every hidden state at the test layer and that layer's full router weight, in float32
layer_hs_fp32 = all_pre_mlp_hs[test_layer_ix].to(torch.float32)
router_weight = model.model.layers[test_layer_ix].mlp.gate.weight.to(torch.float32).detach().cpu()

h_para, h_orth = decompose_orthogonal(layer_hs_fp32, router_weight, method = 'svd')
h_para, h_orth = h_para.to(torch.float32), h_orth.to(torch.float32)

# Drop the temporaries so no extra float32 copy of the activations stays referenced
del layer_hs_fp32, router_weight

In [None]:
"""
Plot PCA + UMAP, color by expert id
"""
def reduce_pca(input_tensor: torch.Tensor, n_components = 2):
    """
    Reduce a tensor with cuML PCA and return the projected coordinates.

    NOTE(review): this repeats the `reduce_pca` helper defined earlier in this notebook; keep one copy.

    Params:
        @input_tensor: Tensor to reduce; it is cast to float32 before being handed to CuPy.
        @n_components: Number of principal components to keep.

    Returns:
        The projected data converted with `cupy.asnumpy`, one column per component.
    """
    hs_cupy = cupy.asarray(input_tensor.to(torch.float32))
    pca_model = cuml.PCA(
        iterated_power = 100,
        n_components = n_components,
        verbose = True
    )
    # Fit and project in a single pass; the model used to be fitted twice (fit, then fit_transform).
    pred = cupy.asnumpy(pca_model.fit_transform(hs_cupy))
    # print(f'Explained variance ratio: {pca_model.explained_variance_ratio_}')
    print(f'Cumulative variance ratio: {np.cumsum(pca_model.explained_variance_ratio_)[-1]}')
    clear_all_cuda_memory(False)
    return pred

def reduce_umap(input_tensor: torch.Tensor, n_components = 2, metric = 'cosine', n_epochs = 200):
    """
    Embed a tensor with cuML UMAP and return the embedding.

    NOTE(review): this repeats the `reduce_umap` helper defined earlier in this notebook; keep one copy.

    Params:
        @input_tensor: Tensor to embed; it is cast to float32 before being handed to CuPy.
        @n_components: Dimensionality of the embedding.
        @metric: Distance metric name passed to UMAP.
        @n_epochs: Number of optimisation epochs passed to UMAP.

    Returns:
        The embedding converted with `cupy.asnumpy`.
    """
    gpu_input = cupy.asarray(input_tensor.to(torch.float32))
    umap_settings = dict(
        n_components = n_components,
        n_neighbors = 20,     # neighbourhood size
        metric = metric,
        min_dist = 0.5,       # minimum spacing between embedded points
        n_epochs = n_epochs,
        random_state = 123,   # fixed seed
        verbose = False
    )
    umap_model = cuml.UMAP(**umap_settings)
    embedding = cupy.asnumpy(umap_model.fit_transform(gpu_input))
    clear_all_cuda_memory(False)
    return embedding

def plot_manifold(reduced_np, sample_df):
    """
    Scatter a 2-D reduction colored by expert id and show it immediately.

    NOTE(review): this redefinition replaces the earlier `plot_manifold(plot_df, color_col, hover_col, title)`
    in this notebook, so cells above that use the four-argument form cannot be re-run after this cell.

    Params:
        @reduced_np: Array whose first two columns are the plot coordinates.
        @sample_df: Dataframe with 'expert' and 'token' columns, row-aligned with `reduced_np`.

    Returns:
        None; the figure is rendered with `.show()`.
    """
    plot_df = (
        pd.concat([pd.DataFrame({'d1': reduced_np[:, 0], 'd2': reduced_np[:, 1]}), sample_df], axis = 1)
        # Cap the draw at the frame length: `.sample(5000)` raises when fewer rows are available
        .pipe(lambda df: df.sample(min(5_000, len(df))))
        # Cast the column that is actually used for coloring (the cast used to be written to 'prev_expert'),
        # so expert ids are drawn as discrete categories rather than on a continuous color scale
        .assign(expert = lambda df: df['expert'].astype(str))
    )
    px.scatter(plot_df, x = 'd1', y = 'd2', color = 'expert', hover_data = ['token']).show()

# Recompute the reduction for this section. With this line commented out, the plot reused whatever
# `pca_res` an earlier cell left behind (computed for other layers / row sets), which is not
# row-aligned with sample_df_test.
pca_res = reduce_pca(h_para, 2)
plot_manifold(pca_res, sample_df_test)

# plot_reduction(relevant_pre_mlp_hs, relevant_samples_df, reduce_umap, 2, 'cosine', 500)

In [None]:
# Redraw with the two-argument plot_manifold defined in the previous cell; each call draws a fresh random sample of rows.
plot_manifold(pca_res, sample_df_test)