In [None]:
import os

import torch
from safetensors import safe_open
from torch.utils.data import DataLoader, Subset
import plotly.graph_objects as go
import plotly.express as px
import cuml
from cuml.manifold.umap import UMAP as umap
from cuml.random_projection import SparseRandomProjection as random_projection
from cuml.decomposition import IncrementalPCA as pca

from analysis.common import load_autoencoder, load_model
from koopmann.data import DatasetConfig, get_dataset_class
from koopmann.utils import get_device
from scripts.train_ae.shape_metrics import Processor, build_acts_dict, prepare_acts

In [None]:
# Directory holding the saved model / autoencoder checkpoints. Set the
# KOOPMANN_MODEL_DIR environment variable to point elsewhere instead of
# editing this hardcoded, user-specific default.
file_dir = os.environ.get("KOOPMANN_MODEL_DIR", "/scratch/nsa325/koopmann_model_saves")
dim = 800
k_steps = 1
scale_idx = 1
dataset_name = "cifar10"

# Autoencoder flavour. Other values used for this project's checkpoints:
#   f"lowrank_{rank}" (e.g. rank = 100) and "exponential".
flavor = "standard"

model_name = f"convresnet_{dataset_name}"
ae_name = f"dim_{dim}_k_{k_steps}_loc_{scale_idx}_{flavor}_autoencoder_{dataset_name}_model"
# CPU by default; get_device() (imported above) is the accelerator alternative.
device = "cpu"

In [None]:
# Load the trained model plus its metadata, then chain:
# eval mode -> install activation hooks -> move to the configured device.
model, model_metadata = load_model(file_dir, model_name)
(
    model.eval()
    .hook_model()
    .to(device)
)
print(model_metadata)

In [None]:
# Load the trained autoencoder and its metadata; show the metadata and put the
# autoencoder in eval mode (trailing ';' keeps the module repr out of the output).
autoencoder, ae_metadata = load_autoencoder(file_dir, ae_name)
print(ae_metadata)
autoencoder.eval();

In [None]:
# Dataset config
dataset_config = DatasetConfig(
    dataset_name=model_metadata["dataset"],
    num_samples=3_000,
    split="train",
    seed=42,
)
DatasetClass = get_dataset_class(name=dataset_config.dataset_name)
dataset = DatasetClass(config=dataset_config)

# Optionally keep only the first `subset_size` samples (None/0 = use everything).
# Clamp to the real dataset length: asking Subset for indices past the end
# (e.g. 20_000 indices over a 3_000-sample dataset) only fails later, when the
# DataLoader reaches them. Clamping the variable itself also keeps any later
# `[:subset_size]` slices consistent with the subset.
subset_size = 20_000
if subset_size:
    subset_size = min(subset_size, len(dataset))
    subset_indices = list(range(subset_size))
    dataset = Subset(dataset, subset_indices)

batch_size = 3_000
batch_size = min(subset_size, batch_size) if subset_size else batch_size
# shuffle=False so downstream activations stay aligned with dataset order.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Read every tensor from the autoencoder's preprocessing safetensors file
# (stored in file_dir under the autoencoder's name) into a name -> tensor dict.
preproc_path = f"{file_dir}/{ae_name}_preprocessing.safetensors"
preproc_dict = {}
with safe_open(preproc_path, framework="pt", device="cpu") as handle:
    preproc_dict.update((name, handle.get_tensor(name)) for name in handle.keys())


In [None]:
# Run the hooked model over the (unshuffled) dataloader and collect activations.
# only_first_last=True presumably restricts collection to the first and last
# hooked layers — TODO confirm against build_acts_dict.
orig_act_dict = build_acts_dict(
    model=model,
    data_train_loader=dataloader,
    device=device,
    only_first_last=True,
)

In [None]:
def _dim_reduce_umap(
    x,
    dim,
    n_neighbors=100,
    spread=5,
    min_dist=0.01,
    random_state=None,
):
    """Embed ``x`` into ``dim`` dimensions with cuML UMAP.

    Other reducers tried for this notebook (IncrementalPCA, SparseRandomProjection)
    are imported at the top of the file.

    Args:
        x: Array-like of samples to embed, one row per sample. In this notebook
            the caller passes a 2-D numpy array of flattened activations.
        dim: Number of output components (UMAP ``n_components``).
        n_neighbors, spread, min_dist: UMAP hyperparameters. The defaults are
            the values this notebook originally hard-coded.
        random_state: Optional seed forwarded to UMAP for reproducible
            embeddings. ``None`` (default) leaves UMAP unseeded, as before.

    Returns:
        ``(embedding, None)``: the UMAP embedding and a placeholder in the
        position where a linear reducer would return projection directions.
        UMAP has no such directions.
    """
    umap_kwargs = {}
    if random_state is not None:
        # Only forward a seed when requested so the default call is exactly the
        # original unseeded configuration.
        umap_kwargs["random_state"] = random_state
    reducer = umap(
        n_neighbors=n_neighbors,
        n_components=dim,
        spread=spread,
        min_dist=min_dist,
        **umap_kwargs,
    )
    return reducer.fit_transform(x), None

In [None]:
# Embed the activations of the first entry in orig_act_dict into 3-D with UMAP.
# Only one entry is used; the earlier exploratory loop broke after one item.
new_dim = 3
center_acts = False  # set True to mean-centre each feature before UMAP

if not orig_act_dict:
    raise ValueError("orig_act_dict is empty; build_acts_dict collected no activations.")
key, curr_act = next(iter(orig_act_dict.items()))
print(key)

# Flatten every non-batch dimension. The UMAP helper is fed numpy, so move to
# CPU explicitly rather than overwriting the global `device` from the config cell.
flat_act = torch.flatten(curr_act.detach().cpu(), start_dim=1)
if center_acts:
    flat_act = flat_act - flat_act.mean(dim=0, keepdim=True)

embedding, directions = _dim_reduce_umap(flat_act.numpy(), new_dim)
processed_act = torch.tensor(embedding)
# Rescale so the whole embedding has Frobenius norm 1000 (plot-friendly range).
norms = torch.linalg.norm(processed_act, ord="fro") / 1000
processed_act = processed_act / norms
# Optional whitening, as tried previously:
# processed_act = Processor._whiten(processed_act, alpha=0.9)


In [None]:
# 3-D scatter of the UMAP embedding, coloured by class label.
embedding_np = processed_act.detach().cpu().numpy()
x, y, z = embedding_np[:, 0], embedding_np[:, 1], embedding_np[:, 2]

# Row i of the embedding is the i-th sample of `dataset` (activations were
# collected with shuffle=False). Look labels up through the Subset's own indices
# instead of assuming `dataset` is always a Subset over range(subset_size).
if isinstance(dataset, Subset):
    base_targets, label_indices = dataset.dataset.targets, dataset.indices
else:
    base_targets, label_indices = dataset.targets, range(len(dataset))

# Strings force a categorical colour scale; .item() unwraps tensor/ndarray
# scalars so categories read "3" rather than "tensor(3)".
target_categories = []
for idx in label_indices:
    label = base_targets[idx]
    target_categories.append(str(label.item() if hasattr(label, "item") else label))

assert len(target_categories) == len(embedding_np), (
    f"{len(target_categories)} labels vs {len(embedding_np)} embedded points"
)

fig = px.scatter_3d(
    x=x,
    y=y,
    z=z,
    color=target_categories,
    color_discrete_sequence=px.colors.qualitative.T10,
    labels={"x": "UMAP 1", "y": "UMAP 2", "z": "UMAP 3", "color": "class"},
    title=f"UMAP of {key} activations ({model_name})",
)

fig.update_traces(marker_size=2)
fig.show()