In [2]:
import math
from ast import literal_eval

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.subplots as sp
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
import torchvision.transforms.functional as TVF
from matrepr import mdisplay
from plotly.subplots import make_subplots
from rich import print as rprint
from torch import linalg
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from koopmann import aesthetics
from koopmann.data import (
    DatasetConfig,
    create_data_loader,
    get_dataset_class,
)
from koopmann.models import (
    MLP,
    Autoencoder,
    ExponentialKoopmanAutencoder,
    LowRankKoopmanAutoencoder,
    ResMLP,
)
from koopmann.models.utils import get_device
from koopmann.visualization import plot_eigenvalues

%load_ext autoreload
%autoreload 2

In [4]:
from analysis.common import (
    compare_model_autoencoder_acc,
)

In [None]:
# Which scaling run / checkpoint to analyse (all values are strings because they are
# spliced directly into checkpoint file names).
task, scale_idx, k, dim = "mnist", "0", "1", "1024"

# Autoencoder flavor; other values used with this notebook: "exponential", "standard"
flavor = "lowrank_20"

# Owner of the /scratch directory holding the checkpoints
user = "nsa325"

# Checkpoint base names derived from the task
model_name, ae_name = f"{task}_probed", f"{task}_model"


Load and prepare MLP

In [8]:
# Checkpoint of the original (non-autoencoder) model
model_file_path = f"/scratch/{user}/koopmann_model_saves/{model_name}.safetensors"

# Probed and plain checkpoints are MLPs; non-probed "residual" checkpoints are ResMLPs.
if "probed" in model_name:
    ModelClass = MLP
elif "residual" in model_name:
    ModelClass = ResMLP
else:
    ModelClass = MLP

model, model_metadata = ModelClass.load_model(file_path=model_file_path)
is_probed = "probed" in model_name

if is_probed:
    # Drop the nonlinearity of the second- and third-to-last modules of the probed model.
    model.modules[-2].remove_nonlinearity()
    model.modules[-3].remove_nonlinearity()
    # Alternative tried previously: model.modules[-3].update_nonlinearity("leakyrelu")

model.eval().hook_model()

Build data

In [9]:
# Fixed-seed 5,000-sample subset of the test split of the model's training dataset
dataset_config = DatasetConfig(
    dataset_name=model_metadata["dataset"],
    num_samples=5_000,
    split="test",
    seed=21,
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)

# Unprocessed images and their labels
raw_images = dataset.data
labels = dataset.labels

# MLP preprocessing: scale by 1/255, then standardise
# (0.1307 / 0.3081 are presumably the usual MNIST mean / std — confirm against training code)
mlp_transform = transforms.Compose(
    [
        transforms.Lambda(lambda pixels: pixels / 255),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)
mlp_inputs = mlp_transform(raw_images)

# Autoencoder preprocessing: scale by 1/255, then map [0, 1] -> [-1, 1]
ae_transform = transforms.Compose(
    [
        transforms.Lambda(lambda pixels: pixels / 255),
        transforms.Lambda(lambda pixels: pixels * 2 - 1),
    ]
)
ae_inputs = ae_transform(raw_images)

Build autoencoder

In [10]:
# Autoencoder checkpoint for this scaling configuration (in the /scratch work dir)
ae_file_path = f"/scratch/{user}/koopmann_model_saves/scaling/dim_{dim}_k_{k}_loc_{scale_idx}_{flavor}_autoencoder_{ae_name}.safetensors"

# Choose the autoencoder class from the flavor string. An unrecognised flavor must fail
# loudly: otherwise AutoencoderClass is either unbound (NameError) or, on a re-run,
# silently left over from a previous configuration.
if "standard" in flavor:
    AutoencoderClass = Autoencoder
elif "lowrank" in flavor:
    AutoencoderClass = LowRankKoopmanAutoencoder
elif "exponential" in flavor:
    AutoencoderClass = ExponentialKoopmanAutencoder
else:
    raise ValueError(
        f"Unknown autoencoder flavor {flavor!r}; expected it to contain "
        "'standard', 'lowrank' or 'exponential'."
    )

autoencoder, ae_metadata = AutoencoderClass.load_model(
    ae_file_path,
    strict=True,
    remove_param=True,
)
_ = autoencoder.eval()

# Koopman operator as a detached matrix (transpose of the linear layer's weight)
K_matrix = autoencoder.koopman_matrix.linear_layer.weight.T.detach()

Verify accuracy

In [None]:
# acc_mlp, acc_koopman = compare_model_autoencoder_acc(
#     model, autoencoder, int(k), len(dataset.classes), mlp_inputs, ae_inputs, labels
# )
# mdisplay(acc_mlp, title="Original Model Testing Accuracy")
# mdisplay(acc_koopman, title="Autoencoder Prediction Testing Accuracy")

Get class-specific inputs

In [None]:
# Restrict the analysis to one class
target_label = 0
target_idx = torch.where(labels == target_label)

# That class's inputs, preprocessed for each network
target_mlp_inputs, target_ae_inputs = mlp_inputs[target_idx], ae_inputs[target_idx]

# Predictions of the original MLP
mlp_output = model(target_mlp_inputs)
mlp_predictions = mlp_output.argmax(dim=1)

# Predictions from the autoencoder's output (run with k=int(k)), passed through the
# MLP's last two modules
koopman_output = autoencoder(target_ae_inputs, k=int(k)).predictions.squeeze(0)
koopman_head_output = model.modules[-2:](koopman_output)
koopman_predictions = koopman_head_output.argmax(dim=1)

In [None]:
# Partition the target-class samples by whether the Koopman route predicts the target label
is_koopman_hit = koopman_predictions == target_label
correct_idx = torch.where(is_koopman_hit)[0]
misclassified_idx = torch.where(~is_koopman_hit)[0]

# Samples the Koopman route classifies correctly
correct_mlp_inputs, correct_ae_inputs, correct_labels = (
    target_mlp_inputs[correct_idx],
    target_ae_inputs[correct_idx],
    koopman_predictions[correct_idx],
)

# Samples the Koopman route misclassifies
# (the global name `misclassifed_mlp_inputs` keeps its original spelling for later cells)
misclassifed_mlp_inputs, misclassified_ae_inputs, incorrect_labels = (
    target_mlp_inputs[misclassified_idx],
    target_ae_inputs[misclassified_idx],
    koopman_predictions[misclassified_idx],
)

In [None]:
import numpy as np
import plotly.graph_objects as go
import plotly.subplots as sp
from ipywidgets import interact, widgets

# SVD of the Koopman matrix: K = U @ diag(S) @ Vh
U, S, Vh = torch.linalg.svd(K_matrix)

# A mode is "significant" when its singular value exceeds the threshold
significant_threshold = 1e-5
significant_modes = (S > significant_threshold).nonzero(as_tuple=True)[0]
n_significant = significant_modes.numel()
print(f"Number of significant modes: {n_significant} out of {len(S)}")

with torch.no_grad():
    # Encodings of both sample groups
    correct_encodings = autoencoder._encode(correct_ae_inputs)
    misclass_encodings = autoencoder._encode(misclassified_ae_inputs)

    # Coordinates along the significant left singular vectors (one column per mode)
    sig_U = U[:, significant_modes]
    correct_projections = correct_encodings @ sig_U
    misclass_projections = misclass_encodings @ sig_U

    # NumPy copies for plotting
    correct_proj_np, misclass_proj_np = (
        tensor.cpu().numpy() for tensor in (correct_projections, misclass_projections)
    )
    significant_S = S[significant_modes].cpu().numpy()
    significant_modes_np = significant_modes.cpu().numpy()

# Group sizes used by the interactive widgets
n_correct, n_misclass = correct_proj_np.shape[0], misclass_proj_np.shape[0]


@interact(
    sample_type=widgets.RadioButtons(
        options=[("Correctly Classified", "correct"), ("Misclassified", "misclassified")],
        description="Sample Type:",
        disabled=False,
    ),
    # max(..., 1) keeps the slider range valid even if both sample groups are empty.
    sample_idx=widgets.IntSlider(
        min=0, max=max(n_correct, n_misclass, 1) - 1, step=1, value=0, description="Sample Index:"
    ),
    # Bounds keep min <= value <= max when fewer than 5 (or 10) significant modes exist;
    # for >= 20 significant modes this is min=5, max=20, value=10 as before.
    n_modes=widgets.IntSlider(
        min=min(5, max(n_significant, 1)),
        max=max(min(20, n_significant), 1),
        step=1,
        value=min(10, max(min(20, n_significant), 1)),
        description="# of Modes:",
    ),
    normalization=widgets.Dropdown(
        options=[
            ("Raw Projections", "raw"),
            ("Normalized by Singular Values", "normalized"),
            ("Output Contribution", "contribution"),
        ],
        value="raw",
        description="View:",
    ),
)
def update_plot(sample_type, sample_idx, n_modes, normalization):
    """Show one sample's image next to its projections onto the leading Koopman modes.

    Args:
        sample_type: "correct" or "misclassified"; selects the sample group to browse.
        sample_idx: Index into that group; clamped to the group's last valid index.
        n_modes: Number of leading significant modes to show (capped at n_significant).
        normalization: "raw" (p), "normalized" (p / sigma) or "contribution" (p * sigma).

    Returns:
        A plotly Figure (a titled placeholder if there is nothing to show).
    """
    is_correct = sample_type == "correct"
    n_samples = n_correct if is_correct else n_misclass
    if n_samples == 0:
        status = "correctly classified" if is_correct else "misclassified"
        return go.Figure().update_layout(title=f"No {status} samples for class {target_label}")

    # Ensure we don't exceed available modes
    n_modes = min(n_modes, n_significant)
    if n_modes == 0:
        return go.Figure().update_layout(title="No significant Koopman modes to display")

    # The slider spans the larger group, so clamp to this group's size
    sample_idx = min(sample_idx, n_samples - 1)

    if is_correct:
        sample_input = correct_ae_inputs[sample_idx]
        sample_proj = correct_proj_np[sample_idx]
        sample_class = target_label
        title_color = "blue"
    else:
        sample_input = misclassified_ae_inputs[sample_idx]
        sample_proj = misclass_proj_np[sample_idx]
        sample_class = incorrect_labels[sample_idx].item()
        title_color = "red"

    # Image on the left, bar chart of projections on the right
    fig = sp.make_subplots(
        rows=1, cols=2, column_widths=[0.3, 0.7], specs=[[{"type": "image"}, {"type": "bar"}]]
    )

    img = sample_input.reshape(28, 28).cpu().numpy()
    fig.add_trace(go.Heatmap(z=img, colorscale="Greys", showscale=False), row=1, col=1)
    # Heatmaps draw row 0 at the bottom; reverse the y-axis so the digit is upright.
    fig.update_yaxes(autorange="reversed", row=1, col=1)

    mode_indices = significant_modes_np[:n_modes]
    raw_values = sample_proj[:n_modes]
    # A later cell rebinds `significant_S` to a torch tensor; converting to a CPU NumPy
    # array keeps this callback working if the widget is moved after that cell ran.
    singular_values = torch.as_tensor(significant_S[:n_modes]).cpu().numpy()

    if normalization == "raw":
        proj_values = raw_values
        y_label = "Raw Projection"
    elif normalization == "normalized":
        proj_values = raw_values / singular_values
        y_label = "Normalized Projection (p/σ)"
    else:  # contribution
        proj_values = raw_values * singular_values
        y_label = "Output Contribution (p·σ)"

    # Negative projections in red, non-negative in blue
    colors = ["red" if val < 0 else "blue" for val in proj_values]

    fig.add_trace(
        go.Bar(
            x=mode_indices,
            y=proj_values,
            marker_color=colors,
            text=[f"σ={sv:.4f}" for sv in singular_values],
            hovertemplate="Mode %{x}<br>Value: %{y:.4f}<br>%{text}<extra></extra>",
            name="Projections",
        ),
        row=1,
        col=2,
    )

    sample_status = (
        "Correctly Classified" if is_correct else f"Misclassified as {sample_class}"
    )
    fig.update_layout(
        title=f"Sample {sample_idx+1}/{n_samples} ({sample_status}, True Class: {target_label})",
        title_font=dict(color=title_color),
        height=500,
        width=1000,
        showlegend=False,
    )

    fig.update_xaxes(title_text="Mode Index", row=1, col=2)
    fig.update_yaxes(title_text=y_label, row=1, col=2)

    # Dashed reference line at y=0
    fig.add_shape(
        type="line",
        line=dict(dash="dash", color="black"),
        x0=min(mode_indices) - 0.5,
        y0=0,
        x1=max(mode_indices) + 0.5,
        y1=0,
        row=1,
        col=2,
    )

    return fig

In [None]:
import numpy as np
import plotly.graph_objects as go
from ipywidgets import interact, widgets


# Define overlap calculation function first
def compute_distribution_overlap(dist1, dist2, bins=50):
    """Estimate the percentage overlap between two 1-D empirical distributions.

    Both samples are histogrammed as densities over a shared range. The overlap is
    the integral of the pointwise minimum of the two densities, as a percentage in
    [0, 100]; 100 means the histograms coincide.

    Args:
        dist1: First sample (array-like; flattened).
        dist2: Second sample (array-like; flattened).
        bins: Number of histogram bins over the shared range.

    Returns:
        The overlap percentage as a float, or NaN if either sample is empty.
    """
    dist1 = np.asarray(dist1, dtype=float).ravel()
    dist2 = np.asarray(dist2, dtype=float).ravel()
    if dist1.size == 0 or dist2.size == 0:
        return float("nan")

    min_val = min(dist1.min(), dist2.min())
    max_val = max(dist1.max(), dist2.max())

    # Histogram both samples on exactly the same bin edges
    hist1, bin_edges = np.histogram(dist1, bins=bins, range=(min_val, max_val), density=True)
    hist2, _ = np.histogram(dist2, bins=bin_edges, density=True)

    # Integrate with the actual edge widths: np.histogram widens a degenerate range
    # (min_val == max_val), where (max_val - min_val) / bins would be 0.
    bin_widths = np.diff(bin_edges)
    return float(np.sum(np.minimum(hist1, hist2) * bin_widths) * 100)


# Create histogram visualization
@interact(
    # max(..., 0) keeps the slider valid when there are no significant modes
    mode_idx=widgets.IntSlider(
        min=0, max=max(min(19, n_significant - 1), 0), step=1, value=0, description="Mode:"
    ),
    normalization=widgets.Dropdown(
        options=[
            ("Raw Projections", "raw"),
            ("Normalized by Singular Values", "normalized"),
            ("Output Contribution", "contribution"),
        ],
        value="raw",
        description="View:",
    ),
)
def plot_projection_histogram(mode_idx, normalization):
    """Overlay histograms of correct vs. misclassified projections onto one mode.

    The title reports Cohen's d (misclassified minus correct, pooled sample SD); an
    annotation reports the histogram overlap percentage.

    Args:
        mode_idx: Position within the significant modes (not the raw SVD index).
        normalization: "raw" (p), "normalized" (p / sigma) or "contribution" (p * sigma).

    Returns:
        A plotly Figure (a titled placeholder if a group is empty).
    """
    if n_significant == 0:
        return go.Figure().update_layout(title="No significant Koopman modes to display")

    mode = significant_modes_np[mode_idx]
    # float() accepts a NumPy scalar or a 0-d torch tensor; a later cell rebinds
    # `significant_S` to a torch tensor.
    sigma = float(significant_S[mode_idx])

    correct_proj = correct_proj_np[:, mode_idx]
    misclass_proj = misclass_proj_np[:, mode_idx]
    if len(correct_proj) == 0 or len(misclass_proj) == 0:
        return go.Figure().update_layout(
            title=f"Mode {mode}: need both correctly classified and misclassified samples"
        )

    if normalization == "raw":
        y_label = "Raw Projection"
    elif normalization == "normalized":
        correct_proj = correct_proj / sigma
        misclass_proj = misclass_proj / sigma
        y_label = "Normalized Projection (p/σ)"
    else:  # contribution
        correct_proj = correct_proj * sigma
        misclass_proj = misclass_proj * sigma
        y_label = "Output Contribution (p·σ)"

    correct_mean = np.mean(correct_proj)
    misclass_mean = np.mean(misclass_proj)

    fig = go.Figure()
    fig.add_trace(
        go.Histogram(
            x=correct_proj, name="Correctly Classified", opacity=0.7, marker_color="blue", nbinsx=30
        )
    )
    fig.add_trace(
        go.Histogram(
            x=misclass_proj, name="Misclassified", opacity=0.7, marker_color="red", nbinsx=30
        )
    )

    # Dashed lines at the group means
    fig.add_vline(
        x=correct_mean,
        line_dash="dash",
        line_color="blue",
        annotation_text="Correct Mean",
        annotation_position="top right",
    )
    fig.add_vline(
        x=misclass_mean,
        line_dash="dash",
        line_color="red",
        annotation_text="Misclass Mean",
        annotation_position="top left",
    )

    # Cohen's d with the pooled *sample* SD. Pooling sums of squared deviations equals
    # (n - 1) * var(ddof=1) per group; the previous (n - 1) * var(ddof=0) under-estimated it.
    n_c, n_m = len(correct_proj), len(misclass_proj)
    sum_sq = np.sum((correct_proj - correct_mean) ** 2) + np.sum((misclass_proj - misclass_mean) ** 2)
    dof = n_c + n_m - 2
    pooled_std = np.sqrt(sum_sq / dof) if dof > 0 else 0.0
    effect_size = (misclass_mean - correct_mean) / pooled_std if pooled_std > 0 else float("nan")

    fig.update_layout(
        title=f"Mode {mode} (σ={sigma:.4f}) Projection Distribution - Effect Size: {effect_size:.2f}",
        xaxis_title=y_label,
        yaxis_title="Count",
        barmode="overlay",
        height=500,
        width=900,
    )

    fig.update_traces(hovertemplate="Value: %{x:.4f}<br>Count: %{y}<extra></extra>")

    # Separation statistic shown above the plot
    overlap = compute_distribution_overlap(correct_proj, misclass_proj)
    fig.add_annotation(
        x=0.5,
        y=1.05,
        xref="paper",
        yref="paper",
        text=f"Distribution Overlap: {overlap:.1f}%",
        showarrow=False,
        font=dict(size=14),
    )

    return fig

In [None]:
import numpy as np

# Decompose the Koopman matrix: K = U @ diag(S) @ Vh
U, S, Vh = torch.linalg.svd(K_matrix)

# Keep the directions whose singular value clears the threshold
significant_threshold = 1e-5
significant_indices = (S > significant_threshold).nonzero(as_tuple=True)[0]
print(f"Number of significant singular values: {len(significant_indices)} out of {len(S)}")

significant_U = U[:, significant_indices]
significant_S = S[significant_indices]

with torch.no_grad():
    # Coordinates of every target-class encoding along the significant left singular vectors
    all_encodings = autoencoder._encode(target_ae_inputs)
    projections = all_encodings @ significant_U

    # Importance metric 1: spread of the raw coordinates
    raw_variance = projections.var(dim=0)

    # Importance metric 2: spread after weighting each coordinate by its singular value,
    # i.e. how much the direction varies in the operator's output
    scaled_projections = projections * significant_S
    scaled_variance = scaled_projections.var(dim=0)

    # Rank directions by scaled variance, largest first
    sorted_indices = scaled_variance.argsort(descending=True)

    print("\nSignificant directions sorted by SCALED variance:")
    print("Rank | Orig Idx | Scaled Var | Raw Var | Singular Value")
    print("-----|---------|------------|---------|---------------")
    for i, idx in enumerate(sorted_indices.tolist()):
        original_idx = significant_indices[idx].item()
        scaled_var = scaled_variance[idx].item()
        raw_var = raw_variance[idx].item()
        s_value = significant_S[idx].item()
        print(f"{i+1:4d} | {original_idx:7d} | {scaled_var:.6f} | {raw_var:.6f} | {s_value:.6f}")

# Original SVD indices ordered from highest to lowest scaled variance
high_scaled_var_indices = significant_indices[sorted_indices].tolist()

In [None]:
# Numbers of top directions to ablate: de-duplicated and capped at the number of
# available directions (larger sizes would silently ablate fewer directions than
# reported, duplicating the capped result).
max_dirs = min(len(significant_indices), 20)
candidate_sizes = {1, 2, 3, 5, min(10, max_dirs), min(15, max_dirs), max_dirs}
n_directions_to_test = sorted(n for n in candidate_sizes if 0 < n <= max_dirs)


def _koopman_class_predictions(encodings):
    """Class predictions from one Koopman step, decoding, and the model's last two modules."""
    decoded = autoencoder._decode(autoencoder.koopman_matrix(encodings))
    return torch.argmax(model.modules[-2:](decoded), dim=1)


total = len(target_ae_inputs)
if total == 0:
    raise ValueError(f"No test samples with label {target_label}; nothing to ablate.")

results = {}
with torch.no_grad():
    # Loop-invariant: the encodings and unablated predictions do not depend on n_dir,
    # so compute them once for the whole batch.
    target_encodings = autoencoder._encode(target_ae_inputs)
    original_preds = _koopman_class_predictions(target_encodings)
    significant_list = significant_indices.tolist()

    for n_dir in n_directions_to_test:
        # Remove the top-n high SCALED-variance directions: keep only the projection
        # onto the remaining significant left singular vectors, z -> (z B) B^T.
        removed = set(high_scaled_var_indices[:n_dir])
        kept = torch.tensor(
            [m for m in significant_list if m not in removed], dtype=torch.long, device=U.device
        )
        basis = U[:, kept]
        z_proj = (target_encodings @ basis) @ basis.T
        modified_preds = _koopman_class_predictions(z_proj)

        results[n_dir] = {
            "original_accuracy": (original_preds == target_label).sum().item() / total,
            "modified_accuracy": (modified_preds == target_label).sum().item() / total,
            "prediction_change_rate": (original_preds != modified_preds).sum().item() / total,
        }

# Print results
print("\nClassification Impact Test Results (using SCALED variance):")
print("High-var Dirs | Original Acc | Modified Acc | Change Rate")
print("-------------|--------------|--------------|------------")
for n_dir, stats in results.items():
    print(
        f"{n_dir:12d} | {stats['original_accuracy']:.4f} | {stats['modified_accuracy']:.4f} | {stats['prediction_change_rate']:.4f}"
    )

In [None]:
# Rank the significant directions by RAW (unscaled) projection variance
raw_sorted_indices = torch.argsort(raw_variance, descending=True)
high_raw_var_indices = [significant_indices[idx].item() for idx in raw_sorted_indices]

print("\nSignificant directions sorted by RAW variance:")
print("Rank | Orig Idx | Raw Var | Singular Value | Scaled Var")
print("-----|---------|---------|----------------|------------")
for i in range(len(significant_indices)):
    idx = raw_sorted_indices[i].item()
    original_idx = significant_indices[idx].item()
    raw_var = raw_variance[idx].item()
    s_value = significant_S[idx].item()
    scaled_var = scaled_variance[idx].item()
    print(f"{i+1:4d} | {original_idx:7d} | {raw_var:.6f} | {s_value:.6f} | {scaled_var:.6f}")

# Ablation test: remove the top-n high RAW-variance directions from the encodings
total = len(target_ae_inputs)
if total == 0:
    raise ValueError(f"No test samples with label {target_label}; nothing to ablate.")

results_raw = {}
with torch.no_grad():
    # Loop-invariant: the encodings and unablated predictions do not depend on n_dir,
    # so compute them once for the whole batch.
    target_encodings = autoencoder._encode(target_ae_inputs)
    original_preds = torch.argmax(
        model.modules[-2:](autoencoder._decode(autoencoder.koopman_matrix(target_encodings))),
        dim=1,
    )
    significant_list = significant_indices.tolist()

    for n_dir in n_directions_to_test:
        # Keep only the projection onto the non-removed significant left singular
        # vectors: z -> (z B) B^T
        removed = set(high_raw_var_indices[:n_dir])
        kept = torch.tensor(
            [m for m in significant_list if m not in removed], dtype=torch.long, device=U.device
        )
        basis = U[:, kept]
        z_proj = (target_encodings @ basis) @ basis.T
        modified_preds = torch.argmax(
            model.modules[-2:](autoencoder._decode(autoencoder.koopman_matrix(z_proj))),
            dim=1,
        )

        results_raw[n_dir] = {
            "original_accuracy": (original_preds == target_label).sum().item() / total,
            "modified_accuracy": (modified_preds == target_label).sum().item() / total,
            "prediction_change_rate": (original_preds != modified_preds).sum().item() / total,
        }

print("\nClassification Impact Test Results (using RAW variance):")
print("High-var Dirs | Original Acc | Modified Acc | Change Rate")
print("-------------|--------------|--------------|------------")
for n_dir, stats in results_raw.items():
    print(
        f"{n_dir:12d} | {stats['original_accuracy']:.4f} | {stats['modified_accuracy']:.4f} | {stats['prediction_change_rate']:.4f}"
    )

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Per-mode metrics as NumPy arrays for plotting
orig_indices = significant_indices.cpu().numpy()
raw_var = raw_variance.cpu().numpy()
sing_vals = significant_S.cpu().numpy()
scaled_var = scaled_variance.cpu().numpy()

fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=["Approach 1: Singular Value Allocation", "Approach 2: Variance Transformation"],
)

# One entry per panel: both use raw variance on x; the y quantity and the marker colour
# swap between singular value and scaled variance.
panels = [
    {
        "col": 1,
        "y": sing_vals,
        "color": scaled_var,
        "colorscale": "Viridis",
        "colorbar": dict(title="Scaled Variance", x=0.45),
        "hovertemplate": "Mode: %{text}<br>Raw Variance: %{x:.6f}<br>Singular Value: %{y:.4f}<extra></extra>",
        "y_title": "Singular Value (log scale)",
    },
    {
        "col": 2,
        "y": scaled_var,
        "color": sing_vals,
        "colorscale": "Plasma",
        "colorbar": dict(title="Singular Value", x=1.0),
        "hovertemplate": "Mode: %{text}<br>Raw Variance: %{x:.6f}<br>Scaled Variance: %{y:.4f}<extra></extra>",
        "y_title": "Scaled Variance (log scale)",
    },
]

for panel in panels:
    fig.add_trace(
        go.Scatter(
            x=raw_var,
            y=panel["y"],
            mode="markers+text",
            marker=dict(
                size=15,
                color=panel["color"],
                colorscale=panel["colorscale"],
                colorbar=panel["colorbar"],
                showscale=True,
            ),
            text=[str(idx) for idx in orig_indices],
            textposition="top center",
            hovertemplate=panel["hovertemplate"],
            name="Modes",
        ),
        row=1,
        col=panel["col"],
    )
    # Log scales make the spread across orders of magnitude readable
    fig.update_xaxes(type="log", title_text="Raw Variance (log scale)", row=1, col=panel["col"])
    fig.update_yaxes(type="log", title_text=panel["y_title"], row=1, col=panel["col"])

fig.update_layout(
    title_text="Koopman's Feature Amplification Strategy: Two Complementary Views",
    height=600,
    width=1200,
    showlegend=False,
)

fig.show()

In [None]:
# Fisher-score analysis needs encodings from every class, not just the target class.
classes_to_analyze = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]  # MNIST digits
class_encodings = {}
class_projections = {}

with torch.no_grad():
    for c in classes_to_analyze:
        class_idx = torch.where(labels == c)[0]
        if len(class_idx) > 0:
            # Encode the class's samples and project onto the significant left singular vectors
            class_inputs = ae_inputs[class_idx]
            encodings = autoencoder._encode(class_inputs)
            class_encodings[c] = encodings
            proj = torch.matmul(encodings, significant_U)
            class_projections[c] = proj

    if not class_projections:
        raise ValueError(f"No samples found for any class in {classes_to_analyze}")

    # Fisher's discriminant ratio for every direction at once:
    #   between = sum_c w_c (mu_c - mu)^2,  within = sum_c w_c var_c,  w_c = n_c / N,
    # with population variances. Both terms are sample-weighted on the same scale (an
    # unweighted variance of class means times the class count is not a variance), and
    # a singleton class contributes 0 within-class variance instead of NaN.
    per_class = list(class_projections.values())
    class_means = torch.stack([p.mean(dim=0) for p in per_class])  # (n_classes, n_modes)
    class_vars = torch.stack([((p - p.mean(dim=0)) ** 2).mean(dim=0) for p in per_class])
    class_counts = torch.tensor(
        [p.shape[0] for p in per_class], dtype=class_means.dtype, device=class_means.device
    )
    class_weights = (class_counts / class_counts.sum()).unsqueeze(1)  # (n_classes, 1)

    grand_mean = (class_weights * class_means).sum(dim=0)
    between_var = (class_weights * (class_means - grand_mean) ** 2).sum(dim=0)
    within_var = (class_weights * class_vars).sum(dim=0)

    # Higher = more discriminative; the epsilon guards against zero within-class spread
    fisher_scores = (between_var / (within_var + 1e-10)).tolist()

    # Associate each direction with its original SVD index and other metrics
    mode_metrics = []
    for i, idx in enumerate(significant_indices):
        mode_metrics.append(
            {
                "mode": idx.item(),
                "singular_value": significant_S[i].item(),
                "raw_variance": raw_variance[i].item(),
                "scaled_variance": scaled_variance[i].item(),
                "fisher_score": fisher_scores[i],
            }
        )

    # Sort by Fisher score (discriminative power)
    discriminative_modes = sorted(mode_metrics, key=lambda x: x["fisher_score"], reverse=True)

# Print results sorted by discriminative power
print("\nModes ranked by discriminative power (Fisher score):")
print("Rank | Mode | Fisher Score | Singular Value | Raw Variance")
print("-----|------|-------------|----------------|------------")
for i, mode in enumerate(discriminative_modes[:20]):
    print(
        f"{i+1:4d} | {mode['mode']:4d} | {mode['fisher_score']:11.4f} | {mode['singular_value']:14.4f} | {mode['raw_variance']:.6f}"
    )

In [None]:
import plotly.express as px
import plotly.graph_objects as go
from ipywidgets import HBox, VBox, interact, widgets

# Fisher score per mode for the dropdown labels. Iterating in reverse keeps the FIRST
# entry for any repeated mode, matching a first-match lookup.
_fisher_by_mode = {entry["mode"]: entry["fisher_score"] for entry in reversed(discriminative_modes)}

# Original SVD indices of the modes whose singular value clears the threshold
valid_modes = []
for idx, sv in zip(significant_indices, significant_S):
    if sv > significant_threshold:
        valid_modes.append(idx.item())

# (label, value) pairs for the mode dropdowns; modes without a score show Fisher=0
mode_options = []
for m in valid_modes:
    label = f"Mode {m} (σ={S[m]:.2f}, Fisher={_fisher_by_mode.get(m, 0):.2f})"
    mode_options.append((label, m))


# Function to create visualizations
def create_mode_visualization(n_dimensions, mode1=None, mode2=None, mode3=None):
    """Dispatch to the 1-D, 2-D or 3-D class-separation plot.

    The non-None modes are used in order (mode1, mode2, mode3). If fewer than
    `n_dimensions` are supplied, a placeholder figure asking for more modes is returned.
    Any `n_dimensions` other than 1 or 2 produces the 3-D plot.
    """
    chosen = [m for m in (mode1, mode2, mode3) if m is not None]
    if len(chosen) < n_dimensions:
        return go.Figure().update_layout(title="Please select modes for all dimensions")

    if n_dimensions == 1:
        return create_1d_visualization(chosen[0])
    if n_dimensions == 2:
        return create_2d_visualization(chosen[0], chosen[1])
    return create_3d_visualization(chosen[0], chosen[1], chosen[2])


def create_1d_visualization(mode):
    """Box-plot every class's projections onto a single significant mode.

    Args:
        mode: Original SVD index of the mode; must appear in `significant_indices`.

    Returns:
        A plotly Figure with one box (all points shown) per class.
    """
    # Column of the mode within the projection tensors
    column = torch.where(significant_indices == mode)[0].item()

    fig = go.Figure()
    for c in sorted(class_projections.keys()):
        fig.add_trace(
            go.Box(
                y=class_projections[c][:, column].cpu().numpy(),
                name=f"Class {c}",
                boxpoints="all",  # overlay every sample point
                jitter=0.3,
                pointpos=-1.8,
            )
        )

    # Fisher score of this mode (0 if it was not scored)
    fisher = 0
    for entry in discriminative_modes:
        if entry["mode"] == mode:
            fisher = entry["fisher_score"]
            break
    singular_value = S[mode].item()

    fig.update_layout(
        title=f"Class Distribution on Mode {mode}",
        yaxis_title=f"Projection Value (σ={singular_value:.4f}, Fisher={fisher:.4f})",
        height=600,
        width=800,
    )
    return fig


def create_2d_visualization(mode1, mode2):
    """Scatter every class's projections onto two significant modes.

    Args:
        mode1: Original SVD index for the x axis.
        mode2: Original SVD index for the y axis.

    Returns:
        A plotly Figure with one marker trace per class.
    """

    def _axis_label(m):
        # "Mode m (σ=..., Fisher=...)", Fisher defaulting to 0 if the mode was not scored
        fisher = next((entry["fisher_score"] for entry in discriminative_modes if entry["mode"] == m), 0)
        return f"Mode {m} (σ={S[m].item():.4f}, Fisher={fisher:.4f})"

    # Columns of the two modes within the projection tensors
    col_x = torch.where(significant_indices == mode1)[0].item()
    col_y = torch.where(significant_indices == mode2)[0].item()
    x_title, y_title = _axis_label(mode1), _axis_label(mode2)

    fig = go.Figure()
    for c in sorted(class_projections.keys()):
        coords = class_projections[c]
        fig.add_trace(
            go.Scatter(
                x=coords[:, col_x].cpu().numpy(),
                y=coords[:, col_y].cpu().numpy(),
                mode="markers",
                name=f"Class {c}",
                marker=dict(size=8),
            )
        )

    fig.update_layout(
        title="Class Separation by Selected Modes",
        xaxis_title=x_title,
        yaxis_title=y_title,
        height=600,
        width=800,
    )
    return fig


def create_3d_visualization(mode1, mode2, mode3):
    """3-D scatter of every class's projections onto three significant modes.

    Args:
        mode1, mode2, mode3: Original SVD indices for the x, y and z axes.

    Returns:
        A plotly Figure with one Scatter3d trace per class.
    """

    def _axis_label(m):
        # "Mode m (σ=..., Fisher=...)", Fisher defaulting to 0 if the mode was not scored
        fisher = next((entry["fisher_score"] for entry in discriminative_modes if entry["mode"] == m), 0)
        return f"Mode {m} (σ={S[m].item():.4f}, Fisher={fisher:.4f})"

    axis_modes = (mode1, mode2, mode3)
    # Columns of the three modes within the projection tensors
    columns = [torch.where(significant_indices == m)[0].item() for m in axis_modes]
    x_title, y_title, z_title = (_axis_label(m) for m in axis_modes)

    fig = go.Figure()
    for c in sorted(class_projections.keys()):
        xs, ys, zs = (class_projections[c][:, col].cpu().numpy() for col in columns)
        fig.add_trace(
            go.Scatter3d(
                x=xs, y=ys, z=zs, mode="markers", name=f"Class {c}", marker=dict(size=4, opacity=0.8)
            )
        )

    fig.update_layout(
        title="3D Class Separation by Selected Modes",
        scene=dict(xaxis_title=x_title, yaxis_title=y_title, zaxis_title=z_title),
        height=700,
        width=900,
    )
    return fig


# Create interactive widget
def _mode_dropdown(position, description):
    """Dropdown over `mode_options`, preselecting the option at `position` (None if absent)."""
    default = mode_options[position][1] if len(mode_options) > position else None
    return widgets.Dropdown(
        options=mode_options,
        value=default,
        description=description,
        style={"description_width": "initial"},
        disabled=False,
    )


@interact(
    n_dimensions=widgets.RadioButtons(
        options=[("1D", 1), ("2D", 2), ("3D", 3)],
        value=3,
        description="Dimensions:",
        style={"description_width": "initial"},
    ),
    mode1=_mode_dropdown(0, "X-axis Mode:"),
    mode2=_mode_dropdown(1, "Y-axis Mode:"),
    mode3=_mode_dropdown(2, "Z-axis Mode:"),
)
def update_visualization(n_dimensions, mode1, mode2, mode3):
    """Render the class-separation plot for the currently selected modes."""
    create_mode_visualization(n_dimensions, mode1, mode2, mode3).show()

Modal surgery

In [None]:
# Modal projection analysis: project correct vs. misclassified latent codes onto the
# left singular vectors of the Koopman matrix and rank modes by Cohen's d.
with torch.no_grad():
    # Encode both groups into the autoencoder's latent space.
    correct_encodings = autoencoder._encode(correct_ae_inputs)  # Shape: [n_correct, d]
    misclassified_encodings = autoencoder._encode(misclassified_ae_inputs)  # Shape: [n_misclass, d]

    # K_matrix = U @ diag(S) @ Vh; torch.linalg.svd returns S in descending order.
    U, S, Vh = torch.linalg.svd(K_matrix)

    # Keep only modes whose singular value is above a numerical-noise floor.
    significant_threshold = 1e-5
    significant_modes = torch.where(S > significant_threshold)[0]
    sig_U = U[:, significant_modes]
    sig_S = S[significant_modes]

    # Coordinates of each encoding along the significant left singular vectors.
    correct_U_proj = torch.matmul(correct_encodings, sig_U)
    misclass_U_proj = torch.matmul(misclassified_encodings, sig_U)

    n_correct = correct_U_proj.shape[0]
    n_misclass = misclass_U_proj.shape[0]
    # torch.std is unbiased (divides by N-1): with < 2 rows in a group every statistic below is NaN.
    assert n_correct >= 2 and n_misclass >= 2, (
        f"need >= 2 examples per group, got {n_correct} correct / {n_misclass} misclassified"
    )

    # Per-mode summary statistics of each group.
    correct_mean = torch.mean(correct_U_proj, dim=0)
    correct_std = torch.std(correct_U_proj, dim=0)
    misclass_mean = torch.mean(misclass_U_proj, dim=0)
    misclass_std = torch.std(misclass_U_proj, dim=0)

    # Standardized effect size (Cohen's d) using the pooled standard deviation.
    pooled_std = torch.sqrt(
        ((n_correct - 1) * correct_std**2 + (n_misclass - 1) * misclass_std**2)
        / (n_correct + n_misclass - 2)
    )
    mean_diff = misclass_mean - correct_mean
    effect_size = mean_diff / pooled_std

    # Percent difference relative to the correct-group mean (inf/NaN where that mean is 0).
    pct_diff = 100 * mean_diff / torch.abs(correct_mean)

    # Rank modes by |d|. A zero-variance mode gives 0/0 = NaN, and torch's sort treats NaN as
    # larger than any number, so it would top a descending ranking; rank NaN as 0 instead.
    # `ranked_indices` are positions within `significant_modes` (i.e. columns of sig_U).
    ranked_indices = torch.argsort(
        torch.nan_to_num(torch.abs(effect_size), nan=0.0), descending=True
    )

print("Modal Projection Analysis Results:")
print("Mode | Singular Value | Effect Size | % Difference")
print("-----|---------------|------------|-------------")

for i in range(len(significant_modes)):
    mode_idx = ranked_indices[i]
    mode = significant_modes[mode_idx].item()  # column index into the full U
    s_val = sig_S[mode_idx].item()
    d_val = effect_size[mode_idx].item()
    p_diff = pct_diff[mode_idx].item()

    print(f"{mode:4d} | {s_val:13.4f} | {d_val:10.4f} | {p_diff:+11.2f}%")

In [None]:
# Modal surgery: overwrite the top-ranked modal coordinates of one misclassified example
# with the correct-class mean, re-apply the SVD factors, decode, and re-classify.
with torch.no_grad():
    # SVD of the Koopman matrix (recomputed so this cell does not rely on the previous cell's U/S/Vh).
    U, S, Vh = torch.linalg.svd(K_matrix)

    # Same significance threshold as the projection-analysis cell.
    significant_threshold = 1e-5
    significant_modes = torch.where(S > significant_threshold)[0]

    # Mean modal coordinates (over all columns of U) of the correctly classified examples.
    correct_encodings = autoencoder._encode(correct_ae_inputs)
    correct_projections = torch.matmul(correct_encodings, U)
    correct_mean_proj = torch.mean(correct_projections, dim=0)

    # `ranked_indices` (from the projection-analysis cell) indexes positions within
    # `significant_modes`; map them back to column indices of the full U.
    target_modes = significant_modes[ranked_indices]

    # Process a single misclassified example.
    example_idx = 0  # first misclassified example
    x_mis = misclassified_ae_inputs[example_idx].unsqueeze(0)

    # Original MLP classification.
    # NOTE(review): `misclassifed_mlp_inputs` (sic) is kept as-is; verify against the defining cell.
    mlp_pred = torch.argmax(model(misclassifed_mlp_inputs[example_idx].unsqueeze(0))).item()

    # Original Koopman transformation and prediction.
    z_mis = autoencoder._encode(x_mis)
    koopman_transformed = autoencoder.koopman_matrix(z_mis)
    koopman_output = autoencoder._decode(koopman_transformed)

    # Push the decoded state through the last two modules of the MLP.
    koopman_pred = torch.argmax(model.modules[-2:](koopman_output)).item()

    # NOTE(review): `target_label` is not defined in this cell; it should be the true label of
    # misclassified example `example_idx` — confirm against the cell that sets it.
    print(
        f"MLP prediction: {mlp_pred}, Koopman prediction: {koopman_pred}, True class: {target_label}"
    )

    # Loop invariants: modal coordinates of the unmodified example and the diagonal of S.
    S_diag = torch.diag(S)
    z_proj_original = torch.matmul(z_mis, U)

    # Sanity check: with no modes replaced, z U diag(S) Vh should reproduce koopman_matrix(z).
    # A large value means the SVD path uses a different convention (e.g. transposed K or a bias term).
    svd_path = torch.matmul(torch.matmul(z_proj_original, S_diag), Vh)
    svd_mismatch = (svd_path - koopman_transformed).abs().max().item()
    print(f"Max |z U diag(S) Vh - koopman_matrix(z)| = {svd_mismatch:.3e}")

    # Replace 0..19 modes, capped by how many ranked modes actually exist.
    max_n_modes = min(20, len(target_modes) + 1)
    for n_modes in range(max_n_modes):
        # Overwrite the top-`n_modes` modal coordinates with the correct-class means.
        z_proj = z_proj_original.clone()
        selected_modes = target_modes[:n_modes]
        z_proj[0, selected_modes] = correct_mean_proj[selected_modes]

        # Finish the factored Koopman step: z_proj -> z_proj diag(S) -> z_proj diag(S) Vh.
        transformed = torch.matmul(torch.matmul(z_proj, S_diag), Vh)

        # Decode and classify.
        corrected_output = autoencoder._decode(transformed)
        corrected_pred = torch.argmax(model.modules[-2:](corrected_output)).item()

        status = "FIXED ✓" if corrected_pred == target_label else "still wrong ✗"
        print(f"After correcting top {n_modes} modes: {status} (koopman pred={corrected_pred})")