In [21]:
import math
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.nn.functional as F
from sklearn.decomposition import PCA
from plotly.subplots import make_subplots
from safetensors import safe_open
from torch.utils.data import DataLoader, Subset, TensorDataset
from torcheval.metrics import MulticlassAccuracy
from tqdm import tqdm

from analysis.common import load_autoencoder, load_model
from koopmann import aesthetics
from koopmann.data import DatasetConfig, get_dataset_class
from koopmann.models import ConvResNet
from koopmann.utils import get_device, set_seed
from scripts.train_ae.shape_metrics import prepare_acts, undo_preprocessing_acts

set_seed(42)


%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Control panel

In [22]:
# Which dataset / classifier pair to analyse.
dataset_name = "yinyang"
model_name = f"resmlp_{dataset_name}"

# Storage locations are resolved under the current user's home directory rather than
# hard-coding one machine's absolute paths, so the notebook runs unchanged elsewhere.
file_dir = os.path.join(os.path.expanduser("~"), "koopmann_model_saves")
# The empty last component keeps the trailing path separator the original value had.
data_root = os.path.join(os.path.expanduser("~"), "datasets", "")

File setup

In [23]:
# Autoencoder hyperparameters per dataset, as (dim, scale_idx, k_steps, flavor).
# These values are only used here to rebuild the saved autoencoder's file name.
_AE_SETTINGS = {
    "lotusroot": (20, 1, 100, "exponential"),
    "torus": (50, 1, 100, "exponential"),
    "yinyang": (20, 1, 100, "exponential"),
    "mnist": (800, 1, 10, "exponential"),
    "cifar10": (1_000, 1, 5000, "exponential"),
}

if dataset_name not in _AE_SETTINGS:
    raise NotImplementedError()
dim, scale_idx, k_steps, flavor = _AE_SETTINGS[dataset_name]

ae_name = f"dim_{dim}_k_{k_steps}_loc_{scale_idx}_{flavor}_autoencoder_{dataset_name}_model"
# Everything in this notebook is pinned to the CPU (`get_device` is imported but not used here).
device = "cpu"

Load models

In [24]:
# Load the trained classifier and its metadata dict from `file_dir`.
model, model_metadata = load_model(file_dir, model_name)
# Eval mode, then the project's `hook_model()`, then move to `device`. The calls are chained,
# so `.to(device)` is applied to whatever `hook_model()` returns.
# NOTE(review): presumably the hooks record intermediate activations for `prepare_acts` — confirm.
model.eval().hook_model().to(device)
print("Model: ", model_metadata)

# Load the Koopman autoencoder saved under `ae_name`, together with its metadata dict.
autoencoder, ae_metadata = load_autoencoder(file_dir, ae_name)
autoencoder.eval().to(device)
new_dim = ae_metadata["in_features"]  # autoencoder input width, as recorded in its metadata
preprocess = ae_metadata["preprocess"]  # flag reused below to decide whether activations get preprocessed
K_matrix = autoencoder.koopman_weights.T  # transposed Koopman weight matrix
print("Autoencoder: ", ae_metadata)

Model:  {'batchnorm': True, 'bias': True, 'created_at': '2025-05-05T16:23:51.006581', 'dataset': 'YinYangDataset', 'hidden_config': [10, 10, 10, 10, 10, 10, 10, 10], 'in_features': 2, 'model_class': 'ResMLP', 'nonlinearity': 'relu', 'out_features': 3, 'stochastic_depth_mode': 'batch', 'stochastic_depth_prob': 0.0}
Autoencoder:  {'batchnorm': False, 'bias': True, 'created_at': '2025-05-07T11:02:43.149590', 'hidden_config': [30], 'in_features': 10, 'k_steps': 100, 'latent_features': 20, 'model_class': 'ExponentialKoopmanAutencoder', 'nonlinearity': 'leaky_relu', 'preprocess': True}


In [25]:
# Test split of the dataset named in the classifier's metadata.
dataset_config = DatasetConfig(
    dataset_name=model_metadata["dataset"],
    num_samples=5_000,
    split="test",
    seed=42,
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config, root=data_root)

# Set `subset_size` to an int to evaluate only the first `subset_size` samples;
# leave it as None to use the whole dataset.
subset_size = None

if subset_size:
    subset_indices = list(range(0, subset_size))
    subset = Subset(dataset, subset_indices)
    # Never ask for batches larger than the subset itself.
    batch_size = min(subset_size, 3_000)
    dataloader = DataLoader(subset, batch_size=batch_size, shuffle=False)
else:
    batch_size = 3_000
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [26]:
# Copy every tensor stored in the autoencoder's preprocessing file into a plain dict.
preproc_dict = {}
with safe_open(
    f"{file_dir}/{ae_name}_preprocessing.safetensors", framework="pt", device="cpu"
) as f:
    preproc_dict.update((key, f.get_tensor(key)) for key in f.keys())

In [27]:
print(f"Preprocess activations?: {preprocess}")

# Collect activations from the hooked model; with `only_first_last=True` only the two
# end-point layers are kept (the progress bar in the output shows 2 entries).
orig_act_dict, proc_act_dict, _ = prepare_acts(
    data_train_loader=dataloader,
    model=model,
    device=device,
    svd_dim=ae_metadata["in_features"],
    whiten_alpha=preproc_dict["wh_alpha_0"],
    preprocess=preprocess,
    preprocess_dict=preproc_dict,
    only_first_last=True,
)

# Without preprocessing, the "processed" activations are simply the raw ones.
if not preprocess:
    proc_act_dict = orig_act_dict

# Keys of the first and last recorded activations.
act_keys = list(orig_act_dict.keys())
init_idx, final_idx = act_keys[0], act_keys[-1]

Preprocess activations?: True


Processing activations: 100%|██████████| 2/2 [00:00<00:00, 1588.75it/s]


In [28]:
# Short aliases for the autoencoder attributes used in the operator products below.
# NOTE(review): presumably an eigendecomposition of the Koopman operator — confirm in the model class.
V, D_exp, V_inv = autoencoder.V, autoencoder.D_exp, autoencoder.V_inv

In [29]:
# Keys of the first and last recorded activations.
act_keys = list(orig_act_dict.keys())
init_idx, final_idx = act_keys[0], act_keys[-1]

with torch.no_grad():
    # Raw and processed activations at both ends of the network.
    x, y = orig_act_dict[init_idx], orig_act_dict[final_idx]
    x_proj, y_proj = proc_act_dict[init_idx], proc_act_dict[final_idx]

    # Map the processed activations back to the original activation space (identity if
    # no preprocessing was applied).
    if preprocess:
        x_unproj = undo_preprocessing_acts(x_proj, preproc_dict, init_idx, device)
        y_unproj = undo_preprocessing_acts(y_proj, preproc_dict, final_idx, device)
    else:
        x_unproj, y_unproj = x_proj, y_proj

    # Autoencoder round trip for the first activation ...
    x_proj_obs = autoencoder.encode(x_proj)
    x_proj_recon = autoencoder.decode(x_proj_obs)

    # ... and for the final activation.
    y_proj_obs = autoencoder.encode(y_proj)
    y_proj_recon = autoencoder.decode(y_proj_obs)

    # "Beta" route: advance the encoded input `k_steps` times through V / D_exp / V_inv.
    eig_power = torch.linalg.matrix_power(D_exp, k_steps)
    pred_proj_obs = x_proj_obs @ V.T @ eig_power @ V_inv.T
    pred_proj = autoencoder.decode(pred_proj_obs)
    # "Alpha" route (kept for reference): pred_proj = autoencoder(x_proj).predictions[-1]

    # Undo preprocessing on the prediction and on the reconstructed final activation.
    if preprocess:
        pred = undo_preprocessing_acts(pred_proj, preproc_dict, final_idx, device)
        y_recon = undo_preprocessing_acts(y_proj_recon, preproc_dict, final_idx, device)
    else:
        pred, y_recon = pred_proj, y_proj_recon

    # ConvResNet's last component is fed feature maps, so restore the hard-coded
    # (512, 4, 4) layout from the flattened prediction.
    if type(model) is ConvResNet:
        pred = pred.reshape(-1, 512, 4, 4)

In [30]:
# Classify the Koopman-predicted final activations with the network's last component.
# This is inference only, so run it without autograd like the other evaluation cells.
with torch.no_grad():
    koopman_pred = model.components[-1:](pred)

# `dataset.labels` is already a tensor: clone/detach it rather than wrapping it in
# `torch.tensor(...)`, which emits the "copy construct from a tensor" UserWarning.
koopman_targets = dataset.labels[:subset_size].clone().detach().squeeze()

koopman_metric = MulticlassAccuracy(num_classes=dataset.out_features)
koopman_metric.update(koopman_pred, koopman_targets)
print("Koopman accuracy: ", koopman_metric.compute())

Koopman accuracy:  tensor(0.9846)


  koopman_metric.update(koopman_pred, torch.tensor(dataset.labels[:subset_size]).squeeze())


In [31]:
# Class whose samples are singled out below.
target_class = 2
vis_batch = 100

# Row indices of every sample labelled `target_class`. The two-output unpacking relies on
# `dataset.labels` being 2-D; the second output is discarded.
target_idx, _ = torch.where(dataset.labels == target_class)
subset_labels = dataset.labels[target_idx]

# Processed activations and their encodings for that class, at the first layer ...
subset_x_proj, subset_x_proj_obs = x_proj[target_idx], x_proj_obs[target_idx]

# ... and at the final layer.
subset_y_proj, subset_y_proj_obs = y_proj[target_idx], y_proj_obs[target_idx]

In [32]:
def imshow(x):
    """Show ``x`` as a heatmap on a diverging colour scale centred at zero, without a colour bar.

    Args:
        x: Tensor to display; it is detached before being handed to plotly.
    """
    heatmap = px.imshow(
        x.detach(),
        color_continuous_scale="balance_r",
        color_continuous_midpoint=0.0,
    )
    heatmap.update_layout(coloraxis_showscale=False).show()


def scatter(x, y, labels, z=None, colormap=None):
    """Show a 2-D (or, when ``z`` is given, 3-D) scatter plot coloured by label.

    Args:
        x, y: Coordinates of the points.
        labels: One label per point; each is converted with ``str`` so plotly treats the
            colouring as categorical.
        z: Optional third coordinate; when provided a 3-D scatter is drawn.
        colormap: Optional mapping from stringified label to colour. When omitted, labels
            "0".."3" are mapped to the first four colours of plotly's qualitative palette.
            (Previously this argument was accepted but always overwritten by that default.)
    """
    if colormap is None:
        palette = px.colors.qualitative.Plotly
        colormap = {str(class_id): palette[class_id] for class_id in range(4)}

    # Arguments shared by the 2-D and 3-D variants.
    shared_kwargs = dict(
        x=x,
        y=y,
        color=[str(label) for label in labels],
        color_discrete_map=colormap,
    )
    if z is not None:
        fig = px.scatter_3d(z=z, **shared_kwargs)
    else:
        fig = px.scatter(**shared_kwargs)

    fig.update_layout(showlegend=True, width=800)
    fig.update_traces(marker=dict(size=3, line=dict(width=0.0001, color="DarkSlateGrey")))
    fig.show()

Operator post-image

In [33]:
# Reduce the k-step-advanced encodings to their first three principal components and
# plot them coloured by the true class label.
pca_engine = PCA(n_components=3)
reduced = pca_engine.fit_transform(pred_proj_obs)
scatter(
    x=reduced[:, 0],
    y=reduced[:, 1],
    z=reduced[:, 2],
    labels=dataset.labels.squeeze().numpy(),
)

Identifying medoid cluster in post-image

In [34]:
with torch.no_grad():
    # Advance all encoded inputs by the full horizon. The horizon is `k_steps` rather than a
    # literal 100, so this stays consistent with the k_steps-based prediction and with the
    # inverse roll-out later in the notebook for datasets whose k_steps is not 100.
    x_obs_advanced = x_proj_obs @ torch.linalg.matrix_power(
        autoencoder.koopman_weights.T, k_steps
    )

    # Advance only the target-class cluster (via the V / D_exp / V_inv factors).
    subset_advanced = (
        subset_x_proj_obs @ V.T @ torch.linalg.matrix_power(D_exp, k_steps) @ V_inv.T
    )

    # Medoid: the cluster member with the smallest summed distance to all other members.
    pairwise_distances = torch.cdist(subset_advanced, subset_advanced, p=2)
    sum_distances = torch.sum(pairwise_distances, dim=1)
    medoid_idx = torch.argmin(sum_distances)
    medoid = subset_advanced[medoid_idx].unsqueeze(0)

    # Radius: a quantile of the member-to-medoid distances (1.00 == the farthest member).
    distances_from_medoid = pairwise_distances[medoid_idx]
    radius = torch.quantile(distances_from_medoid, 1.00)  # percentile

    # Distance of every advanced point (any class) to the medoid.
    distances_to_medoid_all = torch.cdist(x_obs_advanced, medoid, p=2).squeeze()

    # Everything inside the ball is treated as a "harmful" output.
    mask_all = distances_to_medoid_all <= radius
    harmful_outputs = x_obs_advanced[mask_all]
    indices_within_radius = torch.where(mask_all)[0]  # row indices into `x_obs_advanced`

    # Statistics
    num_points = mask_all.sum().item()
    percentage = num_points / len(x_obs_advanced) * 100
    print(f"Selected {num_points} points ({percentage:.2f}% of all points)")


# Visualise the decoded predictions, relabelling the selected points as -1.
pca_engine = PCA(n_components=3)
reduced = pca_engine.fit_transform(pred_proj)
labels = dataset.labels.clone().squeeze().numpy()
labels[indices_within_radius.numpy()] = -1
scatter(x=reduced[:, 0], y=reduced[:, 1], z=reduced[:, 2], labels=labels)

Selected 1686 points (33.72% of all points)


In [35]:
def get_clean_indices(all_inputs, harmful_inputs, tolerance=1e-3):
    """Split the rows of ``all_inputs`` into clean / harmful and pair each harmful input
    with its nearest clean row.

    For every row of ``harmful_inputs`` the nearest row of ``all_inputs`` (Euclidean
    distance) is looked up; if that nearest row lies within ``tolerance`` it is flagged
    as harmful. All remaining rows are "clean".

    Args:
        all_inputs: ``(n_all, d)`` tensor of candidate inputs.
        harmful_inputs: ``(n_harmful, d)`` tensor of inputs known to be harmful.
        tolerance: Maximum distance at which the nearest row counts as a match.

    Returns:
        ``(non_harmful_indices, closest_clean_indices)``: the ascending row indices of
        ``all_inputs`` that were not flagged, and, for each harmful input, the row index
        of ``all_inputs`` of its nearest clean row. Both are int64 tensors on the same
        device as the inputs.

    Raises:
        ValueError: If every row of ``all_inputs`` is flagged as harmful, so no clean
            row exists to pair with.
    """
    # (n_all, n_harmful) pairwise Euclidean distances.
    distances = torch.cdist(all_inputs, harmful_inputs)

    # Nearest row of `all_inputs` for each harmful input, kept only when close enough.
    min_distances, closest_indices = torch.min(distances, dim=0)
    matched_rows = closest_indices[min_distances <= tolerance]

    # A boolean mask gives deterministic ascending order, the right dtype even when the
    # result is empty, and indices on the same device as `distances`. The previous
    # Python-set / list round trip guaranteed none of these.
    harmful_mask = torch.zeros(distances.shape[0], dtype=torch.bool, device=distances.device)
    harmful_mask[matched_rows] = True
    non_harmful_indices = torch.nonzero(~harmful_mask).squeeze(1)

    if non_harmful_indices.numel() == 0:
        raise ValueError(
            "every input was matched to a harmful input; no clean inputs remain to pair with"
        )

    # Nearest clean row for each harmful input, mapped back to rows of `all_inputs`.
    _, closest_clean_rel = torch.min(distances[non_harmful_indices], dim=0)
    closest_clean_indices = non_harmful_indices[closest_clean_rel]

    return non_harmful_indices, closest_clean_indices


with torch.no_grad():
    # Pull the selected outputs back to the input side by applying the pseudo-inverse of
    # the one-step operator `k_steps` times.
    one_step_pinv = torch.linalg.pinv(autoencoder.koopman_weights.T)
    harmful_inputs = harmful_outputs @ torch.linalg.matrix_power(one_step_pinv, k_steps)
    dummy_labels = ["2" for _ in range(harmful_inputs.shape[0])]

    # Separate clean from harmful encodings, and use the original k-step prediction of each
    # harmful input's nearest clean neighbour as its replacement target.
    clean_inputs_idx, closest_clean_idx = get_clean_indices(x_proj_obs, harmful_inputs)
    closest_clean_target = pred_proj_obs[closest_clean_idx]


In [36]:
def edit_operator(original_operator, clean_inputs, harmful_inputs, closest_clean_target, reg=0.1):
    """Return ``original_operator`` plus a closed-form correction that redirects harmful inputs.

    With ``C0 = clean_inputs^T clean_inputs + reg*I`` (preservation term) and
    ``R = closest_clean_target - harmful_inputs @ original_operator`` (residual to fix),
    the update is::

        delta = C0^{-1} H^T (H C0^{-1} H^T + reg*I)^{-1} R,      H = harmful_inputs

    Args:
        original_operator: ``(m, m)`` operator applied on the right (``inputs @ operator``).
        clean_inputs: ``(n_clean, m)`` inputs whose outputs should be preserved.
        harmful_inputs: ``(e, m)`` inputs whose outputs should be redirected.
        closest_clean_target: ``(e, m)`` desired outputs for ``harmful_inputs``.
        reg: Ridge term added to both inverted matrices so they stay invertible
            (the original comments attribute it to Section 5.1 of the referenced paper).

    Returns:
        The edited ``(m, m)`` operator, in the dtype and on the device of the inputs.
    """

    def _ridge(matrix):
        # The identity is built in the matrix's own dtype and device. The previous
        # `np.eye` / default `torch.eye` silently promoted to float64 or stayed on the
        # CPU, which breaks the later matmuls for float32 or non-CPU inputs.
        eye = torch.eye(matrix.shape[0], dtype=matrix.dtype, device=matrix.device)
        return matrix + reg * eye

    # Preservation component C0 (m x m), regularised.
    C0_reg = _ridge(clean_inputs.T @ clean_inputs)

    # C0^{-1} H^T (m x e) via a linear solve instead of an explicit inverse.
    C0_inv_Ht = torch.linalg.solve(C0_reg, harmful_inputs.T)

    # Residual error R (e x m) - memorization component.
    R = closest_clean_target - harmful_inputs @ original_operator

    # D = H C0^{-1} H^T (e x e), regularised.
    D_reg = _ridge(harmful_inputs @ C0_inv_Ht)

    # delta (m x m) = C0^{-1} H^T D^{-1} R
    delta = C0_inv_Ht @ torch.linalg.solve(D_reg, R)

    return original_operator + delta


# Build the edited k-step operator. This is a closed-form update with no training involved,
# so run it under `no_grad` like the other inference cells; that keeps `new_operator` free
# of an autograd graph. All inputs are promoted to float64 for the linear algebra.
with torch.no_grad():
    new_operator = edit_operator(
        original_operator=torch.linalg.matrix_power(
            autoencoder.koopman_weights.T, k_steps
        ).double(),
        clean_inputs=x_proj_obs[clean_inputs_idx].double(),
        harmful_inputs=harmful_inputs.double(),
        closest_clean_target=closest_clean_target.double(),
    )

In [37]:
# Baseline picture: all encoded inputs advanced by the unedited k-step operator.
plot_labels = dataset.labels.clone().detach().squeeze().numpy()
with torch.no_grad():
    unedited_operator = torch.linalg.matrix_power(autoencoder.koopman_weights.T, k_steps)
    plot_inputs = x_proj_obs @ unedited_operator

pca_engine = PCA(n_components=3)
reduced = pca_engine.fit_transform(plot_inputs)
scatter(
    x=reduced[:, 0],
    y=reduced[:, 1],
    z=reduced[:, 2],
    labels=plot_labels,
)

In [38]:
# Same picture after the edit: all encoded inputs advanced by the edited operator
# (cast back to float32 to match the encodings).
plot_labels = dataset.labels.clone().detach().squeeze().numpy()
with torch.no_grad():
    plot_inputs = x_proj_obs @ new_operator.float()

pca_engine = PCA(n_components=3)
reduced = pca_engine.fit_transform(plot_inputs)
scatter(
    x=reduced[:, 0],
    y=reduced[:, 1],
    z=reduced[:, 2],
    labels=plot_labels,
)

In [39]:
# Evaluate the classifier head on predictions made with the edited operator.
# This is inference only, so no autograd graph is needed.
with torch.no_grad():
    new_output_obs = x_proj_obs @ new_operator.float()
    new_output_proj = autoencoder.decode(new_output_obs)

    if preprocess:
        new_pred = undo_preprocessing_acts(new_output_proj, preproc_dict, final_idx, device)
    else:
        # Use the decoded output of the *edited* operator. This branch previously fell back
        # to `pred_proj`, i.e. the unedited prediction, so the edit was never evaluated.
        new_pred = new_output_proj

    # Mirror the reshape applied to `pred` when it was first computed, so ConvResNet's
    # last component receives feature maps here as well.
    if type(model) is ConvResNet:
        new_pred = new_pred.reshape(-1, 512, 4, 4)

    new_koopman_pred = model.components[-1:](new_pred)

# Targets as int64. `dataset.labels` is already a tensor, so clone/detach and cast instead
# of `torch.tensor(...)`, which emits the "copy construct from a tensor" UserWarning.
target = dataset.labels[:subset_size].clone().detach().to(torch.long).squeeze()

# average=None yields one accuracy value per class.
koopman_metric = MulticlassAccuracy(num_classes=dataset.out_features, average=None)
koopman_metric.update(new_koopman_pred, target)
class_accuracies = koopman_metric.compute()
print("Koopman accuracy per class: ", class_accuracies)

# Overall accuracy with the metric's default averaging.
overall_metric = MulticlassAccuracy(num_classes=dataset.out_features)
overall_metric.update(new_koopman_pred, target)
print("Overall Koopman accuracy: ", overall_metric.compute())

Koopman accuracy per class:  tensor([0.9673, 0.9935, 0.3680])
Overall Koopman accuracy:  tensor(0.7824)



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


The reduce argument of torch.scatter with Tensor src is deprecated and will be removed in a future PyTorch release. Use torch.scatter_reduce instead for more reduction options. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/TensorAdvancedIndexing.cpp:233.)

