# Plotting ablation study on OOD size

In [None]:
import sys

sys.path.append("/vol/biomedic3/mb121/calibration_exploration/")

from classification.load_model_and_config import (
    get_run_id_from_config,
    _clean_config_for_backward_compatibility,
)
from hydra import initialize, compose
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Global seaborn theme applied to every figure produced below.
sns.set_style("whitegrid")

# Name of the experiment config whose runs are loaded and plotted.
experiment = "base_chexpert"

from matplotlib import ticker

# Fixed colour per calibrator label, picked from seaborn's "Paired" palette so
# the colours stay the same from one figure to the next.
cm = sns.color_palette("Paired")
palette = dict(
    TS=cm[1],
    EBS=cm[9],
    IRM=cm[3],
    IROVATS=cm[7],
)

In [None]:
# Build one list of Hydra overrides per encoder, then look up the run id that
# matches each composed config. Configs without a matching run are skipped.
pretrained = False
model_names = [
    "resnet18",
    "resnet50",
    "mobilenetv2_100",
    "convnext_tiny",
    "vit_base_patch16_224",
    "efficientnet_b0",
]
configs_to_evaluate = [
    [
        f"experiment={experiment}",
        f"model.encoder_name={model}",
        f"model.pretrained={pretrained}",
    ]
    for model in model_names
]

with initialize(version_base=None, config_path="../configs"):
    run_ids = []

    for config_str in configs_to_evaluate:
        config = compose(
            config_name="config.yaml",
            overrides=[*config_str, "trainer.label_smoothing=0.00"],
        )
        # The learning rate is removed from the config before the lookup.
        del config.trainer.lr
        _clean_config_for_backward_compatibility(config)
        run_id = get_run_id_from_config(
            config,
            allow_multiple_runs=False,
            allow_return_none_if_no_runs=True,
        )
        if run_id is None:
            # No run exists for this configuration; nothing to collect.
            continue
        run_ids.append(run_id)

print(experiment, len(run_ids))

In [None]:
def retrieve_metrics_df(list_run_ids, metric):
    """Load and stack the per-run ablation CSVs for one metric.

    For every run id the file ``ebs_ablation_metrics_{metric}.csv`` is read
    from that run's output directory. Runs whose file is missing are reported
    on stdout and skipped.

    Args:
        list_run_ids: Iterable of run identifiers.
        metric: Metric name used in the CSV file name (e.g. ``"ECE"``).

    Returns:
        The concatenation of all loaded frames, with the unnamed first CSV
        column renamed to ``domain``. An empty ``DataFrame`` if nothing could
        be loaded.
    """

    def _severity_label(name):
        # "id" is kept as-is; any other label becomes its last digit plus one.
        return "id" if name == "id" else int(name[-1]) + 1

    outputs_root = Path("/vol/biomedic3/mb121/calibration_exploration/outputs")
    frames = []
    for run_id in list_run_ids:
        csv_path = (
            outputs_root / f"run_{run_id}" / f"ebs_ablation_metrics_{metric}.csv"
        )
        try:
            frame = pd.read_csv(csv_path)
        except FileNotFoundError:
            print(str(csv_path) + " Not found")
            continue
        frame = frame.rename(columns={"Unnamed: 0": "domain"})
        # Only frames containing a "brightness_s0" domain get their labels
        # converted to integers.
        if "brightness_s0" in frame.domain.to_numpy():
            frame["domain"] = frame["domain"].apply(_severity_label)
        frames.append(frame)

    return pd.concat(frames) if frames else pd.DataFrame()


# Metric name -> stacked per-run DataFrame. If several run lists are given,
# later lists overwrite the entries of earlier ones.
all_dfs = {}
run_lists = [run_ids]
metrics = ["ECE", "Brier"]
for run_list in run_lists:
    all_dfs.update({m: retrieve_metrics_df(run_list, m) for m in metrics})

In [None]:
# Reshape the ECE table so each row pairs an OOD value with an ID value for
# the same column ("variable") of the original CSV.

# Work on a copy so the cached frame in `all_dfs` keeps its original domains.
df = all_dfs["ECE"].copy()

# Collapse every domain to "ID" or "OOD". `retrieve_metrics_df` may have
# turned brightness domains into integers, so stringify before upper-casing
# (calling `.upper()` on an int raises AttributeError).
df["domain"] = df["domain"].map(lambda x: "ID" if str(x).upper() == "ID" else "OOD")
df = df.melt(id_vars="domain")

# The last four characters of the column name (underscores removed) are parsed
# as the relative OOD size; the "probas" column gets size 0.
# NOTE(review): presumably columns look like "calib_<name>_<size>" - confirm.
df["OOD_rel_size"] = (
    df["variable"]
    .apply(lambda x: x[-4:].replace("_", "") if x != "probas" else 0)
    .astype(float)
)

# Calibrator label: the column name without its last four characters, with any
# "0", the "calib_" prefix and underscores removed, upper-cased.
df["Calibrator with OOD"] = df["variable"].apply(
    lambda x: (
        x[:-4].replace("0", "").replace("calib_", "").replace("_", "").upper()
        if x != "probas"
        else x.upper()
    )
)

# Pair every OOD row with every ID row that shares the same column.
is_ood = df["domain"] == "OOD"
is_id = df["domain"] == "ID"
df = pd.merge(
    df.loc[is_ood].drop(columns="domain"),
    df.loc[is_id].drop(columns="domain"),
    on=["variable", "Calibrator with OOD", "OOD_rel_size"],
    suffixes=["_OOD", "_ID"],
)

# Keep only sizes above 0.005; this also drops the "probas" rows (size 0).
df = df.loc[df["OOD_rel_size"] > 0.005]
# NOTE(review): labels are upper-cased above, so comparing against lower-case
# "irm" never matches and this filter removes nothing. Left unchanged because
# `palette` defines an "IRM" colour - confirm whether IRM should be excluded.
df = df.loc[df["Calibrator with OOD"] != "irm"]

In [None]:
# Scatter of mean ID value (x) against mean OOD value (y), one marker per
# (column, calibrator, OOD size) group. Marker size follows the relative OOD
# size and colour follows the calibrator.
f, ax = plt.subplots(1, 1, figsize=(3, 3))
sns.scatterplot(
    data=df.groupby(["variable", "Calibrator with OOD", "OOD_rel_size"]).mean(),
    x="value_ID",
    y="value_OOD",
    size="OOD_rel_size",
    size_norm=(0, 1),
    sizes=(20, 300),
    hue="Calibrator with OOD",
    edgecolor="dimgrey",
    linewidth=0.5,
    ax=ax,
    legend=False,
    palette=palette,
)

# Bold title: a fixed name for known experiments, otherwise the experiment
# name without its "base_" prefix, upper-cased.
dataset_titles = {"base_chexpert": "CXR", "base_density": "EMBED"}
dataset_title = dataset_titles.get(
    experiment, experiment.replace("base_", "").upper()
)
# Raw string: "\m" is not a valid Python escape sequence, so a non-raw literal
# triggers a SyntaxWarning on recent Python versions.
f.suptitle(r"$\mathbf{" + dataset_title + "}$")
ax.set_xlabel("ID")
ax.set_ylabel("OOD")

# Retrieved for the optional stand-alone legend figure.
handles, labels = ax.get_legend_handles_labels()

# Show both axes with two decimals.
two_decimals = "{x:.2f}"
ax.xaxis.set_major_formatter(ticker.StrMethodFormatter(two_decimals))
ax.yaxis.set_major_formatter(ticker.StrMethodFormatter(two_decimals))
f.savefig(
    f"/vol/biomedic3/mb121/calibration_exploration/outputs/figures/ablation/{experiment}.pdf",
    bbox_inches="tight",
)
# f2, ax2 = plt.subplots(1,1,figsize=(10,1))
# ax2.legend(
#     handles,
#     labels,
#     loc='center',  # Center the legend
#     #bbox_to_anchor=(0.5, -0.1),  # Position at bottom
#     ncol=10  # Display legend items in 3 columns
# )
# ax2.axis('off')
# f2.tight_layout()
# f2.savefig(f'/vol/biomedic3/mb121/calibration_exploration/outputs/figures/ablation/legend.pdf',bbox_inches='tight')