In [None]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List
from la.data.my_dataset_dict import MyDatasetDict

from hydra.core.global_hydra import GlobalHydra

# Hydra keeps a process-wide singleton: clearing it first makes this cell safe to
# re-run (calling `initialize` on an already-initialized GlobalHydra raises).
GlobalHydra.instance().clear()
initialize(version_base=None, config_path="../conf", job_name="matrioska_learning")

# Instantiate configuration

In [None]:
from nn_core.common import PROJECT_ROOT

# Compose the "matrioska_learning" configuration (no command-line overrides).
cfg = compose(overrides=[], config_name="matrioska_learning")

# Instantiate dataset

In [None]:
from la.utils.io_utils import add_ids_to_dataset, load_data
from la.utils.io_utils import preprocess_dataset


# Keep a handle on the raw dataset: its `features` (label names) are read below.
# For faster development iterations, `.shard(num_shards=10, index=0)` can be appended to `load_data(cfg)`.
original_dataset = load_data(cfg)
dataset = add_ids_to_dataset(preprocess_dataset(original_dataset, cfg))

In [None]:
# Size read from index 1 of the first training sample's `x` shape; it is later passed
# to the models as `input_dim` (presumably the image height for CHW tensors — TODO confirm).
first_train_sample = dataset["train"][0]
img_size = first_train_sample["x"].shape[1]
dataset

In [None]:
# Hugging Face-specific variables: switching dataset only requires redefining these
# (names and integer ids of the "fine_label" classes).
label_feature = original_dataset["train"].features["fine_label"]
class_names = label_feature.names
class_idxs = [label_feature.str2int(name) for name in class_names]

class_names, class_idxs

In [None]:
import dataclasses


@dataclasses.dataclass()
class Result:
    """A single evaluation score (`metric_name` -> `score`) for one embedding model.

    Later cells also store a task id string in `matrioska_idx`.
    """

    matrioska_idx: int  # index of the matrioska model (or task id in the subtask analysis)
    num_train_classes: int  # number of classes the embedding model was trained on
    metric_name: str
    score: float

# Define matrioska datasets

In [None]:
# Matrioska parameters: every model is trained on the starting classes...
MATRIOSKA_START_NCLASSES = [0, 1]
# ...plus a growing prefix of (at most LIMIT_N_CLASSES of) the remaining classes, in sorted order.
LIMIT_N_CLASSES = 30
remanining_classes = sorted(set(class_idxs).difference(MATRIOSKA_START_NCLASSES))[:LIMIT_N_CLASSES]
MATRIOSKA_START_NCLASSES, remanining_classes

In [None]:
# Nested ("matrioska") class sets: set n adds the first n remaining classes to the
# starting ones, so every set contains all the previous ones.
matrioskaclasses = []
for n_extra in range(len(remanining_classes) + 1):
    matrioskaclasses.append(set(MATRIOSKA_START_NCLASSES + remanining_classes[:n_extra]))
matrioskaclasses

In [None]:
# One dataset per matrioska model, restricted to the samples of its classes.
# TODO: do we want the same number of samples in every dataset?
# Current answer: no. This is considered fairer, and if the approach works here
# we are in the worst-case scenario.
matrioskaidx2dataset = {}
for matrioska_idx, allowed_classes in enumerate(matrioskaclasses):
    matrioskaidx2dataset[matrioska_idx] = dataset.filter(lambda row: row["y"] in allowed_classes)

# Labels are not remapped: the class sets follow a prefix convention, which keeps local
# and global class ids consistent (simpler to work with).
matrioskaidx2dataset

# Train matrioska models

In [None]:
from datasets import Dataset, DatasetDict

from typing import Dict
import tqdm
import torch

# One (initially empty) train/test container per matrioska model, filled later with embeddings.
matrioskaidx2embeds: Dict[str, DatasetDict] = {}
for matrioska_idx in range(len(matrioskaclasses)):
    matrioskaidx2embeds[f"matrioska{matrioska_idx}"] = DatasetDict(train=DatasetDict(), test=DatasetDict())
len(matrioskaidx2embeds), matrioskaidx2embeds

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm

HF_EMBEDDING_DATASET_PATH = PROJECT_ROOT / "matrioska_learning" / "hf_embedding_dataset"


def embed_and_save_samples(
    matrioskaidx2embeds, dataset, model, matrioska_idx, batch_size=1024, device="cuda"
) -> None:
    """Embed the train/test splits of `dataset` with `model` and save them to disk.

    The resulting splits (columns `embeds`, `id`, `y`) are stored in
    `matrioskaidx2embeds[f"matrioska{matrioska_idx}"]`, which is then saved under
    `HF_EMBEDDING_DATASET_PATH / f"matrioska{matrioska_idx}"`.

    Args:
        matrioskaidx2embeds: mapping updated in place; must already contain the key above.
        dataset: splits "train" and "test" whose batches expose `x`, `id` and `y`
            (assumed torch-formatted so they collate to tensors — TODO confirm).
        model: module whose output dict contains "embeds".
        matrioska_idx: index of the matrioska model, used for the key and the save path.
        batch_size: DataLoader batch size.
        device: device for the forward passes; the model is moved back to CPU at the end.
    """
    key = f"matrioska{matrioska_idx}"
    model.to(device).eval()

    for mode in ("train", "test"):
        mode_embeddings = []
        mode_ids = []
        mode_labels = []
        mode_loader = DataLoader(
            dataset[mode],
            batch_size=batch_size,
            pin_memory=True,
            shuffle=False,
            num_workers=4,
        )
        # Inference only: no autograd graph; move each batch of embeddings to CPU so the
        # whole split is not accumulated in GPU memory.
        with torch.no_grad():
            for batch in tqdm(mode_loader, desc=f"Embedding {mode} samples for {matrioska_idx}th matrioska"):
                x = batch["x"].to(device)
                mode_embeddings.extend(model(x)["embeds"].detach().cpu())
                mode_ids.extend(batch["id"])
                mode_labels.extend(batch["y"])

        matrioskaidx2embeds[key][mode] = Dataset.from_dict(
            {
                "embeds": mode_embeddings,
                "id": mode_ids,
                "y": mode_labels,
            }
        )

    model.cpu()
    matrioskaidx2embeds[key].save_to_disk(HF_EMBEDDING_DATASET_PATH / key)

In [None]:
from typing import List
from nn_core.callbacks import NNTemplateCore
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from pytorch_lightning import Callback

from la.utils.utils import build_callbacks


matrioskaidx2model = {}


for i in range(len(matrioskaclasses)):
    print(f"Training model {i}...")

    model: pl.LightningModule = hydra.utils.instantiate(
        cfg.nn.model,
        _recursive_=False,
        num_classes=len(matrioskaclasses[i]),
        model=cfg.nn.model.model,
        input_dim=img_size,
    )

    # Apply the model-specific input transform, then expose tensors to the DataLoaders.
    processed_dataset = matrioskaidx2dataset[i].map(
        desc="Preprocessing samples",
        function=lambda x: {"x": model.transform_func(x["x"])},
    )
    processed_dataset.set_format(type="torch", columns=["x", "y", "id"])

    train_loader = DataLoader(
        processed_dataset["train"],
        batch_size=512,
        pin_memory=True,
        shuffle=True,
        num_workers=4,
    )
    # Validation metrics do not depend on sample order: no shuffling here
    # (Lightning warns about shuffled validation dataloaders).
    val_loader = DataLoader(
        processed_dataset["test"],
        batch_size=512,
        pin_memory=True,
        shuffle=False,
        num_workers=1,
    )

    template_core: NNTemplateCore = NNTemplateCore(
        restore_cfg=cfg.train.get("restore", None),
    )
    callbacks: List[Callback] = build_callbacks(cfg.train.callbacks, template_core)

    storage_dir: str = cfg.core.storage_dir

    logger: NNLogger = NNLogger(logging_cfg=cfg.train.logging, cfg=cfg, resume_id=template_core.resume_id)
    try:
        # Use this in case we need to restore models, search for it in the wandb UI
        logger.experiment.config["matrioska_idx"] = i

        trainer = pl.Trainer(
            default_root_dir=storage_dir,
            plugins=[NNCheckpointIO(jailing_dir=logger.run_dir)],
            logger=logger,
            callbacks=callbacks,
            **cfg.train.trainer,
        )
        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

        # Keep a frozen CPU copy of each trained model for later inspection.
        matrioskaidx2model[i] = trainer.model.eval().cpu().requires_grad_(False)

        embed_and_save_samples(matrioskaidx2embeds, processed_dataset, matrioskaidx2model[i], i, batch_size=1024)
    finally:
        # Close the logger's run even when training/embedding fails, so it is not left open.
        logger.experiment.finish()

# Evaluate matrioska models with classifier

In [None]:
from datasets import Dataset, DatasetDict

import tqdm
import torch
from nn_core.common import PROJECT_ROOT
from torch.utils.data import DataLoader
from tqdm import tqdm
import datasets

HF_EMBEDDING_DATASET_PATH = PROJECT_ROOT / "matrioska_learning" / "hf_embedding_dataset"
MATRIOSKA_DIR_PREFIX = "matrioska"

# Load every saved "matrioska<idx>" embedding dataset instead of assuming a fixed count:
# the number of matrioska models depends on LIMIT_N_CLASSES (up to LIMIT_N_CLASSES + 1
# models), and a stale hard-coded constant silently drops some of them.
matrioskaidx2embeds = {
    int(path.name[len(MATRIOSKA_DIR_PREFIX) :]): datasets.load_from_disk(path)
    for path in HF_EMBEDDING_DATASET_PATH.glob(f"{MATRIOSKA_DIR_PREFIX}*")
    if path.is_dir() and path.name[len(MATRIOSKA_DIR_PREFIX) :].isdigit()
}
# Order by matrioska index so iteration (and plots) follow the nesting order.
matrioskaidx2embeds = dict(sorted(matrioskaidx2embeds.items()))
N_MATRIOSKA = len(matrioskaidx2embeds)
len(matrioskaidx2embeds)

In [None]:
# Classes the downstream classifiers are evaluated on (changing this may be interesting).
EVALUATION_CLASSES = set(range(5))
EVALUATION_CLASSES

In [None]:
from typing import List
from nn_core.callbacks import NNTemplateCore
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from pytorch_lightning import Callback
from la.pl_modules.classifier import Classifier

from la.utils.utils import build_callbacks

performance = []

# The Classifier expects labels in [0, len(EVALUATION_CLASSES)): map each evaluated global
# class id to a contiguous local id (the identity when EVALUATION_CLASSES == {0, ..., K-1}).
eval_global_to_local = {
    global_idx: local_idx for local_idx, global_idx in enumerate(sorted(EVALUATION_CLASSES))
}

for matrioska_idx, embeds in matrioskaidx2embeds.items():
    # Only evaluate models that have been trained on (at least) all the EVALUATION_CLASSES.
    if not set(EVALUATION_CLASSES).issubset(matrioskaclasses[matrioska_idx]):
        continue

    # `int(...)` keeps the lookups valid whatever format the loaded dataset exposes.
    embeds_dataset = embeds.filter(lambda x: int(x["y"]) in EVALUATION_CLASSES)
    embeds_dataset = embeds_dataset.map(lambda x: {"y": eval_global_to_local[int(x["y"])]})
    embeds_dataset.set_format(type="torch", columns=["embeds", "y"])

    eval_train_loader = DataLoader(
        embeds_dataset["train"],
        batch_size=64,
        pin_memory=True,
        shuffle=True,
        num_workers=0,
    )

    eval_test_loader = DataLoader(
        embeds_dataset["test"],
        batch_size=64,
        pin_memory=True,
        shuffle=False,
        num_workers=0,
    )

    model = Classifier(
        input_dim=embeds_dataset["train"]["embeds"].size(1),
        num_classes=len(EVALUATION_CLASSES),
        lr=1e-4,
        deep=True,
        x_feature="embeds",
        y_feature="y",
    )

    # NOTE(review): these callbacks are built but not passed to the Trainer below.
    callbacks: List[Callback] = build_callbacks(cfg.train.callbacks)

    storage_dir: str = cfg.core.storage_dir

    trainer = pl.Trainer(
        default_root_dir=storage_dir,
        logger=None,
        fast_dev_run=False,
        gpus=1,
        precision=32,
        max_epochs=50,
        accumulate_grad_batches=1,
        num_sanity_val_steps=2,
        gradient_clip_val=10.0,
        val_check_interval=1.0,
    )
    trainer.fit(model, train_dataloaders=eval_train_loader, val_dataloaders=eval_test_loader)

    classifier_model = trainer.model.eval().cpu().requires_grad_(False)
    run_results = trainer.test(model=classifier_model, dataloaders=eval_test_loader)[0]

    # (Result metric name, key in the dict returned by `trainer.test`)
    for metric_name, result_key in (("test_accuracy", "accuracy"), ("test_f1", "f1"), ("test_loss", "test_loss")):
        performance.append(
            Result(
                matrioska_idx=matrioska_idx,
                num_train_classes=len(matrioskaclasses[matrioska_idx]),
                metric_name=metric_name,
                score=run_results[result_key],
            )
        )

In [None]:
import pandas as pd

# One row per Result (columns follow the dataclass fields).
perf = pd.DataFrame.from_records([dataclasses.asdict(result) for result in performance])

In [None]:
import plotly.express as px

# Downstream classifier scores for each matrioska model the embeddings come from.
px.scatter(
    perf,
    x="matrioska_idx",
    y="score",
    color="metric_name",
    hover_data=["num_train_classes"],
    labels={"matrioska_idx": "Matrioska model index", "score": "Score", "metric_name": "Metric"},
    title=f"Classifier performance on classes {sorted(EVALUATION_CLASSES)} per matrioska model",
)

# ~~Evaluate matrioska models with clusters~~

In [None]:
# NOTE: Deprecated!
# # This experiment is deprecated:
# #  - Clustering metrics are not informative
# #  - The current code requires having the models in memory (it can easily be adapted to load the HF dataset from disk)


# # Decide which classes to evaluate on -- it may be interesting to change this
# EVALUATION_CLASSES = MATRIOSKA_START_NCLASSES
# EVALUATION_CLASSES

# from torch.utils.data import DataLoader

# # Define the evaluation dataset according to chosen classes
# eval_dataset = dataset.filter(lambda row: row["y"] in set(EVALUATION_CLASSES))
# eval_dataset = eval_dataset.map(
#     desc=f"Preprocessing samples",
#     function=lambda x: {"x": model.transform_func(x["x"])},
# )
# eval_dataset.set_format(type="torch", columns=["x", "y"])

# eval_train_loader = DataLoader(
#     eval_dataset["test"],
#     batch_size=64,
#     pin_memory=True,
#     shuffle=True,
#     num_workers=0,
# )

# eval_test_loader = DataLoader(
#     eval_dataset["test"],
#     batch_size=64,
#     pin_memory=True,
#     shuffle=True,
#     num_workers=0,
# )

# trainer = Trainer(
#     accelerator="auto",
#     devices=1,
#     max_epochs=3,
#     logger=None,
#     # callbacks=[RichProgressBar()],
#     enable_progress_bar=True,
# )

# eval_dataset

# import dataclasses


# @dataclasses.dataclass
# class Result:
#     matrioska_idx: int
#     clusterer: str
#     metric_name: str
#     score: float

#     from sklearn.cluster import KMeans, BisectingKMeans
# import sklearn
# import torch


# model = matrioskaidx2model[0]


# def compute_eval_embedings(model, eval_loader):
#     eval_embeddings = []
#     eval_labels = []
#     for batch in eval_loader:
#         out = model(batch["x"])
#         eval_embeddings.append(out["embeds"])
#         eval_labels.append(batch["y"])

#     eval_embeddings = torch.cat(eval_embeddings, dim=0)
#     eval_labels = torch.cat(eval_labels, dim=0)
#     return eval_embeddings.detach().cpu().numpy(), eval_labels.detach().cpu().numpy()


# clusterizer = {
#     "kmeans": lambda embeds: KMeans(n_clusters=len(EVALUATION_CLASSES)).fit(embeds).labels_,
#     "bisect-kmeans": lambda embeds: BisectingKMeans(n_clusters=len(EVALUATION_CLASSES)).fit(embeds).labels_,
#     "dbscan": lambda embeds: sklearn.cluster.DBSCAN().fit(embeds).labels_,
#     "spectral": lambda embeds: sklearn.cluster.SpectralClustering(n_clusters=len(EVALUATION_CLASSES))
#     .fit(embeds)
#     .labels_,
#     "birch": lambda embeds: sklearn.cluster.Birch(n_clusters=len(EVALUATION_CLASSES)).fit(embeds).labels_,
#     "agglomerative": lambda embeds: sklearn.cluster.AgglomerativeClustering(n_clusters=len(EVALUATION_CLASSES))
#     .fit(embeds)
#     .labels_,
#     "optics": lambda embeds: sklearn.cluster.OPTICS().fit(embeds).labels_,
# }

# clustering_metric = {
#     "v_measure_score": lambda x, y_pred, y_true: sklearn.metrics.v_measure_score(y_true, y_pred),
#     "adjusted_mutual_info_score": lambda x, y_pred, y_true: sklearn.metrics.adjusted_mutual_info_score(y_true, y_pred),
#     "adjusted_rand_score": lambda x, y_pred, y_true: sklearn.metrics.adjusted_rand_score(y_true, y_pred),
#     "completeness_score": lambda x, y_pred, y_true: sklearn.metrics.completeness_score(y_true, y_pred),
#     "fowlkes_mallows_score": lambda x, y_pred, y_true: sklearn.metrics.fowlkes_mallows_score(y_true, y_pred),
#     # "homogeneity_completeness_v_measure": lambda x, y_pred, y_true: sklearn.metrics.homogeneity_completeness_v_measure(
#     # y_true, y_pred
#     # ),
#     "homogeneity_score": lambda x, y_pred, y_true: sklearn.metrics.homogeneity_score(y_true, y_pred),
#     "mutual_info_score": lambda x, y_pred, y_true: sklearn.metrics.mutual_info_score(y_true, y_pred),
#     "normalized_mutual_info_score": lambda x, y_pred, y_true: sklearn.metrics.normalized_mutual_info_score(
#         y_true, y_pred
#     ),
#     "rand_score": lambda x, y_pred, y_true: sklearn.metrics.rand_score(y_true, y_pred),
# }

# performance = []
# for i in range(len(matrioskaidx2model)):
#     result = trainer.test(model=matrioskaidx2model[i], dataloaders=eval_loader)[0]

#     performance.append(Result(matrioska_idx=i, clusterer="none", metric_name="test_acc", score=result["acc/test"]))
#     performance.append(Result(matrioska_idx=i, clusterer="none", metric_name="test_loss", score=result["loss/test"]))
#     eval_embeddings, eval_labels = compute_eval_embedings(model, eval_loader)

#     for clusterizer_name, clusterizer_func in clusterizer.items():
#         clustering_labels = clusterizer[clusterizer_name](eval_embeddings)

#         for metric_name, metric_func in clustering_metric.items():
#             performance.append(
#                 Result(
#                     matrioska_idx=i,
#                     clusterer=clusterizer_name,
#                     metric_name=metric_name,
#                     score=metric_func(x=eval_embeddings, y_pred=clustering_labels, y_true=eval_labels),
#                 )
#             )

# import pandas as pd
# import plotly.express as px

# perf = pd.DataFrame(performance)
# perf
# perf["ntrain_classes"] = perf["matrioska_idx"] + 2

# fig = px.scatter(
#     perf,
#     facet_col="clusterer",
#     facet_row="metric_name",
#     x="ntrain_classes",
#     y="score",
#     labels={"matrioska_idx": "Number of classes trained on", "test_acc": "Test accuracy"},
#     height=2000,
# )
# fig

# Classifying K classes in the latent space of a model trained on all N classes vs. one trained on only those K < N classes

In [None]:
from typing import List
from nn_core.callbacks import NNTemplateCore
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO
from datasets import Dataset, DatasetDict
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from pytorch_lightning import Callback
from la.utils.utils import build_callbacks
from torch.utils.data import DataLoader
from la.data.my_dataset_dict import MyDatasetDict
from tqdm import tqdm

In [None]:
import random

# Seed Python's `random` so the subtasks sampled with `random.sample` are reproducible.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

num_task_classes = 10  # number of classes in each sampled subtask
num_tasks = 2  # number of sampled subtasks (in addition to the all-classes task)

In [None]:
import dataclasses
from typing import Any, Dict, Optional, Union


@dataclasses.dataclass
class Task:
    """A classification (sub)task: a subset of the global classes plus its data and model.

    Fields:
        class_idxs: global class ids included in the task.
        classes: class names, aligned with `class_idxs`.
        global_to_local: global class id -> contiguous task-local label.
        id: task identifier, also used as the directory name when saving.
        dataset: train/test splits for the task.
        embedded_dataset: train/test embeddings produced by the task model.
        model: reference to the trained checkpoint, a dict with keys "path" (checkpoint
            path) and "class" (fully-qualified model class name); None until trained.
    """

    class_idxs: list
    classes: list
    global_to_local: Dict
    id: str
    dataset: Union[DatasetDict, MyDatasetDict] = None
    embedded_dataset: MyDatasetDict = None
    model: Optional[Dict[str, str]] = None

    def metadata(self) -> Dict[str, Any]:
        """Describe the task (everything but the datasets); stored alongside the saved datasets."""
        return {
            "id": self.id,
            "class_idxs": self.class_idxs,
            "classes": self.classes,
            "global_to_local": self.global_to_local,
            "model": self.model,
        }

## Task creation or loading

In [None]:
# If False, the previously saved subtasks are loaded from SUBTASK_DATASET_PATH instead.
create_new_tasks = True

SUBTASK_DATASET_PATH = PROJECT_ROOT / "matrioska_learning" / "subtasks"
# `parents=True`: the "matrioska_learning" directory may not exist yet.
SUBTASK_DATASET_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
if not create_new_tasks:
    tasks = []
    # Sorted for a deterministic task order; stray files next to the task directories are skipped.
    for task_dir in sorted(path for path in SUBTASK_DATASET_PATH.glob("*") if path.is_dir()):
        # Dedicated name: reusing `dataset` would clobber the full dataset defined earlier.
        saved_task_dataset = MyDatasetDict.load_from_disk(task_dir)
        task_metadata = saved_task_dataset["metadata"]
        tasks.append(
            Task(
                id=task_dir.name,
                class_idxs=task_metadata["class_idxs"],
                classes=task_metadata["classes"],
                global_to_local=task_metadata["global_to_local"],
                dataset=saved_task_dataset,
                # Restore what the embedding step needs: an empty container for the
                # embeddings and the checkpoint reference saved in the metadata.
                embedded_dataset=MyDatasetDict(train=DatasetDict(), test=DatasetDict()),
                model=task_metadata["model"],
            )
        )

In [None]:
if create_new_tasks:
    tasks = []

    # Prepare the data exactly like the sampled subtasks (model input transform + torch
    # format), so the all-classes model is trained and embedded on comparable inputs.
    all_classes_transform = hydra.utils.instantiate(cfg.nn.model.transform_func)
    all_classes_dataset = dataset.map(
        desc="Preprocessing samples",
        function=lambda x: {"x": all_classes_transform(x["x"])},
    )
    all_classes_dataset.set_format(type="torch", columns=["x", "y", "id"])

    all_classes_task = Task(
        class_idxs=class_idxs,
        classes=class_names,
        # All classes: local labels coincide with the global ones.
        global_to_local={i: i for i in range(len(class_names))},
        id="all_classes",
        dataset=MyDatasetDict(train=all_classes_dataset["train"], test=all_classes_dataset["test"]),
        embedded_dataset=MyDatasetDict(train=DatasetDict(), test=DatasetDict()),
    )

    tasks.append(all_classes_task)

In [None]:
from hydra.utils import instantiate

transform_func = instantiate(cfg.nn.model.transform_func)

if create_new_tasks:
    for task_number in range(num_tasks):
        # Sample the task classes (global ids) and map them to contiguous local labels.
        task_class_indices = sorted(random.sample(class_idxs, k=num_task_classes))
        global_to_local = dict(zip(task_class_indices, range(len(task_class_indices))))

        task_classes = [class_names[class_idx] for class_idx in task_class_indices]
        task_str = "_".join(map(str, task_class_indices))

        # Keep only the task samples, relabel them locally, then apply the model input transform.
        task_dataset = dataset.filter(lambda row: row["y"] in task_class_indices)
        task_dataset = task_dataset.map(lambda row: {"y": global_to_local[row["y"]]})
        task_dataset = task_dataset.map(
            desc="Preprocessing samples",
            function=lambda x: {"x": transform_func(x["x"])},
        )
        task_dataset.set_format(type="torch", columns=["x", "y", "id"])

        tasks.append(
            Task(
                class_idxs=task_class_indices,
                classes=task_classes,
                global_to_local=global_to_local,
                id=task_str,
                embedded_dataset=MyDatasetDict(train=DatasetDict(), test=DatasetDict()),
                dataset=task_dataset,
            )
        )

tasks

In [None]:
from la.utils.utils import get_checkpoint_callback

SUBTASK_DATASET_PATH = PROJECT_ROOT / "matrioska_learning" / "subtasks"


if create_new_tasks:
    # NOTE(review): this overrides the shared `cfg`, which stays modified for every later
    # Trainer built from cfg.train.trainer; 1 epoch looks like a development setting — confirm.
    cfg.train.trainer.max_epochs = 1

    for task in tasks:
        print(f"Training model for task {task.id}")

        num_classes = len(task.classes)
        model: pl.LightningModule = hydra.utils.instantiate(
            cfg.nn.model,
            _recursive_=False,
            num_classes=num_classes,
            model=cfg.nn.model.model,
            input_dim=img_size,
        )

        train_loader = DataLoader(
            task.dataset["train"],
            batch_size=100,
            pin_memory=False,
            shuffle=True,
            num_workers=8,
        )

        val_loader = DataLoader(
            task.dataset["test"],
            batch_size=100,
            pin_memory=False,
            shuffle=False,
            num_workers=8,
        )

        template_core: NNTemplateCore = NNTemplateCore(
            restore_cfg=cfg.train.get("restore", None),
        )
        callbacks: List[Callback] = build_callbacks(cfg.train.callbacks, template_core)

        storage_dir: str = cfg.core.storage_dir
        logger: NNLogger = NNLogger(logging_cfg=cfg.train.logging, cfg=cfg, resume_id=template_core.resume_id)
        try:
            # Use this in case we need to restore models, search for it in the wandb UI
            logger.experiment.config["task_classes"] = task.id

            trainer = pl.Trainer(
                default_root_dir=storage_dir,
                plugins=[NNCheckpointIO(jailing_dir=logger.run_dir)],
                logger=logger,
                callbacks=callbacks,
                **cfg.train.trainer,
            )
            trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

            # Record where the best checkpoint is and which class can load it back.
            task.model = {
                "path": get_checkpoint_callback(callbacks).best_model_path,
                "class": f"{model.__class__.__module__}.{model.__class__.__qualname__}",
            }
        finally:
            # Close the logger's run even when training fails, so it is not left open.
            logger.experiment.finish()

        # Save the task data together with its metadata, then reload it from disk.
        # (Built with the `/` operator: `Path` itself is not imported in this notebook.)
        task_dir = SUBTASK_DATASET_PATH / task.id
        task_dir.mkdir(exist_ok=True, parents=True)

        task_dataset = MyDatasetDict(task.dataset)
        task_dataset["metadata"] = task.metadata()
        task_dataset.save_to_disk(str(task_dir))

        task.dataset = MyDatasetDict.load_from_disk(str(task_dir))

In [None]:
from pydoc import locate
from nn_core.serialization import load_model


SUBTASK_EMBEDDING_PATH = PROJECT_ROOT / "matrioska_learning" / "subtasks_embeddings"


def embed_and_save_samples(task, batch_size=100) -> None:
    """Embed the train/test splits of `task.dataset` with the task's trained model and save them.

    Loads the checkpoint referenced by `task.model` (`"path"` + ".zip", fully-qualified
    `"class"`), fills `task.embedded_dataset["train"/"test"]` with `embeds`/`id`/`y`
    columns and saves it under `SUBTASK_EMBEDDING_PATH / task.id`.

    Raises:
        ImportError: if the model class recorded in `task.model["class"]` cannot be located.
    """
    # Local import so this function does not rely on a notebook-level `Path` import.
    from pathlib import Path

    model_class_name = task.model["class"]
    model_class = locate(model_class_name)
    if model_class is None:
        raise ImportError(f"Cannot locate model class {model_class_name!r} for task {task.id}")

    model = load_model(model_class, checkpoint_path=Path(task.model["path"] + ".zip"))
    model.eval().cuda()

    for mode in ("train", "test"):
        mode_embeddings = []
        mode_ids = []
        mode_labels = []

        # Embed the task's own split (filtered, locally relabelled, transformed), not the
        # notebook-global `dataset`, which holds every class with global labels.
        mode_loader = DataLoader(
            task.dataset[mode],
            batch_size=batch_size,
            pin_memory=True,
            shuffle=False,
            num_workers=4,
        )

        # Inference only: no autograd graph; keep the collected embeddings on CPU.
        with torch.no_grad():
            for batch in tqdm(mode_loader, desc=f"Embedding {mode} samples for task {task.id}"):
                x = batch["x"].to("cuda")
                embeds = model(x)["embeds"].detach().cpu()

                mode_embeddings.extend(embeds)
                mode_ids.extend(batch["id"])
                mode_labels.extend(batch["y"])

        task.embedded_dataset[mode] = Dataset.from_dict(
            {
                "embeds": mode_embeddings,
                "id": mode_ids,
                "y": mode_labels,
            }
        )

    model.cpu()
    # NOTE(review): stored as an attribute, unlike `task_dataset["metadata"]` in the training
    # cell — confirm that MyDatasetDict.save_to_disk persists it.
    task.embedded_dataset.metadata = task.metadata()

    (SUBTASK_EMBEDDING_PATH / task.id).mkdir(exist_ok=True, parents=True)
    task.embedded_dataset.save_to_disk(SUBTASK_EMBEDDING_PATH / task.id)


for task in tasks:
    embed_and_save_samples(task)

In [None]:
# Show the checkpoint reference recorded for each task (set by the training cell; None if missing).
for task in tasks:
    print(task.model)

## Latent space analysis via classifier

In [None]:
from typing import List
from nn_core.callbacks import NNTemplateCore
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from pytorch_lightning import Callback
from la.pl_modules.classifier import Classifier

from la.utils.utils import build_callbacks

performance = []


for task in tasks:
    # Task labels were remapped to contiguous task-local ids when the task was created,
    # so keep rows by local id (the global `task.class_idxs` do not match them for subtasks).
    local_class_ids = set(task.global_to_local.values())

    # Build a fresh train/test DatasetDict instead of re-formatting `task.embedded_dataset`
    # in place; `int(...)` keeps the membership test valid whatever its current format.
    embeds_dataset = DatasetDict(
        {
            split: task.embedded_dataset[split].filter(lambda row: int(row["y"]) in local_class_ids)
            for split in ("train", "test")
        }
    )
    embeds_dataset.set_format(type="torch", columns=["embeds", "y"])

    eval_train_loader = DataLoader(
        embeds_dataset["train"],
        batch_size=64,
        pin_memory=True,
        shuffle=True,
        num_workers=8,
    )

    eval_test_loader = DataLoader(
        embeds_dataset["test"],
        batch_size=64,
        pin_memory=True,
        shuffle=False,
        num_workers=0,
    )

    model = Classifier(
        input_dim=embeds_dataset["train"]["embeds"].size(1),
        num_classes=len(task.classes),
        lr=1e-4,
        deep=True,
        x_feature="embeds",
        y_feature="y",
    )

    # NOTE(review): these callbacks are built but not passed to the Trainer below.
    callbacks: List[Callback] = build_callbacks(cfg.train.callbacks)

    storage_dir: str = cfg.core.storage_dir

    trainer = pl.Trainer(
        default_root_dir=storage_dir,
        logger=None,
        fast_dev_run=False,
        gpus=1,
        precision=32,
        max_epochs=250,
        accumulate_grad_batches=1,
        num_sanity_val_steps=2,
        gradient_clip_val=10.0,
        # NOTE(review): Lightning reads an integral float > 1 as a number of training
        # batches (validate every 5 batches); if "every 5 epochs" was intended, use
        # check_val_every_n_epoch=5 instead — confirm.
        val_check_interval=5.0,
    )
    trainer.fit(model, train_dataloaders=eval_train_loader, val_dataloaders=eval_test_loader)

    classifier_model = trainer.model.eval().cpu().requires_grad_(False)
    run_results = trainer.test(model=classifier_model, dataloaders=eval_test_loader)[0]

    # (Result metric name, key in the dict returned by `trainer.test`)
    for metric_name, result_key in (("test_accuracy", "accuracy"), ("test_f1", "f1"), ("test_loss", "test_loss")):
        performance.append(
            Result(
                matrioska_idx=task.id,
                num_train_classes=len(task.classes),
                metric_name=metric_name,
                score=run_results[result_key],
            )
        )