In [None]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose

from hydra.core.global_hydra import GlobalHydra

# Hydra keeps a process-wide singleton and `initialize` raises if it is already set.
# Clearing it first makes this cell safe to re-run without restarting the kernel.
GlobalHydra.instance().clear()
initialize(version_base=None, config_path="../conf", job_name="matrioska_learning")

In [None]:
from nn_core.common import PROJECT_ROOT

# Compose the Hydra configuration named "matrioska_learning", without any overrides.
cfg = compose(
    config_name="matrioska_learning",
    overrides=[],
)

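# TODO: load the matrioska embeddings here. The evaluation loop below iterates over a
# `matrioskaidx2embeds` mapping, which must be defined before that cell is run.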


# Evaluate matrioska models

In [None]:
# Decide which classes to evaluate on -- it may be interesting to change this.
# The evaluation loop keeps only the samples whose "y" label is in this set and passes
# len(EVALUATION_CLASSES) as the classifier's `num_classes`.
# NOTE(review): presumably the labels must then be contiguous 0..K-1 for the classifier
# to accept them -- confirm before choosing e.g. {3, 7}.
EVALUATION_CLASSES = {0, 1}
EVALUATION_CLASSES

In [None]:
import dataclasses
from typing import Optional


@dataclasses.dataclass
class Result:
    """One scalar metric measured for one matrioska model.

    Instances are collected in a list and later turned into a DataFrame,
    so each field becomes a column.
    """

    # Key of the evaluated entry in the matrioska-embeddings mapping.
    matrioska_idx: int
    # Optional because the evaluation loop currently records None here.
    # NOTE(review): presumably the number of classes seen at training time -- confirm.
    num_train_classes: Optional[int]
    # Name of the reported metric, e.g. "test_accuracy".
    metric_name: str
    # Value of the metric.
    score: float

In [None]:
from typing import List
from nn_core.callbacks import NNTemplateCore
from nn_core.model_logging import NNLogger
from nn_core.serialization import NNCheckpointIO
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from pytorch_lightning import Callback
from la.pl_modules.classifier import Classifier

from la.utils.utils import build_callbacks

# Train one classifier per matrioska model on its embeddings, restricted to
# EVALUATION_CLASSES, and collect the test metrics of each run in `performance`.
#
# NOTE(review): `matrioskaidx2embeds` is not defined in any visible cell. It must map a
# matrioska index to a dataset exposing "train"/"test" splits with "embeds" and "y"
# columns before this cell is run.

# Loop-invariant: every trainer writes under the same root directory.
storage_dir: str = cfg.core.storage_dir

# Reported metric name -> key of that metric in the dict returned by `trainer.test`.
metric_name2result_key = {
    "test_accuracy": "accuracy",
    "test_f1": "f1",
    "test_loss": "test_loss",
}

performance = []
for matrioska_idx, embeds in matrioskaidx2embeds.items():
    # Keep only the samples whose label is among the classes under evaluation.
    embeds_dataset = embeds.filter(
        lambda x: x["y"] in EVALUATION_CLASSES,
    )
    embeds_dataset.set_format(type="torch", columns=["embeds", "y"])

    eval_train_loader = DataLoader(
        embeds_dataset["train"],
        batch_size=64,
        pin_memory=True,
        shuffle=True,
        num_workers=0,
    )

    eval_test_loader = DataLoader(
        embeds_dataset["test"],
        batch_size=64,
        pin_memory=True,
        shuffle=False,
        num_workers=0,
    )

    model = Classifier(
        # The second dimension of the stacked train embeddings is the feature size.
        input_dim=embeds_dataset["train"]["embeds"].size(1),
        num_classes=len(EVALUATION_CLASSES),
        lr=1e-4,
        deep=True,
        x_feature="embeds",
        y_feature="y",
    )

    # NOTE(review): `gpus=1` assumes one CUDA device and a PyTorch Lightning version
    # that still accepts the `gpus` argument -- confirm for the target environment.
    trainer = pl.Trainer(
        default_root_dir=storage_dir,
        logger=None,
        fast_dev_run=False,
        gpus=1,
        precision=32,
        max_epochs=50,
        accumulate_grad_batches=1,
        num_sanity_val_steps=2,
        gradient_clip_val=10.0,
        val_check_interval=1.0,
    )
    # The test split doubles as the validation set during fitting.
    trainer.fit(model, train_dataloaders=eval_train_loader, val_dataloaders=eval_test_loader)

    # Switch the fitted module to eval mode, move it to CPU and disable gradients
    # before handing it to `trainer.test`.
    classifier_model = trainer.model.eval().cpu().requires_grad_(False)
    run_results = trainer.test(model=classifier_model, dataloaders=eval_test_loader)[0]

    # Build the full list first so a missing metric key leaves `performance` untouched.
    performance.extend(
        [
            Result(
                matrioska_idx=matrioska_idx,
                num_train_classes=None,
                metric_name=metric_name,
                score=run_results[result_key],
            )
            for metric_name, result_key in metric_name2result_key.items()
        ]
    )

In [None]:
import pandas as pd

# One row per collected Result; the dataclass fields become the columns.
performance_records = [dataclasses.asdict(result) for result in performance]
perf = pd.DataFrame(performance_records)

In [None]:
import plotly.express as px

# Plot every recorded score against its matrioska index, one colour per metric.
# The title and labels let the figure be read on its own.
px.scatter(
    perf,
    x="matrioska_idx",
    y="score",
    color="metric_name",
    title="Evaluation metrics per matrioska model",
    labels={
        "matrioska_idx": "Matrioska index",
        "score": "Score",
        "metric_name": "Metric",
    },
)