# Plotting results for the paper

Run all inference notebooks first: this notebook only reads and plots the result CSV files in `../outputs/`.

In [None]:
import sys

sys.path.append("/vol/biomedic3/mb121/causal-contrastive")

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from evaluation.helper_functions import (
    extract_train_label_prop,
    extract_pretraining_type,
    extract_finetuning_type,
)

# Line colour per pretraining method. Matching SimCLR / DINO variants share a
# hue (CF-* -> colorblind[1], plain -> colorblind[0], "+" -> colorblind[2]) so
# the two method families read the same way across figures.
_colorblind3 = sns.color_palette("colorblind", 3)
_colorblind8 = sns.color_palette("colorblind", 8)

color_dict = {
    "ImageNet": _colorblind8[-1],
    "CF-SimCLR": _colorblind3[1],
    "SimCLR": _colorblind3[0],
    "SimCLR+": _colorblind3[2],
    "CF-DINO": _colorblind3[1],
    "DINO": _colorblind3[0],
    "DINO+": _colorblind3[2],
    "CF-DiNO-long": _colorblind3[1],
    "DiNO-long": _colorblind3[0],
    "CF-DiNO-short": _colorblind8[-1],
    "DiNO-short": _colorblind8[-2],
}

# Maps names returned by `extract_pretraining_type` to the legend label shown
# in the SimCLR figures (the DINO cells use the raw names).
legend_conversion = {
    "ImageNet": "ImageNet",
    "CF-SimCLR": "CF-SimCLR",
    "SimCLR": "SimCLR",
    "SimCLR with CF\nin training set": "SimCLR+",
    "CF-DINO": "CF-DINO",
    "DINO": "DINO",
}

# Seaborn `errorbar` spec: standard error, scaled by 1.
type_error = ("se", 1)

# Base line-plot keyword arguments; the plotting cells redefine this with an
# extra "dashes" entry.
plt_kwargs = {
    "errorbar": type_error,
    "palette": color_dict,
    "linewidth": 3,
}

# Share (%) of the in-distribution training set per EMBED scanner, shown in
# the panel titles as "<x>% ID train set".
training_prop = {
    "Selenia Dimensions": 89,
    "Senograph 2000D ADS_17.5": 4.4,
    "Lorad Selenia": 3.5,
    "Clearview CSm": 2.7,
    "Senographe Pristina": 0.2,
}

# Share (%) of the VinDr training set per VinDr "Model Name", shown in titles.
vindr_training_prop = {
    "(OOD) VinDr\nMammomat Inspiration": 80,
    "(OOD) VinDr\nPlanmed Nuance": 20,
}

# Draw-order rank per pretraining method (lower is plotted first); the
# plotting cells add 10 for runs whose classifier is not linear probing.
order_dict = {
    "ImageNet": 3,
    "CF-SimCLR": 0,
    "SimCLR": 2,
    "SimCLR+": 1,
    "CF-DINO": 0,
    "DINO": 2,
    "DINO+": 1,
}

# Seaborn dash pattern per classifier: dashed for fine-tuning, solid ("") for
# linear probing.
style_dict = {"Finetuning": [2, 2], "Linear Probing": ""}

# SimCLR

## Mammography results

In [None]:
# SimCLR / CF-SimCLR / ImageNet on mammography: a 2x4 grid of ROC vs. the
# amount of labelled training data. In-distribution EMBED scanners come first,
# then the Senographe Essential and VinDr out-of-distribution panels.
plt_kwargs = {
    "errorbar": type_error,
    "palette": color_dict,
    "linewidth": 3,
    "dashes": style_dict,
}

sns.set_theme(context="paper", style="whitegrid", font_scale=1.6)
matplotlib.rcParams["font.family"] = "serif"
rotation = 90


def _load_results(path, name_map=None):
    """Read a results CSV and derive the columns the plots group by.

    ``ctrain_label_prop``, ``Pretraining`` and ``Classifier`` are parsed from
    ``run_name`` by the project helpers. When *name_map* is given, each
    pretraining name is translated through it (``KeyError`` if absent).
    """
    frame = pd.read_csv(path)
    frame["ctrain_label_prop"] = frame.run_name.apply(extract_train_label_prop)
    pretraining = frame.run_name.apply(extract_pretraining_type)
    if name_map is not None:
        pretraining = pretraining.apply(name_map.__getitem__)
    frame["Pretraining"] = pretraining
    frame["Classifier"] = frame.run_name.apply(extract_finetuning_type)
    return frame


def _sorted_for_plot(frame):
    """Return a sorted copy of *frame* with a draw-order column ``o``.

    ``o`` is ``order_dict[Pretraining]``, plus 10 for any classifier other than
    linear probing, so linear-probing curves come first. Seaborn derives hue
    and style level order from the data, so this also fixes the legend order.
    ``assign`` returns a new frame, so nothing is written into a filtered
    slice (which would trigger pandas' SettingWithCopyWarning).
    """
    method_rank = frame.Pretraining.apply(order_dict.__getitem__)
    classifier_rank = frame.Classifier.apply(
        lambda name: 0 if name == "Linear Probing" else 10
    )
    return frame.assign(o=method_rank + classifier_rank).sort_values(by="o")


def _plot_panel(axis, frame, legend):
    """Plot ROC vs. label fraction on *axis*: colour = pretraining, dash = classifier."""
    sns.lineplot(
        data=_sorted_for_plot(frame),
        x="ctrain_label_prop",
        y="ROC",
        hue="Pretraining",
        ax=axis,
        legend=legend,
        style="Classifier",
        **plt_kwargs,
    )


def _format_label_axis(axis, ticks, n_total, rotation):
    """Log-scale the x axis and label each tick with its label count and percentage.

    *n_total* is the number of labels in the full training set, so the tick
    at fraction ``x`` is annotated with ``int(n_total * x)`` labels.
    """
    axis.set_xscale("log")
    axis.set_xticks(ticks)
    axis.set_xticklabels(
        [f"N={int(n_total * x)}\n({x * 100}%)" for x in ticks], rotation=rotation
    )


def _add_shared_legend(axis, bbox_to_anchor):
    """Redraw *axis*' legend as one wide row anchored at *bbox_to_anchor*.

    Presumably labels[0] and labels[-3] are the "Pretraining" and "Classifier"
    titles that seaborn inserts as pseudo-entries (assumes two Classifier
    levels); both are set in bold. The trailing spaces on the last label
    appear to pad the legend box.
    """
    handles, labels = axis.get_legend_handles_labels()
    if len(labels) < 3:
        # Too few entries to restyle (e.g. nothing was plotted on this axis).
        return
    labels[-1] += "          "
    labels[0] = r"$\bf{" + labels[0] + r"}$"
    labels[-3] = r"$\bf{" + labels[-3] + r"}$"
    axis.legend(
        handles,
        labels,
        loc="upper left",
        bbox_to_anchor=bbox_to_anchor,
        ncol=8,
        fontsize=16,
    )


fig, axes = plt.subplots(2, 4, figsize=(20, 14), facecolor="none")
plt.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0.2, hspace=0.58)
axes = axes.ravel()

# (ID) EMBED scanners, largest test set first, filling panels from 0.
xticks = [0.01, 0.05, 0.1, 0.25, 1.0]
results = _load_results(
    "../outputs/classification_tissueden_results_finetune2.csv", legend_conversion
)
results = results.dropna(subset=["ROC", "ctrain_label_prop"])
results = results.loc[results.ctrain_label_prop.isin(xticks)]
models = results.sort_values(by="N_test", ascending=False)["Model Name"].unique()
for i, model in enumerate(models):
    panel = results.loc[results["Model Name"] == model]
    # NOTE(review): panel 0 keeps seaborn's in-axes legend in addition to the
    # shared legend added at the end — confirm that is intended.
    _plot_panel(axes[i], panel, legend=i == 0)
    axes[i].set_title(
        r"$\bf{(ID)}$"
        + f" {model}\n{training_prop[model]}% ID train set, N test = {panel.N_test.unique()[0]}"
    )
    axes[i].set_xlabel("Total number of EMBED labels\n(proportion of training set)")
    _format_label_axis(axes[i], xticks, n_total=223086, rotation=rotation)
    if i > 0:
        axes[i].set_ylabel("")

# (OOD) Senographe Essential, from panel 5.
xticks = [0.05, 0.1, 0.25, 1.0]
results = _load_results(
    "../outputs/classification_tissueden_results_ood.csv", legend_conversion
)
results = results.dropna(subset=["ROC", "ctrain_label_prop"])
results = results.loc[results.ctrain_label_prop.isin(xticks)]
models = results.sort_values(by="N_test", ascending=False)["Model Name"].unique()
for i, model in enumerate(models, start=5):
    panel = results.loc[results["Model Name"] == model]
    _plot_panel(axes[i], panel, legend=False)
    axes[i].set_title(
        r"$\bf{(OOD)}$"
        + f" Senographe Essential\n100% OOD train set, N test = {panel.N_test.unique()[0]}"
    )
    axes[i].set_xlabel("Total number of Senograph labels\n(proportion of training set)")
    _format_label_axis(axes[i], xticks, n_total=10927, rotation=rotation)
    axes[i].set_ylabel("")

# (OOD) VinDr, one panel per remaining "Model Name" from panel 6; rows named
# exactly "(OOD) VinDr" are excluded. Panel 7 carries the legend source.
xticks = [0.05, 0.1, 0.25, 1.0]
results = _load_results("../outputs/vindr2.csv", legend_conversion)
results = results.dropna(subset=["ROC", "ctrain_label_prop"])
results = results.loc[results.ctrain_label_prop.isin(xticks)]
results = results.loc[results["Model Name"] != "(OOD) VinDr"]
models = results.sort_values(by="N_test", ascending=False)["Model Name"].unique()
for i, model in enumerate(models, start=6):
    panel = results.loc[results["Model Name"] == model]
    _plot_panel(axes[i], panel, legend=i == 7)
    modelname = model.replace("\n", " - ").replace("(OOD)", r"$\bf{(OOD)}$")
    axes[i].set_title(
        f"{modelname}\n{vindr_training_prop[model]}% OOD train set, N test = {panel.N_test.unique()[0]}"
    )
    axes[i].set_xlabel("Total number of VinDR labels\n(proportion of training set)")
    _format_label_axis(axes[i], xticks, n_total=11212, rotation=rotation)
    axes[i].set_yticks([0.75, 0.80, 0.85, 0.90, 0.95])
    axes[i].set_ylabel("")

plt.tight_layout()
# Anchored far left/below the last axis, presumably to sit under the whole grid.
_add_shared_legend(axes[-1], (-3.70, -0.45))

# NOTE(review): "simlcr" looks like a typo for "simclr"; the path is kept
# unchanged in case something else already references this file name.
plt.savefig("figures/simlcr_embed.pdf", bbox_inches="tight", dpi=300)
plt.show()

## Chest X-ray results

In [None]:
# SimCLR / CF-SimCLR / ImageNet on chest X-rays: PadChest per scanner (ID),
# then RSNA Pneumonia and CheXpert (OOD); ROC vs. labelled training data.
plt_kwargs = {
    "errorbar": type_error,
    "palette": color_dict,
    "linewidth": 3,
    "dashes": style_dict,
}

sns.set_theme(context="paper", style="whitegrid", font_scale=1.6)
matplotlib.rcParams["font.family"] = "serif"
rotation = 0


def _load_results(path, name_map=None):
    """Read a results CSV and derive the columns the plots group by.

    ``ctrain_label_prop``, ``Pretraining`` and ``Classifier`` are parsed from
    ``run_name`` by the project helpers. When *name_map* is given, each
    pretraining name is translated through it (``KeyError`` if absent).
    """
    frame = pd.read_csv(path)
    frame["ctrain_label_prop"] = frame.run_name.apply(extract_train_label_prop)
    pretraining = frame.run_name.apply(extract_pretraining_type)
    if name_map is not None:
        pretraining = pretraining.apply(name_map.__getitem__)
    frame["Pretraining"] = pretraining
    frame["Classifier"] = frame.run_name.apply(extract_finetuning_type)
    return frame


def _sorted_for_plot(frame):
    """Return a sorted copy of *frame* with a draw-order column ``o``.

    ``o`` is ``order_dict[Pretraining]``, plus 10 for any classifier other than
    linear probing, so linear-probing curves come first. Seaborn derives hue
    and style level order from the data, so this also fixes the legend order.
    ``assign`` returns a new frame, so nothing is written into a filtered
    slice (which would trigger pandas' SettingWithCopyWarning).
    """
    method_rank = frame.Pretraining.apply(order_dict.__getitem__)
    classifier_rank = frame.Classifier.apply(
        lambda name: 0 if name == "Linear Probing" else 10
    )
    return frame.assign(o=method_rank + classifier_rank).sort_values(by="o")


def _plot_panel(axis, frame, legend):
    """Plot ROC vs. label fraction on *axis*: colour = pretraining, dash = classifier."""
    sns.lineplot(
        data=_sorted_for_plot(frame),
        x="ctrain_label_prop",
        y="ROC",
        hue="Pretraining",
        ax=axis,
        legend=legend,
        style="Classifier",
        **plt_kwargs,
    )


def _format_label_axis(axis, ticks, n_total, rotation):
    """Log-scale the x axis and label each tick with its label count and percentage.

    *n_total* is the number of labels in the full training set, so the tick
    at fraction ``x`` is annotated with ``int(n_total * x)`` labels.
    """
    axis.set_xscale("log")
    axis.set_xticks(ticks)
    axis.set_xticklabels(
        [f"N={int(n_total * x)}\n({x * 100}%)" for x in ticks], rotation=rotation
    )


def _add_shared_legend(axis, bbox_to_anchor):
    """Redraw *axis*' legend as one wide row anchored at *bbox_to_anchor*.

    Presumably labels[0] and labels[-3] are the "Pretraining" and "Classifier"
    titles that seaborn inserts as pseudo-entries (assumes two Classifier
    levels); both are set in bold. The trailing spaces on the last label
    appear to pad the legend box.
    """
    handles, labels = axis.get_legend_handles_labels()
    if len(labels) < 3:
        # Too few entries to restyle (e.g. nothing was plotted on this axis).
        return
    labels[-1] += "          "
    labels[0] = r"$\bf{" + labels[0] + r"}$"
    labels[-3] = r"$\bf{" + labels[-3] + r"}$"
    axis.legend(
        handles,
        labels,
        loc="upper left",
        bbox_to_anchor=bbox_to_anchor,
        ncol=8,
        fontsize=16,
    )


fig, axes = plt.subplots(1, 4, figsize=(25, 6), facecolor="none")

# (ID) PadChest: the string-encoded "Scanner" value selects panel 0 or 1.
padchest_scanners = {"0": "Philips", "1": "Imaging"}
xticks = [0.05, 0.1, 0.25, 1.0]
results = _load_results(
    "../outputs/classification_padchestfinetunepneumo_results.csv", legend_conversion
)
results["Scanner"] = results["Scanner"].astype(str)
results = results.loc[results.ctrain_label_prop.isin(xticks)]
for i, (scanner, scanner_name) in enumerate(padchest_scanners.items()):
    panel = results.loc[results.Scanner == scanner]
    if panel.empty:
        continue
    _plot_panel(axes[i], panel, legend=False)
    axes[i].set_title(r"$\bf{(ID)}$" + f" PadChest - Scanner: {scanner_name}")
    axes[i].set_xlabel(
        "Total number of PadChest labels\n(proportion of training set)"
    )
    if i > 0:
        axes[i].set_ylabel("")
for axis in axes[:2]:
    _format_label_axis(axis, xticks, n_total=64989, rotation=rotation)

# (OOD) RSNA Pneumonia on panel 2.
xticks = [0.1, 0.25, 1.0]
results = _load_results("../outputs/classification_rsna_results.csv", legend_conversion)
results = results.loc[results.ctrain_label_prop.isin(xticks)]
if not results.empty:
    _plot_panel(axes[-2], results, legend=False)
axes[-2].set_ylim([0.78, 0.87])
axes[-2].set_ylabel("")
axes[-2].set_xlabel("Total number of RSNA labels\n(proportion of training set)")
axes[-2].set_title(r"$\bf{(OOD)}$" + " RSNA Pneumonia")
_format_label_axis(axes[-2], xticks, n_total=8633, rotation=rotation)

# (OOD) CheXpert on the last panel (same label fractions as RSNA); its legend
# is restyled into the shared figure legend.
results = _load_results(
    "../outputs/classification_chexfinetunepneumo_results.csv", legend_conversion
)
results = results.loc[results.ctrain_label_prop.isin(xticks)]
if not results.empty:
    _plot_panel(axes[-1], results, legend=True)
axes[-1].set_ylabel("")
_add_shared_legend(axes[-1], (-3.60, -0.3))
axes[-1].set_xlabel("Total number of CheXpert labels\n(proportion of training set)")
axes[-1].set_title(r"$\bf{(OOD)}$" + " CheXpert")
_format_label_axis(axes[-1], xticks, n_total=13811, rotation=rotation)
plt.savefig("figures/simclr_cxr.pdf", bbox_inches="tight", dpi=300)

# DINO

In [None]:
# DINO / CF-DINO / ImageNet on chest X-rays: PadChest per scanner (ID), then
# RSNA Pneumonia and CheXpert (OOD); ROC vs. labelled training data.
plt_kwargs = {
    "errorbar": type_error,
    "palette": color_dict,
    "linewidth": 3,
    "dashes": style_dict,
}

sns.set_theme(context="paper", style="whitegrid", font_scale=1.6)
matplotlib.rcParams["font.family"] = "serif"
rotation = 0


def _load_results(path, name_map=None):
    """Read a results CSV and derive the columns the plots group by.

    ``ctrain_label_prop``, ``Pretraining`` and ``Classifier`` are parsed from
    ``run_name`` by the project helpers. When *name_map* is given, each
    pretraining name is translated through it (``KeyError`` if absent).
    """
    frame = pd.read_csv(path)
    frame["ctrain_label_prop"] = frame.run_name.apply(extract_train_label_prop)
    pretraining = frame.run_name.apply(extract_pretraining_type)
    if name_map is not None:
        pretraining = pretraining.apply(name_map.__getitem__)
    frame["Pretraining"] = pretraining
    frame["Classifier"] = frame.run_name.apply(extract_finetuning_type)
    return frame


def _sorted_for_plot(frame):
    """Return a sorted copy of *frame* with a draw-order column ``o``.

    ``o`` is ``order_dict[Pretraining]``, plus 10 for any classifier other than
    linear probing, so linear-probing curves come first. Seaborn derives hue
    and style level order from the data, so this also fixes the legend order.
    ``assign`` returns a new frame, so nothing is written into a filtered
    slice (which would trigger pandas' SettingWithCopyWarning).
    """
    method_rank = frame.Pretraining.apply(order_dict.__getitem__)
    classifier_rank = frame.Classifier.apply(
        lambda name: 0 if name == "Linear Probing" else 10
    )
    return frame.assign(o=method_rank + classifier_rank).sort_values(by="o")


def _plot_panel(axis, frame, legend):
    """Plot ROC vs. label fraction on *axis*: colour = pretraining, dash = classifier."""
    sns.lineplot(
        data=_sorted_for_plot(frame),
        x="ctrain_label_prop",
        y="ROC",
        hue="Pretraining",
        ax=axis,
        legend=legend,
        style="Classifier",
        **plt_kwargs,
    )


def _format_label_axis(axis, ticks, n_total, rotation):
    """Log-scale the x axis and label each tick with its label count and percentage.

    *n_total* is the number of labels in the full training set, so the tick
    at fraction ``x`` is annotated with ``int(n_total * x)`` labels.
    """
    axis.set_xscale("log")
    axis.set_xticks(ticks)
    axis.set_xticklabels(
        [f"N={int(n_total * x)}\n({x * 100}%)" for x in ticks], rotation=rotation
    )


def _add_shared_legend(axis, bbox_to_anchor):
    """Redraw *axis*' legend as one wide row anchored at *bbox_to_anchor*.

    Presumably labels[0] and labels[-3] are the "Pretraining" and "Classifier"
    titles that seaborn inserts as pseudo-entries (assumes two Classifier
    levels); both are set in bold. The trailing spaces on the last label
    appear to pad the legend box.
    """
    handles, labels = axis.get_legend_handles_labels()
    if len(labels) < 3:
        # Too few entries to restyle (e.g. nothing was plotted on this axis).
        return
    labels[-1] += "          "
    labels[0] = r"$\bf{" + labels[0] + r"}$"
    labels[-3] = r"$\bf{" + labels[-3] + r"}$"
    axis.legend(
        handles,
        labels,
        loc="upper left",
        bbox_to_anchor=bbox_to_anchor,
        ncol=8,
        fontsize=16,
    )


fig, axes = plt.subplots(1, 4, figsize=(25, 5), facecolor="none")

# (ID) PadChest: the string-encoded "Scanner" value selects panel 0 or 1.
padchest_scanners = {"0": "Philips", "1": "Imaging"}
xticks = [0.05, 0.1, 0.25, 1.0]
results = _load_results("../outputs/dino_padchest2.csv")
results["Scanner"] = results["Scanner"].astype(str)
results = results.loc[results.ctrain_label_prop.isin(xticks)]
for i, (scanner, scanner_name) in enumerate(padchest_scanners.items()):
    panel = results.loc[results.Scanner == scanner]
    if panel.empty:
        continue
    _plot_panel(axes[i], panel, legend=False)
    axes[i].set_title(r"$\bf{(ID)}$" + f" PadChest - Scanner: {scanner_name}")
    axes[i].set_xlabel(
        "Total number of PadChest labels\n(proportion of training set)"
    )
    if i > 0:
        axes[i].set_ylabel("")
for axis in axes[:2]:
    _format_label_axis(axis, xticks, n_total=64989, rotation=rotation)

# (OOD) RSNA Pneumonia on panel 2.
xticks = [0.1, 0.25, 1.0]
results = _load_results("../outputs/dino_rsna2.csv")
results = results.loc[results.ctrain_label_prop.isin(xticks)]
if not results.empty:
    _plot_panel(axes[-2], results, legend=False)
axes[-2].set_ylabel("")
axes[-2].set_xlabel("Total number of RSNA labels\n(proportion of training set)")
axes[-2].set_title(r"$\bf{(OOD)}$" + " RSNA Pneumonia")
_format_label_axis(axes[-2], xticks, n_total=8633, rotation=rotation)

# (OOD) CheXpert on the last panel (same label fractions as RSNA); its legend
# is restyled into the shared figure legend.
results = _load_results("../outputs/dino_chexpert2.csv")
results = results.loc[results.ctrain_label_prop.isin(xticks)]
if not results.empty:
    _plot_panel(axes[-1], results, legend=True)
axes[-1].set_ylabel("")
_add_shared_legend(axes[-1], (-3.28, -0.3))
axes[-1].set_xlabel("Total number of CheXpert labels\n(proportion of training set)")
axes[-1].set_title(r"$\bf{(OOD)}$" + " CheXpert")
_format_label_axis(axes[-1], xticks, n_total=13811, rotation=rotation)
plt.savefig("figures/dino_cxr_full.pdf", bbox_inches="tight", dpi=300)

In [None]:
# DINO / CF-DINO / ImageNet on mammography: a 2x4 grid of ROC vs. the amount
# of labelled training data. In-distribution EMBED scanners come first, then
# the Senographe Essential and VinDr out-of-distribution panels.
# plt_kwargs is defined here so the cell does not depend on which cell ran last.
plt_kwargs = {
    "errorbar": type_error,
    "palette": color_dict,
    "linewidth": 3,
    "dashes": style_dict,
}

sns.set_theme(context="paper", style="whitegrid", font_scale=1.6)
matplotlib.rcParams["font.family"] = "serif"
rotation = 90


def _load_results(path, name_map=None):
    """Read a results CSV and derive the columns the plots group by.

    ``ctrain_label_prop``, ``Pretraining`` and ``Classifier`` are parsed from
    ``run_name`` by the project helpers. When *name_map* is given, each
    pretraining name is translated through it (``KeyError`` if absent).
    """
    frame = pd.read_csv(path)
    frame["ctrain_label_prop"] = frame.run_name.apply(extract_train_label_prop)
    pretraining = frame.run_name.apply(extract_pretraining_type)
    if name_map is not None:
        pretraining = pretraining.apply(name_map.__getitem__)
    frame["Pretraining"] = pretraining
    frame["Classifier"] = frame.run_name.apply(extract_finetuning_type)
    return frame


def _sorted_for_plot(frame):
    """Return a sorted copy of *frame* with a draw-order column ``o``.

    ``o`` is ``order_dict[Pretraining]``, plus 10 for any classifier other than
    linear probing, so linear-probing curves come first. Seaborn derives hue
    and style level order from the data, so this also fixes the legend order.
    ``assign`` returns a new frame, so nothing is written into a filtered
    slice (which would trigger pandas' SettingWithCopyWarning).
    """
    method_rank = frame.Pretraining.apply(order_dict.__getitem__)
    classifier_rank = frame.Classifier.apply(
        lambda name: 0 if name == "Linear Probing" else 10
    )
    return frame.assign(o=method_rank + classifier_rank).sort_values(by="o")


def _plot_panel(axis, frame, legend):
    """Plot ROC vs. label fraction on *axis*: colour = pretraining, dash = classifier."""
    sns.lineplot(
        data=_sorted_for_plot(frame),
        x="ctrain_label_prop",
        y="ROC",
        hue="Pretraining",
        ax=axis,
        legend=legend,
        style="Classifier",
        **plt_kwargs,
    )


def _format_label_axis(axis, ticks, n_total, rotation):
    """Log-scale the x axis and label each tick with its label count and percentage.

    *n_total* is the number of labels in the full training set, so the tick
    at fraction ``x`` is annotated with ``int(n_total * x)`` labels.
    """
    axis.set_xscale("log")
    axis.set_xticks(ticks)
    axis.set_xticklabels(
        [f"N={int(n_total * x)}\n({x * 100}%)" for x in ticks], rotation=rotation
    )


def _add_shared_legend(axis, bbox_to_anchor):
    """Redraw *axis*' legend as one wide row anchored at *bbox_to_anchor*.

    Presumably labels[0] and labels[-3] are the "Pretraining" and "Classifier"
    titles that seaborn inserts as pseudo-entries (assumes two Classifier
    levels); both are set in bold. The trailing spaces on the last label
    appear to pad the legend box.
    """
    handles, labels = axis.get_legend_handles_labels()
    if len(labels) < 3:
        # Too few entries to restyle (e.g. nothing was plotted on this axis).
        return
    labels[-1] += "          "
    labels[0] = r"$\bf{" + labels[0] + r"}$"
    labels[-3] = r"$\bf{" + labels[-3] + r"}$"
    axis.legend(
        handles,
        labels,
        loc="upper left",
        bbox_to_anchor=bbox_to_anchor,
        ncol=8,
        fontsize=16,
    )


fig, axes = plt.subplots(2, 4, figsize=(20, 12), facecolor="none")
plt.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9, wspace=0.2, hspace=0.58)
axes = axes.ravel()

# (ID) EMBED scanners, largest test set first, filling panels from 0.
xticks = [0.01, 0.05, 0.1, 0.25, 1.0]
results = _load_results("../outputs/dino_embed2.csv")
results = results.dropna(subset=["ROC", "ctrain_label_prop"])
results = results.loc[results.ctrain_label_prop.isin(xticks)]
models = results.sort_values(by="N_test", ascending=False)["Model Name"].unique()
for i, model in enumerate(models):
    panel = results.loc[results["Model Name"] == model]
    # NOTE(review): panel 0 keeps seaborn's in-axes legend in addition to the
    # shared legend added at the end — confirm that is intended.
    _plot_panel(axes[i], panel, legend=i == 0)
    axes[i].set_title(
        r"$\bf{(ID)}$"
        + f" {model}\n{training_prop[model]}% ID train set, N test = {panel.N_test.unique()[0]}"
    )
    axes[i].set_xlabel("Total number of EMBED labels\n(proportion of training set)")
    _format_label_axis(axes[i], xticks, n_total=223086, rotation=rotation)
    if i > 0:
        axes[i].set_ylabel("")

# (OOD) Senographe Essential, from panel 5.
xticks = [0.05, 0.1, 0.25, 1.0]
results = _load_results("../outputs/dino_embed_ood.csv")
results = results.dropna(subset=["ROC", "ctrain_label_prop"])
results = results.loc[results.ctrain_label_prop.isin(xticks)]
models = results.sort_values(by="N_test", ascending=False)["Model Name"].unique()
for i, model in enumerate(models, start=5):
    panel = results.loc[results["Model Name"] == model]
    _plot_panel(axes[i], panel, legend=False)
    axes[i].set_title(
        r"$\bf{(OOD)}$"
        + f" Senographe Essential\n100% OOD train set, N test = {panel.N_test.unique()[0]}"
    )
    axes[i].set_xlabel("Total number of Senograph labels\n(proportion of training set)")
    _format_label_axis(axes[i], xticks, n_total=10927, rotation=rotation)
    axes[i].set_ylabel("")

# (OOD) VinDr, one panel per remaining "Model Name" from panel 6; rows named
# exactly "(OOD) VinDr" are excluded. Panel 7 carries the legend source.
xticks = [0.05, 0.1, 0.25, 1.0]
results = _load_results("../outputs/dino_vindr2.csv")
results = results.dropna(subset=["ROC", "ctrain_label_prop"])
results = results.loc[results.ctrain_label_prop.isin(xticks)]
results = results.loc[results["Model Name"] != "(OOD) VinDr"]
models = results.sort_values(by="N_test", ascending=False)["Model Name"].unique()
for i, model in enumerate(models, start=6):
    panel = results.loc[results["Model Name"] == model]
    _plot_panel(axes[i], panel, legend=i == 7)
    modelname = model.replace("\n", " - ").replace("(OOD)", r"$\bf{(OOD)}$")
    axes[i].set_title(
        f"{modelname}\n{vindr_training_prop[model]}% OOD train set, N test = {panel.N_test.unique()[0]}"
    )
    axes[i].set_xlabel("Total number of VinDR labels\n(proportion of training set)")
    _format_label_axis(axes[i], xticks, n_total=11212, rotation=rotation)
    axes[i].set_yticks([0.75, 0.80, 0.85, 0.90, 0.95])
    axes[i].set_ylabel("")

plt.tight_layout()
# Anchored far left/below the last axis, presumably to sit under the whole grid.
_add_shared_legend(axes[-1], (-3.5, -0.5))

plt.savefig("figures/dino_embed_full.pdf", bbox_inches="tight", dpi=300)
plt.show()