In [None]:
# Notebook kernels do not define `__file__`; bind it to the current directory.
# NOTE(review): presumably required by a later import that reads `__file__` — confirm.
__file__ = "."

In [None]:
import torchvision

from rae.modules.enumerations import Output
from rae.pl_modules.pl_gautoencoder import LightningAutoencoder
from rae.utils.evaluation import parse_checkpoints_tree, parse_checkpoint

try:
    # StrEnum is part of the standard library from Python 3.11; older versions use the backport
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum
from rae.utils.evaluation import get_dataset
import matplotlib.pyplot as plt
from tueplots import bundles
from tueplots import figsizes
import logging
from typing import Optional

import numpy as np
import torch
from sklearn.utils import shuffle
from torch.utils.data import DataLoader

from nn_core.common import PROJECT_ROOT
from pytorch_lightning import seed_everything

# Fix the global random seed once, before any data or model is created.
seed_everything(0)

# Silence everything below ERROR on the root logger to keep cell output readable.
logging.getLogger().setLevel(logging.ERROR)


# Device string used for training: the GPU when CUDA is available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Checkpoints of the "sec:data-manifold" experiment live under the project root.
EXPERIMENT_ROOT = PROJECT_ROOT / "experiments" / "sec:data-manifold"
EXPERIMENT_CHECKPOINTS = EXPERIMENT_ROOT / "checkpoints"
# `RUNS` is returned alongside the checkpoints but is not used by the cells shown here.
checkpoints, RUNS = parse_checkpoints_tree(EXPERIMENT_CHECKPOINTS)
# Use the first checkpoint stored under checkpoints["fmnist"]["ae"].
ckpt = checkpoints["fmnist"]["ae"][0]

In [None]:
import hydra

# Lightning module class handed to `parse_checkpoint` (through `get_dataset`) to read the checkpoint.
PL_MODULE = LightningAutoencoder


def get_dataset(pl_module, ckpt):
    """Rebuild the datamodule described by a checkpoint and return its datasets.

    NOTE: this definition shadows the ``get_dataset`` imported from
    ``rae.utils.evaluation`` earlier in the notebook.

    Args:
        pl_module: module class forwarded to ``parse_checkpoint`` as ``module_class``.
        ckpt: checkpoint path forwarded to ``parse_checkpoint`` as ``checkpoint_path``.

    Returns:
        A ``(train_dataset, val_dataset, metadata)`` tuple, where ``val_dataset`` is the
        first entry of the datamodule's ``val_datasets``.
    """
    # Only the configuration (second element) is needed; the checkpoint is mapped to the CPU.
    _unused, config = parse_checkpoint(module_class=pl_module, checkpoint_path=ckpt, map_location="cpu")

    # Instantiate the datamodule from the stored config without recursive instantiation.
    data_module = hydra.utils.instantiate(config.nn.data, _recursive_=False)
    data_module.setup()

    return data_module.train_dataset, data_module.val_datasets[0], data_module.metadata


# Load the datasets and metadata the checkpoint was trained with.
# Note: `test_dataset` is bound to the first *validation* dataset returned by `get_dataset`.
train_dataset, test_dataset, metadata = get_dataset(pl_module=PL_MODULE, ckpt=ckpt)

In [None]:
# Training batches are reshuffled on every pass; the evaluation loader keeps dataset order.
# Both loaders use 8 worker processes and pinned host memory.
train_dl = DataLoader(train_dataset, batch_size=256, shuffle=True, pin_memory=True, num_workers=8)
test_dl = DataLoader(test_dataset, batch_size=256, shuffle=False, pin_memory=True, num_workers=8)

In [None]:
# Anchor images stored in the datamodule metadata, moved to the training device.
anchors = metadata.anchors_images.to(DEVICE)

In [None]:
import itertools

from rae.modules.blocks import build_dynamic_encoder_decoder
from torch.nn import CrossEntropyLoss, MSELoss
from torch.optim import Adam
import torch
from tqdm import tqdm
from torch import nn
from pytorch_lightning import seed_everything
from torch.utils.data import TensorDataset, DataLoader
from torch.nn import functional as F
import math
from rae.modules.rel_ae import VanillaRelAE
from rae.modules.attention import RelativeAttention


def fit(dataset_dl, lr=1e-3, epochs=1, seed=0, hidden_dims=[3, 6, 12, 24], batch_lim=100):
    """Train a fresh ``VanillaRelAE`` on image reconstruction and return it with its last loss.

    Relies on the notebook globals ``metadata``, ``anchors`` and ``DEVICE``.

    Args:
        dataset_dl: iterable of batches; every batch must expose its images under the
            ``"image"`` key.
        lr: learning rate handed to Adam.
        epochs: number of passes over ``dataset_dl``.
        seed: value passed to ``seed_everything`` before the model is built.
        hidden_dims: forwarded to ``VanillaRelAE`` as ``hidden_dims``; a copy is made on
            every call so the shared default list is never handed out directly.
        batch_lim: maximum number of batches consumed per epoch (``None`` means no limit).

    Returns:
        ``(model, loss)``: the trained model in eval mode on the CPU, and the MSE of the
        last processed training batch as a Python float.

    Raises:
        ValueError: if no batch was processed (``epochs`` is 0 or ``dataset_dl`` is empty).
    """
    seed_everything(seed)
    model = VanillaRelAE(
        metadata=metadata,
        input_size=None,
        latent_dim=None,
        # Defensive copy: the default argument is a single list shared by all calls.
        hidden_dims=list(hidden_dims) if hidden_dims is not None else None,
        relative_attention=RelativeAttention(
            n_anchors=anchors.shape[0],
            n_classes=len(metadata.class_to_idx),
            similarity_mode="inner",
            values_mode="similarities",
            normalization_mode="l2",
        ),
        remove_encoder_last_activation=False,
    )

    model = model.to(DEVICE)
    opt = Adam(model.parameters(), lr=lr)
    loss_fn = MSELoss()

    # Stays None until at least one batch has been processed.
    loss = None
    for _ in (tqdm_bar := tqdm(range(epochs), leave=False, desc="epoch")):
        for batch in itertools.islice(dataset_dl, batch_lim):
            batch_x = batch["image"].to(DEVICE, non_blocking=True)
            # Auto-encoding objective: the reconstruction is compared with the input itself.
            pred_y = model.decode(**model.encode(batch_x))["reconstruction"]
            loss = loss_fn(pred_y, batch_x)
            loss.backward()
            opt.step()
            opt.zero_grad()
        if loss is not None:
            tqdm_bar.set_description(f"Loss: {loss:2f}")

    if loss is None:
        raise ValueError("fit() processed no batches: check `epochs`, `batch_lim` and the data loader")

    model = model.eval().cpu()

    return model, loss.item()


# Baseline run: train a single model with the default hyper-parameters of `fit`.
best_model, best_loss = fit(train_dl)

In [None]:
def plot_images(ax, images: torch.Tensor, title: Optional[str] = None, images_per_row=10, padding=2, resize=None):
    """Draw a batch of images as a single grid on a matplotlib axis.

    Args:
        ax: axis to draw on; its frame and ticks are hidden.
        images: batch of images passed to ``torchvision.utils.make_grid``.
        title: optional title set on ``ax``; nothing is set when ``None``.
        images_per_row: number of images per grid row (second positional argument of ``make_grid``).
        padding: padding between grid cells, forwarded to ``make_grid``.
        resize: optional callable applied to ``images`` before the grid is built.
    """
    ax.axis("off")
    ax.set_aspect("equal")
    # The title used to be accepted but silently ignored.
    if title is not None:
        ax.set_title(title)

    if resize is not None:
        images = resize(images)
    # Detach from autograd and move to the CPU once, before building the grid.
    images = images.detach().cpu()
    grid = torchvision.utils.make_grid(images, images_per_row, padding=padding, pad_value=1)
    # Move the first dimension last so the grid has the layout passed to imshow.
    ax.imshow(grid.permute(1, 2, 0))

In [None]:
# Two stacked axes: the original images on top, their reconstructions below.
fig, [source, pred] = plt.subplots(
    2,
    1,
    dpi=150,
    sharey=False,
    sharex=False,
)
# First 10 anchor images as they are.
plot_images(source, anchors.cpu()[:10])
# The same 10 anchors after an encode/decode round trip through `best_model` (on the CPU).
plot_images(pred, best_model.decode(**best_model.encode(anchors.cpu()[:10]))["reconstruction"])

In [None]:
models2loss = []

# Hyper-parameter grid: epochs x learning rate x seed x per-epoch batch limit.
# Materialised as a list so tqdm knows the total and can show real progress.
hparam_grid = list(itertools.product([1, 2], [1e-5, 1e-3, 1e-1], [1, 2, 3], [10, 100, None]))

for epoch, lr, seed, batch_lim in tqdm(hparam_grid):
    # Forward every grid value to `fit`; calling `fit(train_dl)` alone would train the
    # same default configuration for every combination.
    models2loss.append(fit(train_dl, lr=lr, epochs=epoch, seed=seed, batch_lim=batch_lim))

In [None]:
# Order the (model, loss) pairs by ascending loss. A NaN loss compares false against
# everything and would leave `sorted` with an unreliable order, so NaN runs are pushed
# to the end explicitly; the first entry is then the best run with a finite loss.
models2loss = sorted(models2loss, key=lambda pair: (math.isnan(pair[1]), pair[1]))

In [None]:
# After sorting, the first entry is the run with the lowest loss.
best_model, best_loss = models2loss[0]
# Reference representation: the "similarities" output of the best model for the anchors (on the CPU).
best_similarities = best_model.encode(anchors.cpu())["similarities"]
# Displayed as the cell output for a quick sanity check.
best_similarities.shape

In [None]:
def latents_distance(latents1, latents2):
    """Return the mean L2 pairwise distance between two latent tensors as a Python float.

    The distances come from ``F.pairwise_distance`` with ``p=2`` and are averaged
    into a single scalar.
    """
    pairwise = F.pairwise_distance(latents1, latents2, p=2)
    return pairwise.mean().item()

In [None]:
# For every trained model, collect its distance to the best model's anchor
# similarities together with its training loss (same order as `models2loss`).
dists_to_best = []
losses = []
for model, loss in models2loss:
    # Anchor representation of this model, computed on the CPU.
    similarities = model.encode(anchors.cpu())["similarities"]

    dists_to_best.append(latents_distance(similarities, best_similarities))
    losses.append(loss)

In [None]:
# Figure styling: high DPI plus the tueplots ICML 2022 bundle and a full-width figure size.
plt.rcParams.update({"figure.dpi": 300})
plt.rcParams.update(bundles.icml2022())
plt.rcParams.update(figsizes.icml2022_full(ncols=1, nrows=1, height_to_width_ratio=0.4))

fig, ax = plt.subplots(nrows=1, ncols=1, sharey=True, sharex=True, squeeze=True)


# One point per trained model: distance to the best model (x) against its loss (y).
ax.scatter(x=dists_to_best, y=losses, s=5)
ax.set_xlabel("Distances to best model")
ax.set_ylabel("Loss")
# ax.set_title("title")

In [None]:
# Save the scatter plot as SVG with a tight bounding box and no padding.
fig.savefig("lossVSdist.svg", bbox_inches="tight", pad_inches=0)
# Convert the SVG to PDF with the external `rsvg-convert` tool, then delete the intermediate SVG.
# These are IPython shell escapes and require `rsvg-convert` and `rm` on the PATH.
!rsvg-convert -f pdf -o lossVSdist.pdf lossVSdist.svg
!rm lossVSdist.svg