In [None]:
from rae.ui.evaluation import parse_checkpoints_tree
import logging
from collections import defaultdict
from enum import auto
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type, Union

import numpy as np
import pandas as pd
import rich
import torch
import typer
from torchmetrics import (
    ErrorRelativeGlobalDimensionlessSynthesis,
    MeanSquaredError,
    MetricCollection,
    MultiScaleStructuralSimilarityIndexMeasure,
    PeakSignalNoiseRatio,
    StructuralSimilarityIndexMeasure,
)

from rae.modules.enumerations import Output
from rae.pl_modules.pl_gautoencoder import LightningAutoencoder

try:
    # StrEnum is part of the standard library from Python 3.11 onwards;
    # on older interpreters fall back to the backport package.
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum

# Raise the root logger threshold to ERROR (presumably to keep the notebook
# output free of lower-severity library messages).
logging.getLogger().setLevel(logging.ERROR)


BATCH_SIZE = 32


# `Path(".").parent` is still ".", so every path below is resolved relative to
# the current working directory of the notebook kernel.
EXPERIMENT_ROOT = Path(".").parent
EXPERIMENT_CHECKPOINTS = EXPERIMENT_ROOT.joinpath("checkpoints")
PREDICTIONS_TSV = EXPERIMENT_ROOT.joinpath("predictions.tsv")
PERFORMANCE_TSV = EXPERIMENT_ROOT.joinpath("performance.tsv")

# Dataset name -> (dotted path string, split name) pair.
# NOTE(review): every key, including "mnist", "cifar10" and "cifar100", maps to
# the FashionMNISTDataset path — this looks like a placeholder; confirm the
# intended target for each dataset before relying on this table.
DATASET_SANITY = {
    "mnist": ("rae.data.vision.fmnist.FashionMNISTDataset", "test"),
    "fmnist": ("rae.data.vision.fmnist.FashionMNISTDataset", "test"),
    "cifar10": ("rae.data.vision.fmnist.FashionMNISTDataset", "test"),
    "cifar100": ("rae.data.vision.fmnist.FashionMNISTDataset", "test"),
}
# Model type -> dotted path string.
# NOTE(review): all four model types map to the same VanillaAE path — presumably
# a placeholder as well; verify against the models actually trained.
MODEL_SANITY = {
    "vae": "rae.modules.ae.VanillaAE",
    "ae": "rae.modules.ae.VanillaAE",
    "relvae": "rae.modules.ae.VanillaAE",
    "relae": "rae.modules.ae.VanillaAE",
}


# Parse the checkpoints directory into a nested mapping plus the run list
# (exact structure is defined by `parse_checkpoints_tree` — presumably
# dataset -> model -> ...; verify against its implementation).
checkpoints, RUNS = parse_checkpoints_tree(EXPERIMENT_CHECKPOINTS)


# Dataset names are the top-level keys of the parsed tree, in sorted order.
DATASETS = sorted(checkpoints)
# Model names are read from the first dataset only.
# NOTE(review): assumes every dataset has the same set of models — confirm.
# Raises IndexError when no dataset was found under the checkpoints directory.
MODELS = sorted(checkpoints[DATASETS[0]])

In [None]:
# Load the tab-separated predictions table, using its first column as the
# index, and show it as the cell output.
preds = pd.read_csv(
    PREDICTIONS_TSV,
    index_col=0,
    sep="\t",
)
preds

In [None]:
# Load the tab-separated performance table, using its first column as the
# index, and show it as the cell output.
perf = pd.read_csv(
    PERFORMANCE_TSV,
    index_col=0,
    sep="\t",
)
perf

In [None]:
# Discard the `run_id` column so it is not treated as a metric by the
# aggregation performed in the following cell; `perf` itself is left untouched.
aggregated_performnace = perf.drop(labels=["run_id"], axis="columns")
aggregated_performnace

In [None]:
# Per (dataset_name, model_type) group, compute the mean and the standard
# deviation of every remaining column. The result has a two-level column
# index: (original column, "mean" | "std").
#
# The string aliases are used instead of `np.mean` / `np.std` on purpose:
# recent pandas versions emit a FutureWarning for NumPy callables passed to
# `agg`, because they are currently mapped to the groupby methods but will be
# called directly in the future. Calling `np.std` directly would switch from
# pandas' sample standard deviation (ddof=1) to NumPy's population one
# (ddof=0), silently changing the reported numbers.
#
# Note: `aggregated_perfomance` (result) and `aggregated_performnace` (input)
# are two distinct variables that differ only in spelling; both names are kept
# because other cells refer to them.
aggregated_perfomance = aggregated_performnace.groupby(
    ["dataset_name", "model_type"]
).agg(["mean", "std"])

In [None]:
# Round every aggregated value to four decimal places for display, then show
# the table as the cell output.
aggregated_perfomance = aggregated_perfomance.round(decimals=4)
aggregated_perfomance

In [None]:
# Restrict the table to the listed metrics and reindex both row-index levels
# with explicit label lists (presumably the presentation order for the report).
# NOTE(review): the model labels here are "rel_ae" / "rel_vae", whereas
# MODEL_SANITY uses "relae" / "relvae" — confirm which spelling actually occurs
# in the `model_type` column.
aggregated_perfomance = (
    # metric columns, in this left-to-right order
    aggregated_perfomance[["mse", "ergas", "psnr", "ssim"]]
    # model labels within the `model_type` level
    .reindex(labels=["ae", "vae", "rel_ae", "rel_vae"], level="model_type")
    # dataset labels within the `dataset_name` level
    .reindex(
        labels=["mnist", "fmnist", "cifar10", "cifar100"],
        level="dataset_name",
    )
)