In [37]:
import os

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from rich import print as rprint
from torch import nn

from koopmann import aesthetics
from koopmann.data import (
    DatasetConfig,
    create_data_loader,
    get_dataset_class,
)
from koopmann.models import MLP
from koopmann.models.utils import parse_safetensors_metadata
from koopmann.utils import (
    compute_model_accuracy,
)
from koopmann.visualization import plot_decision_boundary

# Enable IPython's autoreload so edits to imported modules (e.g. the koopmann
# package) are re-imported automatically before each cell runs, without a
# kernel restart. Mode 2 reloads all modules, not just ones listed via %aimport.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [38]:
# Stem of the saved checkpoint to analyse (the file name without its
# ".safetensors" suffix).
model_name: str = "mnist_probed"

In [39]:
# Load the trained checkpoint. The save directory defaults to the original
# cluster location, but can be overridden with the KOOPMANN_MODEL_DIR
# environment variable so the notebook runs elsewhere without editing code.
MODEL_DIR = os.environ.get("KOOPMANN_MODEL_DIR", "/scratch/nsa325/koopmann_model_saves")
file_path = os.path.join(MODEL_DIR, f"{model_name}.safetensors")
model, _ = MLP.load_model(file_path)

# Update nonlinearities on the freshly loaded model. Re-running this cell
# reloads the checkpoint first, so these in-place edits are never stacked.
# NOTE(review): the rationale for these particular edits is not visible here.
# modules[-2] has its nonlinearity removed; modules[-3] is switched to
# LeakyReLU. Confirm the intended layers if the architecture changes.
model.modules[-2].remove_nonlinearity()
model.modules[-3].update_nonlinearity("leakyrelu")
model.eval()
# Presumably registers forward hooks whose outputs are later read back via
# model.get_fwd_activations(). TODO confirm.
model.hook_model()

In [40]:
# Dataset config: build the evaluation dataset named in the checkpoint's own
# safetensors metadata (presumably the dataset the model was trained on).
N_EVAL_SAMPLES = 3_000  # number of samples requested for evaluation
EVAL_SPLIT = "test"  # evaluate on the test split
EVAL_SEED = 42  # fixed seed, presumably controlling which samples are drawn; TODO confirm

metadata = parse_safetensors_metadata(file_path=file_path)
rprint(metadata)
dataset_config = DatasetConfig(
    dataset_name=metadata["dataset"],
    num_samples=N_EVAL_SAMPLES,
    split=EVAL_SPLIT,
    seed=EVAL_SEED,
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)

In [41]:
# Report the accuracy of the (nonlinearity-edited) model on `dataset`.
# NOTE(review): this forward pass presumably also populates the hooked
# activations inspected in the next cell. TODO confirm.
test_accuracy = compute_model_accuracy(model, dataset)
rprint(f"Testing Accuracy: {test_accuracy}")

In [42]:
# Fetch the activations recorded by the model's forward hooks.
# NOTE(review): these presumably reflect the most recent forward pass through
# `model`, most likely the final batch evaluated by compute_model_accuracy.
# The displayed values therefore depend on which cells ran before this one.
# TODO confirm.
act_dict = model.get_fwd_activations()

ACT_KEY = 4  # which recorded layer to inspect; the meaning of this key is not documented here
act_dict[ACT_KEY]

tensor([[ 8.3945e-01,  1.3439e+00,  1.1622e+00,  ..., -1.0257e-02,
          1.4405e+00, -2.2859e-03],
        [-8.5872e-03, -6.6275e-03, -4.5610e-03,  ...,  9.9042e-01,
         -7.5687e-03,  1.4994e-03],
        [ 1.3039e-01, -8.5643e-03,  2.3912e-01,  ...,  6.8025e-01,
         -4.0691e-03, -3.3473e-03],
        ...,
        [-1.5791e-03, -2.8675e-03, -5.0830e-03,  ...,  1.0581e+00,
         -3.8913e-03,  8.3512e-01],
        [ 1.6383e-01,  2.0019e+00,  5.6350e-01,  ..., -8.5797e-03,
          1.9227e+00, -7.0373e-03],
        [ 1.1906e+00,  6.4509e-01, -4.5690e-03,  ..., -8.8295e-03,
         -1.1835e-02,  2.4179e-01]])