In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torchvision

from rae.modules.enumerations import Output
from rae.pl_modules.pl_gautoencoder import LightningAutoencoder
from rae.utils.evaluation import parse_checkpoints_tree, parse_checkpoint

try:
    # be ready for 3.10 when it drops
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum
from rae.utils.evaluation import get_dataset
import matplotlib.pyplot as plt
from tueplots import bundles
from tueplots import figsizes
import logging
from typing import Optional

import numpy as np
import torch
from sklearn.utils import shuffle

from nn_core.common import PROJECT_ROOT

# Raise the root logger's threshold so only ERROR (and above) records are shown.
logging.root.setLevel(logging.ERROR)


BATCH_SIZE = 256


# Checkpoints of this experiment live under
# <PROJECT_ROOT>/experiments/sec:model-reusability-ae/checkpoints.
EXPERIMENT_ROOT = PROJECT_ROOT.joinpath("experiments", "sec:model-reusability-ae")
EXPERIMENT_CHECKPOINTS = EXPERIMENT_ROOT.joinpath("checkpoints")

# `checkpoints` is indexed below as checkpoints[dataset][model_kind][i]; RUNS is not referenced later.
checkpoints, RUNS = parse_checkpoints_tree(EXPERIMENT_CHECKPOINTS)

In [None]:
def plot_images(ax, images: torch.Tensor, title: Optional[str] = None, images_per_row=10, padding=2, resize=None):
    """Draw a batch of images on ``ax`` as a single tiled grid.

    Args:
        ax: matplotlib Axes to draw on.
        images: image batch; assumed (N, C, H, W) with values in [0, 1] — TODO confirm.
        title: optional axes title; set only when not None.
        images_per_row: number of tiles per grid row (``nrow`` of ``make_grid``).
        padding: pixels between tiles; padding is filled with value 1.
        resize: optional callable applied to the whole batch before tiling.
    """
    if resize is not None:
        images = resize(images)
    # imshow needs a grad-free CPU tensor; one transfer is enough.
    images = images.detach().cpu()
    grid = torchvision.utils.make_grid(images, nrow=images_per_row, padding=padding, pad_value=1)
    # The grid is channel-first; imshow expects channel-last.
    ax.imshow(grid.permute(1, 2, 0))
    if title is not None:
        ax.set_title(title)

# Dataset

In [None]:
%%capture

PL_MODULE = LightningAutoencoder

num_samples = 20

# Each dataset is resolved from the first "ae" checkpoint trained on it.
mnist, fmnist, cifar10 = (
    get_dataset(pl_module=PL_MODULE, ckpt=checkpoints[dataset_name]["ae"][0])
    for dataset_name in ("mnist", "fmnist", "cifar10")
)

In [None]:
from pprint import pprint
from pytorch_lightning import seed_everything

# Global seeding through Lightning so any stochastic step below is reproducible.
seed_everything(seed=0)


def get_class2idx(dataset, k: int = 10):
    """Return ``{label: indices}`` holding up to ``k`` sample indices per class.

    The indices come from a shuffle of the whole dataset with a fixed
    ``random_state=0``, so the selection is identical on every run.

    Args:
        dataset: a sized dataset exposing ``.targets`` (one label per sample).
        k: maximum number of indices kept for each class.

    Returns:
        Dict mapping each distinct label, in sorted order, to a NumPy array of
        at most ``k`` dataset indices whose target equals that label.
    """
    labels = np.asarray(dataset.targets)
    perm_idxs, perm_labels = shuffle(np.arange(len(dataset)), labels, random_state=0)
    # np.unique already returns the distinct labels in sorted order.
    return {label: perm_idxs[perm_labels == label][:k] for label in np.unique(perm_labels)}


# Per-dataset {label: shuffled indices} maps, using get_class2idx's default k.
mnist_class2idx, fmnist_class2idx, cifar10_class2idx = map(get_class2idx, (mnist, fmnist, cifar10))


# Sample selection

In [None]:
# Inspect MNIST: class label -> up to k indices from the fixed shuffle.
mnist_class2idx

In [None]:
# Inspect FashionMNIST: class label -> up to k indices from the fixed shuffle.
fmnist_class2idx

In [None]:
# Inspect CIFAR-10: class label -> up to k indices from the fixed shuffle.
cifar10_class2idx

In [None]:
# First shuffled MNIST index of every class, in label order.
[class_idxs[0] for class_idxs in mnist_class2idx.values()]

In [None]:
from torch.utils.data import default_collate

# Hand-picked sample indices per dataset, reused by every figure below.
# NOTE(review): the trailing expressions look like where the picks came from, but each
# yields one index per class while every list holds 7 entries — presumably curated by hand.
mnist_idxs = [8225, 7407, 4721, 8940, 2846, 5334, 598]  # [mnist_class2idx[x][6] for x in mnist_class2idx]
fmnist_idxs = [19, 8940, 2702, 2606, 6734, 382, 122]  # + [fmnist_class2idx[x][1] for x in fmnist_class2idx]
cifar10_idxs = [11, 60, 6, 84, 98, 8940, 2606]  # + [cifar10_class2idx[x][7] for x in cifar10_class2idx]

# Collate each selection into a single batch dict; batch["image"] holds the stacked images.
batch_mnist = default_collate([mnist[i] for i in mnist_idxs])
batch_fmnist = default_collate([fmnist[i] for i in fmnist_idxs])
batch_cifar10 = default_collate([cifar10[i] for i in cifar10_idxs])


# Quick preview of the selected inputs, one dataset per row (each batch drawn exactly once).
fig, [ax1, ax2, ax3] = plt.subplots(
    3,
    1,
    dpi=150,
)
for preview_ax, preview_batch in zip((ax1, ax2, ax3), (batch_mnist, batch_fmnist, batch_cifar10)):
    plot_images(preview_ax, preview_batch["image"])
# Visualize

In [None]:
from rae.pl_modules.pl_stitching_module import StitchingModule
from torchvision.transforms import Resize

# One shared 28x28 tile size so panels from every dataset line up.
resize = Resize(size=(28, 28))


def plot_images(ax, images: torch.Tensor, title: Optional[str] = None, images_per_row=10, padding=2, resize=None):
    """Draw a batch of images on ``ax`` as a single tiled grid, without axis decorations.

    Args:
        ax: matplotlib Axes to draw on; its axis is hidden and aspect set to equal.
        images: image batch; assumed (N, C, H, W) with values in [0, 1] — TODO confirm.
        title: optional axes title; set only when not None (still shown with the axis off).
        images_per_row: number of tiles per grid row (``nrow`` of ``make_grid``).
        padding: pixels between tiles; padding is filled with value 1.
        resize: optional callable applied to the whole batch before tiling.
    """
    ax.axis("off")
    ax.set_aspect("equal")

    if resize is not None:
        images = resize(images)
    # imshow needs a grad-free CPU tensor; one transfer is enough.
    images = images.detach().cpu()
    grid = torchvision.utils.make_grid(images, nrow=images_per_row, padding=padding, pad_value=1)
    # The grid is channel-first; imshow expects channel-last.
    ax.imshow(grid.permute(1, 2, 0))
    if title is not None:
        ax.set_title(title)


def _load_autoencoder(ckpt):
    """Restore a ``PL_MODULE`` model from ``ckpt`` onto the CPU, in eval mode."""
    model, _ = parse_checkpoint(
        module_class=PL_MODULE,
        checkpoint_path=ckpt,
        map_location="cpu",
    )
    # No-op if parse_checkpoint already returns eval-mode models; otherwise this
    # disables train-time behavior before the models are used for plotting.
    return model.eval()


def plot_stitching(ax, ckpt_a, ckpt_b, images, padding=2, resize=resize):
    """Plot reconstructions of ``images`` by model A and by the A/B stitched model.

    The grid on ``ax`` has one row per model, each row as wide as the batch: first
    model A's own reconstructions, then the output of
    ``StitchingModule(model_a, model_b)`` (presumably A's encoder feeding B's
    decoder — confirm against StitchingModule).

    Args:
        ax: matplotlib Axes to draw on.
        ckpt_a: checkpoint of model A.
        ckpt_b: checkpoint of model B.
        images: input batch fed to both models.
        padding: grid padding forwarded to ``plot_images``.
        resize: optional transform applied to the reconstructions before tiling.
    """
    model_a = _load_autoencoder(ckpt_a)
    model_b = _load_autoencoder(ckpt_b)

    # Inference only: building an autograd graph would just waste memory.
    with torch.no_grad():
        recon_a = model_a(images)[Output.RECONSTRUCTION]
        model_ab = StitchingModule(model_a, model_b).eval()
        recon_ab = model_ab(images)[Output.RECONSTRUCTION]

    plot_images(ax, torch.cat([recon_a, recon_ab]), images_per_row=recon_a.shape[0], padding=padding, resize=resize)

In [None]:
N_ROWS, N_COLS, RATIO = 1, 3, 0.3

# ICML 2022 paper style, then a full-width figure size for an N_ROWS x N_COLS grid.
plt.rcParams.update(bundles.icml2022())
plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))


# One panel per dataset with the selected input images.
fig, (source_mnist_ax, source_fmnist_ax, source_cifar10_ax) = plt.subplots(
    N_ROWS, N_COLS, dpi=300, sharex=True, sharey=False
)
fig.subplots_adjust(wspace=0.01, hspace=0.02)


for panel_ax, panel_batch in zip(
    (source_mnist_ax, source_fmnist_ax, source_cifar10_ax),
    (batch_mnist, batch_fmnist, batch_cifar10),
):
    plot_images(panel_ax, panel_batch["image"], resize=resize)

In [None]:
fig.savefig("source.svg", bbox_inches="tight", pad_inches=0)
!rsvg-convert -f pdf -o source.pdf source.svg
!rm source.svg

In [None]:
import matplotlib.gridspec as gridspec

# Stitching figure: one row per model family, one column per dataset. Each panel shows
# a model's own reconstructions above the stitched reconstructions (plot_stitching).
STITCHING_MODELS = ("ae", "vae", "rel_ae", "rel_vae")
STITCHING_DATASETS = ("mnist", "fmnist", "cifar10")
stitching_batches = {"mnist": batch_mnist, "fmnist": batch_fmnist, "cifar10": batch_cifar10}

N_ROWS = len(STITCHING_MODELS)
N_COLS = len(STITCHING_DATASETS)
RATIO = 0.3

plt.rcParams.update(bundles.icml2022())
plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=RATIO))


fig, stitching_axes = plt.subplots(
    N_ROWS,
    N_COLS,
    dpi=300,
    sharey=False,
    sharex=True,
)
fig.subplots_adjust(hspace=0.02, wspace=0.01)

# Named handles kept so individual panels can still be inspected interactively.
(
    (ae_mnist_ax, ae_fmnist_ax, ae_cifar10_ax),
    (vae_mnist_ax, vae_fmnist_ax, vae_cifar10_ax),
    (relae_mnist_ax, relae_fmnist_ax, relae_cifar10_ax),
    (relvae_mnist_ax, relvae_fmnist_ax, relvae_cifar10_ax),
) = stitching_axes

# Row-major order (model, then dataset): the same order the panels were drawn in before.
for row_axes, model_name in zip(stitching_axes, STITCHING_MODELS):
    for panel_ax, dataset_name in zip(row_axes, STITCHING_DATASETS):
        # Stitch the first two checkpoints of this model family on this dataset.
        model_ckpts = checkpoints[dataset_name][model_name]
        plot_stitching(
            panel_ax,
            model_ckpts[0],
            model_ckpts[1],
            stitching_batches[dataset_name]["image"],
            padding=2,
            resize=resize,
        )

In [None]:
fig.savefig("stitching.svg", bbox_inches="tight", pad_inches=0)
!rsvg-convert -f pdf -o stitching.pdf stitching.svg
!rm stitching.svg