## Imports

In [None]:
%autoreload 2
%load_ext autoreload

In [None]:
import random

import numpy as np
import pytorch_lightning
import torchmetrics
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from la.utils.utils import MyDatasetDict, add_tensor_column
from tqdm import tqdm
import torch
from pathlib import Path
from pytorch_lightning import seed_everything
from torch.nn import functional as F

from la.pl_modules.pl_module import MyLightningModule
from la.utils.utils import MyDatasetDict
from backports.strenum import StrEnum
from enum import auto
from nn_core.common import PROJECT_ROOT

import hdf5storage
from torch.nn.functional import mse_loss, pairwise_distance
from torchmetrics.functional import pearson_corrcoef, spearman_corrcoef

from hydra.core.global_hydra import GlobalHydra
from hydra import compose, initialize
from omegaconf import OmegaConf
from datasets import concatenate_datasets, Dataset

import matplotlib.pyplot as plt

# Select matplotlib's built-in "dark_background" style sheet for plots created after this point.
plt.style.use("dark_background")

In [None]:
from tueplots import bundles

# Seed the random number generators via pytorch_lightning's helper.
seed_everything(43)
# NOTE(review): the value returned by `bundles.icml2022()` is not stored or applied
# here (as the last expression it is only echoed as the cell output). If the ICML
# style was meant to take effect, it presumably needs
# `plt.rcParams.update(bundles.icml2022())` — confirm intent.
bundles.icml2022()

# Data preprocessing

## Data loading

In [None]:
# Identifiers of the model and dataset under analysis; in this section they are
# only used to build the on-disk path below.
model_name = "from_scratch_cnn"
dataset_name = "cifar100"

# Location of the partitioned dataset for this (dataset, model) pair, under the project root.
dataset_path = f"{PROJECT_ROOT}/data/{dataset_name}/partitioned_{model_name}"

In [None]:
# Load the partitioned dataset. The cells below index it with keys of the form
# "task_{i}_train", "task_{i}_test", "task_{i}_anchors", plus a "metadata" entry.
data: MyDatasetDict = MyDatasetDict.load_from_disk(dataset_dict_path=dataset_path)
# Tasks are indexed 0..num_tasks inclusive in the loops below (task 0 is later
# used as the "original" space, tasks 1..num_tasks are merged).
num_tasks = data["metadata"]["num_tasks"]
# Anchor count is read from task 0 — presumably identical for every task; verify.
num_anchors = len(data["task_0_anchors"])
num_classes = data["metadata"]["num_classes"]

In [None]:
# (OPT) select a subset of the anchors
SUBSAMPLE_ANCHORS = True

if SUBSAMPLE_ANCHORS:
    # Keep at most 256 anchors, but never ask for more than are available:
    # `select(range(k))` raises when k exceeds the dataset length.
    # `num_anchors` currently holds the anchor count read from task 0.
    num_anchors = min(256, num_anchors)
    # Tasks run from 0 to num_tasks inclusive; every task keeps its first
    # `num_anchors` anchors so the relative spaces built later share a dimension.
    for task in range(num_tasks + 1):
        anchors_subsample = data[f"task_{task}_anchors"].select(range(num_anchors))
        print(anchors_subsample)
        data[f"task_{task}_anchors"] = anchors_subsample

In [None]:
# Expose the "embedding", "y" and "id" columns of every split of every task
# (0..num_tasks inclusive) in torch format, so the indexing done below yields torch data.
for task in range(num_tasks + 1):
    for mode in ["train", "test", "anchors"]:
        data[f"task_{task}_{mode}"].set_format("torch", columns=["embedding", "y", "id"])

# Map to relative space

In [None]:
# For each task, express the train/test embeddings relative to that task's anchors:
# both sides are L2-normalised along the last dimension, so each entry of the
# product is the cosine similarity between a sample and an anchor.
# Assumes the "embedding" columns come back as 2-D tensors (samples x features) — TODO confirm.
for task_ind in range(0, num_tasks + 1):

    task_anchors = data[f"task_{task_ind}_anchors"]["embedding"]
    norm_anchors = F.normalize(task_anchors, p=2, dim=-1)

    for mode in ["train", "test"]:

        task_embeddings = data[f"task_{task_ind}_{mode}"]["embedding"]

        abs_space = F.normalize(task_embeddings, p=2, dim=-1)

        # One row per sample, one column per anchor.
        rel_space = abs_space @ norm_anchors.T

        # Store the result next to the absolute embeddings under "relative_embeddings".
        data[f"task_{task_ind}_{mode}"] = add_tensor_column(
            data[f"task_{task_ind}_{mode}"], "relative_embeddings", rel_space
        )

In [None]:
# Re-apply the torch format on the train/test splits so that the newly added
# "relative_embeddings" column is exposed alongside the original columns.
for task_ind in range(0, num_tasks + 1):
    for mode in ["train", "test"]:
        data[f"task_{task_ind}_{mode}"].set_format(
            type="torch", columns=["relative_embeddings", "embedding", "y", "id"]
        )

# Merge the spaces

In [None]:
# Split used by the merge and analysis cells below.
mode = "test"

In [None]:
# Stack the selected split of tasks 1..num_tasks into one dataset; task 0 is
# excluded because it is used as the reference ("original") space below.
merged_dataset = concatenate_datasets([data[f"task_{i}_{mode}"] for i in range(1, num_tasks + 1)])

In [None]:
# Quick visual check of the merged relative embeddings.
print(merged_dataset["relative_embeddings"])

In [None]:
# Reference dataset: the selected split of task 0.
# NOTE: this binding is replaced by an id-sorted version in the following cell.
original_dataset = data[f"task_0_{mode}"]

In [None]:
# Sort both datasets by "id" so that row i of one lines up with row i of the other.
# Presumably both contain exactly the same set of ids — verify before comparing rows.
merged_dataset = merged_dataset.sort("id")
original_dataset = data[f"task_0_{mode}"].sort("id")

In [None]:
# Relative representations to compare: merged tasks vs. task 0, both in id order.
merged_space = merged_dataset["relative_embeddings"]
original_space = original_dataset["relative_embeddings"]

In [None]:
from la.utils.cka import CKA

# Similarity between the merged and the original relative spaces, using the
# project's CKA helper in "linear" mode.
# NOTE(review): the device is hard-coded to "cuda"; this cell presumably fails
# on a machine without a GPU — confirm whether a CPU fallback is wanted.
cka = CKA(mode="linear", device="cuda")

cka_score = cka(merged_space, original_space)
print(cka_score)

# Classification analysis

In [None]:
import torch.nn as nn


class Model(pytorch_lightning.LightningModule):
    """Lightning wrapper that trains a classifier on pre-computed embeddings.

    Batches are expected to be mappings holding the labels under ``"y"`` and the
    inputs under either ``"relative_embeddings"`` or ``"embedding"``, depending
    on ``use_relatives``.
    """

    def __init__(
        self,
        classifier: nn.Module,
        use_relatives: bool,
    ):
        """Store the classifier and pick the batch key to read inputs from.

        Args:
            classifier: module mapping an embedding batch to class logits.
            use_relatives: if true, read ``"relative_embeddings"`` from each
                batch; otherwise read ``"embedding"``.
        """
        super().__init__()
        self.classifier = classifier

        # A single metric object serves both the validation and the test loop.
        self.accuracy = torchmetrics.Accuracy()

        self.use_relatives = use_relatives
        if self.use_relatives:
            self.embedding_key = "relative_embeddings"
        else:
            self.embedding_key = "embedding"

    def forward(self, x):
        """Return the classifier output for ``x``."""
        return self.classifier(x)

    def _compute_loss(self, batch):
        """Run the classifier on a batch; return (loss, logits, targets)."""
        inputs = batch[self.embedding_key]
        targets = batch["y"]
        logits = self(inputs)
        return F.cross_entropy(logits, targets), logits, targets

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss of one training batch."""
        loss, _, _ = self._compute_loss(batch)
        self.log("train_loss", loss, on_step=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log loss and accuracy of one validation batch and return the loss."""
        loss, logits, targets = self._compute_loss(batch)
        self.log("val_loss", loss, on_step=True, prog_bar=True)

        batch_accuracy = self.accuracy(logits, targets)
        self.log("val_acc", batch_accuracy, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Log loss and accuracy of one test batch and return the loss."""
        loss, logits, targets = self._compute_loss(batch)
        self.log("test_loss", loss, on_step=True)

        batch_accuracy = self.accuracy(logits, targets)
        self.log("test_acc", batch_accuracy, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """Optimise every parameter with Adam at a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [None]:
from functools import partial

from pytorch_lightning import Trainer

from la.utils.class_analysis import Classifier


def run_classification_experiment(
    num_total_classes,
    input_dim,
    dataset,
    use_relatives,
    classifier_embed_dim,
    *,
    seed=42,
    batch_size=128,
    num_workers=8,
    max_epochs=100,
    patience=10,
    gpus=1,
):
    """Train a classifier on ``dataset`` and report its held-out test accuracy.

    The dataset is split 70% / 15% / 15% into train / validation / test (a 30%
    hold-out is cut first and then halved). Training stops early when
    ``val_loss`` has not improved for ``patience`` validation checks.

    Args:
        num_total_classes: number of output classes of the classifier.
        input_dim: dimensionality of the inputs fed to the classifier.
        dataset: dataset exposing ``train_test_split``; its rows must provide
            ``"y"`` and the embedding column selected by ``use_relatives``.
        use_relatives: train on ``"relative_embeddings"`` if true, otherwise on
            ``"embedding"``.
        classifier_embed_dim: hidden size forwarded to ``Classifier``.
        seed: seed for the global RNGs and for both dataset splits.
        batch_size: batch size of all three dataloaders.
        num_workers: worker processes per dataloader.
        max_epochs: upper bound on the number of training epochs.
        patience: early-stopping patience on ``val_loss``.
        gpus: value forwarded to ``Trainer(gpus=...)``.

    Returns:
        A dict with a single key ``"total_acc"`` holding the ``test_acc_epoch``
        value reported by ``trainer.test``.
    """
    # Seed before building the classifier so its initialisation is repeatable.
    seed_everything(seed)

    make_loader = partial(
        torch.utils.data.DataLoader,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    classifier = Classifier(
        input_dim=input_dim,
        classifier_embed_dim=classifier_embed_dim,
        num_classes=num_total_classes,
    )
    model = Model(
        classifier=classifier,
        use_relatives=use_relatives,
    )

    early_stopping = pytorch_lightning.callbacks.EarlyStopping(monitor="val_loss", patience=patience)
    trainer = Trainer(
        gpus=gpus,
        max_epochs=max_epochs,
        logger=False,
        enable_progress_bar=True,
        callbacks=[early_stopping],
    )

    # First cut off 30% of the data, then halve it into validation and test.
    train_holdout = dataset.train_test_split(test_size=0.3, seed=seed)
    train_dataset = train_holdout["train"]

    val_test = train_holdout["test"].train_test_split(test_size=0.5, seed=seed)
    val_dataset = val_test["train"]
    test_dataset = val_test["test"]

    # Only the training data is shuffled; evaluation order stays fixed.
    train_dataloader = make_loader(train_dataset, shuffle=True)
    val_dataloader = make_loader(val_dataset, shuffle=False)
    test_dataloader = make_loader(test_dataset, shuffle=False)

    trainer.fit(model, train_dataloader, val_dataloader)

    # `trainer.test` returns one metrics dict per test dataloader; only one is used here.
    test_metrics = trainer.test(model, test_dataloader)[0]

    return {
        "total_acc": test_metrics["test_acc_epoch"],
    }

In [None]:
# Hidden size passed to the classifier in every experiment below.
classifier_embed_dim = 128

In [None]:
# Baseline: classify the task-0 ("original") samples from their absolute
# embeddings; the input size is the embedding width.
run_classification_experiment(
    dataset=original_dataset,
    use_relatives=False,
    input_dim=original_dataset["embedding"].shape[1],
    num_total_classes=num_classes,
    classifier_embed_dim=classifier_embed_dim,
)

In [None]:
# Same task-0 samples, but classified from their relative embeddings
# (one input feature per anchor).
run_classification_experiment(
    dataset=original_dataset,
    use_relatives=True,
    input_dim=num_anchors,
    num_total_classes=num_classes,
    classifier_embed_dim=classifier_embed_dim,
)

In [None]:
# Relative embeddings of the merged dataset (tasks 1..num_tasks concatenated),
# for comparison with the two task-0 runs above.
run_classification_experiment(
    dataset=merged_dataset,
    use_relatives=True,
    input_dim=num_anchors,
    num_total_classes=num_classes,
    classifier_embed_dim=classifier_embed_dim,
)