## Imports

In [None]:
import random

import pytorch_lightning
import torchmetrics
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from tqdm import tqdm
import torch
from pathlib import Path
from pytorch_lightning import seed_everything
from torch.nn import functional as F

from la.modules.efficient_net import MyEfficientNet
from la.pl_modules.pl_module import MyLightningModule
from la.utils.utils import MyDatasetDict
from backports.strenum import StrEnum
from enum import auto
from nn_core.common import PROJECT_ROOT

import hdf5storage
from torch.nn.functional import mse_loss, pairwise_distance
from torchmetrics.functional import pearson_corrcoef, spearman_corrcoef

from hydra.core.global_hydra import GlobalHydra
from hydra import compose, initialize
from omegaconf import OmegaConf
from datasets import concatenate_datasets

In [None]:
from tueplots import bundles

# Seed the random number generators so the random sampling further down is repeatable.
seed_everything(43)
# NOTE(review): the value returned by bundles.icml2022() is neither assigned nor
# passed to anything in this cell, so it is only echoed as the cell output.
# Presumably it was meant to be applied to matplotlib's rcParams — confirm
# against the tueplots documentation.
bundles.icml2022()

## Load data


In [None]:
# Experiment configuration: which embedded dataset to load and how its classes
# were divided into shared and novel ones.
dataset_name, num_shared_classes, num_novel_classes = "tiny_imagenet", 100, 20

# The folder name encodes the shared (S) and novel (N) class counts.
dataset_path = "/".join(
    [
        str(PROJECT_ROOT),
        "data",
        dataset_name,
        f"S{num_shared_classes}_N{num_novel_classes}_embedded",
    ]
)

In [None]:
# Load the pre-embedded dataset dictionary from disk; besides the per-task
# splits it carries a "metadata" entry that the rest of the notebook reads.
data: MyDatasetDict = MyDatasetDict.load_from_disk(dataset_dict_path=dataset_path)
# The task loops below run over range(num_tasks + 1) or range(1, num_tasks + 1).
num_tasks = data["metadata"]["num_tasks"]

In [None]:
data

## Focus on training samples


### Sort every task's train and test split by sample id, so rows line up across tasks

In [None]:
# Order every split of every task (task 0 included) by sample id.
for split_name in ("train", "test"):
    for task_ind in range(num_tasks + 1):
        split_key = f"task_{task_ind}_{split_name}"
        data[split_key] = data[split_key].sort("id")

### Subspaces, embeddings from classifiers trained on a subset of the classes


#### Map the local labels back to global

In [None]:
# Rewrite the task-local "y" labels of tasks 1..num_tasks as global class ids.
# Only the train splits are remapped here.
for task_ind in range(1, num_tasks + 1):
    split_key = f"task_{task_ind}_train"

    global_to_local_map = data["metadata"]["global_to_local_class_mappings"][f"task_{task_ind}"]
    # Invert the mapping; the global ids are stored as keys and cast to int.
    local_to_global_map = {local_id: int(global_id) for global_id, local_id in global_to_local_map.items()}

    def _to_global_label(row):
        """Return the row's label translated with this task's inverted mapping."""
        return {"y": local_to_global_map[row["y"].item()]}

    data[split_key] = data[split_key].map(_to_global_label)

# Obtain anchors

### Get shared samples indices
Get the indices of the samples that belong to the shared classes; anchors will be sampled only from these.

In [None]:
metadata = data["metadata"]

num_shared_samples = metadata["num_train_samples_per_class"] * metadata["num_shared_classes"]
shared_classes = set(metadata["shared_classes"])


def _flag_shared(row):
    """Return a boolean "shared" field: True when the row's label is a shared class."""
    return {"shared": row["y"].item() in shared_classes}


# Add the "shared" column to the train split of every task (task 0 included).
for task_ind in range(num_tasks + 1):
    split_key = f"task_{task_ind}_train"
    data[split_key] = data[split_key].map(_flag_shared)

### Get non shared samples indices

In [None]:
# Per-task sizes derived from the metadata: every task sees the shared classes
# plus its own novel classes.
_meta = data["metadata"]
num_classes_per_task = _meta["num_shared_classes"] + _meta["num_novel_classes_per_task"]
num_train_samples_per_task = _meta["num_train_samples_per_class"] * num_classes_per_task

### Sample anchor indices

In [None]:
num_anchors = 256

# shared_ids[t] holds the ids of the rows of task t whose "shared" flag is set.
shared_ids = []

for task_ind in range(num_tasks + 1):
    train_split = data[f"task_{task_ind}_train"]
    # Restrict the torch-formatted output to "id" so it can be indexed with the mask.
    train_split.set_format(type="torch", columns=["id"])

    ids = train_split["id"]
    shared_map = train_split["shared"]
    task_shared_ids = ids[shared_map].tolist()
    shared_ids.append(task_shared_ids)

In [None]:
# Sanity check: every task must expose exactly the same shared sample ids, in the same order.
task_pairs = ((i, j) for i in range(num_tasks + 1) for j in range(i, num_tasks + 1))
for task_i, task_j in task_pairs:
    assert shared_ids[task_i] == shared_ids[task_j]

In [None]:
anchor_ids = random.sample(shared_ids[0], num_anchors)

In [None]:
# Build the lookup once: `anchor_ids` is a list, so testing membership against
# it for every row is a linear scan; a set makes each test constant-time.
anchor_id_set = set(anchor_ids)

# Add a boolean "anchor" column to the train split of every task.
for task_ind in range(num_tasks + 1):
    data[f"task_{task_ind}_train"] = data[f"task_{task_ind}_train"].map(
        lambda row: {"anchor": row["id"].item() in anchor_id_set}
    )

In [None]:
# Optional preprocessing: subtract the per-task mean embedding from every sample.
centering = False
if centering:
    for task_ind in range(num_tasks + 1):
        split_key = f"task_{task_ind}_train"
        # The train splits were last formatted with columns=["id"], so "embedding"
        # is outside the active format: it would not come back as a tensor and
        # would not be visible to the mapped function. Take an embedding-formatted
        # view before computing the mean and mapping.
        embedding_view = data[split_key].with_format(type="torch", columns=["embedding"])
        embedding_mean = embedding_view["embedding"].mean(dim=0)
        data[split_key] = embedding_view.map(
            lambda row: {"embedding": row["embedding"] - embedding_mean}
        )

### Select the anchors

In [None]:
# anchors[t] holds the embeddings of the rows of task t flagged as anchors.
anchors = []
for i in tqdm(range(num_tasks + 1)):
    train_split = data[f"task_{i}_train"]
    train_split.set_format(type="torch", columns=["embedding"])

    anchor_mask = train_split["anchor"]
    task_i_anchors = train_split["embedding"][anchor_mask]
    anchors.append(task_i_anchors)

print(anchors[0].shape)

# Map to relative spaces

In [None]:
# relatives[t] holds, for every train sample of task t, its similarity to each
# anchor of the same task.
relatives = []

for task_ind in range(num_tasks + 1):
    task_embeddings = data[f"task_{task_ind}_train"]["embedding"]
    task_anchors = anchors[task_ind]

    # L2-normalize along the last dimension so the product below is a cosine similarity.
    abs_space, norm_anchors = (
        F.normalize(rows, p=2, dim=-1) for rows in (task_embeddings, task_anchors)
    )

    rel_space = torch.matmul(abs_space, norm_anchors.T)

    # A disabled alternative used to live here: take Vt from
    # torch.linalg.svd(norm_anchors) and project with
    # torch.einsum("nd,ad -> na", abs_space, Vt).

    relatives.append(rel_space)

In [None]:
from datasets import Dataset

# Rebuild every train split from a plain dict, with the relative embeddings
# added as an extra "relative_embeddings" column.
for task_ind in range(num_tasks + 1):
    split_key = f"task_{task_ind}_train"

    dataset_dict = {
        **data[split_key].to_dict(),
        "relative_embeddings": relatives[task_ind].tolist(),
    }

    dataset = Dataset.from_dict(dataset_dict)
    data[split_key] = dataset

In [None]:
del relatives
del dataset_dict

In [None]:
# Columns every train split should return as torch values from here on.
torch_columns = ("relative_embeddings", "embedding", "y", "id", "shared")

for task_ind in range(num_tasks + 1):
    data[f"task_{task_ind}_train"].set_format(type="torch", columns=list(torch_columns))

### Average the shared samples

In [None]:
# For tasks 1..num_tasks, separate the rows of the shared classes (sorted by
# id) from the rows of the task-specific classes.
shared_samples = []
disjoint_samples = []

for task_ind in range(1, num_tasks + 1):
    task_train = data[f"task_{task_ind}_train"]

    task_shared_samples = task_train.filter(lambda row: bool(row["shared"])).sort("id")

    # Negate with `not bool(...)` rather than `~`: `~` is a logical NOT only for
    # tensors; applied to a plain Python bool it returns -1 or -2, both truthy,
    # which would keep every row.
    task_novel_samples = task_train.filter(lambda row: not bool(row["shared"]))

    shared_samples.append(task_shared_samples)
    disjoint_samples.append(task_novel_samples)

In [None]:
novel_samples = concatenate_datasets(disjoint_samples)

In [None]:
shared_samples

In [None]:
novel_samples

In [None]:
del disjoint_samples

for task_ind in range(1, num_tasks + 1):
    del data[f"task_{task_ind}_train"]

In [None]:
# Average the relative embeddings of the shared samples across tasks.
# Collect the "relative_embeddings" column of each task's shared split first.
shared_rel_embeddings = []
for task_shared in shared_samples:
    shared_rel_embeddings.append(task_shared["relative_embeddings"])

# Stack along a new leading (task) axis and average over it: one mean vector per sample.
mean_embeddings = torch.stack(shared_rel_embeddings).mean(dim=0)

In [None]:
# Build the shared dataset from the first task's shared split: same features,
# same column values, except for the relative embeddings.
reference_split = shared_samples[0]
new_features = reference_split.features.copy()

new_data = {}
for column in new_features:
    new_data[column] = reference_split[column]

# Swap in the across-task mean for the "relative_embeddings" column.
new_data["relative_embeddings"] = mean_embeddings.tolist()

# Create the new Hugging Face dataset, then drop the intermediate dict.
shared_dataset = Dataset.from_dict(new_data, features=new_features)
del new_data

### Concat the task-specific samples and the shared samples to go to the merged space

In [None]:
merged_dataset = concatenate_datasets([shared_dataset, novel_samples])
del shared_dataset
del novel_samples

In [None]:
# Sort both datasets by sample id so row i refers to the same sample in each.
merged_dataset = merged_dataset.sort("id")
# Plain string literal: the key has no placeholders, so no f-string is needed.
original_dataset = data["task_0_train"].sort("id")

In [None]:
# Return only these columns, as torch values, from both datasets.
# NOTE(review): earlier cells read the label from "y" and dataset_name is
# "tiny_imagenet"; confirm that "fine_label" and "coarse_label" exist in these datasets.
analysis_columns = ("relative_embeddings", "fine_label", "coarse_label")

for dataset_view in (merged_dataset, original_dataset):
    dataset_view.set_format(type="torch", columns=list(analysis_columns))

In [None]:
merged_space = merged_dataset["relative_embeddings"]
original_space = original_dataset["relative_embeddings"]

In [None]:
original_space_y = original_dataset["fine_label"]
original_space_coarse_labels = original_dataset["coarse_label"]

# Analysis

## Whole space (all classes)

### Pick a subsample

In [None]:
subsample_dim: int = 1000
# random.sample raises ValueError when asked for more items than the population
# holds, so the request is capped at the number of available rows.
num_available_rows = original_space.shape[0]
subsample_indices = random.sample(range(num_available_rows), min(subsample_dim, num_available_rows))

In [None]:
# Pick the same random rows from both spaces and from both label vectors.
subsample_original, subsample_merged, subsample_labels, subsample_coarse_labels = (
    values[subsample_indices]
    for values in (original_space, merged_space, original_space_y, original_space_coarse_labels)
)

In [None]:
# Reorder the subsample by ascending fine label, applying the same permutation
# to both spaces and both label vectors.
sort_indices: torch.Tensor = torch.sort(subsample_labels).indices

(
    subsample_original_sorted,
    subsample_merged_sorted,
    subsample_labels_sorted,
    subsample_coarse_labels_sorted,
) = (
    values[sort_indices]
    for values in (subsample_original, subsample_merged, subsample_labels, subsample_coarse_labels)
)

In [None]:
from la.utils.relative_analysis import plot_pairwise_dist

plot_pairwise_dist(space1=subsample_original_sorted, space2=subsample_merged_sorted, prefix="Relative")

In [None]:
from la.utils.relative_analysis import self_sim_comparison

self_sim_comparison(space1=subsample_original_sorted, space2=subsample_merged_sorted, normalize=True)

In [None]:
from la.utils.relative_analysis import plot_self_dist

plot_self_dist(space1=subsample_original_sorted, space2=subsample_merged_sorted, prefix="Relative")

In [None]:
from la.utils.relative_analysis import Reduction, reduce

# One grid column per reduction method; each entry of `spaces` holds what
# `reduce` returns for that method, unpacked into a list.
x_header, spaces = [], []
for reduction in Reduction:
    x_header.append(reduction.upper())
    spaces.append(
        list(reduce(space1=subsample_original_sorted, space2=subsample_merged_sorted, reduction=reduction))
    )

y_header = ["Relative Space 1", "Relative Space 2"]

In [None]:
from la.utils.relative_analysis import plot_space_grid

fig = plot_space_grid(x_header=x_header, y_header=y_header, spaces=spaces, c=subsample_labels_sorted)
fig

In [None]:
from la.utils.relative_analysis import plot_space_grid

fig = plot_space_grid(x_header=x_header, y_header=y_header, spaces=spaces, c=subsample_coarse_labels_sorted)
fig

## Only non-shared classes


In [None]:
# Class ids in 0..99 that are not listed among the shared classes.
# NOTE(review): the bound 100 is hard-coded while dataset_name is
# "tiny_imagenet"; confirm it matches the number of classes actually present.
non_shared_classes = set(range(100)).difference(data["metadata"]["shared_classes"])
non_shared_classes

In [None]:
def _is_non_shared(row):
    """Return True when the row's fine label is one of the non-shared classes."""
    return row["fine_label"].item() in non_shared_classes


# Apply the same predicate to both datasets.
merged_dataset_nonshared = merged_dataset.filter(_is_non_shared)
original_dataset_nonshared = original_dataset.filter(_is_non_shared)

In [None]:
merged_space_nonshared = merged_dataset_nonshared["relative_embeddings"]
original_space_nonshared = original_dataset_nonshared["relative_embeddings"]
original_space_y_nonshared = original_dataset_nonshared["fine_label"]
original_space_coarse_labels_nonshared = original_dataset_nonshared["coarse_label"]

### Pick a subsample

In [None]:
subsample_dim: int = 1000
# The non-shared subset may hold fewer than `subsample_dim` rows, and
# random.sample raises ValueError when asked for more items than the population
# holds, so the request is capped at the number of available rows.
num_nonshared_rows = original_space_nonshared.shape[0]
subsample_indices = random.sample(range(num_nonshared_rows), min(subsample_dim, num_nonshared_rows))

In [None]:
# Pick the same random rows from both non-shared spaces and from both label vectors.
subsample_original, subsample_merged, subsample_labels, subsample_coarse_labels = (
    values[subsample_indices]
    for values in (
        original_space_nonshared,
        merged_space_nonshared,
        original_space_y_nonshared,
        original_space_coarse_labels_nonshared,
    )
)

In [None]:
# Reorder the non-shared subsample by ascending fine label, applying the same
# permutation to both spaces and both label vectors.
sort_indices: torch.Tensor = torch.sort(subsample_labels).indices

(
    subsample_original_sorted,
    subsample_merged_sorted,
    subsample_labels_sorted,
    subsample_coarse_labels_sorted,
) = (
    values[sort_indices]
    for values in (subsample_original, subsample_merged, subsample_labels, subsample_coarse_labels)
)

In [None]:
from la.utils.relative_analysis import plot_pairwise_dist

plot_pairwise_dist(space1=subsample_original_sorted, space2=subsample_merged_sorted, prefix="Relative")

In [None]:
from la.utils.relative_analysis import self_sim_comparison

self_sim_comparison(space1=subsample_original_sorted, space2=subsample_merged_sorted, normalize=True)

In [None]:
from la.utils.relative_analysis import plot_self_dist

plot_self_dist(space1=subsample_original_sorted, space2=subsample_merged_sorted, prefix="Relative")

In [None]:
from la.utils.relative_analysis import Reduction, reduce

# One grid column per reduction method; each entry of `spaces` holds what
# `reduce` returns for that method on the non-shared subsample, as a list.
x_header, spaces = [], []
for reduction in Reduction:
    x_header.append(reduction.upper())
    spaces.append(
        list(reduce(space1=subsample_original_sorted, space2=subsample_merged_sorted, reduction=reduction))
    )

y_header = ["Relative Space 1", "Relative Space 2"]

In [None]:
from la.utils.relative_analysis import plot_space_grid

fig = plot_space_grid(x_header=x_header, y_header=y_header, spaces=spaces, c=subsample_labels_sorted)
fig

In [None]:
from la.utils.relative_analysis import plot_space_grid

fig = plot_space_grid(x_header=x_header, y_header=y_header, spaces=spaces, c=subsample_coarse_labels_sorted)
fig

# Classifier Analysis


In [None]:
from torch import nn
import pytorch_lightning
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar, TQDMProgressBar

In [None]:
from functools import partial

# Factories with the settings common to both classifier experiments; callers
# supply the dataset/shuffle flag and the callbacks respectively.
loader_defaults = {"batch_size": 32, "num_workers": 4}
dataloader_func = partial(DataLoader, **loader_defaults)

trainer_defaults = {"gpus": 1, "max_epochs": 100, "logger": False, "enable_progress_bar": True}
trainer_func = partial(Trainer, **trainer_defaults)

In [None]:
classifier_embed_dim = 512
num_classes = 100


def get_classifier():
    """Build a fresh, untrained classifier head for the relative embeddings.

    The network is LayerNorm -> Linear -> ReLU -> Linear. The input width is
    read from the second dimension of the module-level ``original_space``; the
    hidden width and the number of outputs come from the module-level
    ``classifier_embed_dim`` and ``num_classes``.

    Returns:
        nn.Sequential: a new module on every call.
    """
    # Use a single input width for both the normalization and the first linear
    # layer; sizing them from two different variables lets them drift apart and
    # fail at the first forward pass.
    input_dim = original_space.shape[1]
    return nn.Sequential(
        nn.LayerNorm(normalized_shape=input_dim),
        nn.Linear(input_dim, classifier_embed_dim),
        nn.ReLU(),
        nn.Linear(classifier_embed_dim, num_classes),
    )

In [None]:
class Model(pytorch_lightning.LightningModule):
    """Lightning module that trains a classifier on relative embeddings.

    Batches are dict-like: inputs are read from ``"relative_embeddings"`` and
    targets from ``"fine_label"``. The loss is cross-entropy and the optimizer
    is Adam with a learning rate of 1e-3.
    """

    def __init__(self, classifier: nn.Module):
        super().__init__()
        self.classifier = classifier
        # One metric instance, used by both the validation and the test step.
        self.accuracy = torchmetrics.Accuracy()

    def forward(self, x):
        """Return the classifier output for ``x``."""
        return self.classifier(x)

    def _predict_with_loss(self, batch):
        """Run the classifier on a batch; return (logits, targets, cross-entropy loss)."""
        inputs, targets = batch["relative_embeddings"], batch["fine_label"]
        logits = self(inputs)
        return logits, targets, F.cross_entropy(logits, targets)

    def _evaluation_step(self, batch, loss_name, acc_name, loss_log_kwargs):
        """Shared body of the validation/test steps: log loss, then accuracy; return the loss."""
        logits, targets, loss = self._predict_with_loss(batch)
        self.log(loss_name, loss, on_step=True, **loss_log_kwargs)

        batch_acc = self.accuracy(logits, targets)
        self.log(acc_name, batch_acc, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        _, _, loss = self._predict_with_loss(batch)
        self.log("train_loss", loss, on_step=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val_acc`` (both on the progress bar); return the loss."""
        return self._evaluation_step(batch, "val_loss", "val_acc", {"prog_bar": True})

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` and ``test_acc`` (only the accuracy on the progress bar); return the loss."""
        return self._evaluation_step(batch, "test_loss", "test_acc", {})

    def configure_optimizers(self):
        """Return Adam over all parameters with lr=1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

## Classifier over the original space


In [None]:
seed_everything(42)

In [None]:
original_classifier = Model(classifier=get_classifier())

In [None]:
trainer = trainer_func(callbacks=[pytorch_lightning.callbacks.EarlyStopping(monitor="val_loss", patience=10)])

In [None]:
# Two-stage split of the original dataset: hold out 30% of the rows first, then
# halve that held-out part into validation and test (both splits use seed 42).
split_dataset = original_dataset.train_test_split(test_size=0.3, seed=42)
original_dataset_train, original_dataset_val_test = (
    split_dataset[part] for part in ("train", "test")
)

split_val_test = original_dataset_val_test.train_test_split(test_size=0.5, seed=42)
original_dataset_val, original_dataset_test = (
    split_val_test[part] for part in ("train", "test")
)

In [None]:
# Only the training loader shuffles; validation and test keep dataset order.
original_train_dataloader, original_val_dataloader, original_test_dataloader = (
    dataloader_func(split, shuffle=is_train)
    for split, is_train in (
        (original_dataset_train, True),
        (original_dataset_val, False),
        (original_dataset_test, False),
    )
)

In [None]:
trainer.fit(original_classifier, original_train_dataloader, original_val_dataloader)

In [None]:
results = trainer.test(original_classifier, original_test_dataloader)
results

## Classifier over the merged space

In [None]:
seed_everything(42)

In [None]:
# Two-stage split of the merged dataset: hold out 20% of the rows first, then
# halve that held-out part into validation and test (both splits use seed 42).
# NOTE(review): the original-space dataset is split with test_size=0.3 but this
# one with 0.2, so the two classifiers see different train/val/test
# proportions — confirm this is intended.
split_dataset = merged_dataset.train_test_split(test_size=0.2, seed=42)
merged_dataset_train, merged_dataset_val_test = (
    split_dataset[part] for part in ("train", "test")
)

split_val_test = merged_dataset_val_test.train_test_split(test_size=0.5, seed=42)
merged_dataset_val, merged_dataset_test = (
    split_val_test[part] for part in ("train", "test")
)

In [None]:
# Only the training loader shuffles; validation and test keep dataset order.
merged_train_dataloader, merged_val_dataloader, merged_test_dataloader = (
    dataloader_func(split, shuffle=is_train)
    for split, is_train in (
        (merged_dataset_train, True),
        (merged_dataset_val, False),
        (merged_dataset_test, False),
    )
)

In [None]:
classifier = get_classifier()

merged_classifier = Model(classifier=classifier)

In [None]:
trainer = trainer_func(callbacks=[pytorch_lightning.callbacks.EarlyStopping(monitor="val_loss", patience=10)])
trainer.fit(merged_classifier, merged_train_dataloader, merged_val_dataloader)

In [None]:
results = trainer.test(merged_classifier, merged_test_dataloader)
results