# Imports

In [None]:
import random

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import transforms

from la.data.datamodule import MetaData

import torch
from pathlib import Path
from pytorch_lightning import seed_everything
from torch.nn import functional as F


from la.data.dataset import MyDataset
from la.modules.module import CNN
from la.pl_modules.pl_module import MyLightningModule

try:
    # be ready for 3.10 when it drops
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum
from enum import auto
from nn_core.common import PROJECT_ROOT

import hdf5storage
from torch.nn.functional import mse_loss, pairwise_distance
from torchmetrics.functional import pearson_corrcoef, spearman_corrcoef

In [None]:
from tueplots import bundles

# Fix all RNG seeds for reproducible anchor sampling.
seed_everything(seed=43)
# NOTE(review): bundles.icml2022() returns a dict of matplotlib rcParams; calling it on its own
# does not apply the style. It presumably needs to go through
# matplotlib.pyplot.rcParams.update(...) to take effect — confirm intent.
bundles.icml2022()

# Load data

In [None]:
# The two digit classes treated as "unseen" in the analysis below.
digit1 = 4
digit2 = 6

In [None]:
# MNIST test split stored under <PROJECT_ROOT>/data/MNIST (downloaded if missing);
# images are only converted to tensors, no other preprocessing.
data_path = PROJECT_ROOT / "data" / "MNIST"
transform = transforms.Compose(transforms=[transforms.ToTensor()])
mnist = MNIST(root=data_path, train=False, transform=transform, download=True)

# Load models

In [None]:
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf

# Drop any GlobalHydra state left by a previous run of this cell so that
# initialize() can be called again, then compose the experiment config.
GlobalHydra.instance().clear()
initialize(config_path="../conf")
cfg = compose(config_name="prelim_exp")  # NOTE(review): `cfg` is not used in the cells shown

In [None]:
# Show the class-name -> index mapping, and wrap it in the project's MetaData
# container (passed to both checkpoints below).
print(mnist.class_to_idx)
metadata = MetaData(class_vocab=mnist.class_to_idx)

In [None]:
# First model, restored from the checkpoint tagged "missing_6" and put in eval mode.
# NOTE(review): the unseen classes analysed below are digits 4 and 6 (digit1/digit2), while the
# checkpoints are tagged missing_6 / missing_9 — confirm these are the intended models.
model1: MyLightningModule = MyLightningModule.load_from_checkpoint(
    checkpoint_path="checkpoints/missing_6/checkpoint.ckpt",
    metadata=metadata,
)
model1.eval()

In [None]:
# Second model, restored from the checkpoint tagged "missing_9" and put in eval mode.
model2: MyLightningModule = MyLightningModule.load_from_checkpoint(
    checkpoint_path="checkpoints/missing_9/checkpoint.ckpt",
    metadata=metadata,
)
model2.eval()

# Embed samples

In [None]:
num_samples = 5_000
# Wrap the MNIST test split in the project dataset and iterate it in its original order
# (no shuffling); presumably MyDataset preserves mnist's order — TODO confirm, since later
# cells match embeddings to mnist.targets by position.
dataset = MyDataset(samples=mnist, split="test", class_vocab=mnist.class_to_idx)
dataloader = DataLoader(dataset, batch_size=16, shuffle=False)

In [None]:
# Embed the whole test set with model1, batch by batch, and stack the results along dim 0.
# Inference only: running under torch.no_grad() avoids building and retaining an autograd
# graph for every batch (the embeddings are only ever used .detach()-ed downstream).
with torch.no_grad():
    embeds1 = torch.cat(
        [model1.model(batch["x"])["embeds"] for batch in dataloader],
        dim=0,
    )

In [None]:
# Embed the whole test set with model2, batch by batch, and stack the results along dim 0.
# Inference only: running under torch.no_grad() avoids building and retaining an autograd
# graph for every batch (the embeddings are only ever used .detach()-ed downstream).
with torch.no_grad():
    embeds2 = torch.cat(
        [model2.model(batch["x"])["embeds"] for batch in dataloader],
        dim=0,
    )

# Analysis

In [None]:
# Presumably a matplotlib colormap name; not referenced in the cells shown here.
CMAP: str = "jet"

In [None]:
# Maximum number of embedded test samples that enter the analysis.
sample_limit: int = 5_000

In [None]:
# Absolute spaces: the first `sample_limit` embeddings of each model, detached from autograd.
# Each is (num_samples, latent_dim).
abs_space1, abs_space2 = (emb.detach()[:sample_limit] for emb in (embeds1, embeds2))

In [None]:
# Labels for the analysed samples. Truncate with the same `sample_limit` used for
# abs_space1/abs_space2 so the boolean masks below line up row-for-row with those tensors
# (assumes the dataloader yielded samples in mnist order — see shuffle=False above).
targets = mnist.targets[:sample_limit]

digit1_mask = targets == digit1
digit2_mask = targets == digit2

# "Unseen" = either held-out digit; "seen" = every other digit.
unseen_classes_mask = digit1_mask | digit2_mask
seen_classes_mask = ~unseen_classes_mask

In [None]:
# Restrict labels and both absolute spaces to the seen classes (all digits except digit1/digit2).
targets_seen_classes, abs_space1_seen_classes, abs_space2_seen_classes = (
    tensor[seen_classes_mask] for tensor in (targets, abs_space1, abs_space2)
)

In [None]:
# Restrict labels and both absolute spaces to the unseen classes (digit1 and digit2 only).
targets_unseen_classes, abs_space1_unseen_classes, abs_space2_unseen_classes = (
    tensor[unseen_classes_mask] for tensor in (targets, abs_space1, abs_space2)
)

In [None]:
# Sanity check: which digit labels are present among the analysed samples.
print(torch.unique(targets))

## Sort unseen-class samples by digit label

In [None]:
# Order the unseen-class samples by label so same-digit rows are contiguous. The same
# permutation is applied to both spaces so row i still refers to the same image in each.
labels, sort_indices = targets_unseen_classes.sort()
abs_space1_unseen_classes = abs_space1_unseen_classes.index_select(0, sort_indices)
abs_space2_unseen_classes = abs_space2_unseen_classes.index_select(0, sort_indices)

assert abs_space1_unseen_classes.shape == abs_space2_unseen_classes.shape
assert labels.size(0) == abs_space1_unseen_classes.size(0)
abs_space1_unseen_classes.shape

## Anchor selection
Anchors are sampled only from the seen classes (every digit except `digit1` and `digit2`), which are meant to be shared by both models.

In [None]:
# Anchors are drawn only from seen-class samples. The sample count gets its own name so the
# global `num_samples` (set in the data-loading section) is not silently overwritten.
num_seen_samples, embedding_dim = abs_space1_seen_classes.size()
# One anchor per embedding dimension (fewer if there are not enough seen-class samples).
num_anchors: int = embedding_dim

# Random subset of seen-class row indices; the shuffle is presumably seeded by
# seed_everything(43) earlier in the notebook.
anchor_idxs = list(range(num_seen_samples))
random.shuffle(anchor_idxs)
anchor_idxs = anchor_idxs[:num_anchors]

In [None]:
# From here on, abs_space1/abs_space2 denote only the label-sorted unseen-class samples
# (this rebinds the names that previously held the first `sample_limit` samples).
abs_space1, abs_space2 = abs_space1_unseen_classes, abs_space2_unseen_classes

In [None]:
# Unit-L2-normalise every row so dot products against the anchors are cosine similarities.
norm_abs_space1, norm_abs_space2 = (
    F.normalize(space, p=2, dim=-1) for space in (abs_space1, abs_space2)
)

assert norm_abs_space1.shape == norm_abs_space2.shape

In [None]:
# Unit-L2-normalise the seen-class rows as well; anchors are picked from these.
norm_abs_space1_seen_classes, norm_abs_space2_seen_classes = (
    F.normalize(space, p=2, dim=-1)
    for space in (abs_space1_seen_classes, abs_space2_seen_classes)
)

In [None]:
# Take the same seen-class rows (anchor_idxs) from each normalised space as that space's anchors.
space1_anchors, space2_anchors = (
    space[anchor_idxs]
    for space in (norm_abs_space1_seen_classes, norm_abs_space2_seen_classes)
)

In [None]:
from la.utils.relative_analysis import plot_pairwise_dist

# Pairwise-distance comparison of the two (unnormalised) absolute unseen-class spaces.
plot_pairwise_dist(prefix="Absolute", space1=abs_space1, space2=abs_space2)

In [None]:
from la.utils.relative_analysis import self_sim_comparison

# Self-similarity comparison of the absolute spaces; these rows are not unit-norm, hence
# normalize=True (presumably normalises inside the helper — TODO confirm).
self_sim_comparison(normalize=True, space1=abs_space1, space2=abs_space2)

In [None]:
from la.utils.relative_analysis import plot_self_dist

# Self-distance plot for the absolute unseen-class spaces.
plot_self_dist(prefix="Absolute", space1=abs_space1, space2=abs_space2)

## Relative projection

In [None]:
# Relative representation: each unseen-class sample's cosine similarity to every anchor of
# its own space (both operands were row-normalised above).
# Result shape: (num_unseen_samples, num_anchors).
rel_space1 = torch.matmul(norm_abs_space1, space1_anchors.T)
rel_space2 = torch.matmul(norm_abs_space2, space2_anchors.T)

In [None]:
from la.utils.relative_analysis import plot_pairwise_dist

# Pairwise-distance comparison of the two relative spaces.
plot_pairwise_dist(prefix="Relative", space1=rel_space1, space2=rel_space2)

In [None]:
from la.utils.relative_analysis import self_sim_comparison

# Self-similarity comparison of the relative spaces, without the extra normalisation
# that is applied to the absolute spaces.
self_sim_comparison(normalize=False, space1=rel_space1, space2=rel_space2)

In [None]:
from la.utils.relative_analysis import plot_self_dist

# Self-distance plot for the relative spaces.
plot_self_dist(prefix="Relative", space1=rel_space1, space2=rel_space2)

In [None]:
from la.utils.relative_analysis import Reduction, reduce

# Column headers: one per reduction method. Row headers: the four spaces, in the order the
# `reduce` outputs are concatenated below (presumably reduce returns one reduced array per
# input space — TODO confirm).
x_header = [method.upper() for method in Reduction]
y_header = ["Absolute Space 1", "Absolute Space 2", "Relative Space 1", "Relative Space 2"]

# spaces[i] holds the reduced absolute spaces followed by the reduced relative spaces for the
# i-th reduction method.
spaces = []
for method in Reduction:
    abs_reduced = reduce(space1=abs_space1, space2=abs_space2, reduction=method)
    rel_reduced = reduce(space1=rel_space1, space2=rel_space2, reduction=method)
    spaces.append([*abs_reduced, *rel_reduced])

In [None]:
from la.utils.relative_analysis import plot_space_grid

# Grid of reduced spaces: x_header names the reductions, y_header the spaces; `c=labels`
# presumably colours points by digit label.
fig = plot_space_grid(spaces=spaces, c=labels, x_header=x_header, y_header=y_header)
fig