In [1]:
import math

import plotly.express as px
import torch
import torch.nn.functional as F
from plotly.subplots import make_subplots
from safetensors import safe_open
from torch.utils.data import DataLoader, Subset
from torcheval.metrics import MulticlassAccuracy

from analysis.common import load_autoencoder, load_model
from koopmann import aesthetics
from koopmann.data import DatasetConfig, get_dataset_class
from koopmann.models import ConvResNet
from koopmann.utils import set_seed
from scripts.train_ae.shape_metrics import prepare_acts, undo_preprocessing_acts

# Fix the seed through the project helper before any random draw in this notebook
# (presumably this also seeds torch's RNG — confirm in koopmann.utils.set_seed).
set_seed(21)

%load_ext autoreload
%autoreload 2

Control panel

In [2]:
# Dataset to analyse; the file-setup cell only has presets for "lotusroot" and "mnist".
dataset_name = "mnist"
# Name of the saved classifier that load_model is asked for.
model_name = f"resmlp_{dataset_name}"

File setup

In [3]:
# Autoencoder settings per dataset. All four values are interpolated into
# `ae_name` below, and `k_steps` is also the power the Koopman matrix is raised
# to when predicting the final activation.
ae_presets = {
    "lotusroot": {"dim": 20, "scale_idx": 1, "k_steps": 100, "flavor": "exponential"},
    "mnist": {"dim": 500, "scale_idx": 1, "k_steps": 1, "flavor": "standard"},
}
if dataset_name not in ae_presets:
    # Same exception type as before, but now it says what went wrong.
    raise NotImplementedError(
        f"No autoencoder preset for dataset {dataset_name!r}; "
        f"expected one of {sorted(ae_presets)}."
    )

preset = ae_presets[dataset_name]
dim = preset["dim"]
scale_idx = preset["scale_idx"]
k_steps = preset["k_steps"]
flavor = preset["flavor"]

# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
file_dir = "/scratch/nsa325/koopmann_model_saves"
ae_name = f"dim_{dim}_k_{k_steps}_loc_{scale_idx}_{flavor}_autoencoder_{dataset_name}_model"
device = "cpu"

Load models

In [4]:
# Classifier under analysis, switched to eval mode and moved to `device`.
# hook_model() is a project method — presumably it registers the activation
# hooks that prepare_acts relies on later (TODO confirm).
model, model_metadata = load_model(file_dir, model_name)
model.eval().hook_model().to(device)
print("Model: ", model_metadata)

# Koopman autoencoder checkpoint selected by `ae_name`.
autoencoder, ae_metadata = load_autoencoder(file_dir, ae_name)
autoencoder.eval().to(device)
# Autoencoder input width; not referenced again in the cells shown here.
new_dim = ae_metadata["in_features"]
# Whether activations go through the saved preprocessing before the autoencoder.
preprocess = ae_metadata["preprocess"]
# Transposed so a batch of row-vector latents can be advanced with `z @ K_matrix`.
K_matrix = autoencoder.koopman_weights.T
print("Autoencoder: ", ae_metadata)

Model:  {'batchnorm': True, 'bias': True, 'created_at': '2025-04-09T02:41:58.432513', 'dataset': 'MNISTDataset', 'hidden_config': [784, 784, 784, 784], 'in_features': 784, 'model_class': 'ResMLP', 'nonlinearity': 'relu', 'out_features': 10, 'stochastic_depth_mode': 'batch', 'stochastic_depth_prob': 0.0}
Autoencoder:  {'batchnorm': False, 'bias': True, 'created_at': '2025-05-01T14:37:36.070348', 'hidden_config': [1000], 'in_features': 784, 'k_steps': 1, 'latent_features': 500, 'model_class': 'KoopmanAutoencoder', 'nonlinearity': 'gelu', 'preprocess': True}




In [None]:
# Test split of the dataset the classifier was trained on (name taken from its metadata).
dataset_config = DatasetConfig(
    dataset_name=model_metadata["dataset"],
    num_samples=3_000,
    split="test",
    seed=42,
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)

# Optionally evaluate on the first `subset_size` samples only; a falsy value
# (None / 0) means "use the whole dataset".
subset_size = 1_000
if subset_size:
    # Clamp so an oversized request cannot create out-of-range Subset indices.
    subset_size = min(subset_size, len(dataset))
    subset_indices = list(range(subset_size))
    subset = Subset(dataset, subset_indices)
    eval_dataset = subset
else:
    eval_dataset = dataset

# Never let batch_size exceed the data actually served: later cells reuse it as
# an upper bound when indexing activations. max(1, ...) keeps DataLoader's
# "batch_size must be positive" requirement satisfied for an empty dataset.
batch_size = max(1, min(3_000, len(eval_dataset)))
dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

In [6]:
# Read every tensor stored next to the autoencoder checkpoint into a plain
# dict keyed by tensor name.
preproc_path = f"{file_dir}/{ae_name}_preprocessing.safetensors"
preproc_dict = {}
with safe_open(preproc_path, framework="pt", device="cpu") as tensor_file:
    preproc_dict.update(
        (tensor_name, tensor_file.get_tensor(tensor_name)) for tensor_name in tensor_file.keys()
    )

In [7]:
print(f"Preprocess activations?: {preprocess}")

# Collect the model's activations over the dataloader, both as recorded and
# after the saved preprocessing; the third return value is not needed here.
orig_act_dict, proc_act_dict, _ = prepare_acts(
    data_train_loader=dataloader,
    model=model,
    device=device,
    svd_dim=ae_metadata["in_features"],
    whiten_alpha=preproc_dict["wh_alpha_0"],
    preprocess=preprocess,
    preprocess_dict=preproc_dict,
    only_first_last=True,
)

# With preprocessing disabled the autoencoder consumes the raw activations.
if not preprocess:
    proc_act_dict = orig_act_dict

Preprocess activations?: True


In [8]:
# Reference predictions from the unmodified classifier. Only the FIRST batch is
# pulled from the loader, so model_pred / labels cover at most batch_size samples.
with torch.no_grad():
    images, labels = next(iter(dataloader))
    # Drops size-1 dims; presumably labels arrive as (N, 1) — TODO confirm.
    labels = labels.squeeze()
    model_pred = model(images.to(device))

In [9]:
def _to_model_space(acts, layer_idx):
    """Map autoencoder-space activations back to the model's activation space.

    When the autoencoder was trained with preprocessing, the stored statistics
    for `layer_idx` are used to undo it; otherwise `acts` is returned unchanged.
    """
    if not preprocess:
        return acts
    return undo_preprocessing_acts(acts, preproc_dict, layer_idx, device)


# First and last recorded activations (prepare_acts ran with only_first_last=True).
layer_keys = list(orig_act_dict.keys())
init_idx, final_idx = layer_keys[0], layer_keys[-1]

with torch.no_grad():
    # Raw and autoencoder-space activations at the first recorded layer ...
    x = orig_act_dict[init_idx]
    x_proj = proc_act_dict[init_idx]

    # ... and at the last recorded layer (the prediction target).
    y = orig_act_dict[final_idx]
    y_proj = proc_act_dict[final_idx]

    # Round trip through the preprocessing only; kept for interactive inspection.
    x_unproj = _to_model_space(x_proj, init_idx)
    y_unproj = _to_model_space(y_proj, final_idx)

    # Reconstruct first act
    x_proj_obs = autoencoder.encode(x_proj)
    x_proj_recon = autoencoder.decode(x_proj_obs)

    # Reconstruct final act
    y_proj_obs = autoencoder.encode(y_proj)
    y_proj_recon = autoencoder.decode(y_proj_obs)

    # Advance the encoded first activation k_steps times with the Koopman
    # matrix, then decode to get the predicted final activation.
    pred_proj_obs = x_proj_obs @ torch.linalg.matrix_power(K_matrix, int(k_steps))
    pred_proj = autoencoder.decode(pred_proj_obs)

    # Bring prediction and reconstruction back to the model's activation space.
    pred = _to_model_space(pred_proj, final_idx)
    y_recon = _to_model_space(y_proj_recon, final_idx)

    if type(model) is ConvResNet:
        # NOTE(review): hard-coded feature-map shape fed to the last two components.
        pred = pred.reshape(-1, 512, 4, 4)
        classifier_head = model.components[-2:]
    else:
        classifier_head = model.components[-1:]

    # Feed pred to classifier
    koopman_pred = torch.argmax(classifier_head(pred), dim=1)


In [10]:
# `dataset` from the data-loading cell already exposes out_features, so there
# is no need to build the dataset a second time just to read it.
num_classes = dataset.out_features

# Accuracy of the unmodified classifier on the first batch.
model_metric = MulticlassAccuracy(num_classes=num_classes)
model_metric.update(model_pred, labels)

# koopman_pred was computed from every activation row, while `labels` holds the
# first batch only; trim to the number of labels actually available rather than
# assuming the batch is exactly batch_size long.
koopman_metric = MulticlassAccuracy(num_classes=num_classes)
koopman_metric.update(koopman_pred[: labels.numel()], labels)

print("Original accuracy: ", model_metric.compute())
print("Koopman accuracy: ", koopman_metric.compute())


Original accuracy:  tensor(0.9853)
Koopman accuracy:  tensor(0.9763)


In [11]:
def compare_images(
    original,
    reconstructed,
    reshape_dims=None,
    height=400,
    width=800,
    titles=None,
):
    """Show two tensors side by side as heatmaps and print their mean squared error.

    Args:
        original: Tensor drawn in the left panel.
        reconstructed: Tensor drawn in the right panel; must hold as many
            elements as `original`.
        reshape_dims: 2-D shape both tensors are reshaped to for display. When
            None, a square is used if the element count is a perfect square,
            otherwise the most square-like (rows, cols) factorisation with
            rows <= cols (a prime count falls back to a single row).
        height: Figure height in pixels.
        width: Figure width in pixels.
        titles: Two subplot titles; defaults to ["Original", "Reconstructed"].

    Returns:
        The plotly figure containing both heatmaps.
    """
    # A None sentinel avoids sharing one mutable default list between calls.
    if titles is None:
        titles = ["Original", "Reconstructed"]

    # Auto-calculate reshape dimensions if not provided
    if reshape_dims is None:
        total_elements = original.numel()
        # isqrt is exact integer arithmetic, so the perfect-square test cannot
        # be fooled by floating-point rounding.
        side = math.isqrt(total_elements)
        if side * side == total_elements:
            reshape_dims = (side, side)
        else:
            # Non-squares are >= 2, hence side >= 1 and the loop stops at 1 at the latest.
            while total_elements % side != 0:
                side -= 1
            reshape_dims = (side, total_elements // side)

    fig = make_subplots(rows=1, cols=2, subplot_titles=titles)

    for col, img in enumerate((original, reconstructed), start=1):
        fig.add_trace(px.imshow(img.reshape(reshape_dims)).data[0], row=1, col=col)

    # Anchor each x axis to its y axis so heatmap cells stay square.
    fig.update_layout(height=height, width=width, xaxis_scaleanchor="y", xaxis2_scaleanchor="y2")

    # Both tensors were just reshaped to the same dims, so they hold the same
    # number of elements; flattening keeps the MSE element-wise even when their
    # original shapes differ (e.g. (784,) vs (1, 28, 28)) instead of broadcasting.
    error = F.mse_loss(original.reshape(-1), reconstructed.reshape(-1), reduction="mean")
    print(f"Error: {error.item():.6f}")

    return fig

In [12]:
# Draw one random row index in [0, batch_size) to inspect in the plots below.
drawn = torch.randint(low=0, high=batch_size, size=(1,))
sample_idx = drawn.item()
print(sample_idx)

2975


In [13]:
# Autoencoder round trip (encode -> decode) of the first recorded activation,
# compared in the space the autoencoder operates in.
compare_images(
    x_proj[sample_idx].cpu(),
    x_proj_recon[sample_idx].cpu(),
    titles=["LoDim Input", "Recon. LoDim Input"],
)

Error: 0.000087


In [14]:
# Autoencoder round trip (encode -> decode) of the last recorded activation.
compare_images(
    y_proj[sample_idx].cpu(),
    y_proj_recon[sample_idx].cpu(),
    titles=["LoDim Target", "Recon. LoDim Target"],
)

Error: 0.000016


In [15]:
# Same reconstruction as above, but after mapping it back to the model's
# activation space. y is flattened here, presumably so its shape lines up with
# y_recon — TODO confirm.
compare_images(
    y[sample_idx].flatten().cpu(), y_recon[sample_idx].cpu(), titles=["Original", "Recon"]
)

Error: 0.058389


In [16]:
# Encoded first activation next to the encoded last activation. The printed
# "Error" is the MSE between these two latent codes, not a reconstruction error.
compare_images(
    x_proj_obs[sample_idx].cpu(), y_proj_obs[sample_idx].cpu(), titles=["Obs Input", "Obs Target"]
)

Error: 0.000382


In [17]:
# Latent code advanced by the Koopman matrix versus the encoding of the true
# last activation.
compare_images(
    pred_proj_obs[sample_idx].cpu(),
    y_proj_obs[sample_idx].cpu(),
    titles=["Obs Predicted", "Obs Target"],
)

Error: 0.000023


In [18]:
# Decoded Koopman prediction versus the true last activation, both in the
# autoencoder's input space.
compare_images(
    pred_proj[sample_idx].cpu(),
    y_proj[sample_idx].cpu(),
    titles=["Predicted LoDim", "LoDim Target"],
)

Error: 0.000037


In [19]:
# Koopman prediction versus the true last activation in the model's own
# activation space (this `pred` is what was fed to the classifier head).
compare_images(pred[sample_idx].cpu(), y[sample_idx].cpu(), titles=["Predicted", "Target"])

Error: 0.254635


In [20]:
# Heatmap of the top-left 200 x 200 block of the (transposed) Koopman weights.
px.imshow(K_matrix.detach()[:200, :200].cpu())