In [1]:
import os

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from rich import print as rprint
from torch import nn

from koopmann import aesthetics
from koopmann.data import (
    DatasetConfig,
    create_data_loader,
    get_dataset_class,
)
from koopmann.models import MLP
from koopmann.models.utils import parse_safetensors_metadata
from koopmann.utils import (
    compute_model_accuracy,
)
from koopmann.visualization import plot_decision_boundary

%load_ext autoreload
%autoreload 2

In [2]:
# Name of the checkpoint to analyse; used below to build the `<model_name>.safetensors` path.
model_name: str = "fashion_probed"

In [3]:
# Checkpoint location. Set the KOOPMANN_MODEL_DIR environment variable to point at
# a different directory instead of editing this user-specific absolute path; an
# unset or empty variable falls back to the original location.
MODEL_DIR = os.environ.get("KOOPMANN_MODEL_DIR") or "/scratch/nsa325/koopmann_model_saves"
file_path = os.path.join(MODEL_DIR, f"{model_name}.safetensors")

# `load_model` returns a pair; only the model is kept, the second item is discarded.
model, _ = MLP.load_model(file_path)

# Update nonlinearities: remove it from the second-to-last module and switch the
# third-to-last module to leaky ReLU.
# NOTE(review): which layers these negative indices select depends on how MLP
# builds `modules` — TODO confirm they are the intended hidden layers.
model.modules[-2].remove_nonlinearity()
model.modules[-3].update_nonlinearity("leakyrelu")
model.eval()
# Presumably registers forward hooks whose captured activations are read back
# later via `model.get_fwd_activations()` — TODO confirm.
model.hook_model()

In [4]:
# Dataset config
metadata = parse_safetensors_metadata(file_path=file_path)
rprint(metadata)
dataset_config = DatasetConfig(
    dataset_name=metadata["dataset"],
    num_samples=3_000,
    split="test",
    seed=42,
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)

In [5]:
# Accuracy of the modified model (nonlinearities changed at load time) on the
# evaluation dataset built in the previous cell.
rprint(
    "Testing Accuracy: {}".format(
        compute_model_accuracy(model, dataset)
    )
)

In [6]:
# Read back the activations captured by the hooks set up with `model.hook_model()`.
# NOTE(review): this cell passes no inputs through the model, so the stored
# activations come from an earlier forward pass (presumably the one inside
# `compute_model_accuracy` in the accuracy cell). After a kernel restart, run that
# cell first — TODO confirm.
act_dict = model.get_fwd_activations()

# Key of the activation entry to inspect (presumably a module/layer index — TODO confirm).
ACT_KEY = 4
act_dict[ACT_KEY]

tensor([[-5.4831e-03, -5.7824e-04, -2.0669e-02,  ..., -1.1237e-02,
          6.2775e-01,  1.2427e+00],
        [-1.3846e-03, -3.1844e-03,  4.0118e-01,  ...,  1.6159e-02,
         -9.3340e-03,  1.0704e-01],
        [ 2.9792e-02, -1.6480e-02,  1.4419e+00,  ..., -8.8502e-03,
         -4.4253e-03,  1.3427e+00],
        ...,
        [-3.2746e-04, -3.3126e-03,  5.9986e-01,  ..., -4.2345e-03,
         -3.1315e-03, -1.0727e-03],
        [ 5.5467e-01, -3.5242e-03,  8.6110e-01,  ..., -5.1412e-03,
         -1.6345e-02, -9.5266e-03],
        [ 1.3673e+00,  2.7034e-01,  7.4204e-01,  ...,  7.1479e-01,
          1.0990e+00,  4.2091e-01]])