In [1]:
import os

import ipywidgets as widgets
import plotly.express as px
import torch
from IPython.display import clear_output, display
from torch.utils.data import DataLoader, Subset

from analysis.common import load_model
from koopmann import aesthetics
from koopmann.data import DatasetConfig, get_dataset_class
from koopmann.log import logger
from koopmann.utils import get_device, set_seed
from scripts.train_ae.shape_metrics import build_acts_dict, preprocess_acts

# Fix the seed before anything else runs.
# NOTE(review): set_seed presumably seeds every RNG the project uses — confirm in koopmann.utils.
set_seed(21)

# autoreload mode 2: re-import all modules before each cell executes, so edits
# to the project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

Control panel

In [2]:
# Experiment selector: the model name is derived from it, and the setup cell
# below picks its per-dataset settings from it.
dataset_name = "cifar10"
model_name = f"convresnet_{dataset_name}"

# Storage locations. The defaults are the original machine-specific paths; set
# KOOPMANN_MODEL_DIR / KOOPMANN_DATA_ROOT in the environment to run the notebook
# on another machine without editing this cell.
file_dir = os.environ.get("KOOPMANN_MODEL_DIR", "/Users/nsa325/Documents/koopmann_model_saves")
data_root = os.environ.get("KOOPMANN_DATA_ROOT", "/Users/nsa325/datasets/")

File setup

In [3]:
# Per-dataset values that make up the autoencoder name built below,
# stored as (dim, scale_idx, k_steps, flavor).
_AE_SETTINGS = {
    "lotusroot": (20, 1, 100, "exponential"),
    "mnist": (800, 1, 10, "exponential"),
    "cifar10": (1_000, 1, 100, "exponential"),
}

# Fail with a message that names the rejected dataset and the supported ones,
# keeping the NotImplementedError type the original chain raised.
if dataset_name not in _AE_SETTINGS:
    raise NotImplementedError(
        f"No settings defined for dataset {dataset_name!r}; "
        f"expected one of {sorted(_AE_SETTINGS)}"
    )

dim, scale_idx, k_steps, flavor = _AE_SETTINGS[dataset_name]

ae_name = f"dim_{dim}_k_{k_steps}_loc_{scale_idx}_{flavor}_autoencoder_{dataset_name}_model"
device = get_device()

Load models

In [4]:
# Load the model and its metadata by name from `file_dir`.
# The recorded run of this cell failed with FileNotFoundError for
# "<file_dir>/<model_name>.safetensors", so that file must exist before running.
model, model_metadata = load_model(file_dir, model_name)
# Kept as a single chain: each call operates on the previous call's return value,
# and what hook_model() returns is not visible from this notebook.
# NOTE(review): hook_model() presumably registers the activation hooks that
# build_acts_dict relies on — confirm.
model.eval().hook_model().to(device)
print("Model: ", model_metadata)


FileNotFoundError: [Errno 2] No such file or directory: '/Users/nsa325/Documents/koopmann_model_saves/convresnet_cifar10.safetensors'

In [None]:
# Build the dataset named in the model's metadata.
dataset_config = DatasetConfig(
    dataset_name=model_metadata["dataset"],
    num_samples=3_000,
    split="train",
    seed=42,
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config, root=data_root)

# Optional cap on the number of leading samples used; None (or 0) means the full dataset.
subset_size = None
batch_size = 5_000

if subset_size:
    # Restrict to the first `subset_size` items and never batch more than that.
    subset_indices = list(range(subset_size))
    subset = Subset(dataset, subset_indices)
    batch_size = min(subset_size, batch_size)
    loader_source = subset
else:
    loader_source = dataset

# shuffle=False keeps the dataset order when iterating.
dataloader = DataLoader(loader_source, batch_size=batch_size, shuffle=False)

In [None]:
# Run the hooked model over the dataloader and collect its activations.
# NOTE(review): presumably returns a dict keyed by layer, with only_first_last=False
# keeping every hooked layer rather than just the first and last — confirm in
# scripts.train_ae.shape_metrics.
original_act_dict = build_acts_dict(
    data_train_loader=dataloader, model=model, only_first_last=False, device=device
)

In [None]:
# Preprocess the collected activations for plotting.
# The plotting cell reads columns 0, 1 and 2 of each entry, so each processed
# entry needs at least three columns (assuming svd_dim sets that width — TODO confirm).
svd_dim = 3
# NOTE(review): whiten_alpha presumably controls the whitening strength — confirm in preprocess_acts.
whiten_alpha = 0.5
processed_act_dict = preprocess_acts(
    original_act_dict=original_act_dict,
    svd_dim=svd_dim,
    whiten_alpha=whiten_alpha,
    preprocess_dict={},
    device=device,
    skip_svd=False,
)

In [None]:
def update_plot(change):
    """Redraw the 3D scatter of processed activations for the selected layer.

    Registered as the observer for the slider's ``value`` and also called once
    by hand to draw the initial figure.

    Args:
        change: Change notification supplied by ``observe`` (``None`` on the
            manual first call). It is not read; the current layer index is
            taken from ``layer_slider.value`` instead.
    """
    # Drop the previous slider + figure; wait=True delays the clear until new
    # output is ready, which avoids flicker between redraws.
    clear_output(wait=True)

    # Re-display the slider first so it stays above the plot.
    display(layer_slider)

    # Resolve the slider position to a layer key and fetch that layer's points.
    layer_idx = layer_slider.value
    layer_key = layer_keys[layer_idx]
    coords = processed_act_dict[layer_key].cpu()

    # Labels as strings so plotly colours them as discrete categories.
    # subset_size=None slices the full label list.
    target_categories = [str(t) for t in dataset.labels[:subset_size]]

    # The title names the layer so the figure is identifiable on its own.
    fig = px.scatter_3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        color=target_categories,
        color_discrete_sequence=px.colors.qualitative.T10,
        title=f"Layer {layer_idx}: {layer_key}",
    )
    fig.update_traces(marker_size=2)
    fig.update_layout(showlegend=False)
    fig.show()


# Ordered layer keys; the slider position is an index into this list.
layer_keys = [*processed_act_dict.keys()]

# One slider stop per layer, starting at the first one.
layer_slider = widgets.IntSlider(
    description="Layer:",
    value=0,
    min=0,
    max=len(layer_keys) - 1,
    step=1,
    continuous_update=False,  # fire only once the handle is released
)

# Redraw on every slider value change, then draw the initial figure.
layer_slider.observe(update_plot, names="value")
update_plot(None)