In [None]:
import os
import sys
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Root of the experiment package, relative to the notebook's working directory.
experiment_dir = "../"
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path; the project imports below rely on this entry.
if experiment_dir not in sys.path:
    sys.path.append(experiment_dir)
from configs import (
    celeba_im2im_ncsnpp,
    celeba_ncsnpp,
    abdomen_im2im_ncsnpp,
    abdomen_ncsnpp,
)
from configs.utils import get_config
from utils import calibration_dict

# Seaborn defaults for everything except the plotting context, which is
# switched to "paper" with the font sizes doubled.
sns.set_theme(context="paper", font_scale=2)

# Output folder for the risk-control figures; created up front so the
# plotting cell can save into it directly.
fig_dir = os.path.join(experiment_dir, "figures", "risk_control")
os.makedirs(fig_dir, exist_ok=True)

In [None]:
# Violin plots of the validation risk for every (UQ method, calibration
# procedure) pair listed in `calibration_dict`, one figure per dataset.

# (dataset name, epsilon) pairs; epsilon is drawn as the dashed horizontal line.
datasets = [("celeba", 0.1), ("abdomen", 0.05)]

# Number of leading entries (along dim 0) of each saved loss tensor that are
# plotted -- presumably one entry per calibration run; TODO confirm.
N_RISK_SAMPLES = 20

# calibration name -> (sub-directory holding its results, title used in the plot)
CALIBRATION_SPECS = {
    "rcps": ("rcps", "RCPS"),
    "krcps": ("k_rcps_128_100_8", r"$K$-RCPS"),
}

# uq name (also the prefix of the saved `*_val_loss.pt` file) -> title used in the plot
UQ_TITLES = {
    "quantile_regression": "QR",
    "std": "MC-Dropout",
    "conffusion_multiplicative": "Conffusion (Multiplicative)",
    "conffusion_additive": "Conffusion (Additive)",
    "naive_sampling_additive": "Naive quantiles",
    "calibrated_quantile": "Calibrated quantiles",
}

for dataset, epsilon in datasets:
    dataset_calibration_dir = os.path.join(experiment_dir, "calibration", dataset)

    # One {"calibration": <label>, "risk": <float>} record per plotted sample.
    val_loss_records = []
    for model_name, calibration in tqdm(calibration_dict.items()):
        config_name = f"{dataset}_{model_name}"
        config_calibration_dir = os.path.join(dataset_calibration_dir, config_name)

        for _calibration in calibration:
            calibration_name = _calibration["name"]
            # Fail loudly on an unknown name rather than silently reusing the
            # directory/title left over from the previous iteration.
            try:
                calibration_subdir, calibration_title = CALIBRATION_SPECS[
                    calibration_name
                ]
            except KeyError:
                raise ValueError(
                    f"Unknown calibration name {calibration_name!r} "
                    f"for config {config_name!r}"
                ) from None
            calibration_dir = os.path.join(config_calibration_dir, calibration_subdir)

            for _uq_name in _calibration["uq_name"]:
                try:
                    uq_title = UQ_TITLES[_uq_name]
                except KeyError:
                    raise ValueError(
                        f"Unknown uq name {_uq_name!r} for config {config_name!r}"
                    ) from None

                val_loss_path = os.path.join(
                    calibration_dir, f"{_uq_name}_val_loss.pt"
                )
                # map_location="cpu" lets the file load on a machine without a
                # GPU in case the tensor was saved from a CUDA device.
                val_loss = torch.load(val_loss_path, map_location="cpu")
                val_loss = val_loss[:N_RISK_SAMPLES]
                assert val_loss.size(0) == N_RISK_SAMPLES, (
                    f"{val_loss_path}: expected at least {N_RISK_SAMPLES} entries "
                    f"along dim 0, found {val_loss.size(0)}"
                )
                # Collapse dims 1 and 2 so that one scalar risk remains per entry.
                r_val_loss = torch.mean(val_loss, dim=(1, 2))

                title = f"{uq_title} \n {calibration_title}"
                val_loss_records.extend(
                    {"calibration": title, "risk": risk}
                    for risk in r_val_loss.tolist()
                )

    val_loss_df = pd.DataFrame(val_loss_records)

    fig, ax = plt.subplots(figsize=(16, 9))
    ax.axhline(epsilon, color="black", linestyle="--", linewidth=2, label=r"$\epsilon$")
    sns.violinplot(data=val_loss_df, x="calibration", y="risk", ax=ax, cut=2)
    # The y-range and ticks are fixed, i.e. identical for every dataset's figure.
    ax.set_ylim(0, 0.105)
    ax.set_yticks([0, 0.05, 0.10])
    ax.legend()
    ax.tick_params(axis="x", labelrotation=30)
    for extension in ("png", "pdf"):
        fig.savefig(
            os.path.join(fig_dir, f"{dataset}.{extension}"), bbox_inches="tight"
        )
    plt.show()
    # Close this specific figure so figures do not accumulate across datasets.
    plt.close(fig)