In [None]:
# Re-import edited project modules (e.g. rae.*) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import logging
from enum import auto
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type, Union

import numpy as np
import pandas as pd
import rich
import torch
import typer
from torchmetrics import (
    ErrorRelativeGlobalDimensionlessSynthesis,
    MeanSquaredError,
    MetricCollection,
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
)

from nn_core.common import PROJECT_ROOT
from rae.modules.enumerations import Output
from rae.pl_modules.pl_gautoencoder import LightningAutoencoder
from rae.utils.evaluation import parse_checkpoint_id, parse_checkpoints_tree, parse_checkpoint
from collections import defaultdict

try:
    # be ready for 3.10 when it drops
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum

from rae.utils.evaluation import plot_latent_space
import matplotlib.pyplot as plt
from tueplots import bundles
from tueplots import figsizes

# Root logger: only ERROR and above are emitted while this notebook runs.
logging.root.setLevel(logging.ERROR)


BATCH_SIZE = 256


# Experiment layout under the project root.
EXPERIMENT_ROOT = PROJECT_ROOT / "experiments" / "fig:latent-rotation-comparison"
EXPERIMENT_CHECKPOINTS = EXPERIMENT_ROOT / "checkpoints"
PREDICTIONS_TSV = EXPERIMENT_ROOT / "predictions.tsv"
PERFORMANCE_TSV = EXPERIMENT_ROOT / "performance.tsv"

# Short dataset name -> (dotted dataset class path, split name — presumably the evaluation split).
DATASET_SANITY = dict(
    mnist=("rae.data.vision.mnist.MNISTDataset", "test"),
    fmnist=("rae.data.vision.fmnist.FashionMNISTDataset", "test"),
    cifar10=("rae.data.vision.cifar10.CIFAR10Dataset", "test"),
    cifar100=("rae.data.vision.cifar100.CIFAR100Dataset", "test"),
)
# Short model name -> dotted model class path.
MODEL_SANITY = dict(
    vae="rae.modules.vae.VanillaVAE",
    ae="rae.modules.ae.VanillaAE",
    rel_vae="rae.modules.rel_vae.VanillaRelVAE",
    rel_ae="rae.modules.rel_ae.VanillaRelAE",
)


# `checkpoints` is indexed below as checkpoints[dataset][model] and yields the checkpoint paths
# of that configuration; RUNS is not used further down in this notebook.
checkpoints, RUNS = parse_checkpoints_tree(EXPERIMENT_CHECKPOINTS)

In [None]:
from rae.utils.evaluation import parse_checkpoint
from sklearn.decomposition import PCA


def get_latents(
    images_batch,
    ckpt,
    pca=None,
    key=Output.DEFAULT_LATENT,
    module_class=None,
    dims: Tuple[int, int] = (0, 1),
):
    """Restore a checkpoint, embed ``images_batch`` and tabulate two latent coordinates.

    Besides its arguments this reads the notebook-level lists ``classes``, ``targets`` and
    ``indexes`` (one entry per row of ``images_batch``) and, when ``module_class`` is None,
    the global ``PL_MODULE``.

    Args:
        images_batch: batch of images passed to the model in a single forward call.
        ckpt: checkpoint path handed to ``parse_checkpoint``.
        pca: not used for the projection; returned unchanged (kept for interface compatibility).
        key: which entry of the model output to read, e.g. ``Output.DEFAULT_LATENT`` or
            ``Output.SIMILARITIES``.
        module_class: LightningModule class used to restore the checkpoint; defaults to ``PL_MODULE``.
        dims: the two columns of the latent matrix reported as ``"x"`` and ``"y"``.

    Returns:
        ``(df, latents, pca)``, where ``df`` has columns x, y, class, target, index and
        ``latents`` is the full detached, squeezed output tensor.
    """
    if module_class is None:
        module_class = PL_MODULE

    model, _ = parse_checkpoint(
        module_class=module_class,
        checkpoint_path=ckpt,
        map_location="cpu",
    )
    # Pure inference: do not build an autograd graph for the whole batch.
    # NOTE(review): eval()/train() mode is whatever parse_checkpoint leaves it in — confirm.
    with torch.no_grad():
        latents = model(images_batch)[key].detach().squeeze()
    # NOTE(review): squeeze() would also drop the batch dimension for a batch of one image.

    x_dim, y_dim = dims
    df = pd.DataFrame(
        {
            "x": latents[:, x_dim].tolist(),
            "y": latents[:, y_dim].tolist(),
            "class": classes,
            "target": targets,
            "index": indexes,
        }
    )
    return df, latents, pca

# Latent Rotations

In [None]:
# MNIST AE checkpoints; MODELS[0] is used next to build the evaluation dataset.
MODELS = checkpoints["mnist"]["ae"]

In [None]:
# Lightning module class passed to get_dataset / parse_checkpoint for every checkpoint in this notebook.
PL_MODULE = LightningAutoencoder

In [None]:
from rae.utils.evaluation import get_dataset

from pytorch_lightning import seed_everything

# Seed every RNG up front so the random subset of evaluation samples is reproducible.
seed_everything(0)

# Evaluation dataset, built by get_dataset from the first checkpoint in MODELS.
val_dataset = get_dataset(pl_module=PL_MODULE, ckpt=MODELS[0])

# K distinct random positions into the dataset.
K = 2_000
idxs = torch.randperm(len(val_dataset))[:K]

# Per-sample metadata, aligned row-by-row with images_batch.
# The (tensor) idx is passed straight to the dataset on purpose: "index" comes back as a tensor.
images, targets, indexes, classes = [], [], [], []
for idx in idxs:
    sample = val_dataset[idx]
    images.append(sample["image"])
    targets.append(sample["target"])
    classes.append(sample["class"])
    indexes.append(sample["index"].item())

# One batch tensor with a new leading sample dimension.
images_batch = torch.stack(images, dim=0)

## AE

In [None]:
# AE checkpoints whose latents are extracted in the next cell (displayed for inspection).
MODELS = checkpoints["mnist"]["ae"]
MODELS

In [None]:
# One (df, latents) pair per AE checkpoint.
# A comprehension keeps loop temporaries out of the notebook namespace — in particular it does not
# bind `_`, which would shadow IPython's last-output variable. [:2] drops the unused pca passthrough.
ae_latents = [get_latents(images_batch, ckpt, None)[:2] for ckpt in MODELS]

## RelAE

In [None]:
# RelAE checkpoints whose similarity outputs are extracted in the next cell (displayed for inspection).
MODELS = checkpoints["mnist"]["rel_ae"]
MODELS

In [None]:
# One (df, similarities) pair per RelAE checkpoint; the relative representation is read from Output.SIMILARITIES.
# A comprehension avoids binding `_` (IPython's last-output variable) and other loop temporaries at top level.
rel_ae_latents = [
    get_latents(images_batch, ckpt, None, key=Output.SIMILARITIES)[:2]
    for ckpt in MODELS
]

In [None]:
# Restore the first RelAE checkpoint and display the mean of its reconstructions.
# parse_checkpoint returns a pair; indexing [0] avoids binding `_`, which would shadow IPython's last output.
model = parse_checkpoint(
    module_class=PL_MODULE,
    checkpoint_path=checkpoints["mnist"]["rel_ae"][0],
    map_location="cpu",
)[0]
# Inference only: no autograd graph needed.
with torch.no_grad():
    reconstruction = model(images_batch)[Output.RECONSTRUCTION]
reconstruction.mean()

In [None]:
# Repeat the forward pass with the same `model`; presumably a quick check that the displayed mean is reproducible.
# Inference only: no autograd graph needed.
with torch.no_grad():
    sim = model(images_batch)[Output.RECONSTRUCTION]
sim.mean()

## RelAE (quantized)

In [None]:
from tqdm import tqdm

In [None]:
quantized_rel_aes = ["rel_ae_0.1", "rel_ae_0.2", "rel_ae_0.3", "rel_ae_0.5"]

# model name -> [(df, similarities), ...], one entry per checkpoint of that model.
quantized_rel_latents = defaultdict(list)

# Loop over names with `model_name`, not `model`: the latter is the restored module used by
# neighbouring cells, and rebinding it to a string would break them on re-run.
for model_name in tqdm(quantized_rel_aes):
    for ckpt in tqdm(checkpoints["mnist"][model_name], leave=False):
        # [:2] keeps (df, latents) and drops the unused pca passthrough without binding `_`.
        quantized_rel_latents[model_name].append(
            get_latents(images_batch, ckpt, None, key=Output.SIMILARITIES)[:2]
        )

In [None]:
# Restore the first rel_ae_0.1 checkpoint for the encode/decode consistency check below.
# Index [0] instead of unpacking into `_`, which would shadow IPython's last-output variable.
model = parse_checkpoint(
    module_class=PL_MODULE,
    checkpoint_path=checkpoints["mnist"]["rel_ae_0.1"][0],
    map_location="cpu",
)[0]

In [None]:
# Reconstruction via the explicit path: the mapping returned by encode() is splatted into decode().
x = model.decode(**model.encode(images_batch))[Output.RECONSTRUCTION]

In [None]:
# Reconstruction via the model's forward pass, to compare against x.
y = model(images_batch)[Output.RECONSTRUCTION]

In [None]:
# True iff encode->decode and forward agree within torch.allclose's default tolerances.
torch.allclose(x, y)

# Visualize

In [None]:
def latents_distance(latents):
    """Average pairwise mean-squared difference between the latent tensors of several runs.

    Args:
        latents: sequence of ``(df, latent_tensor)`` pairs, as built by the extraction cells;
            only the tensor (position 1) is used. Tensors must be broadcast-compatible.

    Returns:
        0-d tensor: the mean, over all unordered pairs ``(i < j)``, of
        ``((x - y) ** 2).mean(dim=-1).mean()``.

    Raises:
        ValueError: if fewer than two entries are given (there is no pair to compare).
    """
    from itertools import combinations

    tensors = [entry[1] for entry in latents]
    if len(tensors) < 2:
        raise ValueError(f"latents_distance needs at least two latent sets, got {len(tensors)}")

    pair_dists = [((x - y) ** 2).mean(dim=-1).mean() for x, y in combinations(tensors, 2)]
    return torch.stack(pair_dists).mean()

In [None]:
# Number of checkpoints shown per figure row (also the number of subplot columns).
LIM = 4

In [None]:
# Reference frame used only to fix the colour normalisation shared by all panels.
template_df = ae_latents[0][0]

N_ROWS = 1
N_COLS = LIM

plt.rcParams.update(bundles.icml2022())
plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))

# plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# pyplot.get_cmap(name, lut) returns the same 10-colour resampled "Set1" colormap.
cmap = plt.get_cmap("Set1", 10)
norm = plt.Normalize(template_df["target"].min(), template_df["target"].max())


def plot_row(df, title, equal=True, sharey=False, sharex=False, dpi=150):
    """Draw one figure holding an N_ROWS x N_COLS grid of latent-space scatter plots.

    Uses the notebook globals N_ROWS, N_COLS, cmap and norm.

    Args:
        df: list of per-checkpoint DataFrames (as produced by get_latents); panel ``j`` shows ``df[j]``.
            Extra entries beyond the number of panels are ignored; panels without an entry are hidden.
        title: label written on the y-axis of the first panel.
        equal: force an equal aspect ratio on every drawn panel.
        sharey, sharex: forwarded to ``plt.subplots``.
        dpi: figure resolution.

    Returns:
        The matplotlib Figure.
    """
    # squeeze=False always yields a 2-D array of Axes, so a single-panel grid is still iterable.
    fig, axes = plt.subplots(dpi=dpi, nrows=N_ROWS, ncols=N_COLS, sharey=sharey, sharex=sharex, squeeze=False)

    for j, ax in enumerate(axes.flat):
        if j >= len(df):
            # Fewer runs than panels: hide the empty panel instead of failing with IndexError.
            ax.set_visible(False)
            continue
        if j == 0:
            ax.set_ylabel(title)
        if equal:
            ax.set_aspect("equal")
        # NOTE(review): the meaning of targets=[0, 1] is defined by plot_latent_space — confirm.
        plot_latent_space(ax, df[j], targets=[0, 1], size=0.75, bg_alpha=0.25, alpha=1, cmap=cmap, norm=norm)
    return fig

In [None]:
f = plot_row(
    [pair[0] for pair in ae_latents[:LIM]],
    "AE",
    equal=True,
    sharey=True,
    sharex=True,
)
# Average pairwise latent MSE across the plotted AE runs.
latents_distance(ae_latents[:LIM])

In [None]:
f.savefig("ae.svg", bbox_inches="tight")
!rsvg-convert -f pdf -o ae.pdf ae.svg
!rm ae.svg

In [None]:
f = plot_row(
    [pair[0] for pair in rel_ae_latents[:LIM]],
    "RelAE",
    equal=True,
    sharey=True,
    sharex=True,
)
# Average pairwise MSE of the relative representations across the plotted RelAE runs.
latents_distance(rel_ae_latents[:LIM])

In [None]:
f.savefig("rel_ae.svg", bbox_inches="tight")
!rsvg-convert -f pdf -o rel_ae.pdf rel_ae.svg
!rm rel_ae.svg

In [None]:
f = plot_row(
    [pair[0] for pair in quantized_rel_latents["rel_ae_0.1"][:LIM]],
    "RelAE 0.1",
    equal=True,
    sharey=True,
    sharex=True,
)
latents_distance(quantized_rel_latents["rel_ae_0.1"][:LIM])

In [None]:
f.savefig("rel_ae_0.1.svg", bbox_inches="tight")
!rsvg-convert -f pdf -o 'rel_ae_0.1.pdf' 'rel_ae_0.1.svg'
!rm 'rel_ae_0.1'.svg

In [None]:
f = plot_row(
    [pair[0] for pair in quantized_rel_latents["rel_ae_0.2"][:LIM]],
    "RelAE 0.2",
    equal=True,
    sharey=True,
    sharex=True,
)
latents_distance(quantized_rel_latents["rel_ae_0.2"][:LIM])

In [None]:
f.savefig("rel_ae_0.2.svg", bbox_inches="tight")
!rsvg-convert -f pdf -o 'rel_ae_0.2.pdf' 'rel_ae_0.2.svg'
!rm 'rel_ae_0.2'.svg

In [None]:
f = plot_row(
    [pair[0] for pair in quantized_rel_latents["rel_ae_0.3"][:LIM]],
    "RelAE 0.3",
    equal=True,
    sharey=True,
    sharex=True,
)
latents_distance(quantized_rel_latents["rel_ae_0.3"][:LIM])

In [None]:
f.savefig("rel_ae_0.3.svg", bbox_inches="tight")
!rsvg-convert -f pdf -o 'rel_ae_0.3.pdf' 'rel_ae_0.3.svg'
!rm 'rel_ae_0.3'.svg

In [None]:
f = plot_row(
    [pair[0] for pair in quantized_rel_latents["rel_ae_0.5"][:LIM]],
    "RelAE 0.5",
    equal=True,
    sharey=True,
    sharex=True,
)
latents_distance(quantized_rel_latents["rel_ae_0.5"][:LIM])

In [None]:
f.savefig("rel_ae_0.5.svg", bbox_inches="tight")
!rsvg-convert -f pdf -o 'rel_ae_0.5.pdf' 'rel_ae_0.5.svg'
!rm 'rel_ae_0.5'.svg