In [11]:
import os

# Ask PyTorch to fall back to the CPU for operators that are not available on the
# MPS backend. This lives in its own cell so the variable is already in the
# environment when the next cell imports torch.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

In [12]:
import math
from pathlib import Path

import numpy as np
import plotly.express as px
import torch
from rich import print as rprint
from rich.console import Console
from rich.table import Table
from safetensors import safe_open
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader
from torcheval.metrics import MulticlassAccuracy

from analysis.utils import load_autoencoder, load_model
from koopmann import aesthetics
from koopmann.data import DatasetConfig, get_dataset_class
from koopmann.models import ConvResNet
from koopmann.shape_metrics import prepare_acts, undo_preprocessing_acts
from koopmann.utils import set_seed

set_seed(36)


%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


File setup

In [13]:
# Directory passed to `load_model` for the trained network, and the dataset root.
mlp_file_dir = "/scratch/nsa325/koopmann_model_saves"
data_root = "/scratch/nsa325/datasets/"

# Dataset selection; the model checkpoint name is derived from it.
dataset_name = "mnist"
model_name = f"resmlp_{dataset_name}"

# Autoencoder checkpoints sit in a per-dataset sub-directory of the model saves.
# Derive the path from `mlp_file_dir` so the two locations cannot drift apart.
ae_file_dir = f"{mlp_file_dir}/{dataset_name}"

# Device handed to the model, the safetensors loader and activation preparation.
device = "cpu"

Load model

In [14]:
# Load the trained network together with its metadata dictionary.
model, model_metadata = load_model(mlp_file_dir, model_name)
# Chain hook setup, eval mode and device placement. The chained result is discarded,
# so this relies on the calls acting on `model` itself.
# NOTE(review): presumably `hook_model()` registers activation hooks and returns the
# model — confirm against koopmann.models.
model.hook_model().eval().to(device)
rprint("MLP Metadata: ", model_metadata)

Load test dataset

In [15]:
# Size of the evaluation subset. The same constant is used as the loader's batch
# size, so (assuming the dataset yields exactly this many samples) the whole test
# subset is delivered as a single batch and the two values cannot get out of sync.
num_test_samples = 5_000

# Build dataset; the dataset name comes from the loaded model's metadata.
test_dataset_config = DatasetConfig(
    dataset_name=model_metadata["dataset"],
    num_samples=num_test_samples,
    split="test",
    seed=42,
)
DatasetClass = get_dataset_class(name=test_dataset_config.dataset_name)
test_dataset = DatasetClass(config=test_dataset_config, root=data_root)
test_labels = test_dataset.labels.squeeze()

# Make dataloader (unshuffled, so batches follow the dataset's own sample order)
batch_size = num_test_samples
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Load autoencoders

In [16]:
# Collect autoencoder checkpoint names in sorted order, skipping the companion
# "preprocessing" files that are stored in the same directory.
ae_files = [
    Path(entry)
    for entry in sorted(os.listdir(ae_file_dir))
    if "autoencoder" in entry and "preprocessing" not in entry
]

In [17]:
@torch.no_grad()
def koopman_intermediates(
    init_idx,
    final_idx,
    orig_act_dict,
    proc_act_dict,
    device,
    preproc_dict,
    autoencoder,
    model,
    preprocess,
    k_steps,
):
    """Run the autoencoder's Koopman pipeline between two activations, gradient-free.

    The activations stored under ``init_idx`` and ``final_idx`` in ``proc_act_dict``
    are each encoded and decoded. The encoding of the initial activation is also
    advanced with ``autoencoder.koopman_forward(..., k_steps)`` and decoded. When
    ``preprocess`` is truthy, the decoded prediction and the reconstruction of the
    final activation are passed through ``undo_preprocessing_acts`` for
    ``final_idx``. The prediction is then fed through ``model.components[-1:]``.

    Args:
        init_idx: Key of the starting activation in ``proc_act_dict``.
        final_idx: Key of the target activation in ``proc_act_dict``.
        orig_act_dict: Accepted for interface compatibility; not read here.
        proc_act_dict: Mapping from layer key to the activation fed to the autoencoder.
        device: Forwarded to ``undo_preprocessing_acts``.
        preproc_dict: Forwarded to ``undo_preprocessing_acts``.
        autoencoder: Object providing ``encode``, ``decode`` and ``koopman_forward``.
        model: Object whose ``components`` can be sliced and the slice called.
        preprocess: Whether to undo preprocessing on the decoded outputs.
        k_steps: Step count handed to ``autoencoder.koopman_forward``.

    Returns:
        dict with keys ``x_proj``, ``y_proj``, ``x_proj_obs``, ``y_proj_obs``,
        ``pred_proj_obs``, ``x_proj_recon``, ``y_proj_recon``, ``y_recon``,
        ``pred_proj`` and ``pred``.
    """
    start_acts = proc_act_dict[init_idx]
    target_acts = proc_act_dict[final_idx]

    # Round-trip the starting activation through the autoencoder
    start_latent = autoencoder.encode(start_acts)
    start_roundtrip = autoencoder.decode(start_latent)

    # Round-trip the target activation through the autoencoder
    target_latent = autoencoder.encode(target_acts)
    target_roundtrip = autoencoder.decode(target_latent)

    # Advance the starting latent k_steps and decode the result
    advanced_latent = autoencoder.koopman_forward(start_latent, k_steps)
    advanced_acts = autoencoder.decode(advanced_latent)

    if not preprocess:
        pred, y_recon = advanced_acts, target_roundtrip
    else:
        pred = undo_preprocessing_acts(advanced_acts, preproc_dict, final_idx, device)
        y_recon = undo_preprocessing_acts(target_roundtrip, preproc_dict, final_idx, device)

    # Push the predicted activation through the trailing slice of the model
    tail = model.components[-1:]
    pred = tail(pred)

    return {
        "x_proj": start_acts,
        "y_proj": target_acts,
        "x_proj_obs": start_latent,
        "y_proj_obs": target_latent,
        "pred_proj_obs": advanced_latent,
        "x_proj_recon": start_roundtrip,
        "y_proj_recon": target_roundtrip,
        "y_recon": y_recon,
        "pred_proj": advanced_acts,
        "pred": pred,
    }


Evaluate

In [18]:
# Baseline: per-class accuracy of the loaded network on the test loader.
mlp_per_class_metric = MulticlassAccuracy(num_classes=test_dataset.out_features, average=None)
# Pure inference: disable autograd so the forward pass does not record a graph.
with torch.no_grad():
    for inputs, labels in test_dataloader:
        mlp_pred = model(inputs)
        mlp_per_class_metric.update(mlp_pred, labels.squeeze().long())
mlp_acc = mlp_per_class_metric.compute()

In [19]:
# Per-class accuracy of the Koopman prediction for every saved autoencoder,
# keyed by the seed encoded in the autoencoder's file name.
accuracies = {}

for ae_file in ae_files:
    ae_name = ae_file.stem

    # Load preprocessing dict and autoencoder
    with safe_open(
        f"{ae_file_dir}/{ae_name}_preprocessing.safetensors", framework="pt", device=device
    ) as f:
        preproc_dict = {k: f.get_tensor(k) for k in f.keys()}
    autoencoder, ae_metadata = load_autoencoder(ae_file_dir, ae_name)
    # NOTE(review): unlike the model, the autoencoder is not explicitly put in eval
    # mode or moved to `device` here — confirm `load_autoencoder` takes care of it.

    # The seed is the text following the first "seed_" in the file stem. Fail with a
    # clear message when the marker is absent instead of slicing at a bogus offset
    # (str.find returns -1 in that case, which would silently mis-parse the name).
    _, seed_marker, seed_text = ae_name.partition("seed_")
    if not seed_marker:
        raise ValueError(f"No 'seed_' marker in autoencoder file name {ae_file.name!r}")
    seed = int(seed_text)

    # Prepared activations
    orig_act_dict, processed_act_dict, _ = prepare_acts(
        data_train_loader=test_dataloader,
        model=model,
        device=device,
        svd_dim=ae_metadata["in_features"],
        whiten_alpha=preproc_dict["wh_alpha_0"],
        preprocess=True,
        preprocess_dict=preproc_dict,
        only_first_last=True,
    )
    # First and last keys of the returned activation dict
    layer_keys = list(orig_act_dict.keys())
    init_idx, final_idx = layer_keys[0], layer_keys[-1]

    # Koopman intermediates
    test_intermediates = koopman_intermediates(
        init_idx=init_idx,
        final_idx=final_idx,
        orig_act_dict=orig_act_dict,
        proc_act_dict=processed_act_dict,
        device=device,
        preproc_dict=preproc_dict,
        autoencoder=autoencoder,
        model=model,
        preprocess=True,
        k_steps=ae_metadata["k_steps"],
    )
    pred = test_intermediates["pred"]
    x_proj_obs = test_intermediates["x_proj_obs"]

    # Per-class accuracy
    per_class_metric = MulticlassAccuracy(num_classes=test_dataset.out_features, average=None)
    per_class_metric.update(pred, test_labels.to(torch.long))
    accuracies[seed] = per_class_metric.compute()


Processing activations: 100%|██████████| 2/2 [00:00<00:00, 19.60it/s]
Processing activations: 100%|██████████| 2/2 [00:00<00:00, 19.08it/s]
Processing activations: 100%|██████████| 2/2 [00:00<00:00, 12.15it/s]
Processing activations: 100%|██████████| 2/2 [00:00<00:00, 14.83it/s]
Processing activations: 100%|██████████| 2/2 [00:00<00:00, 13.12it/s]


Visualization

In [20]:
def format_array(arr, multiplier=100, precision=2):
    """Render every element of ``arr`` as a percentage string inside square brackets.

    Each element is multiplied by ``multiplier``, formatted as a fixed-point number
    with ``precision`` decimals and suffixed with ``%``; the pieces are joined with
    ``", "``.
    """
    pieces = []
    for element in arr:
        scaled = element * multiplier
        pieces.append(f"{scaled:.{precision}f}%")
    return f"[{', '.join(pieces)}]"


table = Table(title=f"{dataset_name} Accuracies")
table.add_column("Seed")
table.add_column("Per-Class Accuracy")
# The last column is the unweighted mean of the per-class accuracies (a macro
# average). It equals the overall sample accuracy only when every class has the
# same number of test samples, so it is labelled for what it actually is.
table.add_column("Mean Per-Class Accuracy")

# Convert dict to numpy arrays for vectorized operations
seeds = list(accuracies.keys())
values = np.array([v.numpy() for v in accuracies.values()])
overall = np.array([v.mean().item() for v in accuracies.values()])

# Add individual rows; the last seed row closes the section to draw a separator line
last_row = len(seeds) - 1
for i, seed in enumerate(seeds):
    table.add_row(
        str(seed),
        format_array(values[i]),
        f"{overall[i]*100:.2f}%",
        end_section=(i == last_row),
    )

# Add statistics rows (mean / std across seeds)
table.add_row("Mean", format_array(values.mean(axis=0)), f"{overall.mean()*100:.2f}%", style="bold")
table.add_row(
    "Std",
    format_array(values.std(axis=0)),
    f"{overall.std()*100:.2f}%",
    style="bold",
    end_section=True,
)
# Baseline row: reuse the per-class accuracy already computed in the evaluation cell
# (`mlp_acc`) instead of calling `compute()` on the metric two more times.
table.add_row(
    "MLP",
    format_array(mlp_acc),
    f"{mlp_acc.mean().item()*100:.2f}%",
    style="bold",
)


console = Console()
console.print(table)