In [64]:
import math
from ast import literal_eval

import plotly.express as px
import plotly.subplots as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from matrepr import mdisplay
from plotly.subplots import make_subplots
from rich import print as rprint
from safetensors.torch import load_file, load_model
from torch import linalg
from torch.nn.utils.parametrizations import orthogonal
from torch.utils.data import DataLoader, Dataset
from torcheval.metrics import MulticlassAccuracy
from torchvision import transforms

from koopmann import aesthetics
from koopmann.data import (
    DatasetConfig,
    create_data_loader,
    get_dataset_class,
)
from koopmann.models import (
    MLP,
    Autoencoder,
    ExponentialKoopmanAutencoder,
    LowRankKoopmanAutoencoder,
    ResMLP,
)
from koopmann.models.utils import (
    get_device,
    pad_act,
    parse_safetensors_metadata,
)

# from koopmann.utils import compute_model_accuracy
from koopmann.visualization import plot_eigenvalues

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# --- Experiment selection ---------------------------------------------------
# These values are interpolated into the checkpoint file names built further down.
task = "mnist"  # prefix of both checkpoint stems
scale_idx = "1"  # fills the `loc_...` field of the autoencoder file name
k = 1  # fills the `k_...` field; also the Koopman matrix power applied later
dim = "1024"  # fills the `dim_...` field of the autoencoder file name

# Autoencoder variant. Other recognised values: "standard", "exponential".
flavor = "lowrank_20"

# User segment of the /scratch/<user>/ checkpoint directory.
user = "nsa325"

# Checkpoint stems. Alternative model stem: f"{task}_probed".
model_name = "{}_model_residual".format(task)
ae_name = "{}_model".format(task)

In [None]:
# Original model path
model_file_path = f"/scratch/{user}/koopmann_model_saves/{model_name}.safetensors"

# Pick the loader from the checkpoint name. The metadata is kept in every branch
# because the dataset cell reads `model_metadata["dataset"]`; previously it was
# only bound for probed models, so the default residual model raised NameError
# on a fresh kernel.
is_probed = "probed" in model_name
if is_probed:
    model, model_metadata = MLP.load_model(file_path=model_file_path)
    # Probed checkpoints: strip the nonlinearity from the last two hidden modules.
    model.modules[-2].remove_nonlinearity()
    model.modules[-3].remove_nonlinearity()
    # model.modules[-3].update_nonlinearity("leakyrelu")
elif "residual" in model_name:
    model, model_metadata = ResMLP.load_model(file_path=model_file_path)
else:
    model, model_metadata = MLP.load_model(file_path=model_file_path)

# Inference mode + forward hooks so activations can be read back later.
model.eval().hook_model()

In [None]:
# Evaluation data: 5,000 samples from the "test" split of the dataset named in
# the model's metadata, with a fixed seed.
dataset_config = DatasetConfig(
    dataset_name=model_metadata["dataset"],
    num_samples=5_000,
    split="test",
    seed=21,
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)

# Raw images and labels
raw_images = dataset.data
labels = dataset.labels

# Processed for MLP: divide by 255, then normalize with mean 0.1307 / std 0.3081
_mlp_steps = [
    transforms.Lambda(lambda pixels: pixels / 255),
    transforms.Normalize((0.1307,), (0.3081,)),
]
mlp_transform = transforms.Compose(_mlp_steps)
mlp_inputs = mlp_transform(raw_images)

# Processed for AE: divide by 255, then map v -> v * 2 - 1
_ae_steps = [
    transforms.Lambda(lambda pixels: pixels / 255),
    transforms.Lambda(lambda unit: unit * 2 - 1),
]
ae_transform = transforms.Compose(_ae_steps)
ae_inputs = ae_transform(raw_images)

In [92]:
# Autoencoder path in work dir
ae_file_path = f"/scratch/{user}/koopmann_model_saves/scaling/dim_{dim}_k_{k}_loc_{scale_idx}_{flavor}_autoencoder_{ae_name}.safetensors"

# Choose the autoencoder class from the flavor string (first matching keyword wins).
if "standard" in flavor:
    AutoencoderClass = Autoencoder
elif "lowrank" in flavor:
    AutoencoderClass = LowRankKoopmanAutoencoder
elif "exponential" in flavor:
    AutoencoderClass = ExponentialKoopmanAutencoder
else:
    # Fail here with a clear message instead of a NameError (or a stale class
    # left over from an earlier run) at the load call below.
    raise ValueError(
        f"Unrecognized autoencoder flavor {flavor!r}; expected it to contain "
        "'standard', 'lowrank', or 'exponential'."
    )

autoencoder, ae_metadata = AutoencoderClass.load_model(
    ae_file_path,
    strict=True,
    remove_param=True,
)
_ = autoencoder.eval()

# Transposed, detached weight of the Koopman linear layer, so that a row-vector
# batch `z` can be advanced with `z @ K_matrix`.
K_matrix = autoencoder.koopman_matrix.linear_layer.weight.T.detach()

In [71]:
# `dataloader` was not defined anywhere in the notebook, so this cell only worked
# with leftover kernel state. Build it here from `dataset`; the batch size matches
# the 1024 that the sample index was previously drawn from.
# NOTE(review): no shuffling is applied — confirm this matches the loader used
# when the saved outputs below were produced.
dataloader = DataLoader(dataset, batch_size=1_024)

# First batch only. Note that `labels` now refers to this batch's labels.
images, labels = next(iter(dataloader))

# Draw the sample index from the actual batch length so a short batch
# (e.g. a dataset smaller than the batch size) cannot index out of range.
sample_idx = torch.randint(images.shape[0], (1,)).item()

In [72]:
with torch.no_grad():
    # Run the hooked model on the flattened batch and read back its activations.
    flat_images = images.flatten(start_dim=1)
    _ = model(flat_images)
    act_dict = model.get_fwd_activations(detach=True)

    # TODO: This is quick and dirty
    ###############################################################
    # Undo MNIST standardization: X_original = X_standardized * std + mean
    unstandardized = flat_images * 0.3081 + 0.1307

    # Convert [0,1] range to [-1,1]
    ae_input = 2 * unstandardized - 1
    ###############################################################

    # Slot 0 holds the rescaled input; each activation is stored at its key + 1.
    input_dict = {0: ae_input}
    input_dict.update({layer + 1: act for layer, act in act_dict.items()})

    # Source state (at the autoencoder's scale location) and target state.
    x = input_dict[int(ae_metadata["scale_location"])]
    y = input_dict[len(act_dict) - 1]

    # Encode both states into observable space, then decode them straight back.
    x_obs, y_obs = autoencoder._encode(x), autoencoder._encode(y)
    x_recon, y_recon = autoencoder._decode(x_obs), autoencoder._decode(y_obs)

    # Advance the source observables k steps with the Koopman matrix and decode.
    koopman_power = torch.linalg.matrix_power(K_matrix, int(k))
    pred_obs = x_obs @ koopman_power
    x_pred = autoencoder._decode(pred_obs)

In [73]:
# Side-by-side heatmaps for one sample: autoencoder input vs. its reconstruction.
panel_titles = ["Original MLP Input/Activation", "Reconstructed MLP Input/Activation"]
fig = make_subplots(rows=1, cols=2, subplot_titles=panel_titles)

panels = (x[sample_idx], x_recon[sample_idx])
for col, vec in enumerate(panels, start=1):
    # Each vector is drawn as a 16x32 image.
    heatmap = px.imshow(vec.reshape(16, 32)).data[0]
    fig.add_trace(heatmap, row=1, col=col)

fig.update_layout(height=400, width=800, xaxis_scaleanchor="y", xaxis2_scaleanchor="y2")
fig.show()

# Mean squared error between the sample and its reconstruction.
recon_error = F.mse_loss(x[sample_idx], x_recon[sample_idx], reduction="mean")
print(f"Error: {recon_error}")

Error: 0.06001684442162514


In [74]:
# Side-by-side heatmaps for one sample: target activation vs. its reconstruction.
panel_titles = ["Original MLP Activation", "Reconstructed MLP Activation"]
fig = make_subplots(rows=1, cols=2, subplot_titles=panel_titles)

panels = (y[sample_idx], y_recon[sample_idx])
for col, vec in enumerate(panels, start=1):
    # Each vector is drawn as a 16x32 image.
    heatmap = px.imshow(vec.reshape(16, 32)).data[0]
    fig.add_trace(heatmap, row=1, col=col)

fig.update_layout(height=400, width=800, xaxis_scaleanchor="y", xaxis2_scaleanchor="y2")
fig.show()

# Mean squared error between the target activation and its reconstruction.
recon_error = F.mse_loss(y[sample_idx], y_recon[sample_idx], reduction="mean")
print(f"Reconstruction Error: {recon_error}")

Reconstruction Error: 0.0653894692659378


In [75]:
# Side-by-side heatmaps of the encoded source and target states for one sample,
# followed by the value range of each observable vector.
fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=[
        f"Input/Activation Label {labels[sample_idx]} in Observable Space",
        f"Activation {labels[sample_idx]} in Observable Space",
    ],
)

for i, img in enumerate([x_obs[sample_idx], y_obs[sample_idx]], 1):
    # Each observable vector is drawn as a 32x32 image.
    fig.add_trace(
        px.imshow(img.reshape(32, 32)).data[0],
        row=1,
        col=i,
    )

fig.update_layout(height=400, width=800, xaxis_scaleanchor="y", xaxis2_scaleanchor="y2")
fig.show()

# Use Tensor.min()/max() directly instead of torch.topk(..., k=1): same extrema,
# clearer intent, no full top-k machinery for a single value.
rprint(f"Min Input/Activation Observable: {x_obs[sample_idx].min().item()}")
rprint(f"Max Input/Activation Observable: {x_obs[sample_idx].max().item()}")

rprint(f"Min Activation Observable: {y_obs[sample_idx].min().item()}")
rprint(f"Max Activation Observable: {y_obs[sample_idx].max().item()}")

In [76]:
# Side-by-side heatmaps for one sample: encoded target vs. Koopman-advanced source.
panel_titles = ["Activation Observable ", "Predicted Activation Observable"]
fig = make_subplots(rows=1, cols=2, subplot_titles=panel_titles)

panels = (y_obs[sample_idx], pred_obs[sample_idx])
for col, vec in enumerate(panels, start=1):
    # Each observable vector is drawn as a 32x32 image.
    heatmap = px.imshow(vec.reshape(32, 32)).data[0]
    fig.add_trace(heatmap, row=1, col=col)

fig.update_layout(height=400, width=800, xaxis_scaleanchor="y", xaxis2_scaleanchor="y2")
fig.show()

# Mean squared error between the target observables and the linear prediction.
linear_error = F.mse_loss(y_obs[sample_idx], pred_obs[sample_idx], reduction="mean")
print(f"Linear Prediction Error: {linear_error}")

Linear Prediction Error: 0.09851443022489548


In [77]:
# Side-by-side heatmaps for one sample: target activation vs. decoded prediction.
panel_titles = ["Original MLP Activation", "Predicted MLP Activation"]
fig = make_subplots(rows=1, cols=2, subplot_titles=panel_titles)

panels = (y[sample_idx], x_pred[sample_idx])
for col, vec in enumerate(panels, start=1):
    # Each vector is drawn as a 16x32 image.
    heatmap = px.imshow(vec.reshape(16, 32)).data[0]
    fig.add_trace(heatmap, row=1, col=col)

fig.update_layout(height=400, width=800, xaxis_scaleanchor="y", xaxis2_scaleanchor="y2")
fig.show()

# Mean squared error between the target activation and the decoded prediction.
prediction_error = F.mse_loss(y[sample_idx], x_pred[sample_idx], reduction="mean")
print(f"Error: {prediction_error}")

Error: 0.6546560525894165


In [93]:
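# NOTE: Disabled draft kept for reference; do not re-enable as-is. Inside the batch
# loop it builds `input_dict` from the notebook-level `x` and `act_dict` instead of
# the current batch's input and activations, and both metric `update` calls are
# commented out, so the returned accuracies would be computed on no data.
# def compute_model_accuracy(model, dataset, batch_size=1_024):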
#     model.eval()
#     model.hook_model()
#     dataloader = DataLoader(dataset, batch_size=batch_size)

#     metric_original = MulticlassAccuracy()
#     metric_altered = MulticlassAccuracy()

#     device = get_device()

#     for batch in dataloader:
#         input, target = batch
#         input, target = input.to(device), target.to(device)

#         output_original = model(input)

#         # TODO: This is quick and dirty
#         ###############################################################
#         # Undo MNIST standardization: X_original = X_standardized * std + mean
#         input = input * 0.3081 + 0.1307

#         # Convert [0,1] range to [-1,1]
#         input = 2 * input - 1
#         ###############################################################

#         input_dict = {0: x}
#         for key, val in act_dict.items():
#             input_dict[key + 1] = val

#         input = input_dict[int(ae_metadata["scale_location"])]

#         pred_act = autoencoder(input.flatten(start_dim=1), k=k).predictions[-1]
#         output_altered, _ = model.modules[-2:](pred_act)
#         # print(output_altered.shape)

#     #         metric_original.update(output_original, target.squeeze())
#     #         metric_altered.update(output_altered, target.squeeze())
#     return metric_original.compute(), metric_altered.compute()


# acc_orig, acc_ae = compute_model_accuracy(model.to("cpu"), dataset)
# print(f"Original Model Testing Accuracy: {acc_orig}")
# print(f"Autoencoder Prediction Testing Accuracy: {acc_ae}")