In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from rich import print as rprint
from torch import nn

from koopmann import aesthetics
from koopmann.data import (
    DatasetConfig,
    create_data_loader,
    get_dataset_class,
)
from koopmann.models import MLP
from koopmann.models.utils import parse_safetensors_metadata
from koopmann.utils import (
    compute_model_accuracy,
)
from koopmann.visualization import plot_decision_boundary

%load_ext autoreload
%autoreload 2

In [None]:
# Base name of the saved checkpoint to inspect; it is interpolated into the
# .safetensors path when the model is loaded.
model_name = "mnist_model"

In [None]:
# Build the checkpoint path from model_name and load the MLP stored there.
# NOTE(review): this is an absolute, user-specific scratch path, so the notebook
# will not run on another machine without editing this line.
file_path = f"/scratch/nsa325/koopmann_model_saves/{model_name}.safetensors"
# load_model returns two values; only the first (the model) is kept.
model, _ = MLP.load_model(file_path)
# Presumably the standard torch.nn.Module.eval() switch to inference mode
# (assumes MLP subclasses nn.Module — confirm).
model.eval()
# Last expression of the cell: whatever summary() prints/returns is the cell output.
model.summary()

In [None]:
# Dataset config: read the metadata stored with the checkpoint and use its
# "dataset" entry to decide which dataset the model is evaluated on.
metadata = parse_safetensors_metadata(file_path=file_path)
rprint(metadata)
# Request num_samples=3_000 from the "test" split with a fixed seed of 42.
dataset_config = DatasetConfig(
    dataset_name=metadata["dataset"],
    num_samples=3_000,
    split="test",
    seed=42,
)
# Look up the dataset class by name, then instantiate it from the config.
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)

In [None]:
# Report the loaded model's accuracy on the dataset currently bound to `dataset`.
rprint(f"Testing Accuracy: {compute_model_accuracy(model, dataset)}")

In [None]:
# Only draw the decision boundary when the dataset reports exactly two input
# features; for any other dataset this cell does nothing.
# NOTE(review): labels=[0, 1, 2] is hard-coded — presumably written for a
# 3-class 2-D dataset; confirm it matches the classes of the loaded dataset.
if dataset.in_features == 2:
    plot_decision_boundary(
        model,
        model.state_dict(),
        dataset.features,
        dataset.labels.squeeze(),
        labels=[0, 1, 2],
    )

In [None]:
# Rebuild the dataset with a different sample count and seed (5_000 samples,
# seed 21) and wrap it in a data loader for the activation pass.
# NOTE: this rebinds metadata, dataset_config, DatasetClass and dataset, so any
# cell re-run after this one sees these values rather than the earlier ones.
metadata = parse_safetensors_metadata(file_path=file_path)
dataset_config = DatasetConfig(
    dataset_name=metadata["dataset"],
    num_samples=5_000,
    split="test",
    seed=21,
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)
# Batches of 1024 samples are drawn from the dataset built just above.
dataloader = create_data_loader(dataset, batch_size=1024)

In [None]:
# Attach model hooks so the forward pass below records activations.
# NOTE(review): re-running this cell calls hook_model() again — confirm that
# repeated calls do not stack duplicate hooks.
model.hook_model()

# Raw forward pass on the first batch; everything after the batch dimension is
# flattened before the batch is fed to the model.
images, labels = next(iter(dataloader))
images = images.flatten(start_dim=1)
with torch.no_grad():
    # Call the module itself rather than .forward(): nn.Module.__call__ is what
    # runs registered forward hooks, while a direct .forward() call skips the
    # hooks registered on this module.
    _ = model(images)

# Activations recorded during the pass above, keyed as the model provides them.
all_acts = model.get_fwd_activations()

# Sanity check: mean of the input batch, then the mean of each recorded activation.
print(images.mean())
for layer_acts in all_acts.values():
    print(layer_acts.mean())