In [10]:
import os

import eagerpy as ep
import foolbox as fb
import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import seaborn as sns
import torch
from foolbox import PyTorchModel, accuracy, samples
from rich import print as rprint
from torch import nn
from torch.utils.data import DataLoader

from analysis.common import load_model
from analysis.residual_alignment_methods import alignment, plotsvals, sab, trajectories

# from koopmann import aesthetics
from koopmann.data import (
    DatasetConfig,
    get_dataset_class,
)
from koopmann.models import MLP, ResMLP
from koopmann.utils import (
    compute_model_accuracy,
    get_device,
)
from koopmann.visualization import plot_decision_boundary

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
# Directory containing the saved model checkpoints and their metadata.
# Defaults to the original cluster path. Set the KOOPMANN_MODEL_DIR environment
# variable to load models from elsewhere without editing the notebook.
file_dir = os.environ.get("KOOPMANN_MODEL_DIR", "/scratch/nsa325/koopmann_model_saves")
# Name of the saved model; passed to `load_model` together with `file_dir`.
model_name = "resmlp_mnist"
device = get_device()

In [12]:
# Load the saved model together with its metadata dict (printed below, e.g.
# dataset name, hidden_config, nonlinearity, batchnorm flag).
model, model_metadata = load_model(file_dir, model_name)
# Presumably registers the model's hooks, then switches to eval mode.
# NOTE(review): `.eval()` is applied to whatever `hook_model()` returns, so this
# relies on `hook_model()` returning the model itself. Confirm.
# NOTE(review): the model is never moved to `device` in this notebook.
# Presumably `compute_model_accuracy` handles device placement. Confirm.
model.hook_model().eval()
print(model_metadata)

{'batchnorm': True, 'bias': True, 'created_at': '2025-04-09T02:41:58.432513', 'dataset': 'MNISTDataset', 'hidden_config': [784, 784, 784, 784], 'in_features': 784, 'model_class': 'ResMLP', 'nonlinearity': 'relu', 'out_features': 10, 'stochastic_depth_mode': 'batch', 'stochastic_depth_prob': 0.0}


In [13]:
# Evaluation data: NUM_EVAL_SAMPLES examples (seeded) from the *test* split of
# the dataset the model was trained on. The dataset class name is read from the
# saved metadata.
NUM_EVAL_SAMPLES = 3_000
EVAL_SEED = 42
EVAL_BATCH_SIZE = 1_000

dataset_config = DatasetConfig(
    dataset_name=model_metadata["dataset"],
    split="test",
    num_samples=NUM_EVAL_SAMPLES,
    seed=EVAL_SEED,
)

# Resolve the dataset class by name, instantiate it, and batch it in a fixed
# order. shuffle=False is DataLoader's default; it is written out here for clarity.
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)
dataloader = DataLoader(dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False)

In [14]:
# Accuracy of `model` on the evaluation `dataloader`, computed on `device`.
test_accuracy = compute_model_accuracy(model, dataloader, device)
rprint(f"Testing Accuracy: {test_accuracy}")

In [15]:
# Decision boundaries are only plotted for 2-feature inputs. For any other
# input width (e.g. the 784-feature MNIST model here) this cell does nothing.
if dataset.in_features == 2:
    boundary_state = model.state_dict()
    boundary_inputs = dataset.features
    boundary_targets = dataset.labels.squeeze()
    plot_decision_boundary(
        model,
        boundary_state,
        boundary_inputs,
        boundary_targets,
        # NOTE(review): class ids are hard-coded; confirm they match the dataset.
        labels=[0, 1, 2],
    )

#### Adversarial attacks

In [16]:
# Adversarial-robustness sweep (currently disabled): runs LinfPGD (attacks[1])
# over a list of Linf epsilons and reports robust accuracy per epsilon.
# NOTE(review): check the following before re-enabling:
#   - `preprocessing` uses 3-channel mean/std values that look like CIFAR-10
#     statistics. This model is an MNIST ResMLP with 784 flat inputs (see
#     `model_metadata`). The normalization should match whatever the training
#     dataset applied. Confirm.
#   - foolbox `samples(..., dataset="mnist")` presumably returns image-shaped
#     batches (N, 1, 28, 28), while the model expects 784 features. A flatten
#     step (e.g. `nn.Sequential(nn.Flatten(), model)`) is likely needed. Confirm.
#   - foolbox `samples()` bundles only a small set of example images and repeats
#     them for large `batchsize`. Attacking batches from `dataloader` would give
#     a more meaningful robust-accuracy estimate. Verify against the foolbox docs.
# epsilons = [0.0, 0.005, 0.01, 0.03, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.15]
# attacks = [fb.attacks.FGSM(), fb.attacks.LinfPGD(steps=50)]

# preprocessing = dict(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616], axis=-3)
# fmodel = fb.PyTorchModel(model, bounds=(0, 1), preprocessing=preprocessing)

# images, labels = ep.astensors(*samples(fmodel, dataset="mnist", batchsize=1024))
# raw_advs, clipped_advs, success = attacks[1](fmodel, images, labels, epsilons=epsilons)
# # calculate and report the robust accuracy (the accuracy of the model when
# # it is attacked)
# robust_accuracy = 1 - success.float32().mean(axis=-1)
# print("Robust accuracy for perturbations with")
# for eps, acc in zip(epsilons, robust_accuracy):
#     print(f"  Linf norm ≤ {eps:<6}: {acc.item() * 100:4.1f} %")