In [143]:
import math

import plotly.express as px
import torch
import torch.nn.functional as F
from plotly.subplots import make_subplots
from safetensors import safe_open
from torch.utils.data import DataLoader
from torcheval.metrics import MulticlassAccuracy

from analysis.common import load_autoencoder, load_model

# from koopmann import aesthetics
from koopmann.data import DatasetConfig, get_dataset_class
from koopmann.utils import get_device, set_seed
from scripts.train_ae.shape_metrics import prepare_acts, undo_processing

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [144]:
set_seed(21)

In [145]:
file_dir = "/scratch/nsa325/koopmann_model_saves"
dim = 1_024
k_steps = 1
scale_idx = 1
dataset_name = "mnist"

rank = 50
flavor = f"lowrank_{rank}"
# flavor = "standard"
# flavor = "exponential"

model_name = "resmlp_mnist"
ae_name = f"dim_{dim}_k_{k_steps}_loc_{scale_idx}_{flavor}_autoencoder_{dataset_name}_model"
device = get_device()

In [146]:
model, model_metadata = load_model(file_dir, model_name)
model.eval().hook_model()
print(model_metadata)

{'batchnorm': True, 'bias': True, 'created_at': '2025-04-09T02:41:58.432513', 'dataset': 'MNISTDataset', 'hidden_config': [784, 784, 784, 784], 'in_features': 784, 'model_class': 'ResMLP', 'nonlinearity': 'relu', 'out_features': 10, 'stochastic_depth_mode': 'batch', 'stochastic_depth_prob': 0.0}


In [147]:
# Dataset config
batch_size = 10_000
dataset_config = DatasetConfig(
    dataset_name=model_metadata["dataset"],
    num_samples=3_000,
    split="test",
    seed=42,
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [148]:
autoencoder, ae_metadata = load_autoencoder(file_dir, ae_name)
autoencoder.eval()
new_dim = ae_metadata["in_features"]
preprocess = ae_metadata["preprocess"]
K_matrix = autoencoder.koopman_weights.T
print(ae_metadata)

{'batchnorm': False, 'bias': True, 'created_at': '2025-04-21T23:41:54.917088', 'hidden_config': [], 'in_features': 784, 'k_steps': 1, 'latent_features': 1024, 'model_class': 'LowRankKoopmanAutoencoder', 'nonlinearity': 'leaky_relu', 'preprocess': True, 'rank': 50}


In [149]:
preproc_dict = {}
with safe_open(
    f"{file_dir}/{ae_name}_preprocessing.safetensors", framework="pt", device="cpu"
) as f:
    for k in f.keys():
        preproc_dict[k] = f.get_tensor(k)


orig_act_dict, proc_act_dict, _ = prepare_acts(
    data_train_loader=dataloader,
    model=model,
    device=device,
    new_dim=ae_metadata["in_features"],
    whiten_alpha=1,
    preprocess=True,
    preprocess_dict=preproc_dict,
    only_first_last=True,
)

In [150]:
with torch.no_grad():
    images, labels = next(iter(dataloader))
    labels = labels.squeeze()
    model_pred = model(images)

In [151]:
with torch.no_grad():
    x = orig_act_dict[0]
    x_proj = proc_act_dict[0]

    y = orig_act_dict[4]
    y_proj = proc_act_dict[4]

    if preprocess:
        x_unproj = undo_processing(x_proj, preproc_dict, 0)
        y_unproj = undo_processing(y_proj, preproc_dict, 4)
    else:
        x_unproj = x_proj
        y_unproj = y_proj

    # Reconstruct first act
    x_proj_obs = autoencoder.encode(x_proj)
    x_proj_recon = autoencoder.decode(x_proj_obs)

    pred_proj_obs = x_proj_obs @ torch.linalg.matrix_power(K_matrix, int(k_steps))
    pred_proj = autoencoder.decode(pred_proj_obs)

    if preprocess:
        pred = undo_processing(pred_proj, preproc_dict, 4)
    else:
        pred = pred_proj

    # Reconstruct final act
    y_proj_obs = autoencoder.encode(y_proj)
    y_proj_recon = autoencoder.decode(y_proj_obs)

    if preprocess:
        y_recon = undo_processing(y_proj_recon, preproc_dict, 4)
    else:
        y_recon = y_proj_recon

    # Feed pred to clssifier
    koopman_pred = torch.argmax(model.components[-1:](pred), dim=1)


In [152]:
model_metric = MulticlassAccuracy(num_classes=dataset.out_features)
model_metric.update(model_pred, labels)

koopman_metric = MulticlassAccuracy(num_classes=dataset.out_features)
koopman_metric.update(koopman_pred[:batch_size], labels)

print("Original accuracy: ", model_metric.compute())
print("Koopman accuracy: ", koopman_metric.compute())


Original accuracy:  tensor(0.9904)
Koopman accuracy:  tensor(0.9846)


In [153]:
def compare_images(
    original,
    reconstructed,
    reshape_dims=None,
    height=400,
    width=800,
    titles=["Original", "Reconstructed"],
):
    # Auto-calculate reshape dimensions if not provided
    if reshape_dims is None:
        total_elements = original.numel()
        sqrt_elements = math.sqrt(total_elements)

        if sqrt_elements.is_integer():
            side = int(sqrt_elements)
            reshape_dims = (side, side)
        else:
            side1 = int(math.sqrt(total_elements))
            while total_elements % side1 != 0 and side1 > 1:
                side1 -= 1

            if side1 > 1:
                side2 = total_elements // side1
                reshape_dims = (side1, side2)
            else:
                reshape_dims = (1, total_elements)

    fig = make_subplots(rows=1, cols=2, subplot_titles=titles)

    for i, img in enumerate([original, reconstructed], 1):
        fig.add_trace(px.imshow(img.reshape(reshape_dims)).data[0], row=1, col=i)

    fig.update_layout(height=height, width=width, xaxis_scaleanchor="y", xaxis2_scaleanchor="y2")

    error = F.mse_loss(original, reconstructed, reduction="mean")
    print(f"Error: {error:.6f}")

    return fig

In [154]:
sample_idx = torch.randint(batch_size, (1,))[0].item()

In [155]:
compare_images(
    x_proj[sample_idx], x_proj_recon[sample_idx], titles=["LoDim Input", "Recon. LoDim Input"]
)

Error: 0.000000


In [156]:
compare_images(
    y_proj[sample_idx], y_proj_recon[sample_idx], titles=["LoDim Target", "Recon. LoDim Target"]
)

Error: 0.000000


In [157]:
compare_images(y[sample_idx], y_recon[sample_idx], titles=["Original", "Recon"])

Error: 0.000048


In [158]:
compare_images(x_proj_obs[sample_idx], y_proj_obs[sample_idx], titles=["Obs Input", "Obs Target"])

Error: 0.000124


In [159]:
compare_images(
    pred_proj_obs[sample_idx], y_proj_obs[sample_idx], titles=["Obs Predicted", "Obs Target"]
)

Error: 0.000004


In [160]:
compare_images(pred_proj[sample_idx], y_proj[sample_idx], titles=["Original Input", "Recon. Input"])

Error: 0.000005


In [161]:
compare_images(pred[sample_idx], y[sample_idx], titles=["Original Input", "Recon. Input"])

Error: 0.254760


In [162]:
px.imshow(K_matrix.detach()[:200, :200])