In [1]:
import os

from rich import print as rprint
from torch.utils.data import DataLoader

from analysis.common import load_model
from koopmann import aesthetics
from koopmann.data import (
    DatasetConfig,
    get_dataset_class,
)
from koopmann.utils import (
    compute_model_accuracy,
    get_device,
)
from koopmann.visualization import plot_decision_boundary

# Enable IPython's autoreload so edits to imported project modules
# (e.g. koopmann, analysis) are picked up without restarting the kernel;
# mode 2 reloads all (non-excluded) modules before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
# Location and name of the saved model to analyse. Override via the
# KOOPMANN_MODEL_DIR / KOOPMANN_MODEL_NAME environment variables instead of
# editing a hard-coded, user-specific absolute path; the defaults reproduce
# the original run.
file_dir = os.environ.get("KOOPMANN_MODEL_DIR", "/scratch/nsa325/koopmann_model_saves")
model_name = os.environ.get("KOOPMANN_MODEL_NAME", "resmlp_mnist")
# Device used for the model and for evaluation below.
device = get_device()

In [3]:
# Load the saved model together with its metadata dict (dataset name,
# architecture, creation time, ...).
model, model_metadata = load_model(file_dir, model_name)
# Hook the model, switch it to eval mode and move it to `device`, as one chain.
# NOTE(review): hook_model() presumably registers activation hooks and returns
# the module — confirm; eval()/to() are applied to whatever it returns.
(
    model.hook_model()
    .eval()
    .to(device)
)
print(model_metadata)

{'batchnorm': True, 'bias': True, 'created_at': '2025-04-09T02:41:58.432513', 'dataset': 'MNISTDataset', 'hidden_config': [784, 784, 784, 784], 'in_features': 784, 'model_class': 'ResMLP', 'nonlinearity': 'relu', 'out_features': 10, 'stochastic_depth_mode': 'batch', 'stochastic_depth_prob': 0.0}


In [4]:
# Dataset config
dataset_config = DatasetConfig(
    dataset_name=model_metadata["dataset"],
    num_samples=3_000,
    split="test",
    seed=42,
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)
dataloader = DataLoader(dataset, batch_size=512)

In [5]:
# Evaluate the loaded model on the test-split dataloader and report the result.
test_accuracy = compute_model_accuracy(model, dataloader, device)
rprint(f"Testing Accuracy: {test_accuracy}")

In [6]:
# Decision-boundary plots are only drawn for 2-D inputs.
if dataset.in_features == 2:
    # Derive the class list from the model's output width rather than
    # hard-coding three classes. NOTE(review): assumes `labels` is the list of
    # class indices the model can predict — confirm against
    # plot_decision_boundary. Falls back to the previous [0, 1, 2] when the
    # metadata has no "out_features" entry.
    num_classes = model_metadata.get("out_features", 3)
    plot_decision_boundary(
        model,
        model.state_dict(),
        dataset.features,
        dataset.labels.squeeze(),
        labels=list(range(num_classes)),
    )
else:
    # Make the skip visible instead of leaving the cell silently empty.
    rprint(
        f"Skipping decision-boundary plot: dataset has "
        f"{dataset.in_features} input features (need 2)."
    )