In [None]:
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import logging
from pathlib import Path

import pandas as pd
import torch

from rae.utils.evaluation import parse_checkpoints_tree
from rae.modules.enumerations import Output
from rae.pl_modules.pl_gautoencoder import LightningAutoencoder

try:
    # enum.StrEnum is part of the standard library from Python 3.11 onwards
    from enum import StrEnum
except ImportError:
    # older interpreters: fall back to the backports.strenum package
    from backports.strenum import StrEnum

import hydra

from rae.data.vision.datamodule import MyDataModule


# Only let ERROR (and above) records through the root logger.
logging.getLogger().setLevel(logging.ERROR)


# NOTE(review): not referenced by any cell visible in this notebook section.
BATCH_SIZE = 32


# NOTE: Path(".").parent is still Path("."), so every path below is relative
# to the current working directory (not to its parent).
EXPERIMENT_ROOT = Path(".").parent
EXPERIMENT_CHECKPOINTS = EXPERIMENT_ROOT / "checkpoints"
PREDICTIONS_TSV = EXPERIMENT_ROOT / "predictions.tsv"
PERFORMANCE_TSV = EXPERIMENT_ROOT / "performance.tsv"

# Dataset key -> (dotted class path, split name).
# Presumably the expected dataset class/split per experiment, used for sanity
# checks elsewhere -- not referenced by the cells visible here; TODO confirm.
DATASET_SANITY = {
    "mnist": ("rae.data.vision.mnist.MNISTDataset", "test"),
    "fmnist": ("rae.data.vision.fmnist.FashionMNISTDataset", "test"),
    "cifar10": ("rae.data.vision.cifar10.CIFAR10Dataset", "test"),
    "cifar100": ("rae.data.vision.cifar100.CIFAR100Dataset", "test"),
}
# Model key -> dotted class path (same caveat as DATASET_SANITY).
MODEL_SANITY = {
    "abs": "rae.modules.vision.resnet.ResNet",
    "rel": "rae.modules.vision.relresnet.RelResNet",
}


# `checkpoints` is indexed below as checkpoints[<dataset>][<model>] and yields
# an indexable collection of checkpoints; RUNS is not used in the visible cells.
checkpoints, RUNS = parse_checkpoints_tree(EXPERIMENT_CHECKPOINTS)

In [None]:
from rae.utils.evaluation import parse_checkpoint
from sklearn.decomposition import PCA


def get_latents(images_batch, ckpt, pca=None):
    """Encode a batch of images with one checkpoint and return 2-D latents.

    Besides its arguments, this function reads the notebook-level globals
    ``PL_MODULE``, ``classes``, ``targets`` and ``indexes``. They must be
    defined before the call, and the three lists must line up row by row
    with ``images_batch``.

    Args:
        images_batch: Batch of images, passed to the model unchanged.
        ckpt: Checkpoint handed to ``parse_checkpoint`` as ``checkpoint_path``.
        pca: Optional already-fitted PCA used to project the latents. When
            ``None`` and the latents have more than two dimensions, a new
            2-component PCA is fit on this checkpoint's latents. Ignored when
            the latents are already 2-dimensional.

    Returns:
        A ``(df, pca)`` tuple. ``df`` has the columns ``x``, ``y``, ``class``,
        ``target`` and ``index``; ``pca`` is the PCA that was used (the one
        passed in, a newly fitted one, or ``None`` if no projection was
        needed).
    """
    model, _ = parse_checkpoint(
        module_class=PL_MODULE,
        checkpoint_path=ckpt,
        map_location="cpu",
    )

    # Inference only. eval() switches layers such as dropout / batch-norm to
    # their deterministic behaviour (harmless if parse_checkpoint already did
    # it), and no_grad() avoids building an autograd graph for the whole
    # batch, which the original code created and then threw away via detach().
    model.eval()
    with torch.no_grad():
        latents = model(images_batch)[Output.DEFAULT_LATENT].detach()

    if latents.shape[-1] == 2:
        # Already two-dimensional: plot the latents as they are.
        latents2d = latents
    else:
        if pca is None:
            pca = PCA(n_components=2)
            pca.fit(latents)

        latents2d = pca.transform(latents)

    df = pd.DataFrame(
        {
            "x": latents2d[:, 0].tolist(),
            "y": latents2d[:, 1].tolist(),
            # Per-sample metadata collected by the sampling cell (globals).
            "class": classes,
            "target": targets,
            "index": indexes,
        }
    )
    return df, pca

# Latent Rotations

In [None]:
# Checkpoints of the "small_ae" runs on MNIST; the bare name displays them.
MODELS = checkpoints["mnist"]["small_ae"]
MODELS

In [None]:
# Lightning module class used to load the checkpoints; read as a global by
# get_latents() and passed to get_dataset().
PL_MODULE = LightningAutoencoder

In [None]:
from pytorch_lightning import seed_everything
from rae.utils.evaluation import get_dataset

# Seed before anything random happens so the sampled subset is repeatable.
seed_everything(0)

val_dataset = get_dataset(pl_module=PL_MODULE, ckpt=MODELS[0])
K = 2_000
# Random subset of K sample positions from the dataset.
idxs = torch.randperm(len(val_dataset))[:K]

# Column-wise buffers. `targets`, `indexes` and `classes` are read as globals
# by get_latents(), so they must stay aligned with `images_batch`.
images, targets, indexes, classes = [], [], [], []
for idx in idxs:
    sample = val_dataset[idx]
    images.append(sample["image"])
    targets.append(sample["target"])
    classes.append(sample["class"])
    indexes.append(sample["index"].item())

images_batch = torch.stack(images, dim=0)

In [None]:
# One DataFrame of 2-D latents per checkpoint. No PCA is passed in, so every
# call decides on its own projection; the returned PCA is not needed here.
all_latents_df = [get_latents(images_batch, ckpt)[0] for ckpt in MODELS]

In [None]:
# Select which runs go into the overview grid; here: all of them, in order.
TO_CONSIDER = range(len(all_latents_df))
latents_df = [all_latents_df[i] for i in TO_CONSIDER]

In [None]:
from rae.utils.evaluation import plot_latent_space
import matplotlib.pyplot as plt
from tueplots import bundles
from tueplots import figsizes

# Overview: a two-row grid with one panel per selected run.
plt.rcParams.update(bundles.icml2022())
N_ROWS = 2
# Round up so an odd number of runs does not silently drop the last one.
N_COLS = (len(latents_df) + 1) // 2

plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))
# plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# plt.get_cmap takes the same (name, lut) arguments.
cmap = plt.get_cmap("Set1", 10)
norm = plt.Normalize(latents_df[0]["target"].min(), latents_df[0]["target"].max())


# squeeze=False keeps `axes` two-dimensional even when there is one column.
fig, axes = plt.subplots(dpi=150, nrows=N_ROWS, ncols=N_COLS, sharey=True, sharex=True, squeeze=False)

# axes.flat walks the grid row by row, i.e. panel k sits at (k // N_COLS, k % N_COLS).
for k, ax in enumerate(axes.flat):
    if k >= len(latents_df):
        # Unused trailing panel (odd number of runs): hide it.
        ax.set_visible(False)
        continue
    ax.set_aspect("equal")
    plot_latent_space(ax, latents_df[k], targets=[0, 2], size=0.5, cmap=cmap, norm=norm, bg_alpha=0.15)

In [None]:
# Hand-picked runs (and their left-to-right order) for the final figure.
# NOTE: the largest index is 8, so this needs at least 9 entries in all_latents_df.
TO_CONSIDER = [4, 6, 5, 8]
chosen_latents_df = [all_latents_df[i] for i in TO_CONSIDER]

In [None]:
from rae.utils.evaluation import plot_latent_space
import matplotlib.pyplot as plt
from tueplots import bundles
from tueplots import figsizes  # used below; previously only imported by an earlier cell

# Final figure: the hand-picked runs side by side in a single row.
plt.rcParams.update(bundles.icml2022())
N_ROWS = 1
N_COLS = len(chosen_latents_df)

plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))
# plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# plt.get_cmap takes the same (name, lut) arguments.
cmap = plt.get_cmap("Set1", 10)
# Colour range taken from the frames that are actually plotted in this cell.
norm = plt.Normalize(chosen_latents_df[0]["target"].min(), chosen_latents_df[0]["target"].max())


# squeeze=False always returns an array of Axes, even for a single panel
# (with squeeze=True a 1x1 grid yields a bare Axes, which is not iterable).
fig, axes = plt.subplots(dpi=150, nrows=N_ROWS, ncols=N_COLS, sharey=True, sharex=True, squeeze=False)

for ax, df in zip(axes.flat, chosen_latents_df):
    ax.set_aspect("equal")
    plot_latent_space(ax, df, targets=[0, 2], size=0.75, bg_alpha=0.15, cmap=cmap, norm=norm)

In [None]:
# Write the current figure as SVG in the working directory, cropped tightly.
fig.savefig("latent_rotation.svg", bbox_inches="tight")

In [None]:
# Convert the SVG to PDF with the external `rsvg-convert` tool (must be on
# PATH), then delete the intermediate SVG.
!rsvg-convert -f pdf -o latent_rotation.pdf latent_rotation.svg
!rm latent_rotation.svg

# Latent Rotations

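Single-PCA check: compare latents projected with one PCA shared across all runs (fit on the first run) against latents projected with a PCA fit independently for each run.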

In [None]:
# Switch to the checkpoints of the "ae" runs on MNIST for this section;
# the bare name at the end displays them.
MODELS = checkpoints["mnist"]["ae"]
PL_MODULE = LightningAutoencoder
MODELS

In [None]:
from pytorch_lightning import seed_everything

# Re-seed so this section draws its random subset reproducibly.
seed_everything(0)

val_dataset = get_dataset(pl_module=PL_MODULE, ckpt=MODELS[0])

K = 2_000
# Random subset of K sample positions from the dataset.
idxs = torch.randperm(len(val_dataset))[:K]

# Fresh column-wise buffers; `targets`, `indexes` and `classes` are read as
# globals by get_latents() and must stay aligned with `images_batch`.
images, targets, indexes, classes = [], [], [], []
for idx in idxs:
    sample = val_dataset[idx]
    images.append(sample["image"])
    targets.append(sample["target"])
    classes.append(sample["class"])
    indexes.append(sample["index"].item())

images_batch = torch.stack(images, dim=0)

In [None]:
# Shared-PCA variant: `pca` starts as None, so get_latents() fits it on the
# first checkpoint's latents and returns it; that same fitted PCA is then fed
# back in and reused to project every following checkpoint.
# (If the latents are already 2-D, get_latents() never creates a PCA and
# `pca` simply stays None.)
latents_single_pca = []
pca = None
for ckpt in MODELS:
    df, pca = get_latents(images_batch, ckpt, pca)
    latents_single_pca.append(df)

In [None]:
# Independent-PCA variant: passing None on every call makes get_latents()
# fit a separate projection for each checkpoint.
pca = None
latents_independent_pca = [get_latents(images_batch, ckpt, None)[0] for ckpt in MODELS]

In [None]:
# Keep at most the first five runs of *this* section. The cap is derived from
# the lists being sliced; the previous code used len(all_latents_df), which
# belongs to the earlier "small_ae" section and could run past the end here.
_n_available = min(len(latents_single_pca), len(latents_independent_pca))
TO_CONSIDER = [0, 1, 2, 3, 4][:_n_available]
latents_single_pca = [latents_single_pca[i] for i in TO_CONSIDER]
latents_independent_pca = [latents_independent_pca[i] for i in TO_CONSIDER]

In [None]:
from tueplots import figsizes

In [None]:
import matplotlib.pyplot as plt
from tueplots import bundles

# Row 1 of the comparison: every run projected with its own, independent PCA.
plt.rcParams.update(bundles.icml2022())
N_ROWS = 1
N_COLS = len(latents_single_pca)

plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))

# plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# plt.get_cmap takes the same (name, lut) arguments.
cmap = plt.get_cmap("Set1", 10)
norm = plt.Normalize(latents_single_pca[0]["target"].min(), latents_single_pca[0]["target"].max())

# squeeze=False always returns an array of Axes, even for a single panel
# (with squeeze=True a 1x1 grid yields a bare Axes, which is not iterable).
fig, axes = plt.subplots(
    ncols=N_COLS,
    nrows=N_ROWS,
    sharey=True,
    sharex=True,
    squeeze=False,
)

for j, (ax, df) in enumerate(zip(axes.flat, latents_independent_pca)):
    ax.set_aspect("equal")
    ax.set_title(f"Train {j}")
    plot_latent_space(
        ax,
        df,
        targets=[
            0,
            1,
        ],
        size=0.5,
        bg_alpha=0.1,
        alpha=0.7,
        cmap=cmap,
        norm=norm,
    )

In [None]:
# Write the independent-PCA row as SVG in the working directory, cropped tightly.
fig.savefig("pca-proof-row1.svg", bbox_inches="tight")

In [None]:
# Row 2 of the comparison: every run projected with the single, shared PCA.
# Reuses N_COLS, N_ROWS, cmap and norm from the row-1 cell.
# squeeze=False always returns an array of Axes, even for a single panel
# (with squeeze=True a 1x1 grid yields a bare Axes, which is not iterable).
fig, axes = plt.subplots(
    ncols=N_COLS,
    nrows=N_ROWS,
    sharey=False,
    sharex=False,
    squeeze=False,
)

for j, (ax, df) in enumerate(zip(axes.flat, latents_single_pca)):
    ax.set_aspect("equal")
    plot_latent_space(
        ax,
        df,
        targets=[
            0,
            1,
        ],
        size=0.5,
        bg_alpha=0.1,
        alpha=0.7,
        cmap=cmap,
        norm=norm,
    )

In [None]:
# Write the shared-PCA row as SVG in the working directory, cropped tightly.
fig.savefig("pca-proof-row2.svg", bbox_inches="tight")

In [None]:
# Convert both SVG rows to PDF with the external `rsvg-convert` tool (must be
# on PATH), then delete the intermediate SVG files.
!rsvg-convert -f pdf -o pca-proof-row1.pdf pca-proof-row1.svg
!rsvg-convert -f pdf -o pca-proof-row2.pdf pca-proof-row2.svg
!rm pca-proof-row2.svg
!rm pca-proof-row1.svg