In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import roc_auc_score

from src.misc_utils import read_pickle_file
from src.train_neonatal import (
    get_full_performance_and_logits,
    init_classifier,
    init_neonatal_dataset,
)

# Location of the shared matplotlib style file used by the figure cell.
matplotlibrc_path = "../matplotlibrc"
# Root directory containing the *_multirun result folders (placeholder: fill in before running).
result_path = "TODO"
# Use the GPU when torch can see one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# Tags of the modalities that also have a dedicated single-modality run.
# Folder tags and file tags are the same list here but are kept as two names.
single_modalities = ["np", "th", "hr", "pr", "spo2", "pco2"]
folder_modalities = single_modalities

exp_id = "TODO"  # placeholder: experiment identifier used in folder and file names
num_run = "01"
multi_run = f"{exp_id}_main_multirun"

# Main run: config, models and results pickles, read in that order.
# NOTE(review): read_pickle_file presumably unpickles its file -- only point it at trusted results.
main_run_dir = os.path.join(result_path, multi_run)
cfg = read_pickle_file(f"{exp_id}_run_{num_run}_config.pkl", main_run_dir)
models_dict = read_pickle_file(f"{exp_id}_run_{num_run}_models.pkl", main_run_dir)
results_dict = read_pickle_file(f"{exp_id}_run_{num_run}_results.pkl", main_run_dir)

# One config / models / results triple per single-modality run, keyed by file tag.
single_modality_configs, single_modality_models, single_modality_results = {}, {}, {}
for folder_tag, file_tag in zip(folder_modalities, single_modalities):
    run_dir = os.path.join(result_path, f"{exp_id}_{folder_tag}_multirun")
    single_modality_configs[file_tag] = read_pickle_file(
        f"{exp_id}_{file_tag}_{num_run}_config.pkl", run_dir
    )
    single_modality_models[file_tag] = read_pickle_file(
        f"{exp_id}_{file_tag}_{num_run}_models.pkl", run_dir
    )
    single_modality_results[file_tag] = read_pickle_file(
        f"{exp_id}_{file_tag}_{num_run}_results.pkl", run_dir
    )


In [None]:
# Re-evaluate each patient's model from the main run and collect, per modality,
# the additive contributions ("logits") returned by get_full_performance_and_logits:
#   pooled_*_logits : all contributions of all patients, concatenated
#                     (input to the pooled histograms).
#   agg_*_logits    : one value per patient -- the standard deviation of that
#                     patient's contributions (input to the per-patient boxplot).
# Modality order follows the 6-way split of `logits`: NP, thorax, HR, PR, SpO2, PCO2.
logit_modalities = ("np", "thorax", "hr", "pr", "spo2", "pco2")
pooled_logits = {mod: [] for mod in logit_modalities}
patient_logit_sds = {mod: [] for mod in logit_modalities}

# `patient_id`, not `id`: a top-level loop variable named `id` would shadow the
# builtin for the rest of the kernel session.
for patient_id in cfg.dataset.ids:
    print(patient_id)
    single_id_dataset = init_neonatal_dataset(patient_id, cfg)
    single_id_model = init_classifier(cfg)
    single_id_model.load_state_dict(models_dict[patient_id])

    # Only the per-modality logits are needed here; labels, scores and bias are discarded.
    _, _, logits, _ = get_full_performance_and_logits(
        single_id_dataset,
        single_id_model,
    )
    logits = tuple(logits)
    # Same failure mode as a 6-way tuple unpack: refuse anything but one entry per modality.
    if len(logits) != len(logit_modalities):
        raise ValueError(
            f"expected {len(logit_modalities)} modality logits for patient "
            f"{patient_id}, got {len(logits)}"
        )

    for mod, mod_logits in zip(logit_modalities, logits):
        pooled_logits[mod].extend(mod_logits)
        patient_logit_sds[mod].append(np.std(mod_logits))

# Expose the per-modality names that the plotting cell reads.
(
    pooled_np_logits,
    pooled_thorax_logits,
    pooled_hr_logits,
    pooled_pr_logits,
    pooled_spo2_logits,
    pooled_pco2_logits,
) = (pooled_logits[mod] for mod in logit_modalities)
(
    agg_np_logits,
    agg_thorax_logits,
    agg_hr_logits,
    agg_pr_logits,
    agg_spo2_logits,
    agg_pco2_logits,
) = (patient_logit_sds[mod] for mod in logit_modalities)



In [None]:
# AuROC of each single-modality model, computed separately for every entry of
# that modality's results dict (presumably one entry per patient -- confirm
# against the training code). Each entry provides "ys" (targets) and "scores".
modality_aucs = {}
modality_aucs_std = {}
for mod, per_entry_results in single_modality_results.items():
    individual_aucs = [
        roc_auc_score(entry["ys"], entry["scores"])
        for entry in per_entry_results.values()
    ]
    modality_aucs[mod] = individual_aucs
    # Spread of the per-entry AuROCs (the same quantity the AuROC panel uses as error bars).
    modality_aucs_std[mod] = np.std(individual_aucs)



In [None]:
with plt.rc_context(fname=matplotlibrc_path):
    # Three-part figure:
    #   a) pooled histograms of each modality's additive contribution (panels A-F),
    #   b) boxplot of the per-patient SD of those contributions (panel G),
    #   c) mean +/- SD AuROC of the single-modality models (panel H).

    figure_width = 12
    width_to_height_ratio = 5 / 17

    large_fontsize = 15
    mid_fontsize = 10

    # Raw strings: "\m" is not a valid Python escape sequence, so a plain string
    # literal triggers a DeprecationWarning (SyntaxWarning on Python >= 3.12).
    bar_labels = [
        "NP",
        "T+A",
        "HR",
        "PPG",
        r"$\mathregular{SpO_2}$",
        r"$\mathregular{PCO_2}$",
    ]
    modality_x = np.arange(len(bar_labels))

    all_pooled_logits = [
        pooled_np_logits,
        pooled_thorax_logits,
        pooled_hr_logits,
        pooled_pr_logits,
        pooled_spo2_logits,
        pooled_pco2_logits,
    ]

    all_patient_logits = [
        agg_np_logits,
        agg_thorax_logits,
        agg_hr_logits,
        agg_pr_logits,
        agg_spo2_logits,
        agg_pco2_logits,
    ]

    colors = ["C1", "C2", "C3", "C4", "C5", "C6"]
    bins = [35, 25, 15, 20, 30, 15]
    hist_axes = ["A", "B", "C", "D", "E", "F"]

    # AuROC summary per modality, explicitly in the same order as bar_labels/colors.
    auc_means = [np.mean(modality_aucs[mod]) for mod in single_modalities]
    auc_stds = [np.std(modality_aucs[mod]) for mod in single_modalities]

    fig = plt.figure(figsize=(figure_width, width_to_height_ratio * figure_width))
    fig.patch.set_facecolor("white")
    # NOTE(review): fig.tight_layout() at the end recomputes the spacing, so these
    # values are likely overridden -- confirm they are still wanted.
    fig.subplots_adjust(wspace=1.5, hspace=0.5)

    # Six single-width histogram panels followed by two double-width panels.
    ax_dict = fig.subplot_mosaic("""ABCDEFGGHH""")

    # --- Panel a: pooled distribution per modality (horizontal histograms).
    for i, (letter, pooled, color, n_bins, label) in enumerate(
        zip(hist_axes, all_pooled_logits, colors, bins, bar_labels)
    ):
        ax = ax_dict[letter]
        if i == 0:
            ax.set_title(
                r"$\bf{a}$", loc="left", fontsize=large_fontsize, y=1.05, x=-1.1
            )
            ax.set_ylabel("additive contribution", fontsize=mid_fontsize)
        ax.hist(
            pooled,
            orientation="horizontal",
            bins=n_bins,
            color=color,
        )
        ax.set_ylim((-4.0, 4.0))
        ax.set_yticks([-4, -2, 0, 2, 4])
        # Counts on the x-axis are not informative here; hide axis, ticks and labels.
        ax.spines["bottom"].set_visible(False)
        ax.tick_params(bottom=False, labelbottom=False)
        ax.text(
            0.2,
            0.9,
            label,
            transform=ax.transAxes,
            fontsize=mid_fontsize,
        )

    # Shared heading above panels A-F.
    fig.text(
        0.32,
        0.94,
        "Distribution of pooled additive contributions",
        fontsize=mid_fontsize,
        ha="center",
    )

    # --- Panel b: per-patient SD of the additive contributions.
    ax_box = ax_dict["G"]
    bplot = ax_box.boxplot(
        all_patient_logits,
        patch_artist=True,
        medianprops=dict(color="black"),
    )
    for patch, color in zip(bplot["boxes"], colors):
        patch.set_facecolor(color)

    ax_box.set_title(
        r"$\bf{b}$", loc="left", fontsize=large_fontsize, y=1.05, x=-0.3
    )
    ax_box.set_ylabel("SD over additive contributions", fontsize=mid_fontsize)
    ax_box.set_ylim((0.0, 2.0))
    ax_box.set_yticks([0.0, 0.5, 1.0, 1.5, 2.0])
    ax_box.set_title(
        "Patient-based\nadditive\ncontributions", fontsize=mid_fontsize, y=0.94
    )
    # boxplot fixes one tick per box; label them here instead of via the
    # deprecated `labels=` keyword (renamed `tick_labels` in Matplotlib 3.9).
    ax_box.set_xticklabels(bar_labels, fontsize=mid_fontsize, rotation=45)

    # --- Panel c: single-modality AuROC, mean bar with SD error bar.
    ax_auc = ax_dict["H"]
    ax_auc.set_title(
        r"$\bf{c}$", loc="left", fontsize=large_fontsize, y=1.05, x=-0.3
    )
    ax_auc.set_xticks(modality_x)
    ax_auc.set_xticklabels(bar_labels, fontsize=mid_fontsize, rotation=45)
    ax_auc.errorbar(
        modality_x,
        auc_means,
        auc_stds,
        marker="",
        color="black",
        linestyle="",
    )
    ax_auc.set_ylabel("AuROC", fontsize=mid_fontsize)
    ax_auc.set_title("Single-modality\nanalysis", fontsize=mid_fontsize)
    ax_auc.bar(
        modality_x,
        auc_means,
        align="center",
        color=colors,
    )
    ax_auc.set_ylim((0.5, 1.0))

    fig.tight_layout()
    plt.show()
