In [None]:
import math
from ast import literal_eval

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.subplots as sp
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from matrepr import mdisplay
from plotly.subplots import make_subplots
from rich import print as rprint
from torch import linalg
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from analysis.common import load_autoencoder, load_mlp
from koopmann import aesthetics
from koopmann.data import (
    DatasetConfig,
    create_data_loader,
    get_dataset_class,
)
from koopmann.models import (
    MLP,
    Autoencoder,
    ExponentialKoopmanAutencoder,
    LowRankKoopmanAutoencoder,
    ResMLP,
)
from koopmann.models.utils import get_device
from koopmann.visualization import plot_eigenvalues

%load_ext autoreload
%autoreload 2

In [2]:
from analysis.residual_alignment_methods import compare_model_autoencoder_acc

In [9]:
scale_idx = "0"
k = "1"
dim = "1024"

# flavor = "exponential"
# flavor = "standard"
flavor = "lowrank_20"

user = "nsa325"

file_dir = "/scratch/nsa325/koopmann_model_saves/"
mlp_name = "mnist_model_residual"
ae_name = f"dim_{dim}_k_{k}_loc_{scale_idx}_{flavor}_autoencoder_mnist_model.safetensors"


Load and prepare MLP

In [4]:
model, model_metadata = load_mlp(file_dir, mlp_name)

Build autoencoder

In [7]:
print(ae_name)

mnist_model


In [None]:
autoencoder, ae_metadata = load_autoencoder(file_dir + "scaling", ae_name)

AssertionError: Model file /scratch/nsa325/koopmann_model_saves//scaling/dim_1024_k_1_loc_0_lowrank_20_autoencoder_mnist_model.safetensors.safetensors does not exist.

Build data

In [None]:
dataset_config = DatasetConfig(
    dataset_name=model_metadata["dataset"], num_samples=5_000, split="test", seed=21
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)

# Raw images and labels
raw_images, labels = dataset.data, dataset.labels

# Processed for MLP
mlp_transform = transforms.Compose(
    [
        transforms.Lambda(lambda x: x / 255),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
mlp_inputs = mlp_transform(raw_images)

# Processed for AE
ae_transform = transforms.Compose(
    [
        transforms.Lambda(lambda x: x / 255),
        transforms.Lambda(lambda x: x * 2 - 1),
    ]
)
ae_inputs = ae_transform(raw_images)

In [None]:
K_matrix = autoencoder.koopman_matrix.linear_layer.weight.T.detach()
k = literal_eval(ae_metadata["num_scaled"])

eigenvalues, eigenvectors = torch.linalg.eig(K_matrix)
plot_eigenvalues({(k, dim): eigenvalues}, axis=[-3, 3])

In [None]:
# acc_mlp, acc_koopman = compare_model_autoencoder_acc(
#     model, autoencoder, k, len(dataset.classes), mlp_inputs, ae_inputs, labels
# )
# mdisplay(acc_mlp, title="Original Model Testing Accuracy")
# mdisplay(acc_koopman, title="Autoencoder Prediction Testing Accuracy")