In [None]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List
from la.data.my_dataset_dict import MyDatasetDict

from hydra.core.global_hydra import GlobalHydra

# Hydra refuses a second `initialize` in the same process: clear any previous global state
# so this cell can be re-run without restarting the kernel.
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
initialize(version_base=None, config_path="../conf", job_name="matrioska_learning")

# Instantiate configuration

In [None]:
from nn_core.common import PROJECT_ROOT

# Compose the `matrioska_learning` Hydra configuration (no command-line overrides).
cfg = compose(config_name="matrioska_learning", overrides=[])

# Instantiate dataset

In [None]:
from la.utils.io_utils import add_ids_to_dataset, load_data
from la.utils.io_utils import preprocess_dataset


# Keep a handle on the raw dataset: its `features` (label names) are read later on.
# Development shortcut: append `.shard(num_shards=10, index=0)` to `load_data(cfg)`
# to work on a 10% subset (TODO: remove when development is done).
original_dataset = load_data(cfg)
dataset = add_ids_to_dataset(preprocess_dataset(original_dataset, cfg))

In [None]:
# Size read from dimension 1 of the first preprocessed training sample
# (presumably a channel-first (C, H, W) image with H == W — TODO confirm).
img_size = tuple(dataset["train"][0]["x"].shape)[1]
dataset

In [None]:
# Dataset-specific variables (Hugging Face `ClassLabel` feature "fine_label").
# Switching to another dataset only requires redefining these two variables.
class_names = original_dataset["train"].features["fine_label"].names
class_idxs = list(map(original_dataset["train"].features["fine_label"].str2int, class_names))

class_names, class_idxs

In [None]:
import dataclasses


@dataclasses.dataclass()
class Result:
    """A single scalar evaluation score obtained for one matrioska model."""

    # Index of the matrioska model (position in `matrioskaclasses`).
    matrioska_idx: int
    # Number of classes the matrioska model was trained on.
    num_train_classes: int
    # Metric identifier, e.g. "test_accuracy".
    metric_name: str
    # Value of the metric.
    score: float

# Define matrioska datasets

In [None]:
# Matrioska parameters: every model sees the starting classes (class ids, despite the
# name) plus a growing prefix of the remaining classes, capped at LIMIT_N_CLASSES extras.
MATRIOSKA_START_NCLASSES = [0, 1]
LIMIT_N_CLASSES = 30
remanining_classes = sorted(set(class_idxs).difference(MATRIOSKA_START_NCLASSES))[:LIMIT_N_CLASSES]
MATRIOSKA_START_NCLASSES, remanining_classes

In [None]:
# Nested class sets: entry n = starting classes + the first n remaining classes,
# for n = 0 .. len(remanining_classes); each set extends the previous one by one class.
matrioskaclasses = [
    set(MATRIOSKA_START_NCLASSES).union(remanining_classes[:n_extra])
    for n_extra in range(len(remanining_classes) + 1)
]
matrioskaclasses

In [None]:
# Build one filtered copy of the dataset per matrioska class set.
# TODO: should all the matrioska datasets have the same number of samples?
# Probably not: keeping every sample of the selected classes is the harder (worst-case)
# setting, so if the approach works here it is a fair result.
matrioskaidx2dataset = {
    idx: dataset.filter(lambda row: row["y"] in classes)
    for idx, classes in enumerate(matrioskaclasses)
}

# The classes follow a prefix convention, so local and global class ids stay consistent;
# keeping it that way is simpler.
matrioskaidx2dataset

# Train matrioska models

In [None]:
from datasets import Dataset, DatasetDict

from typing import Dict
import tqdm
import torch

# One placeholder train/test DatasetDict per matrioska model, keyed "matrioska<idx>";
# the splits are filled in once each model has been trained and used to embed the data.
matrioskaidx2embeds: Dict[str, DatasetDict] = {
    f"matrioska{idx}": DatasetDict(train=DatasetDict(), test=DatasetDict())
    for idx, _ in enumerate(matrioskaclasses)
}
len(matrioskaidx2embeds), matrioskaidx2embeds

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm

HF_EMBEDDING_DATASET_PATH = PROJECT_ROOT / "matrioska_learning" / "hf_embedding_dataset"


def embed_and_save_samples(
    matrioskaidx2embeds, dataset, model, matrioska_idx, batch_size=1024, device="cuda"
) -> None:
    """Embed the train/test splits of ``dataset`` with ``model`` and save the result to disk.

    For each mode in ("train", "test") a ``Dataset`` with columns "embeds", "id" and "y" is
    stored in ``matrioskaidx2embeds[f"matrioska{matrioska_idx}"][mode]``; the resulting
    ``DatasetDict`` is then written to ``HF_EMBEDDING_DATASET_PATH / f"matrioska{matrioska_idx}"``.

    Args:
        matrioskaidx2embeds: mapping updated in place; must already hold the key
            ``f"matrioska{matrioska_idx}"``.
        dataset: mapping with "train" and "test" splits whose rows contain "x", "id" and "y"
            (presumably torch-formatted so the DataLoader yields tensors — as done by the caller).
        model: module whose output is a dict containing an "embeds" tensor.
        matrioska_idx: index of the matrioska model; used for the key and the output directory.
        batch_size: DataLoader batch size.
        device: device the forward pass runs on; the model is moved back to CPU afterwards.
    """
    import torch

    key = f"matrioska{matrioska_idx}"
    model.to(device).eval()
    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for mode in ("train", "test"):
                mode_embeddings, mode_ids, mode_labels = [], [], []
                mode_loader = DataLoader(
                    dataset[mode],
                    batch_size=batch_size,
                    pin_memory=True,
                    shuffle=False,
                    num_workers=4,
                )
                for batch in tqdm(mode_loader, desc=f"Embedding {mode} samples for {matrioska_idx}th matrioska"):
                    embeds = model(batch["x"].to(device))["embeds"]
                    # Move every batch off the device immediately so device memory does not
                    # grow with the size of the split.
                    mode_embeddings.extend(embeds.detach().cpu())
                    mode_ids.extend(batch["id"])
                    mode_labels.extend(batch["y"])

                matrioskaidx2embeds[key][mode] = Dataset.from_dict(
                    {
                        "embeds": mode_embeddings,
                        "id": mode_ids,
                        "y": mode_labels,
                    }
                )
    finally:
        # Release device memory even if embedding fails half-way.
        model.cpu()

    matrioskaidx2embeds[key].save_to_disk(HF_EMBEDDING_DATASET_PATH / key)

In [None]:
from typing import List
from nn_core.callbacks import NNTemplateCore
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from pytorch_lightning import Callback

from la.utils.utils import build_callbacks


matrioskaidx2model = {}

# Loop-invariant: the storage directory does not depend on the matrioska index.
storage_dir: str = cfg.core.storage_dir

for i in range(len(matrioskaclasses)):
    print(f"Training model {i}...")

    model: pl.LightningModule = hydra.utils.instantiate(
        cfg.nn.model,
        _recursive_=False,
        num_classes=len(matrioskaclasses[i]),
        model=cfg.nn.model.model,
        input_dim=img_size,
    )

    processed_dataset = matrioskaidx2dataset[i].map(
        desc="Preprocessing samples",
        function=lambda x: {"x": model.transform_func(x["x"])},
    )
    processed_dataset.set_format(type="torch", columns=["x", "y", "id"])

    train_loader = DataLoader(
        processed_dataset["train"],
        batch_size=512,
        pin_memory=True,
        shuffle=True,
        num_workers=4,
    )
    # Validation order does not affect aggregated metrics: keep it deterministic
    # (Lightning also warns about shuffled validation dataloaders).
    val_loader = DataLoader(
        processed_dataset["test"],
        batch_size=512,
        pin_memory=True,
        shuffle=False,
        num_workers=1,
    )

    template_core: NNTemplateCore = NNTemplateCore(
        restore_cfg=cfg.train.get("restore", None),
    )
    callbacks: List[Callback] = build_callbacks(cfg.train.callbacks, template_core)

    logger: NNLogger = NNLogger(logging_cfg=cfg.train.logging, cfg=cfg, resume_id=template_core.resume_id)

    try:
        # Use this in case we need to restore models, search for it in the wandb UI
        logger.experiment.config["matrioska_idx"] = i

        trainer = pl.Trainer(
            default_root_dir=storage_dir,
            plugins=[NNCheckpointIO(jailing_dir=logger.run_dir)],
            logger=logger,
            callbacks=callbacks,
            **cfg.train.trainer,
        )
        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

        matrioskaidx2model[i] = trainer.model.eval().cpu().requires_grad_(False)

        embed_and_save_samples(matrioskaidx2embeds, processed_dataset, matrioskaidx2model[i], i, batch_size=1024)
    finally:
        # Always close the wandb run, even when training or embedding fails, so that a
        # later iteration (or a re-run of this cell) does not log into a stale run.
        logger.experiment.finish()

# Evaluate matrioska models with a classifier

In [None]:
from datasets import Dataset, DatasetDict

import tqdm
import torch
from nn_core.common import PROJECT_ROOT
from torch.utils.data import DataLoader
from tqdm import tqdm
import datasets

HF_EMBEDDING_DATASET_PATH = PROJECT_ROOT / "matrioska_learning" / "hf_embedding_dataset"
# The training loop saves one embedding dataset per entry of `matrioskaclasses`: derive the
# count from it instead of hard-coding it, so a stale constant cannot silently drop models.
N_MATRIOSKA = len(matrioskaclasses)

matrioskaidx2embeds = {
    i: datasets.load_from_disk(HF_EMBEDDING_DATASET_PATH / f"matrioska{i}") for i in range(N_MATRIOSKA)
}
len(matrioskaidx2embeds)

In [None]:
# Classes the downstream classifier is evaluated on (the first five class ids).
# Changing this set is a natural experiment to try.
EVALUATION_CLASSES = set(range(5))
EVALUATION_CLASSES

In [None]:
from typing import List
from nn_core.callbacks import NNTemplateCore
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from pytorch_lightning import Callback
from la.pl_modules.classifier import Classifier

from la.utils.utils import build_callbacks

performance = []

# Contiguous ids 0..k-1 for the evaluation classes, as required by a classifier head with
# `num_classes=len(EVALUATION_CLASSES)`. This is the identity for EVALUATION_CLASSES = {0, ..., k-1}
# and makes any other choice of evaluation classes valid as well.
eval_label2local = {label: local for local, label in enumerate(sorted(EVALUATION_CLASSES))}

# Name stored in `Result.metric_name` -> key in the dict returned by `trainer.test`.
result_metric2key = {"test_accuracy": "accuracy", "test_f1": "f1", "test_loss": "test_loss"}

storage_dir: str = cfg.core.storage_dir

for matrioska_idx, embeds in matrioskaidx2embeds.items():
    train_classes = matrioskaclasses[matrioska_idx]
    # Only evaluate models trained on (at least) every class in EVALUATION_CLASSES; checked
    # explicitly so it does not depend on dict order or on the prefix layout of the classes.
    if not EVALUATION_CLASSES.issubset(train_classes):
        continue

    embeds_dataset = embeds.filter(
        lambda x: x["y"] in EVALUATION_CLASSES,
    )
    embeds_dataset = embeds_dataset.map(lambda x: {"y": eval_label2local[x["y"]]})
    embeds_dataset.set_format(type="torch", columns=["embeds", "y"])

    eval_train_loader = DataLoader(
        embeds_dataset["train"],
        batch_size=64,
        pin_memory=True,
        shuffle=True,
        num_workers=0,
    )

    eval_test_loader = DataLoader(
        embeds_dataset["test"],
        batch_size=64,
        pin_memory=True,
        shuffle=False,
        num_workers=0,
    )

    model = Classifier(
        input_dim=embeds_dataset["train"]["embeds"].size(1),
        num_classes=len(EVALUATION_CLASSES),
        lr=1e-4,
        deep=True,
        x_feature="embeds",
        y_feature="y",
    )

    # The probe is trained without logger or callbacks; every Trainer setting is explicit.
    trainer = pl.Trainer(
        default_root_dir=storage_dir,
        logger=None,
        fast_dev_run=False,
        gpus=1,
        precision=32,
        max_epochs=50,
        accumulate_grad_batches=1,
        num_sanity_val_steps=2,
        gradient_clip_val=10.0,
        val_check_interval=1.0,
    )
    trainer.fit(model, train_dataloaders=eval_train_loader, val_dataloaders=eval_test_loader)

    classifier_model = trainer.model.eval().cpu().requires_grad_(False)
    run_results = trainer.test(model=classifier_model, dataloaders=eval_test_loader)[0]

    performance.extend(
        Result(
            matrioska_idx=matrioska_idx,
            num_train_classes=len(train_classes),
            metric_name=metric_name,
            score=run_results[result_key],
        )
        for metric_name, result_key in result_metric2key.items()
    )

In [None]:
import pandas as pd

# One row per Result (columns follow the dataclass fields).
perf = pd.DataFrame(data=list(performance))

In [None]:
import plotly.express as px

# Keep only the accuracy rows. Note that `perf` itself is overwritten, so the export cell
# below saves accuracy rows only.
perf = perf[perf["metric_name"] == "test_accuracy"]

# Plot against the number of training classes so the data matches the "# classes" axis title
# (matrioska_idx is offset from it by the size of the starting class set).
fig = px.scatter(perf, x="num_train_classes", y="score")

fig.update_layout(yaxis_title="accuracy", xaxis_title="# classes")

In [None]:
results_path = PROJECT_ROOT / "paper_results" / "matrioska.json"
# `DataFrame.to_json` does not create missing parent directories.
results_path.parent.mkdir(parents=True, exist_ok=True)
perf.to_json(results_path, orient="records")