In [None]:
# Reload imported modules automatically before each cell executes (autoreload mode 2).
%load_ext autoreload
%autoreload 2

In [None]:
import logging
from enum import auto
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type, Union

import numpy as np
import pandas as pd
import rich
import torch
import typer
from torchmetrics import (
    ErrorRelativeGlobalDimensionlessSynthesis,
    MeanSquaredError,
    MetricCollection,
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
)
import torchvision
from nn_core.common import PROJECT_ROOT
from rae.modules.enumerations import Output
from rae.pl_modules.pl_gautoencoder import LightningAutoencoder
from rae.utils.evaluation import parse_checkpoint_id, parse_checkpoints_tree, parse_checkpoint
from collections import defaultdict

try:
    # enum.StrEnum is in the standard library from Python 3.11 onwards; older versions use the backport
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum

from rae.utils.evaluation import plot_latent_space
import matplotlib.pyplot as plt
from tueplots import bundles
from tueplots import figsizes
import logging
from enum import auto
from functools import cached_property, partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union

import hydra
import numpy as np
import omegaconf
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate

from nn_core.common import PROJECT_ROOT
from nn_core.nn_types import Split

# Raise the root logger threshold so only ERROR and above is emitted.
logging.getLogger().setLevel(logging.ERROR)


# NOTE(review): not referenced by any other cell in the visible notebook — confirm it is still needed.
BATCH_SIZE = 256


# Paths of this experiment, all derived from the project root.
EXPERIMENT_ROOT = PROJECT_ROOT / "experiments" / "sec:model-reusability-ae"
EXPERIMENT_CHECKPOINTS = EXPERIMENT_ROOT / "checkpoints"
PREDICTIONS_TSV = EXPERIMENT_ROOT / "predictions.tsv"
PERFORMANCE_TSV = EXPERIMENT_ROOT / "performance.tsv"

# Dataset key -> (dotted class path, presumably the split name — TODO confirm).
# NOTE(review): neither mapping below is referenced by the visible cells.
DATASET_SANITY = {
    "mnist": ("rae.data.vision.mnist.MNISTDataset", "test"),
    "fmnist": ("rae.data.vision.fmnist.FashionMNISTDataset", "test"),
    "cifar10": ("rae.data.vision.cifar10.CIFAR10Dataset", "test"),
    "cifar100": ("rae.data.vision.cifar100.CIFAR100Dataset", "test"),
}
# Model key -> dotted class path.
MODEL_SANITY = {
    "vae": "rae.modules.vae.VanillaVAE",
    "ae": "rae.modules.ae.VanillaAE",
    "rel_vae": "rae.modules.rel_vae.VanillaRelVAE",
    "rel_ae": "rae.modules.rel_ae.VanillaRelAE",
}


# The cells below index the first result as checkpoints[<dataset>][<model>][<position>].
checkpoints, RUNS = parse_checkpoints_tree(EXPERIMENT_CHECKPOINTS)

# Train stitching

In [None]:
# Checkpoint list used by the next cell: MODELS[0] is passed to get_dataset.
MODELS = checkpoints["mnist"]["ae"]
# Lightning module class handed to get_dataset / parse_checkpoint throughout the notebook.
PL_MODULE = LightningAutoencoder

# Number of images to draw from the dataset for the figures.
num_samples = 20
# NOTE(review): looks like a leftover from anchor-sampling code — confirm whether it is still needed.
K = 2

In [None]:
from rae.utils.evaluation import get_dataset
from pytorch_lightning import seed_everything

# Seed the RNGs through Lightning so the selected images are reproducible.
seed_everything(0)

# NOTE(review): named "val", but which split is returned is decided inside get_dataset — confirm.
val_dataset = get_dataset(pl_module=PL_MODULE, ckpt=MODELS[0])


# Shuffle dataset positions and their labels together, with a fixed seed.
shuffled_idxs, shuffled_targets = shuffle(
    np.arange(len(val_dataset)),
    np.asarray(val_dataset.targets),
    random_state=0,
)
all_targets = sorted(set(shuffled_targets))
# Label -> shuffled dataset positions carrying that label.
class2idxs = {target: shuffled_idxs[shuffled_targets == target] for target in all_targets}

# Round-robin over the classes: pass `i` takes the i-th shuffled sample of every class,
# stopping as soon as `num_samples` positions have been collected.
idxs = []
i = 0
while len(idxs) < num_samples:
    for target_idxs in class2idxs.values():
        idxs.append(target_idxs[i])
        # The original condition referenced an undefined name (`anchor_indices`) and raised NameError.
        if len(idxs) == num_samples:
            break
    i += 1


# Gather the selected samples field by field; only `images` is stacked into a batch.
images = []
targets = []
indexes = []
classes = []
for idx in idxs:
    sample = val_dataset[idx]
    indexes.append(sample["index"].item())
    images.append(sample["image"])
    targets.append(sample["target"])
    classes.append(sample["class"])
images_batch = torch.stack(images, dim=0)
images_batch.shape

# Visualize

In [None]:
# Row/column counts fed to tueplots for the figure size: one row, one column per sampled image.
N_ROWS = 1
N_COLS = images_batch.shape[0]

# Apply the ICML 2022 style bundle, then the full-width figure size on top of it.
plt.rcParams.update(bundles.icml2022())
plt.rcParams.update(
    figsizes.icml2022_full(
        nrows=N_ROWS,
        ncols=N_COLS,
        height_to_width_ratio=1.0,
    )
)


# A single axes that will hold the whole image grid.
fig, ax = plt.subplots(nrows=1, ncols=1, dpi=150, sharex=False, sharey=False)


def plot_images(ax, images: torch.Tensor, title: Optional[str] = None):
    """Draw a batch of images as a single grid on ``ax``.

    Args:
        ax: Matplotlib axes to draw on; its axis decorations are switched off.
        images: Batch of image tensors; moved to CPU and detached before plotting.
        title: Optional title set on the axes.
    """
    cpu_images = images.cpu().detach()
    if title is not None:
        ax.set_title(title)
    ax.axis("off")
    ax.set_aspect("equal")
    # Up to 20 images per grid row, 5 pixels of padding between them.
    grid = torchvision.utils.make_grid(cpu_images, nrow=20, padding=5)
    # Move the first axis last before handing the grid to imshow.
    ax.imshow(grid.permute(1, 2, 0))


# Show the sampled input images on the axes created above.
plot_images(ax, images_batch)

# AE

In [None]:
# Select the checkpoints stored under the "mnist" / "ae" keys; the bare name displays them.
MODELS = checkpoints["mnist"]["ae"]
MODELS

In [None]:
# Restore the first entry of MODELS on CPU; the second returned value is discarded.
ae1, _ = parse_checkpoint(module_class=PL_MODULE, checkpoint_path=MODELS[0], map_location="cpu")

In [None]:
# Restore the second entry of MODELS on CPU; the second returned value is discarded.
ae2, _ = parse_checkpoint(module_class=PL_MODULE, checkpoint_path=MODELS[1], map_location="cpu")

In [None]:
# Reconstruct the sampled batch with ae1. The result is only plotted, so skip autograd tracking.
with torch.no_grad():
    reconstructions = ae1(images_batch)[Output.RECONSTRUCTION]

fig, ax = plt.subplots(nrows=1, ncols=1, dpi=150, sharex=False, sharey=False)
plot_images(ax, reconstructions)

In [None]:
# Reconstruct the sampled batch with ae2. The result is only plotted, so skip autograd tracking.
with torch.no_grad():
    reconstructions = ae2(images_batch)[Output.RECONSTRUCTION]

fig, ax = plt.subplots(nrows=1, ncols=1, dpi=150, sharex=False, sharey=False)
plot_images(ax, reconstructions)

In [None]:
from rae.pl_modules.pl_stitching_module import StitchingModule

# Combine the two loaded models with StitchingModule and reconstruct the same batch.
model = StitchingModule(ae1, ae2)
# The result is only plotted, so skip autograd tracking.
with torch.no_grad():
    reconstructions = model(images_batch)[Output.RECONSTRUCTION]

fig, ax = plt.subplots(nrows=1, ncols=1, dpi=150, sharex=False, sharey=False)
plot_images(ax, reconstructions)

# RelAE

In [None]:
# Select the checkpoints stored under the "mnist" / "rel_ae_0.5" keys; the bare name displays them.
# NOTE(review): the Build Figure section looks up "rel_ae" instead — confirm both keys exist.
MODELS = checkpoints["mnist"]["rel_ae_0.5"]
MODELS

In [None]:
# Restore the first entry of MODELS on CPU; the second returned value is discarded.
ae1, _ = parse_checkpoint(module_class=PL_MODULE, checkpoint_path=MODELS[0], map_location="cpu")

In [None]:
# Restore the second entry of MODELS on CPU; the second returned value is discarded.
ae2, _ = parse_checkpoint(module_class=PL_MODULE, checkpoint_path=MODELS[1], map_location="cpu")

In [None]:
# Reconstruct the sampled batch with ae1. The result is only plotted, so skip autograd tracking.
with torch.no_grad():
    reconstructions = ae1(images_batch)[Output.RECONSTRUCTION]

fig, ax = plt.subplots(nrows=1, ncols=1, dpi=150, sharex=False, sharey=False)
plot_images(ax, reconstructions)

In [None]:
# NOTE(review): this cell repeats the ae1 reconstruction of the preceding cell — confirm it can be removed.
# The result is only plotted, so skip autograd tracking.
with torch.no_grad():
    reconstructions = ae1(images_batch)[Output.RECONSTRUCTION]

fig, ax = plt.subplots(nrows=1, ncols=1, dpi=150, sharex=False, sharey=False)
plot_images(ax, reconstructions)

In [None]:
# Reconstruct the sampled batch with ae2. The result is only plotted, so skip autograd tracking.
with torch.no_grad():
    reconstructions = ae2(images_batch)[Output.RECONSTRUCTION]

fig, ax = plt.subplots(nrows=1, ncols=1, dpi=150, sharex=False, sharey=False)
plot_images(ax, reconstructions)

In [None]:
from rae.pl_modules.pl_stitching_module import StitchingModule

# Combine the two loaded models with StitchingModule and reconstruct the same batch.
model = StitchingModule(ae1, ae2)
# The result is only plotted, so skip autograd tracking.
with torch.no_grad():
    reconstructions = model(images_batch)[Output.RECONSTRUCTION]

fig, ax = plt.subplots(nrows=1, ncols=1, dpi=150, sharex=False, sharey=False)
plot_images(ax, reconstructions)

# VAE

In [None]:
# Select the checkpoints stored under the "mnist" / "vae" keys; the bare name displays them.
MODELS = checkpoints["mnist"]["vae"]
MODELS

In [None]:
# Restore the first entry of MODELS on CPU; the second returned value is discarded.
ae1, _ = parse_checkpoint(module_class=PL_MODULE, checkpoint_path=MODELS[0], map_location="cpu")

In [None]:
# Restore the second entry of MODELS on CPU; the second returned value is discarded.
ae2, _ = parse_checkpoint(module_class=PL_MODULE, checkpoint_path=MODELS[1], map_location="cpu")

In [None]:
# Reconstruct the sampled batch with ae1. The result is only plotted, so skip autograd tracking.
with torch.no_grad():
    reconstructions = ae1(images_batch)[Output.RECONSTRUCTION]

fig, ax = plt.subplots(nrows=1, ncols=1, dpi=150, sharex=False, sharey=False)
plot_images(ax, reconstructions)

In [None]:
# Reconstruct the sampled batch with ae2. The result is only plotted, so skip autograd tracking.
with torch.no_grad():
    reconstructions = ae2(images_batch)[Output.RECONSTRUCTION]

fig, ax = plt.subplots(nrows=1, ncols=1, dpi=150, sharex=False, sharey=False)
plot_images(ax, reconstructions)

In [None]:
from rae.pl_modules.pl_stitching_module import StitchingModule

# Combine the two loaded models with StitchingModule and reconstruct the same batch.
model = StitchingModule(ae1, ae2)
# The result is only plotted, so skip autograd tracking.
with torch.no_grad():
    reconstructions = model(images_batch)[Output.RECONSTRUCTION]

fig, ax = plt.subplots(nrows=1, ncols=1, dpi=150, sharex=False, sharey=False)
plot_images(ax, reconstructions)

# RelVAE

In [None]:
# Select the checkpoints stored under the "mnist" / "rel_vae" keys; the bare name displays them.
MODELS = checkpoints["mnist"]["rel_vae"]
MODELS

In [None]:
# Restore the first entry of MODELS on CPU; the second returned value is discarded.
ae1, _ = parse_checkpoint(module_class=PL_MODULE, checkpoint_path=MODELS[0], map_location="cpu")

In [None]:
# Restore the second entry of MODELS on CPU; the second returned value is discarded.
ae2, _ = parse_checkpoint(module_class=PL_MODULE, checkpoint_path=MODELS[1], map_location="cpu")

In [None]:
# Reconstruct the sampled batch with ae1. The result is only plotted, so skip autograd tracking.
with torch.no_grad():
    reconstructions = ae1(images_batch)[Output.RECONSTRUCTION]

fig, ax = plt.subplots(nrows=1, ncols=1, dpi=150, sharex=False, sharey=False)
plot_images(ax, reconstructions)

In [None]:
# Reconstruct the sampled batch with ae2. The result is only plotted, so skip autograd tracking.
with torch.no_grad():
    reconstructions = ae2(images_batch)[Output.RECONSTRUCTION]

fig, ax = plt.subplots(nrows=1, ncols=1, dpi=150, sharex=False, sharey=False)
plot_images(ax, reconstructions)

In [None]:
from rae.pl_modules.pl_stitching_module import StitchingModule

# Combine the two loaded models with StitchingModule and reconstruct the same batch.
model = StitchingModule(ae1, ae2)
# The result is only plotted, so skip autograd tracking.
with torch.no_grad():
    reconstructions = model(images_batch)[Output.RECONSTRUCTION]

fig, ax = plt.subplots(nrows=1, ncols=1, dpi=150, sharex=False, sharey=False)
plot_images(ax, reconstructions)

# Build Figure

In [None]:
# NOTE(review): models_name is not referenced by any later cell in view — confirm it is needed.
models_name = ["ae", "vae", "rel_ae", "rel_vae"]
# Positions, within each checkpoint list, of the two models that plot_stitching loads.
model_a_idx = 0
model_b_idx = 1

In [None]:
def plot_images(ax, images: torch.Tensor, title: Optional[str] = None, images_per_row=10, padding=2):
    """Draw a batch of images as one grid on ``ax``, using the y-label as a row title.

    This redefinition replaces the ``plot_images`` declared earlier in the notebook.

    Args:
        ax: Matplotlib axes to draw on; its spines and ticks are hidden.
        images: Batch of image tensors; moved to CPU and detached before plotting.
        title: Optional text set as the axes' y-label.
        images_per_row: Number of images placed on each grid row.
        padding: Padding between grid cells, filled with value 1.
    """
    cpu_images = images.cpu().detach()
    # Spines and ticks are hidden one by one instead of calling ax.axis("off"),
    # presumably so the y-label set below stays visible.
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(False)
    ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_ticks([])
    if title is not None:
        ax.set_ylabel(title)
    ax.set_aspect("equal")
    grid = torchvision.utils.make_grid(cpu_images, nrow=images_per_row, padding=padding, pad_value=1)
    # Move the first axis last before handing the grid to imshow.
    ax.imshow(grid.permute(1, 2, 0))

In [None]:
def plot_stitching(model_name, name1, name2, height_to_width_ratio, padding=2):
    """Plot model A's reconstructions above those of ``StitchingModule(model_a, model_b)``.

    Reads the notebook globals ``checkpoints``, ``PL_MODULE``, ``model_a_idx``,
    ``model_b_idx``, ``images_batch`` and ``StitchingModule``, and updates
    ``plt.rcParams`` globally as a side effect.

    Args:
        model_name: Key into ``checkpoints["mnist"]`` selecting the checkpoint list.
        name1: Row label for model A's reconstructions.
        name2: Row label for the stitched reconstructions.
        height_to_width_ratio: Forwarded to ``figsizes.icml2022_full``.
        padding: Padding between images in each grid row.

    Returns:
        The matplotlib figure holding the two rows.
    """
    checkpoint_paths = checkpoints["mnist"][model_name]
    model_a, _ = parse_checkpoint(
        module_class=PL_MODULE,
        checkpoint_path=checkpoint_paths[model_a_idx],
        map_location="cpu",
    )
    model_b, _ = parse_checkpoint(
        module_class=PL_MODULE,
        checkpoint_path=checkpoint_paths[model_b_idx],
        map_location="cpu",
    )
    model_ab = StitchingModule(model_a, model_b)

    # The outputs are only plotted, so skip autograd tracking for both forward passes.
    with torch.no_grad():
        recon_a = model_a(images_batch)[Output.RECONSTRUCTION]
        recon_ab = model_ab(images_batch)[Output.RECONSTRUCTION]

    n_rows = 2
    n_cols = 1

    plt.rcParams.update(bundles.icml2022())
    plt.rcParams.update(
        figsizes.icml2022_full(ncols=n_cols, nrows=n_rows, height_to_width_ratio=height_to_width_ratio)
    )

    fig, (ax_top, ax_bottom) = plt.subplots(n_rows, n_cols, dpi=150)
    # Both rows use the size of model A's batch so every image sits on a single grid row.
    images_per_row = recon_a.shape[0]
    plot_images(ax_top, recon_a, title=name1, images_per_row=images_per_row, padding=padding)
    plot_images(ax_bottom, recon_ab, title=name2, images_per_row=images_per_row, padding=padding)
    return fig

In [None]:
# Shared height-to-width ratio passed to figsizes.icml2022_full for the figures below.
height_to_width_ratio = 0.07

In [None]:
# A single axes that shows every source image on one grid row.
N_ROWS = 1
N_COLS = 1

# Apply the ICML 2022 style bundle, then the full-width figure size on top of it.
plt.rcParams.update(bundles.icml2022())
plt.rcParams.update(
    figsizes.icml2022_full(
        nrows=N_ROWS,
        ncols=N_COLS,
        height_to_width_ratio=height_to_width_ratio,
    )
)

fig, ax = plt.subplots(nrows=N_ROWS, ncols=N_COLS, dpi=150, sharex=False, sharey=False)
plot_images(
    ax,
    images_batch,
    title="Source",
    images_per_row=images_batch.shape[0],
)

In [None]:
# Save as SVG, convert it to PDF with rsvg-convert, then delete the intermediate SVG.
fig.savefig("source.svg", bbox_inches="tight")
!rsvg-convert -f pdf -o source.pdf source.svg
!rm source.svg

In [None]:
# Two-row figure for the "ae" checkpoints.
fig = plot_stitching(
    model_name="ae",
    name1="AE_11",
    name2="AE_12",
    height_to_width_ratio=height_to_width_ratio,
    padding=2,
)

In [None]:
# Save as SVG, convert it to PDF with rsvg-convert, then delete the intermediate SVG.
fig.savefig("ae.svg", bbox_inches="tight")
!rsvg-convert -f pdf -o ae.pdf ae.svg
!rm ae.svg

In [None]:
# Two-row figure for the "vae" checkpoints.
fig = plot_stitching(
    model_name="vae",
    name1="VAE_11",
    name2="VAE_12",
    height_to_width_ratio=height_to_width_ratio,
)

In [None]:
# Save as SVG, convert it to PDF with rsvg-convert, then delete the intermediate SVG.
fig.savefig("vae.svg", bbox_inches="tight")
!rsvg-convert -f pdf -o vae.pdf vae.svg
!rm vae.svg

In [None]:
# Two-row figure for the "rel_ae" checkpoints.
fig = plot_stitching(
    model_name="rel_ae",
    name1="RAE_11",
    name2="RAE_12",
    height_to_width_ratio=height_to_width_ratio,
)

In [None]:
# Save as SVG, convert it to PDF with rsvg-convert, then delete the intermediate SVG.
fig.savefig("rel_ae.svg", bbox_inches="tight")
!rsvg-convert -f pdf -o rel_ae.pdf rel_ae.svg
!rm rel_ae.svg

In [None]:
# Two-row figure for the "rel_vae" checkpoints.
fig = plot_stitching(
    model_name="rel_vae",
    name1="RVAE_11",
    name2="RVAE_12",
    height_to_width_ratio=height_to_width_ratio,
)

In [None]:
# Save as SVG, convert it to PDF with rsvg-convert, then delete the intermediate SVG.
fig.savefig("rel_vae.svg", bbox_inches="tight")
!rsvg-convert -f pdf -o rel_vae.pdf rel_vae.svg
!rm rel_vae.svg