In [65]:
import math
import os

import plotly.express as px
import torch
import torch.nn.functional as F
from plotly.subplots import make_subplots
from safetensors import safe_open
from torch.utils.data import DataLoader, Subset
from torcheval.metrics import MulticlassAccuracy

from analysis.common import load_autoencoder, load_model

# from koopmann import aesthetics
from koopmann.data import DatasetConfig, get_dataset_class
from koopmann.models import ConvResNet
from koopmann.utils import get_device, set_seed
from scripts.train_ae.shape_metrics import prepare_acts, undo_processing

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [66]:
# Seed RNGs so the random sample index drawn further down is reproducible.
# NOTE(review): presumably koopmann.utils.set_seed seeds torch (and friends) — confirm.
set_seed(21)

In [67]:
# Directory holding the model / autoencoder checkpoints. Override it with the
# KOOPMANN_MODEL_DIR environment variable instead of editing the notebook;
# falls back to the original cluster path.
file_dir = os.environ.get("KOOPMANN_MODEL_DIR", "/scratch/nsa325/koopmann_model_saves")
dim = 800  # autoencoder latent (observable) dimension; part of the checkpoint name
k_steps = 1  # number of Koopman steps applied to the first-layer observables
scale_idx = 1  # the `loc_` field of the checkpoint name
dataset_name = "cifar10"

# Autoencoder flavor (part of the checkpoint name). Other flavors used here:
#   "exponential", or f"lowrank_{rank}" (e.g. rank = 100).
flavor = "standard"

model_name = f"convresnet_{dataset_name}"
ae_name = f"dim_{dim}_k_{k_steps}_loc_{scale_idx}_{flavor}_autoencoder_{dataset_name}_model"
device = get_device()  # set device = "cpu" here to force CPU execution

In [68]:
# Load the trained classifier and its metadata, switch it to eval mode and
# presumably register activation hooks (hook_model) used when collecting acts.
# NOTE(review): the model is not moved to `device` here, but later cells call
# model(images.to(device)); this assumes load_model already places it there.
model, model_metadata = load_model(file_dir, model_name)
model.eval().hook_model()
print(model_metadata)

{'batchnorm': True, 'bias': False, 'blocks_per_stage': [2, 2, 2, 2], 'created_at': '2025-04-24T01:30:56.594900', 'dataset': 'CIFAR10Dataset', 'hidden_config': [64, 128, 256, 512], 'in_channels': 3, 'initial_downsample_factor': 2, 'input_size': (32, 32), 'model_class': 'ConvResNet', 'nonlinearity': 'relu', 'out_features': 10, 'stochastic_depth_mode': 'batch', 'stochastic_depth_prob': 0.0}


In [69]:
# Dataset config
dataset_config = DatasetConfig(
    dataset_name=model_metadata["dataset"],
    num_samples=3_000,
    split="test",
    seed=42,
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)

subset_size = 1_000
if subset_size:
    subset_indices = list(range(0, subset_size))
    dataset = Subset(dataset, subset_indices)

batch_size = 3_000
batch_size = min(subset_size, batch_size) if subset_size else batch_size
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

Files already downloaded and verified


In [70]:
# Load the Koopman autoencoder checkpoint and put it on `device` in eval mode.
autoencoder, ae_metadata = load_autoencoder(file_dir, ae_name)
autoencoder.eval().to(device)
new_dim = ae_metadata["in_features"]  # AE input width; the value prepare_acts receives as new_dim
preprocess = ae_metadata["preprocess"]  # presumably: were activations preprocessed for AE training?
# Transposed so observables (rows) are advanced as obs @ K, as done in the prediction cell.
K_matrix = autoencoder.koopman_weights.T
print(ae_metadata)

{'batchnorm': False, 'bias': False, 'created_at': '2025-04-29T14:55:18.779415', 'hidden_config': [1600, 800, 800], 'in_features': 3000, 'k_steps': 1, 'latent_features': 800, 'model_class': 'KoopmanAutoencoder', 'nonlinearity': 'gelu', 'preprocess': True}


In [71]:
# Load the preprocessing tensors saved alongside the autoencoder checkpoint
# (read on CPU; undo_processing is later given `device` explicitly).
with safe_open(
    f"{file_dir}/{ae_name}_preprocessing.safetensors", framework="pt", device="cpu"
) as f:
    preproc_dict = {key: f.get_tensor(key) for key in f.keys()}

# Collect raw (orig_act_dict) and processed (proc_act_dict) activations for the
# evaluation data, presumably only at the first and last hooked layers
# (only_first_last=True). Preprocessing follows the autoencoder's metadata
# instead of being forced on, so the `if preprocess:` branches used later agree
# with how these activations were actually built.
orig_act_dict, proc_act_dict, _ = prepare_acts(
    data_train_loader=dataloader,
    model=model,
    device=device,
    new_dim=new_dim,
    whiten_alpha=1,  # NOTE(review): whitening strength; confirm it matches AE training
    preprocess=preprocess,
    preprocess_dict=preproc_dict,
    only_first_last=True,
)

In [72]:
# Reference predictions: run the original model on the first (with the default
# config, the only) batch of evaluation images.
with torch.no_grad():
    images, labels = next(iter(dataloader))
    # Flatten to (N,). Unlike .squeeze(), this stays 1-D when N == 1.
    # Assumes labels are class indices of shape (N,) or (N, 1).
    labels = labels.reshape(-1)
    model_pred = model(images.to(device))

In [73]:
# Compare the autoencoder's k-step Koopman prediction from the first hooked
# layer with the actual activations at the last hooked layer.
init_idx = list(orig_act_dict.keys())[0]
final_idx = list(orig_act_dict.keys())[-1]

with torch.no_grad():
    # Raw activations (x, y) and their processed counterparts (*_proj), as
    # returned by prepare_acts.
    x = orig_act_dict[init_idx]
    x_proj = proc_act_dict[init_idx]

    y = orig_act_dict[final_idx]
    y_proj = proc_act_dict[final_idx]

    if preprocess:
        x_unproj = undo_processing(x_proj, preproc_dict, init_idx, device)
        y_unproj = undo_processing(y_proj, preproc_dict, final_idx, device)
    else:
        x_unproj = x_proj
        y_unproj = y_proj

    # Reconstruct first act (encode -> decode round trip)
    x_proj_obs = autoencoder.encode(x_proj)
    x_proj_recon = autoencoder.decode(x_proj_obs)

    # Reconstruct final act
    y_proj_obs = autoencoder.encode(y_proj)
    y_proj_recon = autoencoder.decode(y_proj_obs)

    # Advance the first-layer observables k_steps with the Koopman matrix
    # (row-vector convention: obs @ K^k) and decode back to the AE input space.
    pred_proj_obs = x_proj_obs @ torch.linalg.matrix_power(K_matrix, int(k_steps))
    pred_proj = autoencoder.decode(pred_proj_obs)

    # Bring the prediction and the final-layer reconstruction back to the raw
    # activation space (identity when no preprocessing was applied).
    if preprocess:
        pred = undo_processing(pred_proj, preproc_dict, final_idx, device)
        y_recon = undo_processing(y_proj_recon, preproc_dict, final_idx, device)
    else:
        pred = pred_proj
        y_recon = y_proj_recon

    if type(model) is ConvResNet:
        # The last two components consume the final feature map, so rebuild
        # (N, C, H, W) from the flat prediction: C is the last stage width from
        # the model config, H == W is inferred from the per-sample size
        # (512 x 4 x 4 for the checkpoint used here).
        n_channels = model_metadata["hidden_config"][-1]
        per_sample = pred[0].numel()
        side = math.isqrt(per_sample // n_channels)
        assert n_channels * side * side == per_sample, (
            f"cannot view {per_sample} features as ({n_channels}, H, H)"
        )
        pred = pred.reshape(-1, n_channels, side, side)
        koopman_pred = torch.argmax(model.components[-2:](pred), dim=1)

    else:
        # Feed pred to classifier
        koopman_pred = torch.argmax(model.components[-1:](pred), dim=1)


In [74]:
# Top-1 accuracy of the original model vs. the Koopman surrogate on the same
# labelled batch. The class count comes from the model config. Rebuilding the
# dataset here would re-download it and overwrite the `dataset` Subset.
num_classes = model_metadata["out_features"]

# torcheval metric state lives on CPU by default; keep inputs there too.
model_metric = MulticlassAccuracy(num_classes=num_classes)
model_metric.update(model_pred.cpu(), labels.cpu())

# koopman_pred may cover more rows than the single labelled batch; keep only
# the rows that line up with `labels`.
koopman_metric = MulticlassAccuracy(num_classes=num_classes)
koopman_metric.update(koopman_pred[: len(labels)].cpu(), labels.cpu())

print("Original accuracy: ", model_metric.compute())
print("Koopman accuracy: ", koopman_metric.compute())


Files already downloaded and verified
Original accuracy:  tensor(0.9490)
Koopman accuracy:  tensor(0.9200)


In [75]:
def _most_square_dims(n):
    """Return ``(a, n // a)``, where ``a`` is the largest divisor of ``n`` that is <= sqrt(n).

    Gives ``(s, s)`` for perfect squares, ``(1, n)`` for primes and ``(0, 0)`` for ``n == 0``.
    Uses exact integer square roots, so it has no float-rounding issues.
    """
    if n == 0:
        return (0, 0)
    side = math.isqrt(n)
    while n % side:  # terminates at side == 1 at the latest
        side -= 1
    return (side, n // side)


def compare_images(
    original,
    reconstructed,
    reshape_dims=None,
    height=400,
    width=800,
    titles=("Original", "Reconstructed"),
):
    """Plot two tensors side by side as heatmaps and print their mean squared error.

    Args:
        original: Reference tensor of any shape; it is reshaped to 2-D for display.
        reconstructed: Tensor to compare with ``original``. It must have the same
            number of elements so it can take the same display shape.
        reshape_dims: 2-D display shape. When ``None``, the most square
            factorisation ``(a, b)`` of ``original.numel()`` with ``a <= b`` is used.
        height: Figure height in pixels.
        width: Figure width in pixels.
        titles: Subplot titles for the two panels.

    Returns:
        The plotly figure. It is rendered when it is the last expression in a cell.
    """
    if reshape_dims is None:
        reshape_dims = _most_square_dims(original.numel())

    fig = make_subplots(rows=1, cols=2, subplot_titles=list(titles))

    for col, img in enumerate((original, reconstructed), start=1):
        fig.add_trace(px.imshow(img.reshape(reshape_dims)).data[0], row=1, col=col)

    # Tie each panel's x scale to its own y axis so heatmap cells render square.
    fig.update_layout(height=height, width=width, xaxis_scaleanchor="y", xaxis2_scaleanchor="y2")

    error = F.mse_loss(original, reconstructed, reduction="mean")
    print(f"Error: {error:.6f}")

    return fig

In [76]:
# Draw one random example index (seeded earlier via set_seed) to inspect below.
sample_idx = torch.randint(0, batch_size, (1,)).item()
print(sample_idx)

658


In [77]:
# First-layer activations in the autoencoder's input space vs. their
# encode -> decode round trip.
compare_images(
    original=x_proj[sample_idx].cpu(),
    reconstructed=x_proj_recon[sample_idx].cpu(),
    titles=["LoDim Input", "Recon. LoDim Input"],
)

Error: 0.017500


In [78]:
# Final-layer activations in the autoencoder's input space vs. their
# encode -> decode round trip.
compare_images(
    original=y_proj[sample_idx].cpu(),
    reconstructed=y_proj_recon[sample_idx].cpu(),
    titles=["LoDim Target", "Recon. LoDim Target"],
)

Error: 0.007718


In [79]:
# Raw final-layer activations (flattened) vs. the autoencoder reconstruction
# mapped back to raw activation space.
compare_images(
    original=y[sample_idx].flatten().cpu(),
    reconstructed=y_recon[sample_idx].cpu(),
    titles=["Original", "Recon"],
)

Error: 0.281716


In [80]:
# Encoded observables of the first-layer input vs. those of the final-layer target.
compare_images(
    original=x_proj_obs[sample_idx].cpu(),
    reconstructed=y_proj_obs[sample_idx].cpu(),
    titles=["Obs Input", "Obs Target"],
)

Error: 0.093370


In [81]:
# Koopman-advanced observables vs. the encoded final-layer target.
compare_images(
    original=pred_proj_obs[sample_idx].cpu(),
    reconstructed=y_proj_obs[sample_idx].cpu(),
    titles=["Obs Predicted", "Obs Target"],
)

Error: 0.058635


In [82]:
# Decoded Koopman prediction vs. the final-layer target, both in the
# autoencoder's input space.
compare_images(
    original=pred_proj[sample_idx].cpu(),
    reconstructed=y_proj[sample_idx].cpu(),
    titles=["Predicted LoDim", "LoDim Target"],
)

Error: 0.007892


In [83]:
# Koopman prediction in raw final-layer activation space vs. the true activations.
compare_images(
    original=pred[sample_idx].cpu(),
    reconstructed=y[sample_idx].cpu(),
    titles=["Predicted", "Target"],
)

Error: 0.287262


In [84]:
# Heatmap of the top-left block of the learned Koopman matrix.
n_show = 200  # rows/columns of K_matrix to display
px.imshow(
    K_matrix.detach()[:n_show, :n_show].cpu(),
    title=f"Koopman matrix K (first {n_show} x {n_show} entries)",
    labels=dict(x="column", y="row", color="value"),
)