In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import logging
from collections import defaultdict
from enum import auto
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type, Union

import numpy as np
import pandas as pd
import rich
import torch
import typer
from torchmetrics import (
    ErrorRelativeGlobalDimensionlessSynthesis,
    MeanSquaredError,
    MetricCollection,
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
)

from rae.modules.enumerations import Output
from rae.pl_modules.pl_gautoencoder import LightningAutoencoder

try:
    # be ready for 3.10 when it drops
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum

import hydra
from omegaconf import DictConfig, OmegaConf
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from nn_core.serialization import NNCheckpointIO

from rae.data.vision.datamodule import MyDataModule


def parse_checkpoint(
    module_class: Type[nn.Module],
    checkpoint_path: Path,
    map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
) -> Tuple[nn.Module, DictConfig]:
    """Load a model and its training config from an ``NNCheckpointIO`` archive.

    Args:
        module_class: class whose ``_load_model_state`` rebuilds the model from the
            loaded checkpoint.
        checkpoint_path: checkpoint location; only names ending in ``.ckpt.zip`` are accepted.
        map_location: forwarded unchanged to ``NNCheckpointIO.load``.

    Returns:
        ``(model, cfg)``: the model switched to eval mode, and the checkpoint's
        ``"cfg"`` entry wrapped with ``OmegaConf.create``.

    Raises:
        ValueError: if ``checkpoint_path`` does not end in ``.ckpt.zip``.
    """
    if not checkpoint_path.name.endswith(".ckpt.zip"):
        raise ValueError(f"Wrong checkpoint: {checkpoint_path}")

    checkpoint = NNCheckpointIO.load(path=checkpoint_path, map_location=map_location)
    # strict=False: presumably tolerates state-dict key mismatches (Lightning semantics) — confirm.
    model = module_class._load_model_state(
        checkpoint=checkpoint,
        metadata=checkpoint.get("metadata", None),
        strict=False,
    )
    model.eval()
    cfg = OmegaConf.create(checkpoint["cfg"])
    return model, cfg


# Raise the root logger's threshold: only ERROR and above are emitted by loggers
# that inherit the root level.
logging.getLogger().setLevel(logging.ERROR)


BATCH_SIZE = 32


# `Path(".").parent` is `Path(".")` itself, so every path below is relative to the
# kernel's current working directory.
EXPERIMENT_ROOT = Path(".")
EXPERIMENT_CHECKPOINTS = EXPERIMENT_ROOT.joinpath("checkpoints")
PREDICTIONS_TSV = EXPERIMENT_ROOT.joinpath("predictions.tsv")
PERFORMANCE_TSV = EXPERIMENT_ROOT.joinpath("performance.tsv")

# Dataset abbreviation -> (dotted path of the dataset class, presumably the split name).
DATASET_SANITY = dict(
    mnist=("rae.data.vision.mnist.MNISTDataset", "test"),
    fmnist=("rae.data.vision.fmnist.FashionMNISTDataset", "test"),
    cifar10=("rae.data.vision.cifar10.CIFAR10Dataset", "test"),
    cifar100=("rae.data.vision.cifar100.CIFAR100Dataset", "test"),
)
# Model abbreviation -> dotted path of the model class.
MODEL_SANITY = dict(
    abs="rae.modules.vision.resnet.ResNet",
    rel="rae.modules.vision.relresnet.RelResNet",
)


def parse_checkpoint_id(ckpt: Path) -> str:
    """Return the run id of a checkpoint: its file name minus its last two suffixes.

    e.g. ``checkpoints/mnist/ae/abc123.ckpt.zip`` -> ``"abc123"``.
    """
    stripped = ckpt
    for _ in range(2):  # first ".zip", then ".ckpt"
        stripped = stripped.with_suffix("")
    return stripped.name


# Parse checkpoints tree: EXPERIMENT_CHECKPOINTS/<dataset>/<model>/<run>.ckpt.zip
#   checkpoints[dataset][model] -> sorted list of checkpoint paths
#   RUNS[dataset][model]        -> the matching run ids (parse_checkpoint_id)
# Only directories are descended into and only "*.ckpt.zip" entries are collected:
# a stray file inside a dataset folder (e.g. ".DS_Store") would otherwise make
# `iterdir()` fail, and any other file would later be rejected by `parse_checkpoint`.
checkpoints = defaultdict(dict)
RUNS = defaultdict(dict)
for dataset_abbrv in sorted(EXPERIMENT_CHECKPOINTS.iterdir()):
    if not dataset_abbrv.is_dir():
        continue
    checkpoints[dataset_abbrv.name] = defaultdict(list)
    RUNS[dataset_abbrv.name] = defaultdict(list)
    for model_abbrv in sorted(dataset_abbrv.iterdir()):
        if not model_abbrv.is_dir():
            continue
        for ckpt in sorted(model_abbrv.iterdir()):
            # Same acceptance rule as parse_checkpoint.
            if not ckpt.name.endswith(".ckpt.zip"):
                continue
            checkpoints[dataset_abbrv.name][model_abbrv.name].append(ckpt)
            RUNS[dataset_abbrv.name][model_abbrv.name].append(parse_checkpoint_id(ckpt))

In [None]:
def get_dataset(ckpt):
    """Rebuild the validation dataset recorded in a checkpoint's config.

    The checkpoint is loaded with the notebook-level ``PL_MODULE``, its
    ``cfg.nn.data`` is instantiated with Hydra (non-recursively), the datamodule
    is set up, and its first validation dataset is returned.

    Args:
        ckpt: path to a ``*.ckpt.zip`` checkpoint.

    Returns:
        ``datamodule.val_datasets[0]``.
    """
    # Read the config from `ckpt` itself. It used to be a notebook global
    # (MODELS[0]), which silently ignored the argument.
    _, cfg = parse_checkpoint(
        module_class=PL_MODULE,
        checkpoint_path=ckpt,
        map_location="cpu",
    )
    datamodule: MyDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False)
    datamodule.setup()
    return datamodule.val_datasets[0]

In [None]:
from sklearn.decomposition import PCA


def get_latents(images_batch, ckpt, pca=None):
    """Encode a batch with one checkpoint and return its latents projected to 2-D.

    Relies on notebook-level state: ``PL_MODULE`` (the module class to load) and the
    per-sample lists ``classes``, ``targets`` and ``indexes``. These are assumed to be
    aligned with the rows of ``images_batch`` (as built by the sampling cell).

    Args:
        images_batch: batch passed directly to the model.
        ckpt: path to a ``*.ckpt.zip`` checkpoint.
        pca: an already fitted ``PCA`` to reuse. If ``None`` and the latents are not
            2-D, a new 2-component PCA is fitted on this model's latents.

    Returns:
        ``(df, pca)``: ``df`` has columns ``x``, ``y``, ``class``, ``target`` and ``index``,
        with rows in a random order. ``pca`` is the projection that was used. It is
        still ``None`` if the latents were already 2-D and no PCA was passed in.
    """
    model, _ = parse_checkpoint(
        module_class=PL_MODULE,
        checkpoint_path=ckpt,
        map_location="cpu",
    )
    # Inference only: don't build an autograd graph for the whole batch.
    with torch.no_grad():
        latents = model(images_batch)[Output.DEFAULT_LATENT].detach().cpu()

    # Random row order, presumably so later scatter plots don't draw points in
    # dataset order.
    perm = torch.randperm(latents.shape[0])
    latents = latents[perm].numpy()
    order = perm.tolist()  # plain ints for indexing the Python lists below

    if latents.shape[-1] == 2:
        latents2d = latents
    else:
        if pca is None:
            pca = PCA(n_components=2)
            pca.fit(latents)
        latents2d = pca.transform(latents)

    df = pd.DataFrame(
        {
            "x": latents2d[:, 0].tolist(),
            "y": latents2d[:, 1].tolist(),
            "class": [classes[i] for i in order],
            "target": [targets[i] for i in order],
            "index": [indexes[i] for i in order],
        }
    )
    return df, pca

In [None]:
def plot_bg(
    ax,
    df,
    size=0.5,
    bg_alpha=0.01,
):
    """Scatter every latent point of ``df`` on ``ax`` with low opacity.

    Intended as the backdrop for ``hightlight_cluster``. It shows the overall shape
    of the latent space so that the highlighted classes are seen in context.
    Colours come from each row's ``target`` through the notebook-level ``cmap``
    and ``norm``.

    Args:
        ax: matplotlib Axes to draw on.
        df: DataFrame with ``x``, ``y`` and ``target`` columns (as built by ``get_latents``).
        size: marker size.
        bg_alpha: marker opacity.

    Returns:
        The same ``ax``.
    """
    ax.scatter(df.x, df.y, c=cmap(norm(df["target"])), alpha=bg_alpha, s=size)
    return ax


def hightlight_cluster(
    ax,
    df,
    target,
    alpha,
    size=0.5,
):
    """Overlay the points of a single class on ``ax``.

    Args:
        ax: matplotlib Axes to draw on.
        df: DataFrame with ``x``, ``y`` and ``target`` columns.
        target: class label to highlight. Only rows with ``df["target"] == target`` are drawn.
        alpha: marker opacity.
        size: marker size.

    Returns:
        The same ``ax``, consistent with ``plot_bg``, so calls can be chained.
    """
    cluster_df = df[df["target"] == target]
    ax.scatter(cluster_df.x, cluster_df.y, c=cmap(norm(cluster_df["target"])), alpha=alpha, s=size)
    return ax


def plot_latent_space(ax, df, targets, size, bg_alpha=0.1, alpha=0.5):
    """Draw a faded background of all points, then highlight the selected classes.

    Args:
        ax: matplotlib Axes to draw on.
        df: latent DataFrame with ``x``, ``y`` and ``target`` columns.
        targets: class labels to highlight, drawn in the given order.
        size: marker size of the highlighted points. The background keeps
            ``plot_bg``'s default size.
        bg_alpha: opacity of the background points.
        alpha: opacity of the highlighted points.

    Returns:
        The Axes returned by ``plot_bg``.
    """
    canvas = plot_bg(ax, df, bg_alpha=bg_alpha)
    for highlighted in targets:
        hightlight_cluster(canvas, df, highlighted, alpha=alpha, size=size)
    return canvas

# Latent Rotations

In [None]:
MODELS = checkpoints["mnist"]["nososmall_ae"]
# The inner mapping is a defaultdict(list): a missing model folder silently yields []
# here and would only fail later as an opaque IndexError on MODELS[0].
assert MODELS, "no checkpoints found for dataset 'mnist', model 'nososmall_ae'"
MODELS

In [None]:
# Lightning module class that `get_dataset` and `get_latents` read (as the notebook
# global PL_MODULE) to load every checkpoint in MODELS.
PL_MODULE = LightningAutoencoder

In [None]:
from pytorch_lightning import seed_everything

images, targets, indexes, classes = [], [], [], []

# Seed before sampling so the same validation subset is drawn on every run.
seed_everything(0)

val_dataset = get_dataset(MODELS[0])
K = 5_000
# At most K distinct validation indices, in random order.
idxs = torch.randperm(len(val_dataset))[:K]

for idx in idxs:
    sample = val_dataset[idx]
    # The four lists stay aligned: position i describes the sample at idxs[i].
    indexes.append(sample["index"].item())
    images.append(sample["image"])
    targets.append(sample["target"])
    classes.append(sample["class"])

# Stack the collected images along a new leading (batch) dimension.
images_batch = torch.stack(images, dim=0)

In [None]:
# Encode the sampled batch with every checkpoint. Each call fits its own PCA when
# the latents are not already 2-D; the returned PCA is discarded.
all_latents_df = [get_latents(images_batch, ckpt)[0] for ckpt in MODELS]

In [None]:
# Keep every model for the overview grid below.
TO_CONSIDER = range(len(all_latents_df))
latents_df = list(map(all_latents_df.__getitem__, TO_CONSIDER))

In [None]:
import math

import matplotlib.pyplot as plt
from tueplots import bundles, figsizes  # figsizes was otherwise only imported further down

plt.rcParams.update(bundles.icml2022())
N_ROWS = 2
# Round up so an odd number of models is not silently dropped from the grid.
N_COLS = math.ceil(len(latents_df) / N_ROWS)

plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))
cmap = plt.get_cmap("Set1", 10)  # plt.cm.get_cmap was removed in Matplotlib 3.9
norm = plt.Normalize(latents_df[0]["target"].min(), latents_df[0]["target"].max())


# squeeze=False keeps `axes` 2-D even when N_COLS == 1.
fig, axes = plt.subplots(dpi=150, nrows=N_ROWS, ncols=N_COLS, sharey=True, sharex=True, squeeze=False)

for i, row in enumerate(axes):
    for j, ax in enumerate(row):
        k = i * N_COLS + j
        if k >= len(latents_df):
            # Padding cell of an incomplete last row.
            ax.set_axis_off()
            continue
        ax.set_aspect("equal")
        # Plot the selection `latents_df` (not `all_latents_df`) so TO_CONSIDER is honoured.
        plot_latent_space(ax, latents_df[k], targets=[0, 4], size=0.5)

In [None]:
# Hand-picked model indices (into all_latents_df) shown in the latent_rotation figure.
TO_CONSIDER = [4, 6, 5, 8]
chosen_latents_df = list(map(all_latents_df.__getitem__, TO_CONSIDER))

In [None]:
import matplotlib.pyplot as plt
from tueplots import bundles, figsizes  # figsizes was otherwise only imported further down

plt.rcParams.update(bundles.icml2022())
N_ROWS = 1
N_COLS = len(chosen_latents_df)

plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))
cmap = plt.get_cmap("Set1", 10)  # plt.cm.get_cmap was removed in Matplotlib 3.9
# Normalize over the frames actually plotted in this figure.
norm = plt.Normalize(chosen_latents_df[0]["target"].min(), chosen_latents_df[0]["target"].max())


fig, axes = plt.subplots(dpi=150, nrows=N_ROWS, ncols=N_COLS, sharey=True, sharex=True, squeeze=True)

# With a single chosen model `axes` is a bare Axes; atleast_1d makes it iterable.
for ax, latents in zip(np.atleast_1d(axes), chosen_latents_df):
    ax.set_aspect("equal")
    plot_latent_space(ax, latents, targets=[0, 2], size=0.75, bg_alpha=0.1)

In [None]:
# Save the chosen-models figure (`fig` from the cell above) as SVG with a tight bounding box.
fig.savefig("latent_rotation.svg", bbox_inches="tight")

In [None]:
# Convert the exported SVG to PDF via the external `rsvg-convert` tool (must be on PATH).
!rsvg-convert -f pdf -o latent_rotation.pdf latent_rotation.svg

# Latent Rotations

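Single PCA proof: the same validation images are encoded by every checkpoint and projected to 2D twice — once with an independent PCA per model (first row) and once with the single PCA fitted on the first model and reused for all others (second row).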

In [None]:
MODELS = checkpoints["mnist"]["ae"]
# The inner mapping is a defaultdict(list): a missing model folder silently yields []
# here and would only fail later as an opaque IndexError on MODELS[0].
assert MODELS, "no checkpoints found for dataset 'mnist', model 'ae'"
PL_MODULE = LightningAutoencoder

In [None]:
from pytorch_lightning import seed_everything

images, targets, indexes, classes = [], [], [], []

# Seed before sampling so the same validation subset is drawn on every run.
seed_everything(0)

val_dataset = get_dataset(MODELS[0])
K = 5_000
# At most K distinct validation indices, in random order.
idxs = torch.randperm(len(val_dataset))[:K]

for idx in idxs:
    sample = val_dataset[idx]
    # The four lists stay aligned: position i describes the sample at idxs[i].
    indexes.append(sample["index"].item())
    images.append(sample["image"])
    targets.append(sample["target"])
    classes.append(sample["class"])

# Stack the collected images along a new leading (batch) dimension.
images_batch = torch.stack(images, dim=0)

In [None]:
pca = None
latents_single_pca = []
for ckpt in MODELS:
    # The first call fits the PCA (when the latents are not 2-D); every later call
    # receives that same PCA back and only applies it.
    latent_frame, pca = get_latents(images_batch, ckpt, pca=pca)
    latents_single_pca.append(latent_frame)

In [None]:
# Independent projections: every checkpoint fits its own PCA, which is discarded.
pca = None
latents_independent_pca = [get_latents(images_batch, ckpt, None)[0] for ckpt in MODELS]

In [None]:
# Show at most the first 6 models. Bound by the lists built in *this* section:
# `all_latents_df` comes from the earlier section (a different MODELS list), so its
# length says nothing about how many checkpoints were encoded here.
TO_CONSIDER = [0, 1, 2, 3, 4, 5][: len(latents_single_pca)]
latents_single_pca = [latents_single_pca[i] for i in TO_CONSIDER]
latents_independent_pca = [latents_independent_pca[i] for i in TO_CONSIDER]

In [None]:
from tueplots import figsizes

In [None]:
import matplotlib.pyplot as plt
from tueplots import bundles, figsizes

# NOTE(review): distinctipy and mpl are not referenced anywhere in this notebook.
import distinctipy
import matplotlib as mpl

plt.rcParams.update(bundles.icml2022())
N_ROWS = 1
N_COLS = len(latents_single_pca)

plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))

cmap = plt.get_cmap("Set1", 10)  # plt.cm.get_cmap was removed in Matplotlib 3.9
norm = plt.Normalize(latents_single_pca[0]["target"].min(), latents_single_pca[0]["target"].max())

# Row 1: each model projected with its own, independently fitted PCA.
fig, axes = plt.subplots(
    ncols=N_COLS,
    nrows=N_ROWS,
    sharey=True,
    sharex=True,
    squeeze=True,
)

# With a single model `axes` is a bare Axes; atleast_1d makes it iterable.
for j, (ax, df) in enumerate(zip(np.atleast_1d(axes), latents_independent_pca)):
    ax.set_aspect("equal")
    ax.set_title(f"Train {j}")
    plot_latent_space(
        ax,
        df,
        targets=[
            0,
            1,
        ],
        size=0.5,
        bg_alpha=0.1,
        alpha=0.7,
    )

In [None]:
# Save row 1 (independent PCA per model) as SVG with a tight bounding box.
fig.savefig("pca-proof-row1.svg", bbox_inches="tight")

In [None]:
# Row 2: every model projected with the single PCA fitted on the first model.
# Axes are not shared, so each panel autoscales to its own data.
fig, axes = plt.subplots(
    ncols=N_COLS,
    nrows=N_ROWS,
    sharey=False,
    sharex=False,
    squeeze=True,
)

# With a single model `axes` is a bare Axes; atleast_1d makes it iterable.
for ax, df in zip(np.atleast_1d(axes), latents_single_pca):
    ax.set_aspect("equal")
    plot_latent_space(
        ax,
        df,
        targets=[
            0,
            1,
        ],
        size=0.5,
        bg_alpha=0.1,
        alpha=0.7,
    )

In [None]:
# Save row 2 (single shared PCA) as SVG with a tight bounding box.
fig.savefig("pca-proof-row2.svg", bbox_inches="tight")

In [None]:
# Convert both exported rows from SVG to PDF via the external `rsvg-convert` tool (must be on PATH).
!rsvg-convert -f pdf -o pca-proof-row1.pdf pca-proof-row1.svg
!rsvg-convert -f pdf -o pca-proof-row2.pdf pca-proof-row2.svg