# Ensemble plots

In [None]:
import sys

sys.path.append("/vol/biomedic3/mb121/calibration_exploration/")

from collections import defaultdict
from pathlib import Path
import numpy as np


import pandas as pd
from default_paths import ROOT

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import ticker

# Apply seaborn's "whitegrid" style before any figure is created.
sns.set_style("whitegrid")

In [None]:
# Interpolated into the results directory name and the saved file name of the
# main-paper figure. Presumably selects the results computed with foundation
# models -- TODO confirm against the script that writes
# `outputs/ensembling_results`.
evaluate_foundation_models = True

In [None]:
# Experiment names; each one is a sub-directory of `outputs/ensembling_results`
# and yields its own figure(s) in the cells below.
all_experiments = [
    "base_density",
    "base_camelyon",
    "base_retina",
    "base_living17",
    "base_entity30",
    "base_domainnet",
    "base_icam",
    "base_chexpert",
]


# Fixed colour per calibration variant (keys are the upper-cased column labels
# used as `hue` in the plots). Variants are assigned slots of seaborn's
# "Paired" palette; uncalibrated probabilities ("PROBAS") are drawn in black.
cm = sns.color_palette("Paired")
_paired_slot = {
    "TS": 0,
    "IRM": 2,
    "ETS": 4,
    "IROVATS": 6,
    "EBS_MINUS": 8,
    "EBS": 9,
    "TS_WITH_OOD": 1,
    "IRM_WITH_OOD": 3,
    "IROVATS_WITH_OOD": 7,
    # The pseudo-logit ensemble variants share the colour of their calibrator.
    "PSEUDO_LOGITS_PROBAS_ENS_TS": 0,
    "PSEUDO_LOGITS_PROBAS_ENS_EBS": 9,
    "PSEUDO_LOGITS_PROBAS_ENS_TS_OOD": 1,
}
palette = {"PROBAS": "black"}
palette.update((label, cm[slot]) for label, slot in _paired_slot.items())


def fn(x, pos):
    """Format a tick value with three decimals and without leading zeros.

    Has the ``(value, position)`` signature expected by
    ``ticker.FuncFormatter``; ``pos`` is accepted but not used.
    Example: ``0.25`` becomes ``".250"``.
    """
    three_decimals = "{:.3f}".format(x)
    return three_decimals.lstrip("0")

## Main paper figure

In [None]:
metrics_list = ["ECE", "Brier"]
c2 = [
    "probas",
    "calib_ebs",
    "calib_ts",
    "calib_irm",
    "calib_irovats",
    "calib_irm_with_ood",
    "calib_ts_with_ood",
    "calib_irovats_with_ood",
    "pseudo_logits_probas_ens",
    "pseudo_logits_probas_ens_ts",
    "pseudo_logits_probas_ens_ebs",
    "pseudo_logits_probas_ens_ts_ood",
]
for experiment in all_experiments:
    all_dfs = defaultdict(list)
    all_baseline_dfs = defaultdict(list)
    for k in metrics_list:
        for ls, er in [(0, 0), (0.05, 0.1)]:
            ensemble_output_dir = ROOT / Path(
                f"outputs/ensembling_results/{experiment}/{float(ls):.2f}_{float(er):.2f}_{evaluate_foundation_models}"
            )
            df = pd.read_csv(
                ensemble_output_dir / f"all_ensemble_metrics_{k}.csv", index_col=0
            )
            clean_name = {c: c.replace("calib_", "").upper() for c in c2}
            df.rename(columns=clean_name, inplace=True)
            df["domain"] = df.index.values
            df["domain"] = df["domain"].map(lambda x: "ID" if "id" == x else "OOD")
            df["experiment"] = experiment
            df["foundation"] = evaluate_foundation_models
            df["metric"] = k
            df["loss"] = "CE" if (ls == er == 0) else "ER+LS"
            all_dfs[k].append(df)

            df = pd.read_csv(
                ensemble_output_dir / f"individual_run_metrics_{k}.csv", index_col=0
            )
            clean_name = {c: c.replace("calib_", "").upper() for c in c2}
            df.rename(columns=clean_name, inplace=True)
            df["domain"] = df.index.values
            df["domain"] = df["domain"].map(lambda x: "ID" if "id" == x else "OOD")
            df["experiment"] = experiment
            df["metric"] = k
            df["loss"] = "CE" if (ls == er == 0) else "ER+LS"
            df["foundation"] = evaluate_foundation_models
            all_baseline_dfs[k].append(df)

    for j, m in enumerate(["ECE", "Brier"]):
        f, ax = plt.subplots(1, 1, figsize=(2.3, 2.9))
        df1 = pd.concat(all_dfs[m])
        df2 = df1.drop(columns=["metric", "foundation"])
        df3 = df2.melt(
            value_vars=[
                "EBS",
                "TS",
                "PROBAS",
                "IRM",
                "TS_WITH_OOD",
                "IRM_WITH_OOD",
                "IROVATS",
                "IROVATS_WITH_OOD",
                "pseudo_logits_probas_ens_ts".upper(),
                "pseudo_logits_probas_ens_ebs".upper(),
                "pseudo_logits_probas_ens_ts_ood".upper(),
            ],
            id_vars=["experiment", "loss", "domain"],
        )
        df3["treatment"] = df3["loss"] + " + " + df3["variable"]
        df_calib_after = df3.loc[
            df3["treatment"].isin(
                [
                    "CE + " + "pseudo_logits_probas_ens_ebs".upper(),
                    "CE + " + "pseudo_logits_probas_ens_ts".upper(),
                    "CE + " + "pseudo_logits_probas_ens_ts_ood".upper(),
                ]
            )
        ]

        c3 = [
            "CE + EBS",
            "CE + TS",
            "CE + PROBAS",
            "CE + TS_WITH_OOD",
            "ER+LS + PROBAS",
        ]
        df3 = df3.loc[df3["treatment"].isin(c3)]

        df_calib_after = (
            pd.merge(
                df_calib_after.loc[df_calib_after.domain == "ID"].drop(
                    columns="domain"
                ),
                df_calib_after.loc[df_calib_after.domain == "OOD"].drop(
                    columns="domain"
                ),
                on=["loss", "variable", "treatment", "experiment"],
                suffixes=["_id", "_ood"],
            )
            .groupby(["loss", "variable", "treatment", "experiment"])
            .mean()
        )
        df5 = (
            pd.merge(
                df3.loc[df3.domain == "ID"].drop(columns="domain"),
                df3.loc[df3.domain == "OOD"].drop(columns="domain"),
                on=["loss", "variable", "treatment", "experiment"],
                suffixes=["_id", "_ood"],
            )
            .groupby(["loss", "variable", "treatment", "experiment"])
            .mean()
        )

        df1_base = pd.concat(all_baseline_dfs[m])
        df1_base = df1_base.drop(columns=["metric", "foundation"])
        df1_base = df1_base.melt(
            value_vars=[
                "EBS",
                "TS",
                "PROBAS",
                "IRM",
                "TS_WITH_OOD",
                "IRM_WITH_OOD",
                "IROVATS",
                "IROVATS_WITH_OOD",
            ],
            id_vars=["experiment", "loss", "domain"],
        )
        df1_base["treatment"] = df1_base["loss"] + " + " + df1_base["variable"]
        df1_base = df1_base.loc[df1_base["treatment"].isin(c3)]
        df1_base = (
            pd.merge(
                df1_base.loc[df1_base.domain == "ID"].drop(columns="domain"),
                df1_base.loc[df1_base.domain == "OOD"].drop(columns="domain"),
                on=["loss", "variable", "treatment", "experiment"],
                suffixes=["_id", "_ood"],
            )
            .groupby(["loss", "variable", "treatment", "experiment"])
            .mean()
        )
        sns.scatterplot(
            data=df1_base,
            style="loss",
            hue="variable",
            y="value_ood",
            x="value_id",
            s=30,
            ax=ax,
            markers={"CE": "o", "ER": "P", "LS": "X", "Focal": "v", "ER+LS": "D"},
            palette=palette,
            legend=False,
        )

        df5 = df5.reset_index()
        for xp, yp, var in zip(
            np.stack([df5.value_id.values, df1_base.value_id.values], 1),
            np.stack([df5.value_ood.values, df1_base.value_ood.values], 1),
            df5.variable.values,
        ):
            ax.plot(xp, yp, c=palette[var], ls=":", linewidth=1)

        df1_base = df1_base.reset_index()
        df1_base = df1_base.loc[
            df1_base["treatment"].isin(["CE + EBS", "CE + TS", "CE + TS_WITH_OOD"])
        ]
        df_calib_after = df_calib_after.reset_index()

        for xp, yp, var in zip(
            np.stack([df_calib_after.value_id.values, df1_base.value_id.values], 1),
            np.stack([df_calib_after.value_ood.values, df1_base.value_ood.values], 1),
            df1_base["variable"].values,
        ):
            ax.plot(xp, yp, c=palette[var], ls=":", linewidth=1)

        sns.scatterplot(
            data=df5,
            style="loss",
            hue="variable",
            y="value_ood",
            x="value_id",
            s=100,
            legend=True,
            ax=ax,
            markers={"CE": "o", "ER": "P", "LS": "X", "Focal": "v", "ER+LS": "D"},
            palette=palette,
            edgecolor="dimgrey",
            linewidth=0.5,
        )

        sns.scatterplot(
            data=df_calib_after,
            hue="variable",
            y="value_ood",
            x="value_id",
            s=100,
            legend=True,
            ax=ax,  # style='loss',
            markers={"CE": "o", "ER+LS": "P"},
            palette=palette,
            marker="$\circ$",
            ec="face",
            linewidth=1.5,  # "$\diamond$"
        )

        # Get handles and labels directly from the axes
        handles, labels = ax.get_legend_handles_labels()
        # Remove the legend from its current position
        if m == "ECE" and experiment == "base_icam":
            ax.set_xlim((0, 0.25))
            ax.set_ylim((0, 0.30))
        ax.get_legend().remove()
        ax.set_xlabel("")

        (
            ax.set_ylabel("$\mathbf{" + m.upper() + "}$" + " - SHIFTED")
            if experiment == "base_living17"
            else ax.set_ylabel("")
        )
        ax.set_xlabel("ID")
        ax.xaxis.set_major_formatter(ticker.FuncFormatter(fn))
        ax.yaxis.set_major_formatter(ticker.FuncFormatter(fn))
        if m == "ECE":
            match experiment:
                case "base_chexpert":
                    f.suptitle("$\mathbf{CXR}$")
                case "base_density":
                    f.suptitle("$\mathbf{EMBED}$")
                case _:
                    f.suptitle(
                        "$\mathbf{" + experiment.replace("base_", "").upper() + "}$"
                    )
        f.savefig(
            f"/vol/biomedic3/mb121/calibration_exploration/outputs/figures/ensemble/main/{m}/{experiment}_{evaluate_foundation_models}.pdf",
            bbox_inches="tight",
        )

## Appendix: pre/post calibration with ERLS

In [None]:
for experiment in all_experiments:
    all_dfs = defaultdict(list)
    all_baseline_dfs = defaultdict(list)
    for k in metrics_list:
        for ls, er in [(0, 0), (0.05, 0.1)]:
            for evaluate_foundation_models in [False]:
                ensemble_output_dir = ROOT / Path(
                    f"outputs/ensembling_results/{experiment}/{float(ls):.2f}_{float(er):.2f}_{evaluate_foundation_models}"
                )
                df = pd.read_csv(
                    ensemble_output_dir / f"all_ensemble_metrics_{k}.csv", index_col=0
                )
                clean_name = {c: c.replace("calib_", "").upper() for c in c2}
                df.rename(columns=clean_name, inplace=True)
                df["domain"] = df.index.values
                df["domain"] = df["domain"].map(lambda x: "ID" if "id" == x else "OOD")
                df["experiment"] = experiment
                df["foundation"] = evaluate_foundation_models
                df["metric"] = k
                df["loss"] = "CE" if (ls == er == 0) else "ER+LS"
                all_dfs[k].append(df)

                df = pd.read_csv(
                    ensemble_output_dir / f"individual_run_metrics_{k}.csv", index_col=0
                )
                clean_name = {c: c.replace("calib_", "").upper() for c in c2}
                df.rename(columns=clean_name, inplace=True)
                df["domain"] = df.index.values
                df["domain"] = df["domain"].map(lambda x: "ID" if "id" == x else "OOD")
                df["experiment"] = experiment
                df["metric"] = k
                df["loss"] = "CE" if (ls == er == 0) else "ER+LS"
                df["foundation"] = evaluate_foundation_models
                all_baseline_dfs[k].append(df)

    sns.set_style("whitegrid")
    f, ax = plt.subplots(1, 2, figsize=(8, 4))

    for j, m in enumerate(["ECE", "Brier"]):
        df1 = pd.concat(all_dfs[m])
        df2 = df1.drop(
            columns=["metric", "foundation"]
        )  # .groupby(['experiment', 'loss', 'domain'], as_index=False).mean()
        df3 = df2.melt(
            value_vars=[
                "EBS",
                "TS",
                "PROBAS",
                "IRM",
                "TS_WITH_OOD",
                "IRM_WITH_OOD",
                "IROVATS",
                "IROVATS_WITH_OOD",
                "pseudo_logits_probas_ens_ts".upper(),
                "pseudo_logits_probas_ens_ebs".upper(),
                "pseudo_logits_probas_ens_ts_ood".upper(),
            ],
            id_vars=["experiment", "loss", "domain"],
        )
        df3["treatment"] = df3["loss"] + " + " + df3["variable"]
        df_calib_after = df3.loc[
            df3["treatment"].isin(
                [
                    "ER+LS + " + "pseudo_logits_probas_ens_ebs".upper(),
                    "ER+LS + " + "pseudo_logits_probas_ens_ts".upper(),
                    "ER+LS + " + "pseudo_logits_probas_ens_ts_ood".upper(),
                ]
            )
        ]

        c3 = [
            "ER+LS + EBS",
            "ER+LS + TS",
            "ER+LS + TS_WITH_OOD",
            "ER+LS + PROBAS",
        ]  # 'CE + IROVATS', 'CE + IROVATS_WITH_OOD', 'CE + IRM_WITH_OOD', 'CE + IRM',
        df_base_ce = df3.loc[df3["treatment"] == "CE + TS"]
        df3 = df3.loc[df3["treatment"].isin(c3)]

        df_calib_after = (
            pd.merge(
                df_calib_after.loc[df_calib_after.domain == "ID"].drop(
                    columns="domain"
                ),
                df_calib_after.loc[df_calib_after.domain == "OOD"].drop(
                    columns="domain"
                ),
                on=["loss", "variable", "treatment", "experiment"],
                suffixes=["_id", "_ood"],
            )
            .groupby(["loss", "variable", "treatment", "experiment"])
            .mean()
        )
        df5 = (
            pd.merge(
                df3.loc[df3.domain == "ID"].drop(columns="domain"),
                df3.loc[df3.domain == "OOD"].drop(columns="domain"),
                on=["loss", "variable", "treatment", "experiment"],
                suffixes=["_id", "_ood"],
            )
            .groupby(["loss", "variable", "treatment", "experiment"])
            .mean()
        )
        sns.scatterplot(
            data=df5,
            style="loss",
            hue="variable",
            y="value_ood",
            x="value_id",
            s=140,
            legend=j == 0,
            ax=ax[j],
            markers={"CE": "o", "ER": "P", "LS": "X", "Focal": "v", "ER+LS": "D"},
            palette=palette,
            edgecolor="dimgrey",
            linewidth=0.5,
        )

        sns.scatterplot(
            data=df_calib_after,
            hue="variable",
            y="value_ood",
            x="value_id",
            s=275,
            legend=j == 0,
            ax=ax[j],  # style='loss',
            marker="$\diamond$",
            ec="face",
            palette=palette,  # edgecolor='dimgrey', linewidth=0.5
        )

        sns.scatterplot(
            data=pd.merge(
                df_base_ce.loc[df_base_ce.domain == "ID"].drop(columns="domain"),
                df_base_ce.loc[df_base_ce.domain == "OOD"].drop(columns="domain"),
                on=["loss", "variable", "treatment", "experiment"],
                suffixes=["_id", "_ood"],
            )
            .groupby(["loss", "variable", "treatment", "experiment"])
            .mean(),
            hue="variable",
            y="value_ood",
            x="value_id",
            s=140,
            legend=j == 0,
            ax=ax[j],
            style="loss",
            markers={"CE": "o"},
            palette=palette,
            edgecolor="red",
            linewidth=0.5,
        )

        df1_base = pd.concat(all_baseline_dfs[m])
        df1_base = df1_base.drop(
            columns=["metric", "foundation"]
        )  # .groupby(['experiment', 'loss', 'domain'], as_index=False).mean()
        df1_base = df1_base.melt(
            value_vars=[
                "EBS",
                "TS",
                "PROBAS",
                "IRM",
                "TS_WITH_OOD",
                "IRM_WITH_OOD",
                "IROVATS",
                "IROVATS_WITH_OOD",
            ],
            id_vars=["experiment", "loss", "domain"],
        )
        df1_base["treatment"] = df1_base["loss"] + " + " + df1_base["variable"]
        df1_base = df1_base.loc[df1_base["treatment"].isin(c3)]
        df1_base = (
            pd.merge(
                df1_base.loc[df1_base.domain == "ID"].drop(columns="domain"),
                df1_base.loc[df1_base.domain == "OOD"].drop(columns="domain"),
                on=["loss", "variable", "treatment", "experiment"],
                suffixes=["_id", "_ood"],
            )
            .groupby(["loss", "variable", "treatment", "experiment"])
            .mean()
        )
        sns.scatterplot(
            data=df1_base,
            style="loss",
            hue="variable",
            y="value_ood",
            x="value_id",
            s=40,
            ax=ax[j],
            markers={"CE": "o", "ER": "P", "LS": "X", "Focal": "v", "ER+LS": "D"},
            palette=palette,
            legend=False,
        )

        df5 = df5.reset_index()
        df5 = df5.loc[df5["treatment"].isin(c3)]
        for xp, yp, var in zip(
            np.stack([df5.value_id.values, df1_base.value_id.values], 1),
            np.stack([df5.value_ood.values, df1_base.value_ood.values], 1),
            df5.variable.values,
        ):
            ax[j].plot(xp, yp, c=palette[var], ls=":")
        df1_base = df1_base.reset_index()
        df1_base = df1_base.loc[
            df1_base["treatment"].isin(
                ["ER+LS + EBS", "ER+LS + TS", "ER+LS + TS_WITH_OOD"]
            )
        ]
        for xp, yp, var in zip(
            np.stack([df_calib_after.value_id.values, df1_base.value_id.values], 1),
            np.stack([df_calib_after.value_ood.values, df1_base.value_ood.values], 1),
            df1_base["variable"].values,
        ):
            ax[j].plot(xp, yp, c=palette[var], ls=":")

        if 0 == j:
            # Get handles and labels directly from the axes
            handles, labels = ax[0].get_legend_handles_labels()
            # Remove the legend from its current position
            ax[0].get_legend().remove()
        ax[j].set_xlabel("")

    ax[0].set_ylabel("SHIFTED")
    ax[0].set_xlabel("ID")
    ax[1].set_ylabel("SHIFTED")
    ax[1].set_xlabel("ID")
    ax[0].set_title("ECE")
    ax[1].set_title("Brier")
    ax[1].set_ylabel("")
    f2, ax2 = plt.subplots(1, 1, figsize=(10, 1))
    ax2.legend(
        handles,
        labels,
        loc="center",  # Center the legend
        ncol=12,  # Display legend items in 3 columns
    )
    ax2.axis("off")
    f2.tight_layout()
    f2.savefig(
        f"/vol/biomedic3/mb121/calibration_exploration/outputs/figures/ensemble/erls/legend.pdf",
        bbox_inches="tight",
    )
    match experiment:
        case "base_chexpert":
            f.suptitle("$\mathbf{CXR}$")
        case "base_density":
            f.suptitle("$\mathbf{EMBED}$")
        case _:
            f.suptitle("$\mathbf{" + experiment.replace("base_", "").upper() + "}$")
    f.savefig(
        f"/vol/biomedic3/mb121/calibration_exploration/outputs/figures/ensemble/erls/{experiment}.pdf",
        bbox_inches="tight",
    )

## Appendix: full comparison of post-hoc calibrator (pre + CE)

In [None]:
all_dfs = defaultdict(list)
all_baseline_dfs = defaultdict(list)
for k in metrics_list:
    for ls, er in [(0, 0), (0.05, 0.1)]:
        for evaluate_foundation_models in [False]:
            ensemble_output_dir = ROOT / Path(
                f"outputs/ensembling_results/{experiment}/{float(ls):.2f}_{float(er):.2f}_{evaluate_foundation_models}"
            )
            df = pd.read_csv(
                ensemble_output_dir / f"all_ensemble_metrics_{k}.csv", index_col=0
            )
            clean_name = {c: c.replace("calib_", "").upper() for c in c2}
            df.rename(columns=clean_name, inplace=True)
            df["domain"] = df.index.values
            df["domain"] = df["domain"].map(lambda x: "ID" if "id" == x else "OOD")
            df["experiment"] = experiment
            df["foundation"] = evaluate_foundation_models
            df["metric"] = k
            df["loss"] = "CE" if (ls == er == 0) else "ER+LS"
            all_dfs[k].append(df)

            df = pd.read_csv(
                ensemble_output_dir / f"individual_run_metrics_{k}.csv", index_col=0
            )
            clean_name = {c: c.replace("calib_", "").upper() for c in c2}
            df.rename(columns=clean_name, inplace=True)
            df["domain"] = df.index.values
            df["domain"] = df["domain"].map(lambda x: "ID" if "id" == x else "OOD")
            df["experiment"] = experiment
            df["metric"] = k
            df["loss"] = "CE" if (ls == er == 0) else "ER+LS"
            df["foundation"] = evaluate_foundation_models
            all_baseline_dfs[k].append(df)

sns.set_style("whitegrid")
f, ax = plt.subplots(1, 2, figsize=(8, 4))
colors_blue = sns.color_palette("Blues_r", n_colors=5)
my_palette = {
    "CE + PROBAS": colors_blue[0],
    "CE + TS": colors_blue[1],
    "CE + IRM": colors_blue[2],
    "CE + EBS": colors_blue[3],
    "CE + EBS_DAC": colors_blue[4],
    "ER+LS + PROBAS": "darkred",
    "ER+LS + EBS": "indianred",
}


for j, m in enumerate(["ECE", "Brier"]):
    df1 = pd.concat(all_dfs[m])
    df2 = df1.drop(columns=["metric", "foundation"])
    df3 = df2.melt(
        value_vars=[
            "EBS",
            "TS",
            "PROBAS",
            "IRM",
            "TS_WITH_OOD",
            "IRM_WITH_OOD",
            "IROVATS",
            "IROVATS_WITH_OOD",
        ],
        id_vars=["experiment", "loss", "domain"],
    )
    df3["treatment"] = df3["loss"] + " + " + df3["variable"]

    c3 = [
        "CE + EBS",
        "CE + TS",
        "CE + PROBAS",
        "CE + TS_WITH_OOD",
        "CE + IROVATS",
        "CE + IROVATS_WITH_OOD",
        "CE + IRM_WITH_OOD",
        "CE + IRM",
    ]  #
    df3 = df3.loc[df3["treatment"].isin(c3)]

    df5 = (
        pd.merge(
            df3.loc[df3.domain == "ID"].drop(columns="domain"),
            df3.loc[df3.domain == "OOD"].drop(columns="domain"),
            on=["loss", "variable", "treatment", "experiment"],
            suffixes=["_id", "_ood"],
        )
        .groupby(["loss", "variable", "treatment", "experiment"])
        .mean()
    )
    sns.scatterplot(
        data=df5,
        style="loss",
        hue="variable",
        y="value_ood",
        x="value_id",
        s=140,
        legend=j == 0,
        ax=ax[j],
        markers={"CE": "o", "ER": "P", "LS": "X", "Focal": "v", "ER+LS": "D"},
        palette=palette,
        edgecolor="dimgrey",
        linewidth=0.5,
    )

    df1_base = pd.concat(all_baseline_dfs[m])
    df1_base = df1_base.drop(columns=["metric", "foundation"])
    df1_base = df1_base.melt(
        value_vars=[
            "EBS",
            "TS",
            "PROBAS",
            "IRM",
            "TS_WITH_OOD",
            "IRM_WITH_OOD",
            "IROVATS",
            "IROVATS_WITH_OOD",
        ],
        id_vars=["experiment", "loss", "domain"],
    )
    df1_base["treatment"] = df1_base["loss"] + " + " + df1_base["variable"]
    df1_base = df1_base.loc[df1_base["treatment"].isin(c3)]
    df1_base = (
        pd.merge(
            df1_base.loc[df1_base.domain == "ID"].drop(columns="domain"),
            df1_base.loc[df1_base.domain == "OOD"].drop(columns="domain"),
            on=["loss", "variable", "treatment", "experiment"],
            suffixes=["_id", "_ood"],
        )
        .groupby(["loss", "variable", "treatment", "experiment"])
        .mean()
    )
    sns.scatterplot(
        data=df1_base,
        style="loss",
        hue="variable",
        y="value_ood",
        x="value_id",
        s=40,
        ax=ax[j],
        markers={"CE": "o", "ER": "P", "LS": "X", "Focal": "v", "ER+LS": "D"},
        palette=palette,
        legend=False,
    )

    df5 = df5.reset_index()
    df5 = df5.loc[df5["treatment"].isin(c3)]
    for xp, yp, var in zip(
        np.stack([df5.value_id.values, df1_base.value_id.values], 1),
        np.stack([df5.value_ood.values, df1_base.value_ood.values], 1),
        df5.variable.values,
    ):
        ax[j].plot(xp, yp, c=palette[var], ls=":")
    if 0 == j:
        # Get handles and labels directly from the axes
        handles, labels = ax[0].get_legend_handles_labels()
        # Remove the legend from its current position
        ax[0].get_legend().remove()
    ax[j].set_xlabel("")

ax[0].set_ylabel("SHIFTED")
ax[0].set_xlabel("ID")
ax[1].set_ylabel("SHIFTED")
ax[1].set_xlabel("ID")
ax[0].set_title("ECE")
ax[1].set_title("Brier")
ax[1].set_ylabel("")
f2, ax2 = plt.subplots(1, 1, figsize=(10, 1))
ax2.legend(
    handles,
    labels,
    loc="center",  # Center the legend
    ncol=12,  # Display legend items in 3 columns
)
ax2.axis("off")
f2.tight_layout()
f2.savefig(
    f"/vol/biomedic3/mb121/calibration_exploration/outputs/figures/ensemble/posthoc/legend.pdf",
    bbox_inches="tight",
)
match experiment:
    case "base_chexpert":
        f.suptitle("$\mathbf{CXR}$")
    case "base_density":
        f.suptitle("$\mathbf{EMBED}$")
    case _:
        f.suptitle("$\mathbf{" + experiment.replace("base_", "").upper() + "}$")
f.savefig(
    f"/vol/biomedic3/mb121/calibration_exploration/outputs/figures/ensemble/posthoc/{experiment}.pdf",
    bbox_inches="tight",
)