In [71]:
import os
from copy import deepcopy

import panel as pn
import plotly.express as px
import torch
from processor import Processor
from safetensors import safe_open
from scipy.spatial import procrustes
from sklearn.decomposition import PCA
from torch import nn
from torch.utils.data import DataLoader, Subset
from torcheval.metrics import MulticlassAccuracy

from analysis.common import load_autoencoder, load_model
from koopmann import aesthetics

# from koopmann import aesthetics
from koopmann.data import (
    DatasetConfig,
    get_dataset_class,
)
from scripts.train_ae.shape_metrics import prepare_acts, undo_preprocessing_acts

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [72]:
# Experiment selection; the model checkpoint name is derived from the dataset name.
dataset_name = "lotusroot"
model_name = f"resmlp_{dataset_name}"

# Storage locations. The defaults are the original hard-coded paths; set the
# KOOPMANN_MODEL_DIR / KOOPMANN_DATA_ROOT environment variables to run the
# notebook on another machine without editing this cell.
file_dir = os.environ.get("KOOPMANN_MODEL_DIR", "/Users/nsa325/Documents/koopmann_model_saves")
data_root = os.environ.get("KOOPMANN_DATA_ROOT", "/Users/nsa325/datasets/")

In [73]:
# Autoencoder hyperparameters per dataset, as (dim, scale_idx, k_steps, flavor).
ae_presets = {
    "lotusroot": (20, 1, 100, "exponential"),
    "mnist": (800, 1, 10, "exponential"),
    "cifar10": (1_000, 1, 100, "exponential"),
}
if dataset_name not in ae_presets:
    raise NotImplementedError()
dim, scale_idx, k_steps, flavor = ae_presets[dataset_name]

# Name of the saved autoencoder checkpoint that matches the chosen hyperparameters.
ae_name = f"dim_{dim}_k_{k_steps}_loc_{scale_idx}_{flavor}_autoencoder_{dataset_name}_model"
device = "cpu"

In [74]:
# Load the trained network and its metadata from file_dir, switch it to eval
# mode, call its hook_model() method and move it to `device`.
# NOTE(review): hook_model() is a project method; presumably it registers
# forward hooks used to capture activations -- confirm against its definition.
model, model_metadata = load_model(file_dir, model_name)
model.eval().hook_model().to(device)
print("Model: ", model_metadata)

# Load the Koopman autoencoder and its metadata the same way.
autoencoder, ae_metadata = load_autoencoder(file_dir, ae_name)
autoencoder.eval().to(device)
# Input width of the autoencoder as recorded in its metadata.
new_dim = ae_metadata["in_features"]
# Metadata flag; later passed to prepare_acts and used to decide whether the
# processed or the original activations are analysed.
preprocess = ae_metadata["preprocess"]
# Transpose of the autoencoder's Koopman weight matrix.
K_matrix = autoencoder.koopman_weights.T
print("Autoencoder: ", ae_metadata)

Model:  {'batchnorm': True, 'bias': True, 'created_at': '2025-04-30T10:33:25.774294', 'dataset': 'LotusRootDataset', 'hidden_config': [10, 10, 10, 10], 'in_features': 2, 'model_class': 'ResMLP', 'nonlinearity': 'relu', 'out_features': 2, 'stochastic_depth_mode': 'batch', 'stochastic_depth_prob': 0.0}
Autoencoder:  {'batchnorm': False, 'bias': True, 'created_at': '2025-05-01T13:48:07.748422', 'hidden_config': [30], 'in_features': 10, 'k_steps': 100, 'latent_features': 20, 'model_class': 'ExponentialKoopmanAutencoder', 'nonlinearity': 'leaky_relu', 'preprocess': True}


In [75]:
# Build the training split of the dataset the model was trained on.
dataset_config = DatasetConfig(
    dataset_name=model_metadata["dataset"],
    num_samples=3_000,
    split="train",
    seed=42,
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config, root=data_root)

# Optionally restrict the analysis to the first `subset_size` samples
# (a falsy value means "use the whole dataset").
subset_size = 3_000
batch_size = 3_000
if subset_size:
    subset_indices = list(range(subset_size))
    subset = Subset(dataset, subset_indices)
    # Never request a batch larger than the subset itself.
    batch_size = min(subset_size, batch_size)

# shuffle=False keeps sample order aligned with dataset.labels.
dataloader = DataLoader(
    subset if subset_size else dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [76]:
def compute_reference_bases(data):
    """Return `data` projected onto its first three principal components.

    A fresh 3-component PCA is fitted on `data` and the transformed points
    are returned, for use as a reference point cloud.
    """
    return PCA(n_components=3).fit_transform(data)


def align_using_procrustes(reference_points, new_points):
    """Procrustes-align `new_points` to `reference_points`.

    `scipy.spatial.procrustes` returns a triple (standardized reference,
    transformed second matrix, disparity); only the transformed second
    matrix is returned here.
    """
    return procrustes(reference_points, new_points)[1]


def process_pca_and_align(data, reference):
    """Project `data` to 3 principal components and align it to `reference`.

    A new PCA is fitted on `data` each call; the projection is then
    Procrustes-aligned to `reference` and the aligned points are returned.
    """
    projected = PCA(n_components=3).fit_transform(data)
    return align_using_procrustes(reference, projected)


def create_3d_scatter_plot(data, labels, axis_range):
    """Build a 3D Plotly scatter figure from the first three columns of `data`.

    Args:
        data: 2D array-like; columns 0, 1 and 2 become x, y and z.
        labels: per-point labels; each is converted to `str` and used as the
            colour key.
        axis_range: range applied to all three axes.

    Returns:
        The Plotly figure, with size-1 markers, a cube-shaped scene and the
        legend hidden.
    """
    # Labels are stringified before being used as colours (presumably so
    # Plotly Express colours them as discrete categories -- TODO confirm).
    point_colors = list(map(str, labels))
    fig = px.scatter_3d(x=data[:, 0], y=data[:, 1], z=data[:, 2], color=point_colors)

    fig.update_traces(marker=dict(size=1))

    # Same range on every axis, rendered in a cube.
    scene = {f"{axis}axis": dict(range=axis_range) for axis in "xyz"}
    scene.update(aspectmode="cube", aspectratio=dict(x=1, y=1, z=1))
    fig.update_layout(scene=scene, showlegend=False)
    return fig

In [77]:
# Read every tensor stored in the autoencoder's preprocessing file into a
# plain dict keyed by tensor name.
preproc_dict = {}
with safe_open(
    f"{file_dir}/{ae_name}_preprocessing.safetensors", framework="pt", device="cpu"
) as f:
    preproc_dict.update((name, f.get_tensor(name)) for name in f.keys())

In [78]:
# Enable Panel for Jupyter
pn.extension()

# Run the data through the original model and collect its activations,
# both as-is and after the stored preprocessing.
with torch.no_grad():
    orig_act_dict, proc_act_dict, _ = prepare_acts(
        data_train_loader=dataloader,
        model=model,
        device=device,
        svd_dim=ae_metadata["in_features"],
        whiten_alpha=preproc_dict["wh_alpha_0"],
        preprocess=preprocess,
        preprocess_dict=preproc_dict,
        only_first_last=True,
    )
    # Without preprocessing, analyse the original activations instead.
    if not preprocess:
        proc_act_dict = orig_act_dict

# The activations stored under key 0 define the reference PCA point cloud.
align_idx = 0
ref_act = compute_reference_bases(proc_act_dict[align_idx])

In [79]:
# First and last keys of the collected activation dict.
layer_keys = list(orig_act_dict.keys())
init_idx, final_idx = layer_keys[0], layer_keys[-1]

# Feed the first-key activations through the autoencoder, keeping the
# intermediate predictions.
with torch.no_grad():
    x_proj = proc_act_dict[init_idx]
    ae_result = autoencoder(x_proj, intermediate=True)

print(ae_result.predictions.shape)

torch.Size([101, 3000, 10])


In [80]:
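# NOTE: disabled variant of the viewer in the next cell; it indexes the model
# activations (proc_act_dict) by layer instead of the autoencoder predictions.
# # Slider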
# layer_select = pn.widgets.IntSlider(
#     name="Layer Selector", start=0, end=list(proc_act_dict.keys())[-1], step=1, value=0
# )


# # Plotter
# def update_plots(data_a, data_b, ref_a, ref_b, labels):
#     """Updates PCA and RP plots with the given data and references."""
#     # Default axis range
#     pca_axis_range = [-0.05, 0.05]

#     # First plot: PCA
#     aligned_pca_result = process_pca_and_align(data_a, ref_a)
#     first_fig = create_3d_scatter_plot(aligned_pca_result, labels, pca_axis_range)

#     # Second plot: PCA
#     aligned_pca_result = process_pca_and_align(data_b, ref_b)
#     second_fig = create_3d_scatter_plot(aligned_pca_result, labels, pca_axis_range)

#     return first_fig, second_fig


# # View fn
# @pn.depends(layer_select.param.value)
# def view(layer_index):
#     figs = update_plots(
#         proc_act_dict[layer_index],
#         proc_act_dict[layer_index],
#         ref_act,
#         ref_act,
#         dataset.labels if not subset_size else dataset.labels[:subset_size],
#     )

#     panes = [pn.pane.Plotly(fig) for fig in figs]

#     return pn.Row(*panes, align="center")


# # Layout
# layout = pn.Column(
#     pn.Row(layer_select, align="center"),
#     view,
#     align="center",
#     sizing_mode="stretch_width",
# )

# # Display
# layout.show()

In [None]:
# Slider over the leading axis of the autoencoder predictions.
# IntSlider's `end` is inclusive, so the last selectable value must be
# shape[0] - 1; using shape[0] lets the slider pick an index one past the end
# and makes `view` fail when indexing ae_result.predictions.
layer_select = pn.widgets.IntSlider(
    name="Layer Selector", start=0, end=ae_result.predictions.shape[0] - 1, step=1, value=0
)


# Plotter
def update_plots(data_a, data_b, ref_a, ref_b, labels):
    """Build two aligned 3D PCA scatter figures.

    Each (data, reference) pair is PCA-projected, Procrustes-aligned to its
    reference and rendered with the same labels and axis range.

    Returns:
        A tuple `(first_fig, second_fig)` for `(data_a, ref_a)` and
        `(data_b, ref_b)` respectively.
    """
    # Axis limits shared by both figures.
    pca_axis_range = [-0.05, 0.05]

    first_fig, second_fig = (
        create_3d_scatter_plot(process_pca_and_align(points, ref), labels, pca_axis_range)
        for points, ref in ((data_a, ref_a), (data_b, ref_b))
    )
    return first_fig, second_fig


# View fn
@pn.depends(layer_select.param.value)
def view(layer_index):
    """Render the plots for the slider's current value.

    Indexes `ae_result.predictions` with `layer_index`, builds the two
    figures against the shared reference `ref_act`, and returns them side by
    side in a centred Panel row.
    """
    step_points = ae_result.predictions[layer_index]
    # Labels are truncated to the subset when one is in use.
    point_labels = dataset.labels[:subset_size] if subset_size else dataset.labels

    figs = update_plots(step_points, step_points, ref_act, ref_act, point_labels)

    return pn.Row(*(pn.pane.Plotly(fig) for fig in figs), align="center")


# Layout
# Centred slider row on top, the reactive `view` output underneath.
layout = pn.Column(
    pn.Row(layer_select, align="center"),
    view,
    sizing_mode="stretch_width",
    align="center",
)

# Display: serves the layout as a Panel app (see the "Launching server" output).
layout.show()

Launching server at http://localhost:65182


<panel.io.server.Server at 0x389df2510>

