In [131]:
import math

import plotly.express as px
import torch
import torch.nn.functional as F
from plotly.subplots import make_subplots
from safetensors import safe_open
from torch.utils.data import DataLoader, Subset
from torcheval.metrics import MulticlassAccuracy

from analysis.common import load_autoencoder, load_model
from koopmann import aesthetics
from koopmann.data import DatasetConfig, get_dataset_class
from koopmann.models import ConvResNet
from koopmann.utils import set_seed
from scripts.train_ae.shape_metrics import prepare_acts, undo_preprocessing_acts

set_seed(21)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Control panel

In [132]:
dataset_name = "lotusroot"
model_name = f"resmlp_{dataset_name}"
file_dir = "/Users/nsa325/Documents/koopmann_model_saves"
data_root = "/Users/nsa325/datasets/"

File setup

In [133]:
if dataset_name == "lotusroot":
    dim = 20
    scale_idx = 1
    k_steps = 100
    flavor = "exponential"
elif dataset_name == "mnist":
    dim = 800
    scale_idx = 1
    k_steps = 10
    flavor = "exponential"
elif dataset_name == "cifar10":
    dim = 1_000
    scale_idx = 1
    k_steps = 100
    flavor = "exponential"
else:
    raise NotImplementedError()

ae_name = f"dim_{dim}_k_{k_steps}_loc_{scale_idx}_{flavor}_autoencoder_{dataset_name}_model"
device = "cpu"

Load models

In [134]:
model, model_metadata = load_model(file_dir, model_name)
model.eval().hook_model().to(device)
print("Model: ", model_metadata)

autoencoder, ae_metadata = load_autoencoder(file_dir, ae_name)
autoencoder.eval().to(device)
new_dim = ae_metadata["in_features"]
preprocess = ae_metadata["preprocess"]
K_matrix = autoencoder.koopman_weights.T
print("Autoencoder: ", ae_metadata)

Model:  {'batchnorm': True, 'bias': True, 'created_at': '2025-04-30T10:33:25.774294', 'dataset': 'LotusRootDataset', 'hidden_config': [10, 10, 10, 10], 'in_features': 2, 'model_class': 'ResMLP', 'nonlinearity': 'relu', 'out_features': 2, 'stochastic_depth_mode': 'batch', 'stochastic_depth_prob': 0.0}
Autoencoder:  {'batchnorm': False, 'bias': True, 'created_at': '2025-05-01T13:48:07.748422', 'hidden_config': [30], 'in_features': 10, 'k_steps': 100, 'latent_features': 20, 'model_class': 'ExponentialKoopmanAutencoder', 'nonlinearity': 'leaky_relu', 'preprocess': True}


In [135]:
dataset_config = DatasetConfig(
    dataset_name=model_metadata["dataset"],
    num_samples=3_000,
    split="test",
    seed=42,
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config, root=data_root)

subset_size = 1_000
if subset_size:
    subset_indices = list(range(0, subset_size))
    subset = Subset(dataset, subset_indices)

batch_size = 3_000
batch_size = min(subset_size, batch_size) if subset_size else batch_size
dataloader = DataLoader(subset if subset_size else dataset, batch_size=batch_size, shuffle=False)

In [136]:
preproc_dict = {}
with safe_open(
    f"{file_dir}/{ae_name}_preprocessing.safetensors", framework="pt", device="cpu"
) as f:
    for k in f.keys():
        preproc_dict[k] = f.get_tensor(k)

In [137]:
print(f"Preprocess activations?: {preprocess}")
orig_act_dict, proc_act_dict, _ = prepare_acts(
    data_train_loader=dataloader,
    model=model,
    device=device,
    svd_dim=ae_metadata["in_features"],
    whiten_alpha=preproc_dict["wh_alpha_0"],
    preprocess=preprocess,
    preprocess_dict=preproc_dict,
    only_first_last=True,
)
proc_act_dict = orig_act_dict if not preprocess else proc_act_dict
init_idx = list(orig_act_dict.keys())[0]
final_idx = list(orig_act_dict.keys())[-1]

Preprocess activations?: True


In [138]:
px.imshow(K_matrix.detach().numpy())

In [139]:
with torch.no_grad():
    eigval, eigvec = torch.linalg.eig(K_matrix)
    dec_eigvec = autoencoder.decode(eigvec[0].real)
    dec_eigvec = undo_preprocessing_acts(dec_eigvec.unsqueeze(0), preproc_dict, final_idx, device)
px.imshow(-dec_eigvec.reshape(28, 28))

RuntimeError: shape '[28, 28]' is invalid for input of size 10