# Plot SexCF results

In [None]:
import sys
import os
import pandas as pd

# Locations on the shared filesystem: the repository root (so that the
# `evaluation` package can be imported) and the directory that the relative
# paths used later in this notebook ("../outputs", "figures") are resolved from.
PROJECT_ROOT = "/vol/biomedic3/mb121/causal-contrastive"
EVALUATION_DIR = "/vol/biomedic3/mb121/causal-contrastive/evaluation"

sys.path.append(PROJECT_ROOT)
os.chdir(EVALUATION_DIR)

In [None]:
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt

from evaluation.helper_functions import (
    extract_train_label_prop,
    extract_pretraining_type,
    extract_finetuning_type,
)


# Global plot appearance used as the notebook-wide default.
sns.set_theme(context="paper", style="whitegrid", font_scale=1.5)
matplotlib.rcParams["font.family"] = "serif"

# Line colour for each pretraining strategy (passed to seaborn as `palette`).
color_dict = {
    "SimCLR": sns.color_palette("colorblind", 3)[0],
    "SexCF-SimCLR": sns.color_palette("colorblind", 8)[-1],
}

# Error bars: one standard error around the estimate.
type_error = ("se", 1)

# Sort key per pretraining strategy; rows are ordered by it before plotting so
# that the hue levels always appear in the same order.
order_dict = {
    "SimCLR": 0,
    "SexCF-SimCLR": 1,
}

# Dash pattern per `style` level: an on/off sequence for "Female" and the
# empty string (solid line) for "Male".
style_dict = {"Female": [1, 2, 5, 2], "Male": ""}

# Keyword arguments forwarded to every `sns.lineplot` call.
plt_kwargs = {
    "errorbar": type_error,
    "palette": color_dict,
    "linewidth": 3,
    "dashes": style_dict,
}

In [None]:
sns.set_theme(context="paper", style="whitegrid", font_scale=1.6)
matplotlib.rcParams["font.family"] = "serif"


def _load_fair_results(csv_path, evaluation_type, label_props):
    """Load one results CSV and return the rows needed for a single panel.

    The CSV must provide the columns ``run_name``, ``Sex`` and ``ROC``. The
    columns ``Pretraining``, ``ctrain_label_prop`` and ``Classifier`` are
    derived from ``run_name`` with the project's ``extract_*`` helpers, and
    ``Gender`` is derived from ``Sex`` (0 -> "Male", anything else -> "Female").

    Args:
        csv_path: Path of the CSV file to read.
        evaluation_type: Value of ``Classifier`` to keep (e.g. "Linear Probing").
        label_props: Values of ``ctrain_label_prop`` to keep.

    Returns:
        A new DataFrame restricted to the requested rows, with an extra sort
        key column ``o`` taken from ``order_dict`` and sorted by that key.

    Raises:
        KeyError: If a remaining row has a pretraining type that is missing
            from ``order_dict``.
    """
    results = pd.read_csv(csv_path)
    results["Pretraining"] = results.run_name.apply(extract_pretraining_type)
    results["ctrain_label_prop"] = results.run_name.apply(extract_train_label_prop)
    results["Classifier"] = results.run_name.apply(extract_finetuning_type)
    results["Gender"] = (
        results["Sex"].apply(lambda x: "Male" if x == 0 else "Female").astype(str)
    )

    keep = results.ctrain_label_prop.isin(label_props) & (
        results.Classifier == evaluation_type
    )
    # Copy so that the column added below is written to an independent frame
    # rather than to a slice of `results` (avoids SettingWithCopyWarning).
    selected = results.loc[keep].copy()
    selected["o"] = selected.Pretraining.apply(lambda x: order_dict[x])
    return selected.sort_values(by="o")


def _plot_panel(ax, results, legend):
    """Draw ROC against label proportion on ``ax``, split by pretraining and gender."""
    sns.lineplot(
        data=results,
        x="ctrain_label_prop",
        y="ROC",
        hue="Pretraining",
        ax=ax,
        legend=legend,
        style="Gender",
        **plt_kwargs,
    )


def _decorate_panel(ax, title, dataset_name, label_props, n_label_train_total):
    """Set the title, x label, log x-scale and "N=... (..%)" tick labels of ``ax``.

    ``n_label_train_total`` is multiplied by each proportion in ``label_props``
    to obtain the absolute count shown in the corresponding tick label.
    """
    tick_labels = [
        f"N={int(n_label_train_total * x)}\n({x * 100}%)" for x in label_props
    ]
    ax.set_title(title)
    ax.set_xlabel(
        f"Total number of {dataset_name} labels\n(proportion of training set)"
    )
    ax.set_xscale("log")
    ax.set_xticks(label_props)
    ax.set_xticklabels(tick_labels, rotation=0)


for evaluation_type in ["Linear Probing"]:
    id_label_props = [0.05, 0.1, 0.25, 1.0]
    ood_label_props = [0.1, 0.25, 1.0]

    fig, axes = plt.subplots(1, 3, figsize=(20, 5), facecolor="none")
    padchest_ax, rsna_ax, chexpert_ax = axes

    # Panel 1: PadChest (in-distribution). Keeps the default "ROC" y label,
    # which doubles as the y label of the whole row.
    padchest = _load_fair_results(
        "../outputs/padchest_fair.csv", evaluation_type, id_label_props
    )
    _plot_panel(padchest_ax, padchest, legend=False)
    _decorate_panel(
        padchest_ax,
        r"$\bf{(ID)}$" + " PadChest",
        "PadChest",
        id_label_props,
        64989,
    )

    # Panel 2: RSNA Pneumonia (out-of-distribution).
    rsna = _load_fair_results(
        "../outputs/rsna_fair.csv", evaluation_type, ood_label_props
    )
    _plot_panel(rsna_ax, rsna, legend=False)
    rsna_ax.set_ylabel("")
    _decorate_panel(
        rsna_ax,
        r"$\bf{(OOD)}$" + " RSNA Pneumonia",
        "RSNA",
        ood_label_props,
        8633,
    )

    # Panel 3: CheXpert (out-of-distribution). This is the only panel drawn
    # with a legend; its entries are reused for the shared legend below.
    chexpert = _load_fair_results(
        "../outputs/chexpert_fair.csv", evaluation_type, ood_label_props
    )
    _plot_panel(chexpert_ax, chexpert, legend=True)
    chexpert_ax.set_ylabel("")

    handles, labels = chexpert_ax.get_legend_handles_labels()
    # Pad the last entry and render the two section headers in bold
    # (first entry, and third from the end).
    labels[-1] += "          "
    labels[0] = r"$\bf{" + labels[0] + r"}$"
    labels[-3] = r"$\bf{" + labels[-3] + r"}$"

    # The layout is fixed *before* the legend is placed and before the last
    # panel is decorated: the legend anchor below is expressed in the axes
    # coordinates of the last panel and was tuned for this layout.
    plt.tight_layout()
    chexpert_ax.legend(
        handles,
        labels,
        loc="upper left",
        bbox_to_anchor=(-2.0, -0.35),
        ncol=8,
        fontsize=16,
    )
    _decorate_panel(
        chexpert_ax,
        r"$\bf{(OOD)}$" + " CheXpert",
        "CheXpert",
        ood_label_props,
        13811,
    )

    # savefig does not create missing directories.
    os.makedirs("figures", exist_ok=True)
    plt.savefig(
        f"figures/xrayfair_{evaluation_type.lower()}.pdf", bbox_inches="tight", dpi=300
    )