In [None]:
from ast import literal_eval

import matplotlib.pyplot as plt
import numpy as np
import torch
from rich import print as rprint
from safetensors.torch import load_model
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from koopmann import aesthetics
from koopmann.data import (
    DatasetConfig,
    create_data_loader,
    get_dataset_class,
)
from koopmann.models import (
    MLP,
    Autoencoder,
    ExponentialKoopmanAutencoder,
    LowRankKoopmanAutoencoder,
)
from koopmann.models.utils import (
    get_device,
    parse_safetensors_metadata,
)

%load_ext autoreload
%autoreload 2

In [None]:
# Experiment selectors. They are kept as strings because each one is
# interpolated verbatim into the checkpoint file names built below.
task = "mnist"  # prefix of both the model and the autoencoder checkpoint names
scale_idx = "0"  # fills the `loc_{scale_idx}` field of the autoencoder file name
k = "1"  # fills the `k_{k}` field of the autoencoder file name
dim = "1024"  # fills the `dim_{dim}` field of the autoencoder file name
flavor = "lowrank_20"  # a "standard"/"lowrank" substring later picks the autoencoder class
# flavor = "standard"
user = "nsa325"  # account name used in the /scratch/{user}/... save directory


In [None]:
# Base names of the two checkpoints, both derived from the selected task
model_name = f"{task}_probed"
ae_name = f"{task}_model"

# Path of the original (probed) model checkpoint
model_file_path = f"/scratch/{user}/koopmann_model_saves/{model_name}.safetensors"

# Path of the autoencoder checkpoint, stored under the `scaling/` subdirectory
# of the same save directory; the file name encodes dim, k, loc and flavor.
ae_file_path = f"/scratch/{user}/koopmann_model_saves/scaling/dim_{dim}_k_{k}_loc_{scale_idx}_{flavor}_autoencoder_{ae_name}.safetensors"

In [None]:
# Load the MLP checkpoint; the second return value is not needed here.
model, _ = MLP.load_model(model_file_path)
# Strip the nonlinearity from the second-to-last and third-to-last modules.
# NOTE(review): presumably this mirrors how the autoencoder was trained — confirm.
model.modules[-2].remove_nonlinearity()
# model.modules[-3].update_nonlinearity("leakyrelu")
model.modules[-3].remove_nonlinearity()
# Switch to evaluation mode before registering hooks.
model.eval()
# NOTE(review): presumably registers activation hooks on the model — confirm in MLP.
model.hook_model()

In [None]:
# Metadata stored in the *model* checkpoint; it names the dataset to load.
# NOTE: `metadata` is reassigned later from the autoencoder checkpoint.
metadata = parse_safetensors_metadata(file_path=model_file_path)

# Request 5,000 samples of the test split with a fixed seed.
dataset_config = DatasetConfig(
    dataset_name=metadata["dataset"],
    num_samples=5_000,
    split="test",
    seed=21,
)
# Resolve the dataset class by name, then instantiate it from the config.
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)
# Loader over the dataset with a batch size of 1024.
dataloader = create_data_loader(dataset, batch_size=1024)

In [None]:
# Parse the metadata stored in the autoencoder checkpoint
# (this replaces the model-checkpoint metadata bound to the same name earlier).
metadata = parse_safetensors_metadata(file_path=ae_file_path)

# Choose the autoencoder class from the `flavor` substring
if "standard" in flavor:
    AutoencoderClass = Autoencoder
elif "lowrank" in flavor:
    AutoencoderClass = LowRankKoopmanAutoencoder
else:
    # Without this branch an unknown flavor surfaces as a NameError on a fresh
    # kernel, or silently reuses the class left over from a previous run.
    raise ValueError(
        f"Unsupported flavor {flavor!r}: expected it to contain 'standard' or 'lowrank'"
    )

# Instantiate the model; metadata values are stored as strings, so the
# non-string hyperparameters are decoded with literal_eval.
autoencoder = AutoencoderClass(
    input_dimension=literal_eval(metadata["input_dimension"]),
    latent_dimension=literal_eval(metadata["latent_dimension"]),
    nonlinearity=metadata["nonlinearity"],
    k=literal_eval(metadata["steps"]),
    batchnorm=literal_eval(metadata["batchnorm"]),
    hidden_configuration=literal_eval(metadata["hidden_configuration"]),
    rank=literal_eval(metadata["rank"]),
)

# Load weights (strict: every tensor in the file must match the module)
load_model(autoencoder, ae_file_path, device="cpu", strict=True)
autoencoder.eval()

# Remove parametrizations so `weight` becomes a plain tensor that can be read
# and later overwritten directly.
if torch.nn.utils.parametrize.is_parametrized(autoencoder.koopman_matrix.linear_layer):
    torch.nn.utils.parametrize.remove_parametrizations(
        autoencoder.koopman_matrix.linear_layer, "weight"
    )

# Transposed, detached copy of the Koopman layer's weight
K_matrix = autoencoder.koopman_matrix.linear_layer.weight.T.detach()

# NOTE(review): this rebinds the config string `k` (used to build ae_file_path)
# to the integer `num_scaled`; re-running the path cell after this one would
# therefore build a different file name. Consider a distinct name once no
# other cell depends on `k` holding this value.
k = literal_eval(metadata["num_scaled"])
print(f"Little K: {k}")

In [None]:
@torch.no_grad()
def shrink_eigenvalues(matrix, shrink_factor=0.95, indices=None, verbose=True):
    """Scale selected eigenvalues of a square matrix and rebuild the matrix.

    The matrix is decomposed as ``A = P D P^-1`` with ``torch.linalg.eig``, the
    chosen entries of ``D`` are multiplied by ``shrink_factor``, and the product
    is re-formed.

    Args:
        matrix: Square 2-D tensor (real or complex floating point).
        shrink_factor: Multiplier applied to every selected eigenvalue.
        indices: Positions in the descending-magnitude ranking of the
            eigenvalues (0 = largest ``|lambda|``). Accepts a list, tuple or
            tensor of ints. ``None`` shrinks every eigenvalue. An empty
            sequence leaves the spectrum unchanged.
        verbose: When true (default), print the requested ranks, the resolved
            eigenvalue positions and the selected eigenvalues.

    Returns:
        For real input: ``(modified_matrix, (eigenvalues, eigenvectors))`` where
        ``modified_matrix`` is the real part of the reconstruction,
        ``eigenvalues`` has shape ``(k,)`` and ``eigenvectors`` has shape
        ``(k, n)`` with one selected eigenvector per row. With
        ``indices=None`` all eigenpairs are returned in ``torch.linalg.eig``
        order.
        For complex input: only the (complex) ``modified_matrix`` is returned;
        this asymmetry is kept for backward compatibility.

    Raises:
        AssertionError: If the matrix is not square or an index is out of range.
    """
    # Ensure the matrix is square
    assert matrix.shape[0] == matrix.shape[1], "Input matrix must be square"

    # Eigendecomposition; eigenvectors are stored as the COLUMNS of `eigenvectors`
    eigenvalues, eigenvectors = torch.linalg.eig(matrix)
    n_eigenvalues = len(eigenvalues)

    if indices is None:
        # Shrink the whole spectrum; every eigenpair counts as "selected"
        selected_indices = torch.arange(n_eigenvalues, device=eigenvalues.device)
        shrunk_eigenvalues = eigenvalues * shrink_factor
    else:
        # Force an integer index tensor: torch.tensor([]) would be float and
        # therefore unusable for indexing.
        indices = torch.as_tensor(
            indices, dtype=torch.long, device=eigenvalues.device
        ).reshape(-1)

        # Validate indices
        assert torch.all(
            (indices >= 0) & (indices < n_eigenvalues)
        ), f"All indices must be between 0 and {n_eigenvalues-1}"

        # Rank eigenvalues by magnitude (largest first) and translate the
        # requested ranks into positions in the eig output.
        _, top_magnitude_indices = torch.topk(torch.abs(eigenvalues), n_eigenvalues)
        selected_indices = top_magnitude_indices[indices]

        if verbose:
            print(indices)
            print("real", selected_indices)
            print(eigenvalues[selected_indices])

        # Shrink only the selected eigenvalues
        shrunk_eigenvalues = eigenvalues.clone()
        shrunk_eigenvalues[selected_indices] = eigenvalues[selected_indices] * shrink_factor

    # Reconstruct A' = P D' P^-1. Scaling the columns of P by the eigenvalues
    # equals P @ diag(D') without materialising the diagonal matrix.
    eigenvectors_inv = torch.linalg.inv(eigenvectors)
    modified_matrix = (eigenvectors * shrunk_eigenvalues) @ eigenvectors_inv

    if torch.is_complex(matrix):
        return modified_matrix

    # Real input: the reconstruction should be real up to round-off. The
    # threshold scales with the dtype's machine epsilon and the matrix size
    # (heuristic), so float32 input does not trip a float64-sized bound.
    magnitude_scale = max(1.0, modified_matrix.real.abs().max().item())
    imag_tolerance = max(
        1e-10, n_eigenvalues * torch.finfo(matrix.dtype).eps * magnitude_scale
    )
    if torch.max(torch.abs(modified_matrix.imag)) >= imag_tolerance:
        # Typically happens when only one member of a complex-conjugate pair
        # was shrunk, which makes the result genuinely complex.
        print("Warning: Reconstructed matrix has non-negligible imaginary parts")

    # Select eigenvector COLUMNS, then transpose so each row is one eigenvector
    selected_eigenvectors = eigenvectors[:, selected_indices].T
    return modified_matrix.real, (eigenvalues[selected_indices], selected_eigenvectors)


# Ranks (by descending eigenvalue magnitude) of the eigenvalues to shrink.
# An empty list keeps the Koopman matrix exactly as loaded.
indices = []
if indices:
    mod_K_matrix, sel_eigen = shrink_eigenvalues(K_matrix, shrink_factor=0.5, indices=indices)
    sel_eigenvalues, sel_eigenvectors = sel_eigen
else:
    mod_K_matrix = K_matrix

# K_matrix is the transposed layer weight, so transpose back when installing
# the (possibly modified) matrix as the layer's new weight parameter.
koopman_layer = autoencoder.koopman_matrix.linear_layer
koopman_layer.weight = torch.nn.Parameter(mod_K_matrix.T)

In [None]:
# Restrict the visualisation to the first 500 samples and their labels
first_samples = slice(None, 500)
labels = dataset.labels[first_samples]
unproc_data = (dataset.data.float() / 255)[first_samples]

###############################################################
# TODO: This is quick and dirty
# Rescale pixel values from the [0, 1] range to [-1, 1]
proc_data = unproc_data * 2 - 1
###############################################################

# Encode the inputs, then pass the encoded observables through the Koopman
# layer; no gradients are needed for this analysis.
with torch.no_grad():
    x_obs = autoencoder._encode(proc_data)
    pred_obs = autoencoder.koopman_matrix(x_obs)

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# Convert tensors to NumPy once; everything below operates on arrays
x_obs_np = x_obs.cpu().numpy()
pred_obs_np = pred_obs.cpu().numpy()
labels_np = labels.cpu().numpy()

# Use PCA for dimensionality reduction to 3D. The projection is fitted on the
# predicted observables and reused for the originals, so both point clouds
# live in the same coordinate system.
pca = PCA(n_components=3, random_state=42)
pred_obs_pca_3d = pca.fit_transform(pred_obs_np)
x_obs_pca_3d = pca.transform(x_obs_np)

# Check explained variance for 3 components
variance_ratios = pca.explained_variance_ratio_
explained_variance = np.sum(variance_ratios)
print(f"Total explained variance with 3 PCA components: {explained_variance:.3f}")
for component, ratio in enumerate(variance_ratios, start=1):
    print(f"Component {component}: {ratio:.3f}")

# Axis titles are shared by the scatter call and the layout update
axis_titles = {
    "PC1": f"PC1 ({variance_ratios[0]:.1%} var)",
    "PC2": f"PC2 ({variance_ratios[1]:.1%} var)",
    "PC3": f"PC3 ({variance_ratios[2]:.1%} var)",
}

# Class names used for colouring and for the legend
unique_labels = np.unique(labels_np)
labels_as_str = [f"Class {label}" for label in labels_np]

# Prepare data for original and predicted observations
df_orig = pd.DataFrame(
    {
        "PC1": x_obs_pca_3d[:, 0],
        "PC2": x_obs_pca_3d[:, 1],
        "PC3": x_obs_pca_3d[:, 2],
        "Class": labels_as_str,
        "Type": "Original",
    }
)

df_pred = pd.DataFrame(
    {
        "PC1": pred_obs_pca_3d[:, 0],
        "PC2": pred_obs_pca_3d[:, 1],
        "PC3": pred_obs_pca_3d[:, 2],
        "Class": labels_as_str,
        "Type": "Predicted",
    }
)

# Combine the DataFrames
df = pd.concat([df_orig, df_pred])

# Map each class name to a colour, cycling through the palette if needed
colors = px.colors.qualitative.Bold
color_discrete_map = {
    f"Class {label}": colors[i % len(colors)] for i, label in enumerate(unique_labels)
}

# Create the 3D scatter plot
fig = px.scatter_3d(
    df,
    x="PC1",
    y="PC2",
    z="PC3",
    color="Class",
    symbol="Type",
    color_discrete_map=color_discrete_map,
    opacity=0.7,
    labels=axis_titles,
    title="3D Movement in PCA-Reduced Latent Space",
    height=800,
    width=1000,
)

# Add lines connecting original and predicted points.
# Subsample so that at most roughly 300 connectors are drawn.
sample_rate = max(1, len(x_obs_np) // 300)
connector_rows = np.arange(0, len(x_obs_pca_3d), sample_rate)

# One trace per class instead of one per segment: hundreds of separate 3D
# traces make the figure slow to build and to render. Within a trace the
# segments are laid out as start, end, NaN so the NaN gap breaks the line
# between consecutive segments.
for label in unique_labels:
    class_name = f"Class {label}"
    rows = connector_rows[labels_np[connector_rows] == label]
    if rows.size == 0:
        continue

    segments = np.full((rows.size, 3, 3), np.nan)
    segments[:, 0] = x_obs_pca_3d[rows]
    segments[:, 1] = pred_obs_pca_3d[rows]
    seg_x, seg_y, seg_z = segments.reshape(-1, 3).T

    # Transparency is baked into the colour string ("rgb(...)" -> "rgba(..., 0.3)")
    # rather than passed through an opacity parameter.
    line_color = color_discrete_map[class_name].replace("rgb", "rgba").replace(")", ", 0.3)")

    fig.add_trace(
        go.Scatter3d(
            x=seg_x,
            y=seg_y,
            z=seg_z,
            mode="lines",
            line=dict(color=line_color, width=1),
            connectgaps=False,
            showlegend=False,
        )
    )

# Update the layout for better viewing
fig.update_layout(
    legend_title_text="Class Labels",
    scene=dict(
        xaxis_title=axis_titles["PC1"],
        yaxis_title=axis_titles["PC2"],
        zaxis_title=axis_titles["PC3"],
    ),
    scene_camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
)

# Show the interactive plot
fig.show()

print("Tip: You can rotate, zoom, and pan the plot using your mouse!")
print(
    "Tip: Double-click on a class in the legend to isolate it, or single-click to toggle visibility."
)