In [12]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import seaborn as sns
import torch
from rich import print as rprint
from torch import nn
from torch.utils.data import DataLoader

from analysis.common import load_model

# from koopmann import aesthetics
from koopmann.data import (
    DatasetConfig,
    get_dataset_class,
)
from koopmann.models import MLP, AnalyticEncoder
from koopmann.utils import (
    compute_model_accuracy,
    get_device,
)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
# Directory holding saved models. Override it with the KOOPMANN_MODEL_DIR environment
# variable rather than editing this hardcoded, user-specific absolute path.
file_dir = os.environ.get("KOOPMANN_MODEL_DIR", "/scratch/nsa325/koopmann_model_saves")
model_name = "mlp"
# NOTE(review): file_dir / model_name are not used by any visible cell (the MLP below is
# freshly initialised, not loaded via load_model), and `device` is never applied to the
# model or inputs here — confirm intent.
device = get_device()

In [14]:
# Build the bias-free 784 -> 512 -> 10 ReLU MLP analysed in this notebook.
# NOTE(review): the weights are randomly initialised here (nothing is loaded from
# file_dir / model_name) — confirm that analysing an untrained model is intended.
# Seed torch's global RNG so the random weights, and everything derived from them
# (the operator and the predictions below), reproduce on Restart & Run All.
MODEL_INIT_SEED = 42
torch.manual_seed(MODEL_INIT_SEED)

model = MLP(
    in_features=784,
    out_features=10,
    hidden_config=[512],
    bias=False,
    batchnorm=False,
    nonlinearity="relu",
)
model

MLP(
  (components): Sequential(
    (linear_0): LinearLayer(
      (components): ModuleDict(
        (linear): Linear(in_features=784, out_features=512, bias=False)
        (nonlinearity): ReLU()
      )
    )
    (linear_1): LinearLayer(
      (components): ModuleDict(
        (linear): Linear(in_features=512, out_features=10, bias=False)
      )
    )
  )
)

In [15]:
# Build the analytic encoder from the MLP defined above. Its operator is inspected below
# via build_operator(), and its embedding via forward().
analytic_encoder = AnalyticEncoder.from_model(model)

In [16]:
# Re-display the model after building the encoder. In the saved output the repr is
# identical to the earlier one, i.e. from_model left the module structure unchanged.
model

MLP(
  (components): Sequential(
    (linear_0): LinearLayer(
      (components): ModuleDict(
        (linear): Linear(in_features=784, out_features=512, bias=False)
        (nonlinearity): ReLU()
      )
    )
    (linear_1): LinearLayer(
      (components): ModuleDict(
        (linear): Linear(in_features=512, out_features=10, bias=False)
      )
    )
  )
)

In [17]:
# Evaluation data: 3,000 samples from the MNIST test split (seed=42 is passed to the
# dataset config), served in batches of 1,000.
dataset_config = DatasetConfig(
    dataset_name="MNISTDataset",
    split="test",
    num_samples=3_000,
    seed=42,
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)
dataloader = DataLoader(dataset=dataset, batch_size=1_000)

In [18]:
# Take the first batch from the dataloader and run the MLP on it.
inputs, labels = next(iter(dataloader))

# Weight matrices of the two linear layers (the model was built with bias=False).
W_0 = model.components[0].components.linear.weight
W_1 = model.components[1].components.linear.weight

# Pure inference: no_grad avoids recording an autograd graph for the whole batch
# (previously the predictions carried grad_fn=<SelectBackward0>).
with torch.no_grad():
    model_pred = model(inputs)

# The MLP was built with out_features=10, so expect one 10-dim output per sample.
assert model_pred.shape == (inputs.shape[0], 10), model_pred.shape

In [19]:
# Materialise the encoder's operator as a dense tensor for inspection.
# NOTE(review): the .to_dense() call suggests build_operator() returns a sparse tensor — confirm.
koopman_operator = analytic_encoder.build_operator()
dense_koopman = koopman_operator.to_dense()

In [20]:
# First 10 rows of the dense operator (all columns).
dense_koopman[:10]

tensor([[-0.0812, -0.0311, -0.0904,  ...,  0.0900,  0.0358,  0.0000],
        [-0.0476,  0.0597,  0.0614,  ...,  0.0470,  0.0637,  0.0000],
        [-0.0120,  0.0951, -0.0375,  ...,  0.0231,  0.0249,  0.0000],
        ...,
        [-0.0494,  0.0360,  0.0376,  ...,  0.0126,  0.0825,  0.0000],
        [-0.0979,  0.0056,  0.0146,  ...,  0.0119,  0.0274,  0.0000],
        [ 0.0393,  0.0092,  0.1405,  ...,  0.0276, -0.1266,  0.0000]])

In [23]:
# Embed the same input batch with the analytic encoder.
# NOTE(review): calling .forward() directly bypasses nn.Module hooks if AnalyticEncoder is
# an nn.Module — prefer analytic_encoder(inputs) once that is confirmed.
# NOTE(review): this cell's execution count (23) is higher than that of the cell that
# prints x_embedded (22), so the saved printout came from an earlier value of x_embedded;
# re-run the notebook top to bottom to refresh it.
x_embedded = analytic_encoder.forward(inputs)

In [22]:
# Show the MLP output and the analytic embedding for the first sample, one after the other.
# NOTE(review): in the saved output the first 10 embedding entries do not match the model
# prediction — confirm whether the operator must be applied before comparing.
for first_sample_tensor in (model_pred[0], x_embedded[0]):
    print(first_sample_tensor)

tensor([ 0.6606, -0.4021,  0.1712, -0.5286, -0.3006,  1.0289,  0.2432, -1.8944,
         0.2768,  0.6058], grad_fn=<SelectBackward0>)
tensor([ 1.3374,  0.2477,  1.9115, -1.6500,  0.9678, -0.0663, -0.5192, -1.0517,
         0.2461,  0.4203,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000])
