In [None]:
import sys

sys.path.append("/vol/biomedic3/mb121/calibration_exploration/")

from plotting_notebooks.plotting_utils import (
    get_all_runs,
    retrieve_metrics_df,
    shiftedColorMap,
    my_pretty_plot,
)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from collections import defaultdict
from matplotlib import ticker

# Global seaborn theme for every figure produced by this notebook.
sns.set_style("whitegrid")

# The plotting cells assign into column-sliced frames; disable pandas'
# chained-assignment check so those assignments stay silent.
pd.options.mode.chained_assignment = None

# Experiments (one per dataset) iterated over by the plotting cells.
all_experiments = [
    "base_" + dataset
    for dataset in (
        "density",
        "camelyon",
        "retina",
        "living17",
        "entity30",
        "icam",
        "domainnet",
        "chexpert",
    )
]

cm = sns.color_palette("Paired")

# Colour per displayed calibration method. A method and its `_WITH_OOD`
# variant use neighbouring palette indices (0/1, 2/3, 6/7).
palette = {"PROBAS": "black"}
palette.update(
    (method, cm[colour_index])
    for method, colour_index in (
        ("TS", 0),
        ("IRM", 2),
        ("ETS", 4),
        ("IROVATS", 6),
        ("EBS_MINUS", 8),
        ("EBS", 9),
        ("TS_WITH_OOD", 1),
        ("IRM_WITH_OOD", 3),
        ("IROVATS_WITH_OOD", 7),
    )
)


def _prepare_df_for_plot(list_dfs):
    df = pd.concat(list_dfs)
    df = pd.pivot_table(
        data=df,
        index=["loss", "domain"],
        values=list(clean_name.values()),
        aggfunc="mean",
    )[clean_name.values()].reset_index()
    df2 = pd.melt(df, id_vars=["loss", "domain"])
    df3 = pd.merge(
        df2.loc[df2.domain == "ID"].drop(columns="domain"),
        df2.loc[df2.domain == "OOD"].drop(columns="domain"),
        on=["loss", "variable"],
        suffixes=["_id", "_ood"],
    )
    return df3


# Metric names handed to `retrieve_metrics_df` by the plotting cells below.
# NOTE(review): later cells reassign this notebook-level name, so its value
# depends on which cells have already been run.
metrics = ["ECE", "Brier"]

# Main results figure

In [None]:
evaluate_foundation_model = True

if evaluate_foundation_model:
    pretty_axes = {
        "ECE": {
            "base_density": [[(0.0, 0.05), (0.10, 0.20)], [(0.02, 0.08), (0.10, 0.15)]],
            "base_camelyon": [None, None],
            "base_retina": [None, None],
            "base_living17": [
                [(0.005, 0.05), (0.07, 0.10)],
                [(0.03, 0.15), (0.18, 0.20)],
            ],
            "base_entity30": [
                [(0.00, 0.054), (0.062, 0.10)],
                [(0.045, 0.19), (0.20, 0.25)],
            ],
            "base_chexpert": [
                [(0.0, 0.09), (0.16, 0.20)],
                [(0.02, 0.12), (0.16, 0.20)],
            ],
            "base_icam": [[(0.02, 0.10), (0.11, 0.18)], [(0.04, 0.160), (0.20, 0.24)]],
            "base_domainnet": [
                [(0.02, 0.11), (0.12, 0.14)],
                [(0.09, 0.44), (0.50, 0.52)],
            ],
        },
        "Brier": {
            "base_density": [
                [(0.30, 0.34), (0.354, 0.42)],
                [(0.35, 0.44), (0.46, 0.48)],
            ],
            "base_camelyon": [None, None],
            "base_retina": [
                [(0.29, 0.37), (0.378, 0.46)],
                [(0.44, 0.54), (0.62, 0.66)],
            ],
            "base_living17": [
                [(0.05, 0.065), (0.30, 0.35)],
                [(0.28, 0.38), (0.75, 0.85)],
            ],
            "base_entity30": [
                [(0.06, 0.09), (0.36, 0.45)],
                [(0.47, 0.59), (0.75, 0.85)],
            ],
            "base_chexpert": [
                [(0.125, 0.16), (0.20, 0.24)],
                [(0.316, 0.368), (0.38, 0.44)],
            ],
            "base_icam": [[(0.33, 0.40), (0.50, 0.60)], [(0.44, 0.58), (0.73, 0.77)]],
            "base_domainnet": [
                [(0.22, 0.28), (0.50, 0.60)],
                [(0.75, 1.08), (1.10, 1.12)],
            ],
        },
    }
else:
    pretty_axes = {
        "ECE": {
            "base_density": [[(0.0, 0.12), (0.16, 0.22)], [(0.04, 0.09), (0.11, 0.16)]],
            "base_camelyon": [None, None],
           "base_retina": [
                [(0.00, 0.11), (0.145, 0.18)],
                [(0.07, 0.19), (0.21, 0.26)],
            ]
            "base_living17": [
                [(0.02, 0.078), (0.14, 0.18)],
                [(0.12, 0.34), (0.45, 0.50)],
            ],
            "base_entity30": [
                [(0.01, 0.10), (0.16, 0.18)],
                [(0.05, 0.30), (0.38, 0.44)],
            ],
            "base_chexpert": [
                [(0.00, 0.08), (0.20, 0.25)],
                [(0.05, 0.14), (0.16, 0.22)],
            ],
            "base_icam": [[(0.02, 0.18), (0.23, 0.30)], [(0.05, 0.27), (0.30, 0.44)]],
            "base_domainnet": [
                [(0.02, 0.23), (0.25, 0.30)],
                [(0.06, 0.47), (0.55, 0.65)],
            ],
        },
        "Brier": {
            "base_density": [
                [(0.32, 0.36), (0.39, 0.42)],
                [(0.39, 0.45), (0.455, 0.46)],
            ]
            "base_camelyon": [None, None],
            "base_retina": [
                [(0.35, 0.50), (0.52, 0.56)],
                [(0.58, 0.67), (0.68, 0.72)],
            ]
            "base_living17": [
                [(0.32, 0.358), (0.38, 0.41)],
                [(0.76, 0.93), (1.02, 1.1)],
            ],
            "base_entity30": [
                [(0.415, 0.45), (0.46, 0.52)],
                [(0.75, 0.86), (0.95, 1.05)],
            ],
            "base_chexpert": [
                [(0.125, 0.16), (0.24, 0.28)],
                [(0.345, 0.44), (0.46, 0.48)],
            ],
            "base_icam": [[(0.55, 0.61), (0.63, 0.70)], [(0.72, 0.84), (0.87, 0.98)]],
            "base_domainnet": [
                [(0.48, 0.615), (0.622, 0.67)],
                [(0.9, 1.255), (1.28, 1.43)],
            ],
        },
    }


def fn(x, pos):
    """Tick formatter: three decimals with leading zeros removed (0.05 -> ".050").

    ``pos`` is the tick position that ``ticker.FuncFormatter`` passes to its
    callback; it is not used.
    """
    three_decimals = format(x, ".3f")
    return three_decimals.lstrip("0")

In [None]:
# Experiments drawn by the main-results loop, in plotting order.
all_experiments = [
    "base_" + dataset
    for dataset in (
        "density",
        "camelyon",
        "retina",
        "living17",
        "entity30",
        "icam",
        "domainnet",
        "chexpert",
    )
]

# Columns selected from each metrics frame: "probas" first, followed by one
# "calib_<method>" column per calibration method.
c2 = ["probas"] + [
    "calib_" + method
    for method in (
        "ts",
        "irm",
        "irovats",
        "ebs_minus",
        "ebs",
        "ts_with_ood",
        "irm_with_ood",
        "irovats_with_ood",
    )
]

# Main-results figures: for every experiment and metric, scatter the mean
# in-distribution (ID) value against the mean shifted (OOD) value of each
# calibration method, one marker shape per training loss.

# Metrics drawn here. The same list drives loading and plotting, so this cell
# does not depend on what the notebook-level `metrics` currently holds.
plot_metrics = ["ECE", "Brier"]

# Marker assigned to each training loss.
loss_markers = {"CE": "o", "ER": "P", "LS": "X", "Focal": "v", "ER+LS": "D"}

# Keyword arguments shared by every calibrator scatter plot in this cell.
calibrator_scatter_kwargs = dict(
    style="loss",
    hue="variable",
    y="value_ood",
    x="value_id",
    s=80,
    markers=loss_markers,
    palette=palette,
    edgecolor="dimgrey",
    linewidth=0.5,
)

# Keyword arguments for the red star marking the from-scratch CE baseline
# (mean `calib_ebs` value), only drawn when evaluating foundation models.
baseline_scatter_kwargs = dict(
    style="loss",
    y="calib_ebs_ood",
    x="calib_ebs_id",
    s=300,
    hue="loss",
    legend=False,
    markers={"CE": "*"},
    palette={"CE": "red"},
)

# Raw column name -> display name ("calib_ts" -> "TS"). Kept at notebook level
# because `_prepare_df_for_plot` may look it up there.
clean_name = {c: c.replace("calib_", "").upper() for c in c2}

for experiment in all_experiments:
    run_ids_dict = get_all_runs(
        experiment, evaluate="foundation" if evaluate_foundation_model else "scratch"
    )
    all_dfs = defaultdict(dict)
    for k, run_list in run_ids_dict.items():
        for m in plot_metrics:
            all_dfs[k][m] = retrieve_metrics_df(run_list, m)
            all_dfs[k][m]["experiment"] = experiment
    keys = run_ids_dict.keys()

    if evaluate_foundation_model:
        # From-scratch runs of the same experiment provide the baseline star.
        all_dfs_base = defaultdict(dict)
        run_ids_dict_base = get_all_runs(experiment, evaluate="scratch")
        for k, run_list in run_ids_dict_base.items():
            for m in plot_metrics:
                all_dfs_base[k][m] = retrieve_metrics_df(run_list, m)
                all_dfs_base[k][m]["experiment"] = experiment
    print(experiment)

    for m in plot_metrics:
        # One frame per training loss, with display names and ID/OOD labels.
        dfs_to_plot = []
        for k in keys:
            df1 = all_dfs[k][m][["domain"] + c2].rename(columns=clean_name)
            df1["domain"] = df1["domain"].map(lambda x: "ID" if "id" == x else "OOD")
            df1["loss"] = k
            dfs_to_plot.append(df1)
        df3 = _prepare_df_for_plot(dfs_to_plot)

        if evaluate_foundation_model:
            baseline = all_dfs_base["CE"][m][["domain", "calib_ebs"]].copy()
            baseline["loss"] = "CE"
            baseline["domain"] = baseline["domain"].map(
                lambda x: "ID" if "id" == x else "OOD"
            )
            # Merging on "loss" alone pairs every ID row with every OOD row;
            # the group mean then yields the mean ID and mean OOD value.
            baseline_df = (
                pd.merge(
                    baseline.loc[baseline.domain == "ID"].drop(columns="domain"),
                    baseline.loc[baseline.domain == "OOD"].drop(columns="domain"),
                    on=["loss"],
                    suffixes=["_id", "_ood"],
                )
                .groupby("loss")
                .mean()
            )

        xs_lims, ys_lims = pretty_axes[m][experiment]

        if xs_lims is not None:
            # Broken-axis figure: a 2x2 grid in which the large bottom-left
            # panel shows the bulk of the points and the narrow top / right
            # strips show the outlying value ranges.
            f, ax = plt.subplots(
                ncols=2,
                nrows=2,
                width_ratios=(7, 1),
                height_ratios=(1, 8),
                figsize=(2.3, 2.9),
            )
            plt.subplots_adjust(hspace=0.14, wspace=0.14)

            ax[0, 0].set_yticks(ys_lims[1])
            ax[1, 1].set_xticks(xs_lims[1])

            # Only the first panel carries a legend, so that its handles can
            # be collected further down.
            sns.scatterplot(
                data=df3, legend=True, ax=ax[0, 0], **calibrator_scatter_kwargs
            )
            for axi in [ax[0, 1], ax[1, 1], ax[1, 0]]:
                sns.scatterplot(
                    data=df3, legend=False, ax=axi, **calibrator_scatter_kwargs
                )

            if evaluate_foundation_model:
                for axi in ax.ravel():
                    sns.scatterplot(data=baseline_df, ax=axi, **baseline_scatter_kwargs)

            for axi in (ax[1, 1], ax[0, 1], ax[0, 0]):
                axi.set_ylabel("")
                axi.set_xlabel("")
            ax[1, 0].set_xlabel("ID", x=0.6)

            # Column selects the x segment; row 0 (top) shows the upper y
            # segment and row 1 (bottom) the lower one.
            for row in range(2):
                for col in range(2):
                    ax[row, col].set_xlim(xs_lims[col])
                    ax[row, col].set_ylim(ys_lims[1 - row])

            # Only the living17 figure carries the metric name on its y axis.
            if experiment == "base_living17":
                ax[1, 0].set_ylabel(
                    r"$\mathbf{" + m.upper() + "}$" + " - SHIFTED", y=0.65
                )
            else:
                ax[1, 0].set_ylabel("")

            for axi in ax.ravel():
                axi.tick_params(axis="x", rotation=45)
                axi.xaxis.set_major_formatter(ticker.FuncFormatter(fn))
                axi.yaxis.set_major_formatter(ticker.FuncFormatter(fn))

            # Right column mirrors the left column's y ticks, without labels,
            # and the facing spines are hidden.
            for j in range(2):
                ax[j, 1].set_yticks(ax[j, 0].get_yticks())
                ax[j, 1].set_ylim(ax[j, 0].get_ylim())
                ax[j, 1].set_yticklabels("")
                ax[j, 0].spines["right"].set_visible(False)
                ax[j, 1].spines["left"].set_visible(False)

            # Top row mirrors the bottom row's x ticks, without labels.
            for j in range(2):
                ax[0, j].spines["bottom"].set_visible(False)
                ax[0, j].set_xticks(ax[1, j].get_xticks())
                ax[0, j].set_xlim(ax[1, j].get_xlim())
                ax[0, j].set_xticklabels("")
                ax[0, j].set_xlabel("")

            # Axis-break marks: short straight ticks of half-length d (axes
            # coordinates) drawn across the outer edges where the axes are cut.
            d = 0.03
            break_marks = [
                # vertical ticks at the cut in the x axis
                (ax[1, 0], (1, 1), (-d, +d)),  # bottom edge, right end of left panel
                (ax[1, 1], (0, 0), (-d, +d)),  # bottom edge, left end of right panel
                (ax[0, 0], (1, 1), (1 - d, 1 + d)),  # top edge, right end of left panel
                (ax[0, 1], (0, 0), (1 - d, 1 + d)),  # top edge, left end of right panel
                # horizontal ticks at the cut in the y axis
                (ax[1, 0], (-d, +d), (1, 1)),  # left edge, top of bottom panel
                (ax[0, 1], (1 - d, 1 + d), (0, 0)),  # right edge, bottom of top panel
                (ax[0, 0], (-d, +d), (0, 0)),  # left edge, bottom of top panel
                (ax[1, 1], (1 - d, 1 + d), (1, 1)),  # right edge, top of bottom panel
            ]
            for axi, mark_x, mark_y in break_marks:
                kwargs = dict(transform=axi.transAxes, color="k", clip_on=False)
                axi.plot(mark_x, mark_y, **kwargs)

            # Keep the legend entries (for a stand-alone legend figure) and
            # remove the legend from the panel itself.
            handles, labels = ax[0, 0].get_legend_handles_labels()
            ax[0, 0].get_legend().remove()

        else:
            # No limits configured for this experiment: single ordinary axes.
            f, ax = plt.subplots(ncols=1, nrows=1, figsize=(2.3, 2.9))
            sns.scatterplot(data=df3, legend=False, ax=ax, **calibrator_scatter_kwargs)

            if evaluate_foundation_model:
                sns.scatterplot(data=baseline_df, ax=ax, **baseline_scatter_kwargs)

            ax.tick_params(axis="x", rotation=45)
            ax.xaxis.set_major_formatter(ticker.FuncFormatter(fn))
            ax.yaxis.set_major_formatter(ticker.FuncFormatter(fn))
            if experiment == "base_living17":
                ax.set_ylabel(r"$\mathbf{" + m.upper() + "}$" + " - SHIFTED")
            else:
                ax.set_ylabel("")
            ax.set_xlabel("ID")

        # Project helper from plotting_utils, applied to the ECE figures only.
        # NOTE(review): what it adds to the figure is not visible from here.
        if m == "ECE":
            my_pretty_plot(experiment, f)

        f.savefig(
            f"/vol/biomedic3/mb121/calibration_exploration/outputs/figures/{m}/{experiment}_{evaluate_foundation_model}.pdf",
            bbox_inches="tight",
        )

    # f2, ax2 = plt.subplots(1,1,figsize=(10,1))
    # ax2.legend(
    #     handles,
    #     labels,
    #     loc='center',  # Center the legend
    #     ncol=7  # Display legend items in 3 columns
    # )
    # ax2.axis('off')
    # f2.tight_layout()
    # f2.savefig(f'/vol/biomedic3/mb121/calibration_exploration/outputs/figures/legend.pdf',bbox_inches='tight')

## Ablation - model size

In [None]:

# Model-size ablation: ECE on ID vs. shifted data for CE-trained models of
# different sizes, plus one foundation model, with marker area scaled by `sizes`.

# Architecture names assigned, in this order, to consecutive groups of rows of
# the from-scratch metrics frame.
# NOTE(review): assumes `retrieve_metrics_df` returns one row per domain for
# each architecture, grouped in exactly this order - confirm against the helper.
model_names = [
    "resnet18",
    "resnet50",
    "mobilenetv2_100",
    "convnext_tiny",
    "vit_base",
    "efficientnet_b0",
]

# Value driving the marker size of each model (presumably parameter count in
# millions - TODO confirm). Models absent from this table are left out.
sizes = {
    "resnet18": 12,
    "resnet50": 26,
    "convnext_tiny": 28.1,
    "vit_base": 86,
    "efficientnet_b0": 5,
    "vit_base_foundation": 86,
}

c2 = [
    "probas",
    "calib_ts",
    "calib_ts_with_ood",
]
# Raw column name -> display name ("calib_ts" -> "TS").
clean_name = {c: c.replace("calib_", "").upper() for c in c2}

for experiment in all_experiments:
    run_ids_dict_scratch = get_all_runs(experiment, evaluate="scratch")
    run_ids_dict_foundation = get_all_runs(experiment, evaluate="foundation")

    df_scratch = retrieve_metrics_df(run_ids_dict_scratch["CE"], "ECE")
    df_foundation = retrieve_metrics_df(run_ids_dict_foundation["CE"], "ECE")

    # Number of distinct evaluation domains, i.e. rows per model.
    n = len(df_scratch.domain.unique())

    # Keep only the first n foundation rows; per the original note these
    # belong to the DINO foundation model (row order not visible here).
    df_foundation = df_foundation[:n].assign(name="vit_base_foundation")

    # Label every from-scratch row with its architecture (n rows per model).
    full_names = [name for name in model_names for _ in range(n)]
    df_scratch = df_scratch.assign(name=full_names)

    df = pd.concat([df_scratch.loc[df_scratch.name.isin(sizes.keys())], df_foundation])

    df1 = df[["domain", "name"] + c2].rename(columns=clean_name)
    df1["domain"] = df1["domain"].map(lambda x: "ID" if "id" == x else "OOD")

    # Mean per (model, domain), then ID and OOD means side by side.
    df2 = pd.pivot_table(
        data=df1,
        index=["name", "domain"],
        values=list(clean_name.values()),
        aggfunc="mean",
    )[list(clean_name.values())].reset_index()
    df2 = pd.melt(df2, id_vars=["name", "domain"])
    df3 = pd.merge(
        df2.loc[df2.domain == "ID"].drop(columns="domain"),
        df2.loc[df2.domain == "OOD"].drop(columns="domain"),
        on=["name", "variable"],
        suffixes=["_id", "_ood"],
    )
    df3["size"] = df3["name"].map(sizes)
    df3["foundation_model"] = df3["name"].apply(lambda x: "foundation" in x)
    # resnet50 is deliberately left out of the figure.
    df3 = df3.loc[df3.name != "resnet50"]

    f, ax = plt.subplots(1, 1, figsize=(3, 3))
    sns.scatterplot(
        data=df3.sort_values(by=["size"]),
        x="value_id",
        y="value_ood",
        size="size",
        ax=ax,
        palette=palette,
        hue="variable",
        style="foundation_model",
        sizes=(30, 300),
        legend=False,
        markers={False: "o", True: "p"},
    )
    plt.suptitle(experiment.upper())
    my_pretty_plot(experiment, f)
    ax.tick_params(axis="x", rotation=45)
    ax.xaxis.set_major_formatter(ticker.FuncFormatter(fn))
    ax.yaxis.set_major_formatter(ticker.FuncFormatter(fn))
    # Only these two figures carry the y-axis label.
    if experiment in ["base_living17", "base_density"]:
        ax.set_ylabel(r"$\mathbf{ECE}$" + " - SHIFTED")
    else:
        ax.set_ylabel("")
    ax.set_xlabel("ID")

    f.savefig(experiment + "_model_size_ablation.pdf", bbox_inches="tight")

## For figure 1

In [None]:
# Figure 1: ECE and Brier side by side for living17, restricted to three
# training losses and four calibration columns.
evaluate_foundation_model = False

fig1_markers = {"CE": "o", "ER": "P", "LS": "X", "Focal": "v", "ER+LS": "D"}

c2 = ["probas", "calib_ts", "calib_irm", "calib_ebs"]
# Raw column name -> display name ("calib_ts" -> "TS").
clean_name = {c: c.replace("calib_", "").upper() for c in c2}

for experiment in ["base_living17"]:
    run_ids_dict = get_all_runs(
        experiment, evaluate="foundation" if evaluate_foundation_model else "scratch"
    )
    all_dfs = defaultdict(dict)
    metrics = ["ECE", "Brier"]
    for k, run_list in run_ids_dict.items():
        for m in metrics:
            all_dfs[k][m] = retrieve_metrics_df(run_list, m)
            all_dfs[k][m]["experiment"] = experiment

    if evaluate_foundation_model:
        # Load the from-scratch baseline for *this* experiment here, instead of
        # reusing whatever `all_dfs_base` another cell left behind (which would
        # belong to the last experiment that cell processed).
        all_dfs_base = defaultdict(dict)
        for k, run_list in get_all_runs(experiment, evaluate="scratch").items():
            for m in metrics:
                all_dfs_base[k][m] = retrieve_metrics_df(run_list, m)
                all_dfs_base[k][m]["experiment"] = experiment

    # Only these training losses are shown in figure 1.
    keys = ["CE", "ER+LS", "Focal"]
    print(experiment)

    f, ax = plt.subplots(1, 2, figsize=(8, 3.5), facecolor="None")
    for i, m in enumerate(["ECE", "Brier"]):
        dfs_to_plot = []
        for k in keys:
            df1 = all_dfs[k][m][["domain"] + c2].rename(columns=clean_name)
            df1["domain"] = df1["domain"].map(lambda x: "ID" if "id" == x else "OOD")
            df1["loss"] = k
            dfs_to_plot.append(df1)

        df3 = _prepare_df_for_plot(dfs_to_plot)
        sns.scatterplot(
            data=df3,
            style="loss",
            hue="variable",
            y="value_ood",
            x="value_id",
            s=80,
            legend=i == 0,  # legend on the first panel only; harvested below
            ax=ax[i],
            markers=fig1_markers,
            palette=palette,
            edgecolor="dimgrey",
            linewidth=0.5,
        )

        if evaluate_foundation_model:
            baseline = all_dfs_base["CE"][m][["domain", "calib_ebs"]].copy()
            baseline["loss"] = "CE"
            baseline["domain"] = baseline["domain"].map(
                lambda x: "ID" if "id" == x else "OOD"
            )
            # Merging on "loss" alone pairs every ID row with every OOD row;
            # the group mean then yields the mean ID and mean OOD value.
            baseline_df = (
                pd.merge(
                    baseline.loc[baseline.domain == "ID"].drop(columns="domain"),
                    baseline.loc[baseline.domain == "OOD"].drop(columns="domain"),
                    on=["loss"],
                    suffixes=["_id", "_ood"],
                )
                .groupby("loss")
                .mean()
            )
            sns.scatterplot(
                data=baseline_df,
                style="loss",
                y="calib_ebs_ood",
                x="calib_ebs_id",
                s=300,
                hue="loss",
                legend=False,
                ax=ax[i],
                markers={"CE": "*"},
                palette={"CE": "red"},
            )

    # Collect the legend entries, then remove the legend from the panel.
    handles, labels = ax[0].get_legend_handles_labels()
    ax[0].get_legend().remove()

    ax[0].set_ylabel("SHIFTED")
    ax[1].set_ylabel("")  # the y label is shown on the left panel only
    ax[0].set_title("ECE")
    ax[1].set_title("Brier")
    for panel in ax:
        panel.set_xlabel("ID")
        panel.xaxis.set_major_formatter(ticker.StrMethodFormatter("{x:.2f}"))
        panel.yaxis.set_major_formatter(ticker.StrMethodFormatter("{x:.2f}"))
    f.suptitle(r"$\mathbf{" + experiment.replace("base_", "").upper() + "}$")

    # Stand-alone legend figure (displayed, but not saved by this cell).
    f2, ax2 = plt.subplots(1, 1, figsize=(10, 1))
    ax2.legend(
        handles,
        labels,
        loc="center",
        ncol=7,  # up to seven legend entries per row
    )
    ax2.axis("off")
    f2.tight_layout()
    f.savefig(
        f"/vol/biomedic3/mb121/calibration_exploration/outputs/figures/figure1_{experiment}_{evaluate_foundation_model}.svg",
        bbox_inches="tight",
        dpi=300,
    )
    # f2.savefig(f'/vol/biomedic3/mb121/calibration_exploration/outputs/figures/legend.pdf',bbox_inches='tight')

# For BalAccuracy plot

In [None]:
# Balanced-accuracy figures: ID vs. shifted balanced accuracy (in percent) for
# the raw probabilities and the two IROVATS variants, one figure per experiment.
evaluate_foundation_model = False
metrics_to_plot = ["BalAccuracy"]

# Columns used for any metric other than BalAccuracy.
all_calibration_columns = [
    "probas",
    "calib_ts",
    "calib_irm",
    "calib_irovats",
    "calib_ebs_minus",
    "calib_ebs",
    "calib_ts_with_ood",
    "calib_irm_with_ood",
    "calib_irovats_with_ood",
]
# Columns used for BalAccuracy.
accuracy_columns = ["probas", "calib_irovats", "calib_irovats_with_ood"]

for experiment in all_experiments:
    run_ids_dict = get_all_runs(
        experiment, evaluate="foundation" if evaluate_foundation_model else "scratch"
    )
    all_dfs = defaultdict(dict)
    # Iterate `metrics_to_plot` directly. Rebinding the notebook-wide `metrics`
    # here would make later cells that loop over `metrics` load only
    # BalAccuracy and then fail when they look up "ECE".
    for k, run_list in run_ids_dict.items():
        for m in metrics_to_plot:
            all_dfs[k][m] = retrieve_metrics_df(run_list, m)
            all_dfs[k][m]["experiment"] = experiment
    keys = run_ids_dict.keys()

    print(experiment)

    for m in metrics_to_plot:
        f, ax = plt.subplots(1, 1, figsize=(3, 3))

        c2 = accuracy_columns if m == "BalAccuracy" else all_calibration_columns
        # Raw column name -> display name ("calib_irovats" -> "IROVATS").
        clean_name = {c: c.replace("calib_", "").upper() for c in c2}

        dfs_to_plot = []
        for k in keys:
            df1 = all_dfs[k][m][["domain"] + c2].rename(columns=clean_name)
            df1["domain"] = df1["domain"].map(lambda x: "ID" if "id" == x else "OOD")
            df1["loss"] = k
            dfs_to_plot.append(df1)

        df3 = _prepare_df_for_plot(dfs_to_plot)
        if m == "BalAccuracy":
            # Balanced accuracy is plotted scaled by 100.
            df3["value_ood"] = df3["value_ood"] * 100
            df3["value_id"] = df3["value_id"] * 100
        sns.scatterplot(
            data=df3,
            style="loss",
            hue="variable",
            y="value_ood",
            x="value_id",
            s=80,
            legend=False,
            ax=ax,
            markers={"CE": "o", "ER": "P", "LS": "X", "Focal": "v", "ER+LS": "D"},
            palette=palette,
            edgecolor="dimgrey",
            linewidth=0.5,
        )

        ax.set_ylabel("SHIFTED")
        ax.set_xlabel("ID")
        my_pretty_plot(experiment, f)
        ax.xaxis.set_major_formatter(ticker.StrMethodFormatter("{x:.1f}"))
        ax.yaxis.set_major_formatter(ticker.StrMethodFormatter("{x:.1f}"))
        f.savefig(
            f"/vol/biomedic3/mb121/calibration_exploration/outputs/figures/{m}/{experiment}_{evaluate_foundation_model}.pdf",
            bbox_inches="tight",
        )
    # f2.savefig(f'/vol/biomedic3/mb121/calibration_exploration/outputs/figures/fig1_legend.pdf',bbox_inches='tight')

# For DAC

In [None]:
# DAC figures: for CE-trained models, plot the ECE of each calibration column
# (small "X") and of its `_dac` counterpart (large "o"), joined by a dotted line.
sns.set_style("whitegrid")
evaluate_foundation_model = False
pd.options.mode.chained_assignment = None

cm = sns.color_palette("Paired")
calibrator_colours = {
    "TS": cm[0],
    "IRM": cm[2],
    "ETS": cm[4],
    "IROVATS": cm[6],
    "EBS_MINUS": cm[8],
    "EBS": cm[9],
    "TS_WITH_OOD": cm[1],
    "IRM_WITH_OOD": cm[3],
    "IROVATS_WITH_OOD": cm[7],
}
# Every "<METHOD>_DAC" entry reuses the colour of "<METHOD>".
palette = {
    "PROBAS": "black",
    **calibrator_colours,
    **{f"{name}_DAC": colour for name, colour in calibrator_colours.items()},
}

# The only metric this cell plots. It is loaded explicitly so the cell does not
# depend on what the notebook-level `metrics` holds (an earlier cell may have
# set it to a list without "ECE", which made the lookups below raise KeyError).
m = "ECE"

c2 = [
    "calib_ts",
    "calib_irovats",
    "calib_ebs",
    "calib_ts_with_ood",
    "calib_irovats_with_ood",
]
c2_dac = [f"{c}_dac" for c in c2]

# Raw column name -> display name ("calib_ts" -> "TS", "calib_ts_dac" -> "TS_DAC").
clean_name = {c: c.replace("calib_", "").upper() for c in c2}
clean_name_dac = {c: c.replace("calib_", "").upper() for c in c2_dac}

for experiment in all_experiments:
    run_ids_dict = get_all_runs(
        experiment, evaluate="foundation" if evaluate_foundation_model else "scratch"
    )
    all_dfs = defaultdict(dict)
    for loss_key, run_list in run_ids_dict.items():
        all_dfs[loss_key][m] = retrieve_metrics_df(run_list, m)
        all_dfs[loss_key][m]["experiment"] = experiment

    print(experiment)

    f, ax = plt.subplots(1, 1, figsize=(3, 3))

    # Only the cross-entropy runs are shown.
    k = "CE"

    df1 = all_dfs[k][m][["domain"] + c2].rename(columns=clean_name)
    df1["domain"] = df1["domain"].map(lambda x: "ID" if "id" == x else "OOD")
    df1["loss"] = k
    dfs_to_plot = [df1]

    df2 = all_dfs[k][m][["domain"] + c2_dac].rename(columns=clean_name_dac)
    df2["domain"] = df2["domain"].map(lambda x: "ID" if "id" == x else "OOD")
    df2["loss"] = k
    dfs_to_plot_dac = [df2]

    df3 = _prepare_df_for_plot(dfs_to_plot)
    sns.scatterplot(
        data=df3,
        hue="variable",
        y="value_ood",
        x="value_id",
        s=40,
        legend=True,
        ax=ax,
        marker="X",
        palette=palette,
        edgecolor="dimgrey",
        linewidth=0.5,
    )

    # Same aggregation for the `_dac` columns: mean per (loss, domain), then
    # the ID and OOD means side by side.
    df_dac = pd.concat(dfs_to_plot_dac)
    df_dac = pd.pivot_table(
        data=df_dac,
        index=["loss", "domain"],
        values=list(clean_name_dac.values()),
        aggfunc="mean",
    )[list(clean_name_dac.values())].reset_index()
    df_dac = pd.melt(df_dac, id_vars=["loss", "domain"])
    df_dac = pd.merge(
        df_dac.loc[df_dac.domain == "ID"].drop(columns="domain"),
        df_dac.loc[df_dac.domain == "OOD"].drop(columns="domain"),
        on=["loss", "variable"],
        suffixes=["_id", "_ood"],
    )
    sns.scatterplot(
        data=df_dac,
        hue="variable",
        y="value_ood",
        x="value_id",
        s=120,
        legend=True,
        ax=ax,
        marker="o",
        palette=palette,
        edgecolor="dimgrey",
        linewidth=0.5,
    )

    # Join every method to its `_dac` counterpart with a dotted line. Rows are
    # matched by position after sorting both frames on "variable"; this relies
    # on "<NAME>" and "<NAME>_DAC" sorting in the same relative order, which
    # holds for the columns listed in `c2`.
    df3 = df3.sort_values(by="variable")
    df_dac = df_dac.sort_values(by="variable")
    for xp, yp, var in zip(
        np.stack([df3.value_id.values, df_dac.value_id.values], 1),
        np.stack([df3.value_ood.values, df_dac.value_ood.values], 1),
        df3.reset_index().variable.values,
    ):
        ax.plot(xp, yp, c=palette[var], ls=":")

    # Collect the legend entries, then remove the legend from the axes.
    handles, labels = ax.get_legend_handles_labels()
    ax.get_legend().remove()
    ax.set_ylabel("SHIFTED")
    ax.set_xlabel("ID")
    ax.xaxis.set_major_formatter(ticker.StrMethodFormatter("{x:.2f}"))
    ax.yaxis.set_major_formatter(ticker.StrMethodFormatter("{x:.2f}"))
    my_pretty_plot(experiment, f)

    # Stand-alone legend figure; the file is overwritten for every experiment.
    f2, ax2 = plt.subplots(1, 1, figsize=(10, 1))
    ax2.legend(
        handles,
        labels,
        loc="center",
        ncol=5,  # up to five legend entries per row
    )
    ax2.axis("off")
    f2.tight_layout()
    f.savefig(
        f"/vol/biomedic3/mb121/calibration_exploration/outputs/figures/dac_{experiment}_{evaluate_foundation_model}_ECE.pdf",
        bbox_inches="tight",
    )
    f2.savefig(
        "/vol/biomedic3/mb121/calibration_exploration/outputs/figures/legend_dac.pdf",
        bbox_inches="tight",
    )