In [10]:
import math
import os
from collections import OrderedDict
from copy import deepcopy
from functools import reduce

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import panel as pn
import plotly.colors as pc
import plotly.express as px
import seaborn as sns
import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
from plotly.subplots import make_subplots
from rich import print as rprint
from scipy.spatial import procrustes
from sklearn.decomposition import PCA, FactorAnalysis
from sklearn.manifold import TSNE, LocallyLinearEmbedding
from sklearn.random_projection import GaussianRandomProjection
from torch import nn
from torch.utils.data import DataLoader, Subset
from torcheval.metrics import MulticlassAccuracy

from analysis.common import load_autoencoder, load_model
from analysis.residual_alignment_methods import alignment, plotsvals, sab, trajectories
from koopmann import aesthetics

# from koopmann import aesthetics
from koopmann.data import (
    DatasetConfig,
    get_dataset_class,
)
from koopmann.models import MLP, Autoencoder, ExponentialKoopmanAutencoder, ResMLP
from koopmann.models.layers import LinearLayer
from koopmann.utils import (
    get_device,
)
from koopmann.visualization import plot_decision_boundary
from scripts.train_ae.losses import pad_act

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
# Root directory of the saved model/autoencoder checkpoints. Set the
# KOOPMANN_MODEL_DIR environment variable to point elsewhere (e.g. on another
# machine) instead of editing this cell; the default is the original path.
file_dir = os.environ.get("KOOPMANN_MODEL_DIR", "/scratch/nsa325/koopmann_model_saves")

# Autoencoder selection. These values are interpolated into `ae_name` below.
dim = 2048  # presumably the autoencoder's latent/observable width — confirm
k = 1  # presumably the number of Koopman steps the autoencoder uses — confirm
scale_idx = 0  # presumably the index of the activation the autoencoder consumes — confirm

# Autoencoder variant; uncomment one of the alternatives to switch.
rank = 20
flavor = f"lowrank_{rank}"
# flavor = "standard"
# flavor = "exponential"

model_name = "resmlp"
# NOTE(review): the dataset is hard-coded as "mnist" in this name, while the data
# cell takes the dataset from the model metadata; keep the two in sync.
ae_name = f"dim_{dim}_k_{k}_loc_{scale_idx}_{flavor}_autoencoder_mnist_model"
device = get_device()

In [12]:
model, model_metadata = load_model(file_dir, model_name)
# Switch to inference mode before any analysis forward passes. The metadata
# printed below reports batchnorm=True, and BatchNorm in train mode normalizes
# with per-batch statistics and updates its running buffers on each forward.
# This is a no-op if load_model already returned the model in eval mode
# (assumes `model` is a torch.nn.Module — confirm).
model.eval()
model.hook_model()
print(model_metadata)

{'batchnorm': True, 'bias': False, 'created_at': '2025-03-26T01:46:32.117345', 'dataset': 'MNISTDataset', 'hidden_config': [784, 784, 784, 784], 'in_features': 784, 'model_class': 'ResMLP', 'nonlinearity': 'relu', 'out_features': 10, 'stochastic_depth_mode': 'batch', 'stochastic_depth_prob': 0.0}


In [13]:
# Evaluation data: 3,000 samples (seed=42) from the test split of the dataset
# named in the loaded model's metadata.
dataset_config = DatasetConfig(
    dataset_name=model_metadata["dataset"],
    split="test",
    num_samples=3_000,
    seed=42,
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)

# Batches come out in a fixed order (shuffle=False is DataLoader's default,
# spelled out here for clarity).
dataloader = DataLoader(dataset, batch_size=1024, shuffle=False)

In [14]:
# # NOTE(review): `train_dataset` is not defined by any live cell shown here; the
# # data cell builds `dataset`. Also, `subset` is built but the loader wraps
# # `train_dataset`, so the `target_class` filter has no effect. Pass `subset` if
# # filtering is intended. shuffle=True with no torch seed set means `data` and
# # `labels` change from run to run.
# target_class = 1
# idx = torch.where(train_dataset.labels == target_class)[0]
# subset = Subset(train_dataset, idx)
# loader = torch.utils.data.DataLoader(train_dataset, batch_size=1_000, shuffle=True)
# data, labels = next(iter(loader))

In [15]:
# # Compute initial PCA
# def compute_reference_bases(data):
#     # Compute PCA reference basis
#     pca = PCA(n_components=3)
#     ref = pca.fit_transform(data)
#     return ref


# # Function to align using Procrustes from scipy
# def align_using_procrustes(reference_points, new_points):
#     _, new_points_aligned, _ = procrustes(reference_points, new_points)
#     return new_points_aligned


In [16]:
# # Enable Panel for Jupyter
# pn.extension()


# def create_3d_scatter_plot(data, labels, axis_range):
#     x, y, z = data[:, 0], data[:, 1], data[:, 2]

#     str_labels = [str(label) for label in labels]
#     color = str_labels

#     # pca_scalar_field = np.linalg.norm(ref_a, axis=1)
#     # color = pca_scalar_field
#     # color_continuous_scale="Viridis")
#     fig = px.scatter_3d(x=x, y=y, z=z, color=color)

#     fig.update_traces(marker=dict(size=1))
#     fig.update_layout(
#         scene=dict(
#             xaxis=dict(range=axis_range),
#             yaxis=dict(range=axis_range),
#             zaxis=dict(range=axis_range),
#             aspectmode="cube",
#             aspectratio=dict(x=1, y=1, z=1),
#         ),
#         showlegend=False,
#     )
#     return fig


# def process_pca_and_align(data, reference):
#     """Applies PCA, aligns using Procrustes, and returns aligned data."""
#     pca = PCA(n_components=3)
#     pca_result = pca.fit_transform(data)
#     aligned_result = align_using_procrustes(reference, pca_result)
#     return aligned_result


# def update_plots(data_a, data_b, ref_a, ref_b, labels):
#     """Updates PCA and RP plots with the given data and references."""
#     pca_axis_range = [-0.05, 0.05]  # Default axis range

#     # First plot: PCA
#     aligned_pca_result = process_pca_and_align(data_a, ref_a)
#     first_fig = create_3d_scatter_plot(aligned_pca_result, labels, pca_axis_range)

#     # Second plot: PCA
#     aligned_pca_result = process_pca_and_align(data_b, ref_b)
#     second_fig = create_3d_scatter_plot(aligned_pca_result, labels, pca_axis_range)

#     return first_fig, second_fig


# # NOTE(review): this block reads `data` and `labels` (only produced by the
# # commented-out loader cell) and `autoencoder` / `ae_metadata` (never loaded in
# # the live cells shown: `ae_name` is built but not passed to load_autoencoder).
# # Define these before re-enabling it.
# # Clone and hook model
# cloned_model = deepcopy(model)
# cloned_model.hook_model()

# # Activations from original model
# with torch.no_grad():
#     _ = cloned_model.forward(data)
# act_dict = cloned_model.get_fwd_activations(detach=True)

# # Re-key so key 0 is the flattened input and key i + 1 holds hooked activation i
# # (assumes get_fwd_activations returns integer keys — confirm).
# temp_act_dict = OrderedDict()
# temp_act_dict[0] = data.flatten(start_dim=1)
# for i in act_dict.keys():
#     temp_act_dict[i + 1] = act_dict[i]
# act_dict = temp_act_dict

# # Get Koopman predictions
# # NOTE(review): this rebinds the config-cell `k` (set to 1 there) with the
# # autoencoder's "num_scaled"; the two values may disagree.
# k = int(ae_metadata["num_scaled"])
# new_keys = list(range(0, k + 1))
# decoded_act = (
#     autoencoder(x=act_dict[scale_idx], k=k, intermediate=True).predictions.detach().numpy()
# )
# decoded_act_dict = OrderedDict(zip(new_keys, decoded_act))
# ref_decoded = compute_reference_bases(decoded_act_dict[0])

# # Get observable predictions
# # Entry i is the encoded activation with autoencoder.koopman_matrix applied i times.
# embedded_act = [autoencoder.encoder(act_dict[scale_idx])] * (k + 1)
# embedded_act = [
#     act if i == 0 else reduce(lambda x, _: autoencoder.koopman_matrix(x), range(i), act)
#     for i, act in enumerate(embedded_act)
# ]
# embedded_act = [act.detach().numpy() for act in embedded_act]
# embedded_act_dict = OrderedDict(zip(new_keys, embedded_act))
# ref_embedded = compute_reference_bases(embedded_act[0])

# # Create slider
# layer_select = pn.widgets.IntSlider(name="Layer Selector", start=0, end=k, step=1, value=0)


# @pn.depends(layer_select.param.value)
# def view(layer_index):
#     figs = update_plots(
#         decoded_act_dict[layer_index],
#         embedded_act_dict[layer_index],
#         ref_decoded,
#         ref_embedded,
#         labels,
#     )
#     panes = [pn.pane.Plotly(fig) for fig in figs]

#     return pn.Row(*panes, align="center")


# # Layout
# layout = pn.Column(
#     pn.Row(layer_select, align="center"),
#     view,
#     align="center",
#     sizing_mode="stretch_width",
# )

# layout.show()