## Imports

In [None]:
%autoreload 2
%load_ext autoreload

In [None]:
import random

import numpy as np
import pytorch_lightning
import torchmetrics
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from la.utils.utils import MyDatasetDict, add_tensor_column
from tqdm import tqdm
import torch
from pathlib import Path
from pytorch_lightning import seed_everything
from torch.nn import functional as F

from la.pl_modules.pl_module import MyLightningModule
from la.utils.utils import MyDatasetDict
from backports.strenum import StrEnum
from enum import auto
from nn_core.common import PROJECT_ROOT

import hdf5storage
from torch.nn.functional import mse_loss, pairwise_distance
from torchmetrics.functional import pearson_corrcoef, spearman_corrcoef

from hydra.core.global_hydra import GlobalHydra
from hydra import compose, initialize
from omegaconf import OmegaConf
from datasets import concatenate_datasets, Dataset

import matplotlib.pyplot as plt

plt.style.use("dark_background")

In [None]:
from tueplots import bundles

# Fix the global seed (43) before any of the random sampling done further down.
seed_everything(43)
# NOTE(review): the value returned here is only shown as the cell output; it is never passed to
# plt.rcParams.update(...), so it presumably does not affect the figures — confirm this is intended.
bundles.icml2022()

# Data preprocessing


## Data loading

In [None]:
# Experiment configuration: these values select which dataset dump is loaded below.
dataset_name = "cifar100"
model_name = "efficient_net"
num_shared_classes = 80
num_novel_classes = 5
num_total_classes = 100

# Layout: <PROJECT_ROOT>/data/<dataset_name>/S<num_shared_classes>_N<num_novel_classes>_<model_name>
dataset_path = f"{PROJECT_ROOT}/data/{dataset_name}/S{num_shared_classes}_N{num_novel_classes}_{model_name}"

In [None]:
# Load the dataset dictionary from `dataset_path`; it is indexed below by "task_<i>_<mode>" and "metadata".
data: MyDatasetDict = MyDatasetDict.load_from_disk(dataset_dict_path=dataset_path)
# Number of tasks as recorded in the metadata; the loops below also visit task_0, hence `num_tasks + 1`.
num_tasks = data["metadata"]["num_tasks"]

In [None]:
# Every global label in [0, num_total_classes) that the metadata does not list as shared.
non_shared_classes = set(range(num_total_classes)) - set(data["metadata"]["shared_classes"])
non_shared_classes

In [None]:
data

In [None]:
# Return "embedding", "y" and "id" as torch tensors for every task (task_0 included) and both splits.
for task in range(num_tasks + 1):
    for mode in ("train", "test"):
        split_key = f"task_{task}_{mode}"
        data[split_key].set_format(type="torch", columns=["embedding", "y", "id"])

In [None]:
# Quantities derived from the metadata stored alongside the datasets.
metadata = data["metadata"]

num_tasks = metadata["num_tasks"]
shared_classes = set(metadata["shared_classes"])

train_samples_per_class = metadata["num_train_samples_per_class"]
num_shared_samples = train_samples_per_class * metadata["num_shared_classes"]

num_classes_per_task = metadata["num_shared_classes"] + metadata["num_novel_classes_per_task"]
num_train_samples_per_task = train_samples_per_class * num_classes_per_task

#### Map the local labels back to global

In [None]:
# Tasks 1..num_tasks store task-local labels; task_0 is not remapped by this loop.
for task_ind in range(1, num_tasks + 1):
    # The metadata stores global -> local, so invert it (global keys are cast to int).
    global_to_local_map = data["metadata"]["global_to_local_class_mappings"][f"task_{task_ind}"]
    local_to_global_map = {local: int(global_label) for global_label, local in global_to_local_map.items()}

    for mode in ["train", "test"]:
        split_key = f"task_{task_ind}_{mode}"
        data[split_key] = data[split_key].map(lambda row: {"y": local_to_global_map[row["y"].item()]})

# Obtain anchors

In [None]:
# Use as many anchors as the second dimension of the task_0 training embeddings,
# so the relative space built below has the same width as the absolute one.
num_anchors = data["task_0_train"]["embedding"].shape[1]
print(f"Using {num_anchors} anchors")

### Get shared samples indices
Add a **shared** column: `True` for samples belonging to shared classes and `False` otherwise.

In [None]:
def _flag_shared(row):
    # A sample is "shared" when its (global) label belongs to shared_classes.
    return {"shared": row["y"].item() in shared_classes}


for task_ind in range(num_tasks + 1):
    for mode in ["train", "test"]:
        split_key = f"task_{task_ind}_{mode}"
        data[split_key] = data[split_key].map(_flag_shared)

In [None]:
# One list of ids per task, holding the training samples whose "shared" flag is True.
shared_ids = []

for task_ind in range(num_tasks + 1):
    train_split = data[f"task_{task_ind}_train"]

    # Boolean-mask the id column with the "shared" column.
    shared_mask = train_split["shared"]
    shared_ids.append(train_split["id"][shared_mask].tolist())

Make sure the shared indices are the same across all the tasks

In [None]:
# Every pair of tasks must expose the same shared ids, in the same order.
task_count = num_tasks + 1
for first_task in range(task_count):
    for second_task in range(first_task, task_count):
        assert shared_ids[first_task] == shared_ids[second_task]

# All lists are equal at this point, so keep a single copy.
shared_ids = shared_ids[0]

### Sample anchor indices

In [None]:
# Draw `num_anchors` distinct ids (random.sample draws without replacement) among the shared training ids.
anchor_ids = random.sample(shared_ids, num_anchors)

Add an **anchor** column that is `True` only if the corresponding sample is an anchor.

In [None]:
# only training samples can be anchors
# The membership test runs once per row, so look ids up in a set (O(1)) instead of
# scanning the sampled list every time.
anchor_id_set = set(anchor_ids)

for task_ind in range(num_tasks + 1):
    data[f"task_{task_ind}_train"] = data[f"task_{task_ind}_train"].map(
        lambda row: {"anchor": row["id"].item() in anchor_id_set}
    )

(Optional) center the spaces

In [None]:
centering = False
if centering:
    for task_ind in range(num_tasks + 1):
        # The mean is estimated on the training split only, then subtracted from BOTH splits:
        # the test embeddings are later compared against anchors taken from the (centered) training
        # split, so they must live in the same frame.
        embedding_mean = data[f"task_{task_ind}_train"]["embedding"].mean(dim=0)
        for mode in ["train", "test"]:
            data[f"task_{task_ind}_{mode}"] = data[f"task_{task_ind}_{mode}"].map(
                lambda row: {"embedding": row["embedding"] - embedding_mean}
            )

# Map to relative spaces

In [None]:
for task_ind in range(0, num_tasks + 1):
    # Anchors are the training rows flagged by the "anchor" column, L2-normalised.
    train_split = data[f"task_{task_ind}_train"]
    anchor_mask = train_split["anchor"]
    norm_anchors = F.normalize(train_split["embedding"][anchor_mask], p=2, dim=-1)

    for mode in ["train", "test"]:
        split_key = f"task_{task_ind}_{mode}"

        abs_space = F.normalize(data[split_key]["embedding"], p=2, dim=-1)

        # Both operands have unit-norm rows, so each entry is the cosine similarity
        # between one sample and one anchor.
        rel_space = torch.matmul(abs_space, norm_anchors.T)

        data[split_key] = add_tensor_column(data[split_key], "relative_embeddings", rel_space)

In [None]:
# Re-declare the torch format so that the new columns are returned as tensors too.
for task_ind in range(0, num_tasks + 1):
    for mode in ("train", "test"):
        split_key = f"task_{task_ind}_{mode}"
        data[split_key].set_format(
            type="torch",
            columns=["relative_embeddings", "embedding", "y", "id", "shared"],
        )

### Divide the shared and the non-shared samples

In [None]:
# Per split: one dataset of shared samples and one of novel samples for each task 1..num_tasks.
shared_samples = {"train": [], "test": []}
disjoint_samples = {"train": [], "test": []}

for task_ind in range(1, num_tasks + 1):

    for mode in ["train", "test"]:
        task_split = data[f"task_{task_ind}_{mode}"]

        # Shared samples are sorted by id so that they line up row by row across tasks.
        # `bool(...)` / `not ...` behave the same for Python bools and 0-dim bool tensors,
        # whereas `~` on a Python bool yields -1/-2, which are both truthy.
        task_shared_samples = task_split.filter(lambda row: bool(row["shared"])).sort("id")

        task_novel_samples = task_split.filter(lambda row: not row["shared"])

        shared_samples[mode].append(task_shared_samples)
        disjoint_samples[mode].append(task_novel_samples)

In [None]:
# Stack the novel samples of all tasks into a single dataset per split.
novel_samples = {split: concatenate_datasets(disjoint_samples[split]) for split in ("train", "test")}

In [None]:
shared_samples

In [None]:
novel_samples

### Average the shared samples across tasks, then concatenate them with the task-specific samples to obtain the merged space

In [None]:
# Average the relative embeddings of the shared samples across tasks, then append the novel
# (task-specific) samples to obtain one merged dataset per split.

merged_datasets = {"train": [], "test": []}

for mode in ["train", "test"]:
    task_datasets = shared_samples[mode]
    reference_dataset = task_datasets[0]

    # The element-wise mean below is only meaningful if row i refers to the same sample in every
    # task; fail loudly instead of silently averaging unrelated samples.
    reference_ids = torch.as_tensor(reference_dataset["id"])
    for task_dataset in task_datasets[1:]:
        if not torch.equal(torch.as_tensor(task_dataset["id"]), reference_ids):
            raise ValueError(f"Shared samples of split '{mode}' are not aligned by id across tasks")

    shared_rel_embeddings = [dataset["relative_embeddings"] for dataset in task_datasets]

    # Mean over the task dimension: one averaged relative embedding per shared sample
    mean_embeddings = torch.mean(torch.stack(shared_rel_embeddings), dim=0)

    # Reuse the schema and every other column of the first task's shared dataset
    new_features = reference_dataset.features.copy()
    new_data = {column: reference_dataset[column] for column in new_features}

    # Only the 'relative_embeddings' column is replaced by the cross-task mean
    new_data["relative_embeddings"] = mean_embeddings.tolist()

    # Create the new Hugging Face dataset
    shared_dataset = Dataset.from_dict(new_data, features=new_features)

    merged_dataset = concatenate_datasets([shared_dataset, novel_samples[mode]])

    merged_datasets[mode] = merged_dataset

## Sort both datasets by ID

In [None]:
# NOTE(review): the configuration cell sets dataset_name to "cifar100" (no underscore), so this
# comparison is False for that value and every coarse-label branch below is skipped — confirm
# whether "cifar100" was intended here and whether the datasets carry a "coarse_label" column.
has_coarse_label = dataset_name == "cifar_100"

In [None]:
mode = "test"  # train or test

In [None]:
# Sort both datasets by "id" so that they can be compared row by row below.
merged_dataset = merged_datasets[mode].sort("id")
# The reference ("original") dataset is task_0 for the selected split.
original_dataset = data[f"task_0_{mode}"].sort("id")

In [None]:
# Columns returned as torch tensors for the analysis below.
columns = ["relative_embeddings", "y", "embedding"]
if has_coarse_label:
    columns.append("coarse_label")

merged_dataset.set_format(type="torch", columns=columns)
original_dataset.set_format(type="torch", columns=columns)

# Both datasets were sorted by "id" in the previous cell.
# NOTE(review): comparing them row by row assumes they contain exactly the same ids — confirm.
merged_space = merged_dataset["relative_embeddings"]
original_space = original_dataset["relative_embeddings"]

# Labels are read from the original dataset only.
original_space_y = original_dataset["y"]

# Similarity Analysis

## CKA

In [None]:
from la.utils.cka import CKA

In [None]:
# Use the GPU when there is one, otherwise fall back to the CPU instead of failing.
# NOTE(review): assumes CKA accepts "cpu" as a device string — confirm against la.utils.cka.
cka_device = "cuda" if torch.cuda.is_available() else "cpu"
cka = CKA(mode="linear", device=cka_device)

cka_score = cka(merged_space, original_space)
print(cka_score)

## Plots

In [None]:
from la.utils.relative_analysis import plot_space_grid
from la.utils.relative_analysis import plot_pairwise_dist
from la.utils.relative_analysis import plot_self_dist
from la.utils.relative_analysis import Reduction, reduce
from la.utils.relative_analysis import self_sim_comparison

### Whole space (all classes)

#### Pick a subsample

In [None]:
# At most 1000 rows: random.sample raises ValueError when asked for more items than available.
subsample_dim: int = min(1000, original_space.shape[0])
subsample_indices = random.sample(range(0, original_space.shape[0]), subsample_dim)

In [None]:
# The same row indices are applied to both spaces and to the labels, which relies on the
# two datasets having been sorted by "id" beforehand.
subsample_original = original_space[subsample_indices]
subsample_merged = merged_space[subsample_indices]
subsample_labels = original_space_y[subsample_indices]

In [None]:
# Reorder the subsample by label so that rows of the same class are contiguous.
sort_indices: torch.Tensor = subsample_labels.sort().indices

subsample_original_sorted: torch.Tensor = subsample_original[sort_indices]
subsample_merged_sorted: torch.Tensor = subsample_merged[sort_indices]
subsample_labels_sorted: torch.Tensor = subsample_labels[sort_indices]

In [None]:
plot_pairwise_dist(space1=subsample_original_sorted, space2=subsample_merged_sorted, prefix="Relative")

In [None]:
self_sim_comparison(space1=subsample_original_sorted, space2=subsample_merged_sorted, normalize=True)

In [None]:
plot_self_dist(space1=subsample_original_sorted, space2=subsample_merged_sorted, prefix="Relative")

In [None]:
# One grid column per reduction method, one grid row per space.
x_header = [reduction.upper() for reduction in Reduction]
y_header = ["Relative Space 1", "Relative Space 2"]

spaces = []
for reduction in Reduction:
    reduced_pair = reduce(space1=subsample_original_sorted, space2=subsample_merged_sorted, reduction=reduction)
    spaces.append(list(reduced_pair))

In [None]:
fig = plot_space_grid(x_header=x_header, y_header=y_header, spaces=spaces, c=subsample_labels_sorted)
fig

In [None]:
# Same grid as above, coloured by coarse label instead of fine label (only when available).
if has_coarse_label:
    original_space_coarse_labels = original_dataset["coarse_label"]
    # Apply the same subsampling and the same label-sorted order used for the spaces.
    subsample_coarse_labels = original_space_coarse_labels[subsample_indices]
    subsample_coarse_labels_sorted: torch.Tensor = subsample_coarse_labels[sort_indices]
    fig = plot_space_grid(x_header=x_header, y_header=y_header, spaces=spaces, c=subsample_coarse_labels_sorted)
    # NOTE(review): a bare expression inside an `if` is not the cell's last top-level expression,
    # so the notebook presumably does not render it — confirm whether display(fig) was intended.
    fig

### Only non-shared classes


In [None]:
def _has_non_shared_label(row):
    # Keep only the rows whose label is outside the shared set.
    return row["y"].item() in non_shared_classes


merged_dataset_nonshared = merged_dataset.filter(_has_non_shared_label)
original_dataset_nonshared = original_dataset.filter(_has_non_shared_label)

In [None]:
# Relative embeddings of the non-shared rows in both datasets; labels come from the original one.
merged_space_nonshared = merged_dataset_nonshared["relative_embeddings"]
original_space_nonshared = original_dataset_nonshared["relative_embeddings"]
original_space_y_nonshared = original_dataset_nonshared["y"]

#### Pick a subsample

In [None]:
# At most 1000 rows: random.sample raises ValueError when asked for more items than available.
subsample_dim: int = min(1000, original_space_nonshared.shape[0])
subsample_indices = random.sample(range(0, original_space_nonshared.shape[0]), subsample_dim)

In [None]:
# The same row indices are applied to both non-shared spaces and to the labels.
subsample_original = original_space_nonshared[subsample_indices]
subsample_merged = merged_space_nonshared[subsample_indices]
subsample_labels = original_space_y_nonshared[subsample_indices]

In [None]:
# Reorder the subsample by label so that rows of the same class are contiguous.
sort_indices: torch.Tensor = subsample_labels.sort().indices

subsample_original_sorted: torch.Tensor = subsample_original[sort_indices]
subsample_merged_sorted: torch.Tensor = subsample_merged[sort_indices]
subsample_labels_sorted: torch.Tensor = subsample_labels[sort_indices]

In [None]:
plot_pairwise_dist(space1=subsample_original_sorted, space2=subsample_merged_sorted, prefix="Relative")

In [None]:
self_sim_comparison(space1=subsample_original_sorted, space2=subsample_merged_sorted, normalize=True)

In [None]:
plot_self_dist(space1=subsample_original_sorted, space2=subsample_merged_sorted, prefix="Relative")

In [None]:
# One grid column per reduction method, one grid row per space.
x_header = [reduction.upper() for reduction in Reduction]
y_header = ["Relative Space 1", "Relative Space 2"]

spaces = []
for reduction in Reduction:
    reduced_pair = reduce(space1=subsample_original_sorted, space2=subsample_merged_sorted, reduction=reduction)
    spaces.append(list(reduced_pair))

In [None]:
fig = plot_space_grid(x_header=x_header, y_header=y_header, spaces=spaces, c=subsample_labels_sorted)
fig

In [None]:
# Same grid as above, coloured by coarse label instead of fine label (only when available).
if has_coarse_label:
    original_space_coarse_labels_nonshared = original_dataset_nonshared["coarse_label"]
    # Apply the same subsampling and the same label-sorted order used for the spaces.
    subsample_coarse_labels = original_space_coarse_labels_nonshared[subsample_indices]
    subsample_coarse_labels_sorted: torch.Tensor = subsample_coarse_labels[sort_indices]
    fig = plot_space_grid(x_header=x_header, y_header=y_header, spaces=spaces, c=subsample_coarse_labels_sorted)
    # NOTE(review): a bare expression inside an `if` is not the cell's last top-level expression,
    # so the notebook presumably does not render it — confirm whether display(fig) was intended.
    fig

#### Color by task

In [None]:
def get_novel_classes(task_classes, shared_classes):
    """Return, as a set, the classes of a task that are not among the shared ones."""
    return set(task_classes) - set(shared_classes)


def get_label_to_task_mapping():
    """Map every novel (non-shared) global label to the zero-based position of its task.

    Reads the notebook globals ``data``, ``num_tasks``, ``shared_classes`` and
    ``num_novel_classes``. Task ``task_<i>`` (i starting at 1) is stored under position ``i - 1``.
    """
    class_mappings = data["metadata"]["global_to_local_class_mappings"]

    label_to_task = {}
    for task_position in range(num_tasks):
        # Keys of the per-task mapping are the global labels, converted to int here.
        task_classes = [int(key) for key in class_mappings[f"task_{task_position + 1}"].keys()]
        task_novel_classes = get_novel_classes(task_classes, shared_classes)

        assert len(task_novel_classes) == num_novel_classes

        for label in task_novel_classes:
            label_to_task[label] = task_position

    return label_to_task


label_to_task = get_label_to_task_mapping()
print(label_to_task)

In [None]:
# `apply_` works in place: map a copy so that `subsample_labels_sorted` keeps the class labels
# and this cell can be re-run safely.
task_labels = subsample_labels_sorted.clone().apply_(lambda x: label_to_task[x])

In [None]:
fig = plot_space_grid(x_header=x_header, y_header=y_header, spaces=spaces, c=task_labels)
fig

#### Only non-shared classes, prototypes

In [None]:
# Re-extract the full non-shared spaces (not the subsample) to compute class prototypes on all rows.
merged_space_nonshared = merged_dataset_nonshared["relative_embeddings"]
original_space_nonshared = original_dataset_nonshared["relative_embeddings"]
original_space_y_nonshared = original_dataset_nonshared["y"]

##### Compute prototypes

In [None]:
# One prototype per non-shared class: the mean merged-space embedding of its rows.
# Rows are selected with the labels of the original dataset.
class_prototypes_merged = torch.stack(
    [
        merged_space_nonshared[original_space_y_nonshared == class_ind].mean(dim=0)
        for class_ind in non_shared_classes
    ]
)

In [None]:
# One prototype per non-shared class: the mean original-space embedding of its rows.
class_prototypes_original = torch.stack(
    [
        original_space_nonshared[original_space_y_nonshared == class_ind].mean(dim=0)
        for class_ind in non_shared_classes
    ]
)

In [None]:
plot_pairwise_dist(space1=class_prototypes_original, space2=class_prototypes_merged, prefix="Relative")

In [None]:
plot_self_dist(space1=class_prototypes_original, space2=class_prototypes_merged, prefix="Relative")

In [None]:
# One grid column per reduction method, one grid row per space (prototypes only).
x_header = [reduction.upper() for reduction in Reduction]
y_header = ["Relative Space 1", "Relative Space 2"]

spaces = []
for reduction in Reduction:
    reduced_pair = reduce(
        space1=class_prototypes_original,
        space2=class_prototypes_merged,
        reduction=reduction,
        perplexity=15,
    )
    spaces.append(list(reduced_pair))

In [None]:
fig = plot_space_grid(x_header=x_header, y_header=y_header, spaces=spaces, c=np.arange(len(non_shared_classes)))
fig

# Classifier Analysis


In [None]:
from torch import nn
import pytorch_lightning
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar, TQDMProgressBar

In [None]:
from functools import partial

# DataLoader settings common to every split; callers pass the dataset and `shuffle`.
dataloader_func = partial(
    torch.utils.data.DataLoader,
    batch_size=128,
    num_workers=4,
)

# Trainer settings common to every experiment; callers pass the callbacks.
# Request a GPU only when one is present so the notebook also runs on CPU-only machines.
trainer_func = partial(
    Trainer,
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=100,
    logger=False,
    enable_progress_bar=True,
)

In [None]:
classifier_embed_dim = 256


class Classifier(nn.Module):
    """LayerNorm followed by a two-layer MLP that produces class logits.

    Args:
        input_dim: size of the last dimension of the input embeddings.
        classifier_embed_dim: width of the hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, input_dim, classifier_embed_dim, num_classes):
        super().__init__()
        self.classifier = nn.Sequential(
            # Normalise over the declared input size rather than over a notebook global,
            # so the module is correct for any `input_dim`.
            nn.LayerNorm(normalized_shape=input_dim),
            nn.Linear(input_dim, classifier_embed_dim),
            nn.ReLU(),
            nn.Linear(classifier_embed_dim, num_classes),
        )

    def forward(self, x):
        """Return the logits computed from `x`."""
        return self.classifier(x)

In [None]:
class Model(pytorch_lightning.LightningModule):
    """Lightning module that trains a classifier on one embedding column of each batch.

    Batches are indexed by key: labels are read from ``"y"`` and inputs from
    ``"relative_embeddings"`` when ``use_relatives`` is true, from ``"embedding"`` otherwise.

    Args:
        classifier: module mapping a batch of embeddings to class logits.
        shared_classes: labels used for the ``test_acc_shared_classes`` metric.
        non_shared_classes: labels used for the ``test_acc_non_shared_classes`` metric.
        use_relatives: selects which embedding column is fed to the classifier.
    """

    def __init__(self, classifier: nn.Module, shared_classes: set, non_shared_classes: set, use_relatives: bool):
        super().__init__()
        self.classifier = classifier

        # Registered as buffers so they are moved together with the module.
        self.register_buffer("shared_classes", torch.tensor(list(shared_classes), dtype=torch.long))
        self.register_buffer("non_shared_classes", torch.tensor(list(non_shared_classes), dtype=torch.long))

        self.accuracy = torchmetrics.Accuracy()

        self.use_relatives = use_relatives
        self.embedding_key = "relative_embeddings" if self.use_relatives else "embedding"

    def forward(self, x):
        """Return the classifier logits for `x`."""
        return self.classifier(x)

    def _forward_loss(self, batch):
        """Run the classifier on the selected embedding column; return (logits, labels, loss)."""
        x, y = batch[self.embedding_key], batch["y"]
        y_hat = self(x)
        return y_hat, y, F.cross_entropy(y_hat, y)

    def _log_subset_accuracy(self, name, predictions, y, classes):
        """Log the accuracy restricted to the samples whose label is in `classes`.

        Nothing is logged for a batch without such samples (the ratio would be 0/0 = NaN and
        poison the epoch average). The number of selected samples is passed as `batch_size`
        so the epoch-level mean is weighted by it.
        """
        mask = torch.isin(y, classes)
        num_selected = int(mask.sum())
        if num_selected == 0:
            return

        subset_acc = (predictions[mask] == y[mask]).float().mean()
        self.log(name, subset_acc, on_step=True, on_epoch=True, prog_bar=True, batch_size=num_selected)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss of one training batch."""
        _, _, loss = self._forward_loss(batch)
        self.log("train_loss", loss, on_step=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log loss and accuracy of one validation batch."""
        y_hat, y, loss = self._forward_loss(batch)
        self.log("val_loss", loss, on_step=True, prog_bar=True)

        val_acc = self.accuracy(y_hat, y)
        self.log("val_acc", val_acc, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Log loss, overall accuracy, and accuracy on shared / non-shared classes for one test batch."""
        y_hat, y, loss = self._forward_loss(batch)
        self.log("test_loss", loss, on_step=True)

        test_acc = self.accuracy(y_hat, y)
        self.log("test_acc", test_acc, on_step=True, on_epoch=True, prog_bar=True)

        predictions = torch.argmax(y_hat, dim=1)
        self._log_subset_accuracy("test_acc_shared_classes", predictions, y, self.shared_classes)
        self._log_subset_accuracy("test_acc_non_shared_classes", predictions, y, self.non_shared_classes)

        return loss

    def configure_optimizers(self):
        """Adam over all parameters with learning rate 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

## Classification over all the classes, original, absolute

In [None]:
seed_everything(42)

In [None]:
# With use_relatives=False the model reads batch["embedding"], so the classifier input size is
# taken from that column instead of from the relative space (the two only coincide because the
# number of anchors was set to the embedding width).
original_classifier = Classifier(
    input_dim=original_dataset["embedding"].shape[1],
    classifier_embed_dim=classifier_embed_dim,
    num_classes=num_total_classes,
)
original_model = Model(
    classifier=original_classifier,
    shared_classes=shared_classes,
    non_shared_classes=non_shared_classes,
    use_relatives=False,
)

In [None]:
trainer = trainer_func(callbacks=[pytorch_lightning.callbacks.EarlyStopping(monitor="val_loss", patience=10)])

In [None]:
# split dataset in train, val and test
split_dataset = original_dataset.train_test_split(test_size=0.3, seed=42)
original_dataset_train = split_dataset["train"]
original_dataset_val_test = split_dataset["test"]

split_val_test = original_dataset_val_test.train_test_split(test_size=0.5, seed=42)
original_dataset_val = split_val_test["train"]
original_dataset_test = split_val_test["test"]

In [None]:
# Only the training loader shuffles; validation and test keep the dataset order.
original_train_dataloader = dataloader_func(original_dataset_train, shuffle=True)
original_val_dataloader = dataloader_func(original_dataset_val, shuffle=False)
original_test_dataloader = dataloader_func(original_dataset_test, shuffle=False)

In [None]:
trainer.fit(original_model, original_train_dataloader, original_val_dataloader)

In [None]:
results = trainer.test(original_model, original_test_dataloader)
results

## Classification over all the classes, original, relative

In [None]:
seed_everything(42)

In [None]:
# Same architecture as the absolute baseline, but the model reads batch["relative_embeddings"]
# (use_relatives=True). Note that the names `original_classifier` / `original_model` are rebound here.
original_classifier = Classifier(
    input_dim=original_space.shape[1], classifier_embed_dim=classifier_embed_dim, num_classes=num_total_classes
)
original_model = Model(
    classifier=original_classifier,
    shared_classes=shared_classes,
    non_shared_classes=non_shared_classes,
    use_relatives=True,
)
# Fresh trainer; training stops early when "val_loss" has not improved for 10 checks.
trainer = trainer_func(callbacks=[pytorch_lightning.callbacks.EarlyStopping(monitor="val_loss", patience=10)])

In [None]:
# split dataset in train, val and test
split_dataset = original_dataset.train_test_split(test_size=0.3, seed=42)
original_dataset_train = split_dataset["train"]
original_dataset_val_test = split_dataset["test"]

split_val_test = original_dataset_val_test.train_test_split(test_size=0.5, seed=42)
original_dataset_val = split_val_test["train"]
original_dataset_test = split_val_test["test"]

In [None]:
# Only the training loader shuffles; validation and test keep the dataset order.
original_train_dataloader = dataloader_func(original_dataset_train, shuffle=True)
original_val_dataloader = dataloader_func(original_dataset_val, shuffle=False)
original_test_dataloader = dataloader_func(original_dataset_test, shuffle=False)

In [None]:
trainer.fit(original_model, original_train_dataloader, original_val_dataloader)

In [None]:
results = trainer.test(original_model, original_test_dataloader)
results

## Classification over all the classes, merged

In [None]:
seed_everything(42)

In [None]:
# split dataset in train, val and test
split_dataset = merged_dataset.train_test_split(test_size=0.3, seed=42)
merged_dataset_train = split_dataset["train"]
merged_dataset_val_test = split_dataset["test"]

split_val_test = merged_dataset_val_test.train_test_split(test_size=0.5, seed=42)
merged_dataset_val = split_val_test["train"]
merged_dataset_test = split_val_test["test"]

In [None]:
# Only the training loader shuffles; validation and test keep the dataset order.
merged_train_dataloader = dataloader_func(merged_dataset_train, shuffle=True)
merged_val_dataloader = dataloader_func(merged_dataset_val, shuffle=False)
merged_test_dataloader = dataloader_func(merged_dataset_test, shuffle=False)

In [None]:
# Classifier over the merged relative space; the model reads batch["relative_embeddings"].
merged_classifier = Classifier(
    input_dim=merged_space.shape[1], classifier_embed_dim=classifier_embed_dim, num_classes=num_total_classes
)
merged_model = Model(
    classifier=merged_classifier,
    shared_classes=shared_classes,
    non_shared_classes=non_shared_classes,
    use_relatives=True,
)

In [None]:
# Fresh trainer; training stops early when "val_loss" has not improved for 10 checks.
trainer = trainer_func(callbacks=[pytorch_lightning.callbacks.EarlyStopping(monitor="val_loss", patience=10)])
trainer.fit(merged_model, merged_train_dataloader, merged_val_dataloader)

In [None]:
results = trainer.test(merged_model, merged_test_dataloader)
results