In [None]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import logging
from enum import auto
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type, Union

import numpy as np
import pandas as pd
import rich
import torch
import typer
from torchmetrics import (
    ErrorRelativeGlobalDimensionlessSynthesis,
    MeanSquaredError,
    MetricCollection,
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
)

from nn_core.common import PROJECT_ROOT
from rae.modules.enumerations import Output
from rae.pl_modules.pl_gautoencoder import LightningAutoencoder
from rae.utils.evaluation import parse_checkpoint_id, parse_checkpoints_tree, parse_checkpoint
from collections import defaultdict

try:
    # be ready for 3.10 when it drops
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum

from rae.utils.evaluation import plot_latent_space
import matplotlib.pyplot as plt
from tueplots import bundles
from tueplots import figsizes

# Keep cell output readable: the root logger only lets errors through.
logging.getLogger().setLevel(level=logging.ERROR)


BATCH_SIZE = 256  # NOTE(review): not referenced by any visible cell


# On-disk layout of this experiment.
EXPERIMENT_ROOT = PROJECT_ROOT / "experiments" / "fig:latent-rotation-comparison"
EXPERIMENT_CHECKPOINTS, PREDICTIONS_TSV, PERFORMANCE_TSV = (
    EXPERIMENT_ROOT / leaf for leaf in ("checkpoints", "predictions.tsv", "performance.tsv")
)

# Short dataset name -> (dataset class import path, split).
DATASET_SANITY = dict(
    mnist=("rae.data.vision.mnist.MNISTDataset", "test"),
    fmnist=("rae.data.vision.fmnist.FashionMNISTDataset", "test"),
    cifar10=("rae.data.vision.cifar10.CIFAR10Dataset", "test"),
    cifar100=("rae.data.vision.cifar100.CIFAR100Dataset", "test"),
)
# Short model name -> model class import path.
MODEL_SANITY = dict(
    vae="rae.modules.vae.VanillaVAE",
    ae="rae.modules.ae.VanillaAE",
    rel_vae="rae.modules.rel_vae.VanillaRelVAE",
    rel_ae="rae.modules.rel_ae.VanillaRelAE",
)


# Index the checkpoints on disk; later cells read it as checkpoints[dataset][model][i].
checkpoints, RUNS = parse_checkpoints_tree(EXPERIMENT_CHECKPOINTS)
checkpoints

In [None]:
from rae.utils.evaluation import parse_checkpoint


def get_latents(
    images_batch,
    model,
    key=Output.DEFAULT_LATENT,
    return_df: bool = True,
    classes=None,
    targets=None,
    indexes=None,
):
    """Encode a batch with ``model`` and optionally tabulate its first two latent dims.

    Args:
        images_batch: input batch passed straight to ``model``; dim 0 is the batch dim.
        model: callable whose output is indexed with ``key``.
        key: output entry that holds the latent codes.
        return_df: when True, also build a DataFrame with columns ``x``/``y``
            (latent dims 0 and 1) and ``class``/``target``/``index``.
        classes, targets, indexes: per-sample metadata for the DataFrame, one
            entry per sample. When omitted, the notebook-level globals with the
            same names are used (the previous, implicit behaviour).

    Returns:
        ``(df, latents)``: ``df`` is None when ``return_df`` is False;
        ``latents`` is the detached latent tensor with singleton non-batch dims removed.
    """
    # The result is detached anyway, so skip building an autograd graph.
    with torch.no_grad():
        latents = model(images_batch)[key].detach()

    # Drop singleton dims but never the batch dim: a plain `.squeeze()` turned a
    # batch of one sample into a 1-D tensor and broke the `[:, ...]` indexing.
    for dim in range(latents.dim() - 1, 0, -1):
        if latents.shape[dim] == 1:
            latents = latents.squeeze(dim)

    df = None
    if return_df:
        notebook_globals = globals()
        df = pd.DataFrame(
            {
                "x": latents[:, 0].tolist(),
                "y": latents[:, 1].tolist(),
                "class": notebook_globals["classes"] if classes is None else classes,
                "target": notebook_globals["targets"] if targets is None else targets,
                "index": notebook_globals["indexes"] if indexes is None else indexes,
            }
        )
    return df, latents

# Latent Rotations

In [None]:
MODELS = checkpoints["mnist"]["ae"]

In [None]:
# Lightning module class used to restore every checkpoint in this notebook.
PL_MODULE = LightningAutoencoder

## Images

In [None]:
from rae.utils.evaluation import get_dataset
from pytorch_lightning import seed_everything

# Seed the global RNGs before drawing the random subset so the selection is reproducible.
seed_everything(0)

K = 2_000  # number of samples to embed and plot
val_dataset = get_dataset(pl_module=PL_MODULE, ckpt=MODELS[0])
idxs = torch.randperm(len(val_dataset))[:K]

# Fetch each sampled item once (in the same order), then split its fields into parallel lists.
picked = [val_dataset[i] for i in idxs]
images = [item["image"] for item in picked]
targets = [item["target"] for item in picked]
indexes = [item["index"].item() for item in picked]
classes = [item["class"] for item in picked]

images_batch = torch.stack(images, dim=0)
images_batch.shape

In [None]:
LIM = 2  # number of runs shown side by side (one panel each)
N_ROWS = 1
N_COLS = LIM

plt.rcParams.update(bundles.icml2022())
plt.rcParams.update(figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=1.0))

# `plt.cm.get_cmap` (matplotlib.cm.get_cmap) was deprecated in Matplotlib 3.7 and
# removed in 3.9; `plt.get_cmap(name, lut)` returns the same 10-colour "Set1" map.
cmap = plt.get_cmap("Set1", 10)
norm = plt.Normalize(min(targets), max(targets))


def plot_row(df, title, equal=True, sharey=False, sharex=False, dpi=150):
    """Draw a grid of latent-space scatter panels, one DataFrame per panel.

    Args:
        df: sequence of DataFrames, one per panel, with the columns built in this
            notebook (``x``, ``y``, ``class``, ``target``, ``index``).
        title: title placed on the first panel only.
        equal: force an equal aspect ratio on every panel.
        sharey, sharex: forwarded to ``plt.subplots``.
        dpi: figure resolution.

    Returns:
        The created matplotlib Figure.

    Reads the notebook globals ``N_ROWS``, ``N_COLS``, ``cmap`` and ``norm``.
    """
    fig, axes = plt.subplots(dpi=dpi, nrows=N_ROWS, ncols=N_COLS, sharey=sharey, sharex=sharex, squeeze=True)
    # squeeze=True yields a bare Axes for a 1x1 grid and a 2-D array for several
    # rows; flatten to 1-D so the loop below works for any grid shape.
    axes = np.atleast_1d(axes).ravel()

    # zip stops at the shorter input: extra panels stay empty instead of raising IndexError.
    for j, (ax, frame) in enumerate(zip(axes, df)):
        if j == 0:
            ax.set_title(title)
        if equal:
            ax.set_aspect("equal")
        plot_latent_space(ax, frame, targets=[0, 1], size=0.75, bg_alpha=0.25, alpha=1, cmap=cmap, norm=norm)
    return fig

In [None]:
def latents_distance(latents):
    """Mean pairwise summed squared error between the latent tensors of several runs.

    Args:
        latents: sequence of ``(df, tensor)`` pairs; only the tensors (second
            element) are compared, and they must all have the same shape.

    Returns:
        0-d tensor: the mean over all unordered pairs ``i < j`` of
        ``((x_i - x_j) ** 2).sum()``.

    Raises:
        ValueError: if fewer than two entries are given (there is no pair to compare).
    """
    # Imported locally: the notebook only imports `F` in a later cell, so
    # Restart & Run All previously failed here with a NameError.
    import itertools

    import torch.nn.functional as F

    tensors = [entry[1] for entry in latents]
    if len(tensors) < 2:
        raise ValueError(f"latents_distance needs at least two entries, got {len(tensors)}")

    # combinations() yields pairs in the same (i, j), i < j order as the nested loop did.
    dists = [F.mse_loss(x, y, reduction="sum") for x, y in itertools.combinations(tensors, 2)]
    return sum(dists) / len(dists)

## Anchors

In [None]:
# Restore the first relative-AE checkpoint on CPU. In the visible cells it is only
# used to read the anchor images stored in its metadata.
model_rel, _ = parse_checkpoint(
    module_class=PL_MODULE,
    checkpoint_path=checkpoints["mnist"]["rel_ae"][0],
    map_location="cpu",
)

In [None]:
# Anchor images saved with the relative model; every AE below encodes this same batch.
anchors_batch = model_rel.metadata.anchors_images
anchors_batch.shape

## AE

In [None]:
MODELS = checkpoints["mnist"]["ae"]
MODELS

In [None]:
import copy

ae_latents, anchors_latents = [], []
for checkpoint_path in MODELS:
    model, _ = parse_checkpoint(
        module_class=PL_MODULE,
        checkpoint_path=checkpoint_path,
        map_location="cpu",
    )
    # Same model, two inputs: the sampled images (with a plotting DataFrame)
    # and the shared anchor images (latent tensor only).
    image_record = get_latents(images_batch, model, return_df=True)
    anchor_latent = get_latents(anchors_batch, model, return_df=False)[1]
    ae_latents.append(image_record)
    anchors_latents.append(anchor_latent)

# Keep untouched copies: later cells rebind `ae_latents` / `anchors_latents`.
original_ae_latents = copy.deepcopy(ae_latents)
original_anchor_latents = copy.deepcopy(anchors_latents)

In [None]:
# Latent spaces of the first LIM AE runs side by side, then their mean pairwise distance.
ae_frames = [frame for frame, _ in original_ae_latents[:LIM]]
f = plot_row(ae_frames, "AE", equal=True, sharey=True, sharex=True)
latents_distance(ae_latents[:LIM])

## Random isometry

# Rel Attention

## Rel Attention Quantized

In [None]:
from sklearn.decomposition import PCA

# Absolute space
from scipy.stats import ortho_group


import torch.nn.functional as F

raw_latents = original_ae_latents[0][1]
raw_anchor_latents = original_anchor_latents[0]


def _latents_frame(points):
    """Tabulate the first two coordinates of ``points`` with the sampled images' metadata."""
    return pd.DataFrame(
        {
            "x": points[:, 0].tolist(),
            "y": points[:, 1].tolist(),
            "class": classes,
            "target": targets,
            "index": indexes,
        }
    )


def _random_isometry(dim):
    """Return the map ``x -> x @ Q`` for a freshly sampled ``dim x dim`` orthogonal ``Q``."""
    rotation = torch.as_tensor(ortho_group.rvs(dim), dtype=torch.float)
    return lambda x: x @ rotation


N_TRANSFORMS = 4  # transformed copies appended after the untransformed reference

anchors_latents = [raw_anchor_latents]
ae_latents = [(_latents_frame(raw_latents), raw_latents)]

for _ in range(N_TRANSFORMS):
    # FIXME(review): `transform` is only defined in the "Optimal transform" section
    # further down, so on a fresh kernel this cell raised NameError. Use the
    # optimised map when it exists, otherwise a random isometry (the approach that
    # was left commented out here).
    if "transform" in globals():
        current_transform = transform
    else:
        current_transform = _random_isometry(raw_latents.shape[-1])

    # Samples and anchors must go through the same map.
    transformed_latents = current_transform(raw_latents)
    anchors_transformed = current_transform(raw_anchor_latents)

    ae_latents.append((_latents_frame(transformed_latents), transformed_latents))
    anchors_latents.append(anchors_transformed)

f = plot_row([df for df, _ in ae_latents[:LIM]], f"AE: {latents_distance(ae_latents[:LIM])}", True, True, True)

# Relative
from rae.modules.attention import *

# Qunatized
for quant_mode, bin_size in (
    (None, None),
    (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.1),
    (SimilaritiesQuantizationMode.DIFFERENTIABLE_ROUND, 0.3),
):
    rel_latents = []
    rel_attention = RelativeAttention(
        n_anchors=anchors_batch.shape,
        n_classes=len(set(targets)),
        similarity_mode=RelativeEmbeddingMethod.INNER,
        values_mode=ValuesMethod.SIMILARITIES,
        normalization_mode=NormalizationMode.L2,
        output_normalization_mode=OutputNormalization.NONE,
        similarities_quantization_mode=quant_mode,
        similarities_bin_size=bin_size,
        # absolute_quantization_mode=quant_mode,
        # absolute_bin_size=bin_size
    )
    assert sum(x.numel() for x in rel_attention.parameters()) == 0
    for (_, latents), a_latents in zip(ae_latents, anchors_latents):
        rel = rel_attention(x=latents, anchors=a_latents)[AttentionOutput.SIMILARITIES]
        rellatents2d = rel[:, [0, 1]]
        pca = PCA(n_components=2)
        rellatents2d = pca.fit(rel.detach())
        rellatents2d = pca.transform(rel.detach())
        df = pd.DataFrame(
            {
                "x": rellatents2d[:, 0].tolist(),
                "y": rellatents2d[:, 1].tolist(),
                "class": classes,
                "target": targets,
                "index": indexes,
            }
        )
        rel_latents.append((df, rel))
    f = plot_row(
        [df for df, _ in rel_latents[:LIM]],
        f"QAtt, bin size: {bin_size}: {latents_distance(rel_latents[:LIM])}",
        True,
        True,
        True,
    )

# Optimal transform

In [None]:
from rae.modules.attention import *

ae, _ = parse_checkpoint(
    module_class=PL_MODULE,
    checkpoint_path=checkpoints["mnist"]["ae"][0],
    map_location="cpu",
)


def _relative_attention(quantization_mode, bin_size):
    """Build the relative projection used in this section.

    Only the similarity quantization differs between the plain and the quantized
    variant; every other setting is shared.
    """
    return RelativeAttention(
        # Number of anchors (the anchor batch size); previously the whole torch.Size was passed.
        n_anchors=anchors_batch.shape[0],
        n_classes=len(set(targets)),
        similarity_mode=RelativeEmbeddingMethod.INNER,
        values_mode=ValuesMethod.SIMILARITIES,
        normalization_mode=NormalizationMode.L2,
        output_normalization_mode=OutputNormalization.NONE,
        similarities_quantization_mode=quantization_mode,
        similarities_bin_size=bin_size,
    )


att = _relative_attention(None, None)
att_q = _relative_attention(SimilaritiesQuantizationMode.CUSTOM_ROUND, 0.1)

ae.eval()
# The latents are fixed inputs for the optimisation below; no graph through the AE is needed.
with torch.no_grad():
    images_z = ae(images_batch)[Output.DEFAULT_LATENT].detach()
    anchors_z = ae(anchors_batch)[Output.DEFAULT_LATENT].detach()

In [None]:
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.optim.adam import Adam
from sklearn.decomposition import PCA

# Absolute space
from scipy.stats import ortho_group

# Learnable affine map, initialised at a random rotation with zero shift.
opt_isometry = torch.tensor(ortho_group.rvs(images_z.shape[-1]), dtype=torch.float, requires_grad=True)
opt_shift = torch.zeros(images_z.shape[-1], dtype=torch.float, requires_grad=True)


opt = Adam([opt_isometry, opt_shift], lr=1e-4)


def transform(x):
    """Apply the learnable affine map ``x @ opt_isometry + opt_shift`` (differentiable in both)."""
    return x @ opt_isometry + opt_shift


# Loss weights.
R = 1000  # plain relative term; negated below, so this change is maximised
Q = 1  # quantized relative term (minimised)
I = 1000  # penalty on off-diagonal entries of opt_isometry @ opt_isometry.T
S = 0  # L1 penalty on the shift (disabled)
N_STEPS = 100

for _ in (bar := tqdm(range(N_STEPS))):
    # Transform each input once and share it between both attention variants.
    images_t = transform(images_z)
    anchors_t = transform(anchors_z)

    rel = att(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]
    rel_iso = att(x=images_t, anchors=anchors_t)[AttentionOutput.SIMILARITIES]
    rel_dist = F.mse_loss(rel, rel_iso, reduction="sum")
    rel_loss = -rel_dist * R

    qrel = att_q(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]
    qrel_iso = att_q(x=images_t, anchors=anchors_t)[AttentionOutput.SIMILARITIES]
    # Compare quantized with quantized. This previously used `rel_iso` (the
    # unquantized similarities), leaving `qrel_iso` unused.
    qrel_dist = F.mse_loss(qrel, qrel_iso, reduction="sum")
    qrel_loss = qrel_dist * Q

    # NOTE(review): only off-diagonal entries are penalised, so row norms are free
    # to drift; penalise `gram - torch.eye(gram.shape[0])` for a strict isometry.
    gram = opt_isometry @ opt_isometry.T
    iso_loss = ((gram - gram.diag().diag()) ** 2).sum() * I
    shift_loss = opt_shift.abs().sum() * S
    loss = rel_loss + qrel_loss + iso_loss + shift_loss

    # `.3f` = three decimals (the previous `3f` meant width 3 with six decimals).
    bar.set_description(f"Rel: {rel_loss.item():.3f} \t Qua: {qrel_loss.item():.3f} \t  Iso: {iso_loss.item():.3f}")
    opt.zero_grad()
    loss.backward()
    opt.step()

# Final evaluation of the learned map; no gradients are needed.
with torch.no_grad():
    rel = att(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]
    rel_iso = att(x=transform(images_z), anchors=transform(anchors_z))[AttentionOutput.SIMILARITIES]
    qrel = att_q(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]
    qrel_iso = att_q(x=transform(images_z), anchors=transform(anchors_z))[AttentionOutput.SIMILARITIES]
print("Relative mse:", F.mse_loss(rel, rel_iso, reduction="sum").item())
print("Quantized mse:", F.mse_loss(qrel, qrel_iso, reduction="sum").item())

In [None]:
# Re-encode with the AE in eval mode, then report how far the learned transform
# moves the plain and the quantized relative representations.
ae.eval()
images_z, anchors_z = (ae(batch)[Output.DEFAULT_LATENT].detach() for batch in (images_batch, anchors_batch))


def _similarity_pair(attention):
    """Similarities of the original and of the transformed latents under ``attention``."""
    before = attention(x=images_z, anchors=anchors_z)[AttentionOutput.SIMILARITIES]
    after = attention(x=transform(images_z), anchors=transform(anchors_z))[AttentionOutput.SIMILARITIES]
    return before, after


rel, rel_iso = _similarity_pair(att)
print("Relative mse:", F.mse_loss(rel, rel_iso, reduction="sum"))

qrel, qrel_iso = _similarity_pair(att_q)
print("Quantized mse:", F.mse_loss(qrel, qrel_iso, reduction="sum"))

In [None]:
# Display the quantized similarities from the most recent evaluation cell.
qrel

In [None]:
import numpy as np

# Scratch plot of x - a*cos(2*pi*f*x + s) / (2*pi*f) with f = 1/b.
# NOTE(review): presumably explored as a smooth rounding surrogate with bin size `b` — confirm intent.
b = torch.as_tensor(0.5)  # bin size
x = torch.linspace(-1, 1, 200)

amplitude = 1
frequency = 1 / b  # renamed from `f`, which is also used as a figure handle elsewhere
phase = 0
# (An earlier variant, x - sin(2*pi*x) / (2*pi), was computed here and immediately overwritten.)
y = x - amplitude * torch.cos(2 * torch.pi * frequency * x + phase) / (2 * torch.pi * frequency)

fig, ax = plt.subplots(1, 1, dpi=150)
ax.plot(x, y)
ax.set(xlabel="x", ylabel="y", title=r"$x - a\cos(2\pi f x + s) / (2\pi f)$");