In [29]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn

from koopmann import aesthetics
from koopmann.data import (
    DatasetConfig,
    create_data_loader,
    get_dataset_class,
)
from koopmann.models import MLP
from koopmann.models.utils import parse_safetensors_metadata
from koopmann.utils import (
    compute_model_accuracy,
)
from koopmann.visualization import plot_decision_boundary

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [30]:
# Stem of the checkpoint to analyse; the file loaded is `<model_name>.safetensors`.
model_name: str = "mnist_model"

In [31]:
# Load the trained checkpoint. The directory defaults to the original
# machine-specific location but can be overridden with the
# KOOPMANN_MODEL_DIR environment variable, so the notebook can run on other
# machines without editing this cell.
MODEL_DIR = Path(
    os.environ.get("KOOPMANN_MODEL_DIR", "/home/nsa325/work/koopmann/model_saves")
)
# Kept as a str: later cells pass it to parse_safetensors_metadata(file_path=...).
file_path = str(MODEL_DIR / f"{model_name}.safetensors")
model, _ = MLP.load_model(file_path)
# Eval mode: the BatchNorm1d layers use their running statistics instead of
# per-batch statistics.
model.eval()
model.summary()

Layer (type (var_name))                       Param #
MLP (MLP)                                     --
+ Sequential (_features)                      --
|    + LinearLayer (0)                        --
|    |    + Sequential (layers)               100,736
|    + LinearLayer (1)                        --
|    |    + Sequential (layers)               8,384
|    + LinearLayer (2)                        --
|    |    + Sequential (layers)               2,144
|    + LinearLayer (3)                        --
|    |    + Sequential (layers)               560
|    + LinearLayer (4)                        --
|    |    + Sequential (layers)               190
Total params: 112,014
Trainable params: 112,014
Non-trainable params: 0

In [32]:
# Dataset config: look up the dataset name recorded in the checkpoint's
# safetensors metadata and build a 3,000-sample test-split dataset (seed 42)
# for the accuracy check.
metadata = parse_safetensors_metadata(file_path=file_path)
print(metadata)

dataset_name = metadata["dataset"]
dataset_config = DatasetConfig(dataset_name=dataset_name, num_samples=3_000, split="test", seed=42)

DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)

{'bias': 'True', 'config': '[128, 64, 32, 16]', 'dataset': 'MNISTDataset', 'input_dimension': '784', 'nonlinearity': 'relu', 'output_dimension': '10'}


In [33]:
# Accuracy of the loaded model on the test-split `dataset` configured in the previous cell.
test_accuracy = compute_model_accuracy(model, dataset)
print(f"Testing Accuracy: {test_accuracy}")

Testing Accuracy: 0.9757999777793884


In [34]:
# A decision-boundary plot only makes sense for 2-D inputs; for this
# checkpoint (input_dimension 784 in the metadata) the cell is a no-op.
is_2d_input = dataset.in_features == 2
if is_2d_input:
    targets = dataset.labels.squeeze()
    plot_decision_boundary(
        model,
        model.state_dict(),
        dataset.features,
        targets,
        # NOTE(review): class ids are hard-coded for a 3-class problem;
        # confirm they match the dataset being plotted.
        labels=[0, 1, 2],
    )

In [35]:
# Larger, differently seeded test-split dataset for the activation probe
# below. Note: this rebinds `dataset` and `dataset_config` from the accuracy cell.
metadata = parse_safetensors_metadata(file_path=file_path)
probe_dataset_name = metadata["dataset"]
dataset_config = DatasetConfig(dataset_name=probe_dataset_name, num_samples=5_000, split="test", seed=21)

DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)
# Only the first batch is consumed by the probe cell; presumably the batch
# size caps how many samples are actually inspected (verify).
dataloader = create_data_loader(dataset, batch_size=1024)

In [54]:
# Full module tree: unlike summary(), the repr lists each LinearLayer's
# linear / batchnorm / nonlinearity submodules.
model

MLP(
  (_features): Sequential(
    (0): LinearLayer(
      (layers): Sequential(
        (linear): Linear(in_features=784, out_features=128, bias=True)
        (batchnorm): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (nonlinearity): ReLU()
      )
    )
    (1): LinearLayer(
      (layers): Sequential(
        (linear): Linear(in_features=128, out_features=64, bias=True)
        (batchnorm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (nonlinearity): ReLU()
      )
    )
    (2): LinearLayer(
      (layers): Sequential(
        (linear): Linear(in_features=64, out_features=32, bias=True)
        (batchnorm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (nonlinearity): ReLU()
      )
    )
    (3): LinearLayer(
      (layers): Sequential(
        (linear): Linear(in_features=32, out_features=16, bias=True)
        (batchnorm): BatchNorm1d(16, eps=1e-05, mo

In [53]:
# Attach model hooks so the forward pass below records per-layer activations.
# NOTE(review): confirm hook_model() does not stack duplicate hooks when this
# cell is re-run.
model.hook_model()

# Raw forward pass on the first batch; flatten all dims after dim 0 (batch).
images, labels = next(iter(dataloader))
images = images.flatten(start_dim=1)
with torch.no_grad():
    # Call the module itself rather than `.forward`: PyTorch only runs hooks
    # registered on `model` when it is invoked via `__call__`.
    _ = model(images)

all_acts = model.get_fwd_activations()

# Mean of the input batch and of each recorded activation, labelled by hook key
# so every value can be traced back to its layer.
print(f"input: {images.mean().item():.4f}")
for key in all_acts.keys():
    print(f"{key}: {all_acts[key].mean().item():.4f}")

tensor(0.0066)
tensor(0.3071)
tensor(0.2930)
tensor(0.3090)
tensor(0.3839)
tensor(0.0026)
