In [None]:
import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch

from src.misc_utils import read_pickle_file
from src.plotting_utils import get_gam_cams, plot_class_activation
from src.train_neonatal import init_classifier, init_neonatal_dataset

# Project-wide matplotlib style file; every figure below is drawn inside plt.rc_context(fname=...).
matplotlibrc_path = "../matplotlibrc"
# Directory containing the "{exp_id}_main_multirun" result folders; set before running.
result_path = "TODO"
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
exp_id = "TODO"
multi_run = f"{exp_id}_main_multirun"
num_run = "01"

# All artefacts of one run live in the same multirun directory and share the
# "{exp_id}_run_{num_run}_" file-name prefix; load config, models and results in that order.
# NOTE(review): the .pkl files are presumably unpickled by read_pickle_file — only load
# result files you produced yourself, since unpickling can execute arbitrary code.
run_dir = os.path.join(result_path, multi_run)
run_prefix = f"{exp_id}_run_{num_run}"
cfg, models_dict, results_dict = (
    read_pickle_file(f"{run_prefix}_{artefact}.pkl", run_dir)
    for artefact in ("config", "models", "results")
)



In [None]:
# Key into `models_dict` selecting which trained model (and matching dataset) to inspect.
analysis_id = "005"

# Rebuild the architecture from the run config, restore the weights stored under
# `analysis_id`, and switch the network to evaluation mode before computing activations.
# NOTE(review): `device` from the config cell is never applied here — confirm whether
# get_gam_cams handles device placement itself.
classifier = init_classifier(cfg)
classifier.load_state_dict(models_dict[analysis_id])
classifier.eval()

# Dataset built for the same `analysis_id` (the name suggests a single subject — verify).
single_id_dataset = init_neonatal_dataset(analysis_id, cfg)

# One index per time window; the plotting cell below iterates over all of them.
num_all_time_windows = len(single_id_dataset.time_window_df)
iteri = np.arange(num_all_time_windows)
print(num_all_time_windows)


In [None]:
# Plot the class-activation maps of every time window in `iteri`: one figure per window
# with the per-modality additive contributions (left, "G"), the six signal panels
# (middle, "A"-"F") and a shared activation colour bar (right, "X").
# NOTE: Creates a lot of plots!
large_fontsize = 15
mid_fontsize = 10
small_fontsize = 8
y_pos_tit = 0.87  # panel-title height in axes coordinates (<1 puts it inside the axis)

# Plot settings per signal length (tick positions span 0-150 vs. 0-30 samples).
fast_signal_kwargs = dict(x_ticks=[0, 75, 150], y_ticks=[-5, 0, 5])
slow_signal_kwargs = dict(
    lower_y=-3.0, upper_y=3.0, x_ticks=[0, 15, 30], y_ticks=[-3, 0, 3]
)
# (mosaic key, key into `signals`/`cams`, panel title, plot settings).
# Each panel must plot a signal against its *own* CAM.
signal_panels = [
    ("A", "NP", "NP", fast_signal_kwargs),
    ("B", "Thorax", "Thorax + Abdomen", fast_signal_kwargs),
    ("C", "PR", "PPG", fast_signal_kwargs),
    ("D", "SpO2", r"$\mathregular{SpO_2}$", slow_signal_kwargs),
    ("E", "HR", "HR", slow_signal_kwargs),
    ("F", "PCO2", r"$\mathregular{PCO_2}$", slow_signal_kwargs),
]

# Colour scale of the shared colour bar.
# NOTE(review): assumed to match the colormap/range used inside plot_class_activation — confirm.
cmap = mpl.cm.cool
norm = mpl.colors.Normalize(vmin=-3.0, vmax=3.0)

with plt.rc_context(fname=matplotlibrc_path):
    # Bar order assumed to match the order of `logits` returned by get_gam_cams.
    bar_labels = ["NP", "T+A", "HR", "PPG", "SPO2", "PCO2"]
    colors = ["C1", "C2", "C3", "C4", "C5", "C6"]
    y_pos = np.arange(len(bar_labels))
    for i in iteri:
        print(i)  # window index, used to pick example windows for the paper figure
        fig = plt.figure(figsize=(12, 3))
        fig.subplots_adjust(wspace=5.0, hspace=0.3)
        ax_dict = fig.subplot_mosaic(
            """GGGAAAAABBBBBCCCCCX
               GGGDDDDDEEEEEFFFFFX"""
        )
        cams, signals, label, logits, bias, score = get_gam_cams(
            i, single_id_dataset, classifier
        )

        fig.suptitle(
            f"Label: {label}, Score: {np.round(score, 2)}", fontsize=mid_fontsize
        )

        for ax_key, signal_key, title, plot_kwargs in signal_panels:
            plot_class_activation(
                signals[signal_key],
                cams[signal_key],
                ax=ax_dict[ax_key],
                fig=fig,
                supress_colorbar=True,
                **plot_kwargs,
            )
            ax_dict[ax_key].set_title(title, fontsize=mid_fontsize, y=y_pos_tit)

        # Only the bottom row gets an x label.
        for ax_key in "DEF":
            ax_dict[ax_key].set_xlabel("Samples", fontsize=mid_fontsize)

        # Per-modality additive contributions.
        bar_ax = ax_dict["G"]
        bar_ax.bar(y_pos, logits, align="center", color=colors)
        bar_ax.set_xticks(y_pos)
        bar_ax.set_xticklabels(bar_labels, fontsize=mid_fontsize, rotation=45)
        bar_ax.set_ylabel("Additive Contribution", fontsize=mid_fontsize)
        bar_ax.set_ylim((-3.0, 3.0))

        fig.colorbar(
            mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
            cax=ax_dict["X"],
            orientation="vertical",
        )
        ax_dict["X"].set_ylabel("Activation", fontsize=mid_fontsize)

        plt.show()
        # One figure per window: close it so the loop does not accumulate open figures.
        plt.close(fig)



In [None]:
# Paper figure: one column per classification outcome (a: true positive, b: true negative,
# c: false positive, d: false negative). Each column shows the per-modality additive
# contributions (top row) above the class-activation maps of two signals (middle and
# bottom rows); a shared activation colour bar sits on the right.
assert analysis_id == "005"  # used in the paper

# Time-window indices into `single_id_dataset` shown in the paper for analysis "005".
TP = 14
TN = 28
FP = 72
FN = 50


def _plot_signal_panel(ax, fig, signal, cam, caption, caption_xy, fontsize, **plot_kwargs):
    """Draw one class-activation panel and write a centred caption inside it.

    Args:
        ax: Target axis.
        fig: Figure owning ``ax`` (forwarded to ``plot_class_activation``).
        signal: Signal entry from ``get_gam_cams``.
        cam: Class-activation map belonging to ``signal``.
        caption: Text written inside the panel.
        caption_xy: ``(x, y)`` caption position in data coordinates.
        fontsize: Caption font size.
        **plot_kwargs: Extra ``plot_class_activation`` arguments (ticks, y limits).
    """
    plot_class_activation(
        signal, cam, ax=ax, fig=fig, supress_colorbar=True, **plot_kwargs
    )
    # Caption is added only after the panel is drawn, so nothing done to the axis
    # by plot_class_activation can drop or misplace it.
    ax.text(*caption_xy, caption, fontsize=fontsize, ha="center")


def _plot_contribution_bars(
    ax,
    logits,
    y_pos,
    colors,
    title,
    letter,
    xlim,
    xticks,
    fontsize,
    letter_fontsize,
    tick_labels=None,
):
    """Draw the horizontal bar chart of per-modality additive contributions.

    Args:
        ax: Target axis.
        logits: One additive contribution per bar, drawn at ``y_pos``.
        y_pos: Bar positions on the y axis.
        colors: One colour per bar.
        title: Centred column title.
        letter: Panel letter drawn as a left-aligned title.
        xlim: x-axis limits.
        xticks: x-axis tick positions.
        fontsize: Font size for labels and the column title.
        letter_fontsize: Font size of the panel letter.
        tick_labels: y tick labels (tick ``k`` gets ``tick_labels[k]``); ``None`` hides
            the y ticks, which is used for every column but the left-most one.
    """
    ax.barh(y_pos, logits, align="center", color=colors)
    ax.set_xlabel("Additive contribution", fontsize=fontsize)
    ax.set_xlim(xlim)
    ax.set_xticks(xticks)
    ax.set_title(title, fontsize=fontsize)
    ax.set_title(letter, loc="left", fontsize=letter_fontsize, y=1.05, x=-0.1)
    if tick_labels is None:
        ax.set_yticks([])
        ax.set_yticklabels([])
    else:
        ax.set_yticks(np.arange(len(tick_labels)))
        ax.set_yticklabels(tick_labels, fontsize=fontsize)


with plt.rc_context(fname=matplotlibrc_path):
    figure_width = 12
    width_to_height_ratio = 8 / 17

    large_fontsize = 15
    mid_fontsize = 10
    small_fontsize = 8

    bar_chart_lim = (-2, 2)
    bar_chart_ticks = [-1, 0, 1]

    colors = ["C1", "C2", "C3", "C4", "C5", "C6"]
    # Tick labels are listed bottom-to-top (tick 0 = PCO2). y_pos is reversed so that
    # logits[0] is drawn as the top bar, which is labelled "NP".
    bar_labels = [
        r"$\mathregular{PCO_2}$",
        r"$\mathregular{SpO_2}$",
        "PPG",
        "HR",
        "T+A",
        "NP",
    ]
    y_pos = np.arange(len(bar_labels))[::-1]

    # Signal-panel specs: (key into signals/cams, caption, caption position, plot kwargs).
    fast_kwargs = dict(x_ticks=[0, 75, 150], y_ticks=[-5, 0, 5])
    np_panel = ("NP", "Nasal Pressure", (75, -4), fast_kwargs)
    thorax_panel = ("Thorax", "Thorax + Abdomen", (75, -4), fast_kwargs)
    ppg_panel = ("PR", "Photoplethysmogram", (75, -4), fast_kwargs)
    spo2_panel = (
        "SpO2",
        r"$\mathregular{SpO_2}$",
        (15, -0.8),
        dict(lower_y=-1.0, upper_y=1.0, x_ticks=[0, 15, 30], y_ticks=[-1, 0, 1]),
    )

    # One column per outcome, in plotting order:
    # (window index, bar axis, top signal axis, bottom signal axis,
    #  top panel, bottom panel, column title, panel letter)
    columns = [
        (TP, "B", "A", "C", np_panel, spo2_panel,
         r"$\bf{True \ positive}$", r"$\bf{a}$"),
        (TN, "F", "D", "E", np_panel, thorax_panel,
         r"$\bf{True \ negative}$", r"$\bf{b}$"),
        (FP, "I", "G", "H", np_panel, thorax_panel,
         r"$\bf{False \ positive}$", r"$\bf{c}$"),
        (FN, "L", "J", "K", np_panel, ppg_panel,
         r"$\bf{False \ negative}$", r"$\bf{d}$"),
    ]

    fig = plt.figure(figsize=(figure_width, width_to_height_ratio * figure_width))
    fig.patch.set_facecolor("white")
    fig.subplots_adjust(wspace=5.0, hspace=0.3)

    ax_dict = fig.subplot_mosaic(
        """BBBBBFFFFFIIIIILLLLLN
         AAAAADDDDDGGGGGJJJJJM
         CCCCCEEEEEHHHHHKKKKKM"""
    )

    for col_idx, column in enumerate(columns):
        (window_idx, bar_key, top_key, bottom_key,
         top_panel, bottom_panel, title, letter) = column
        is_first_column = col_idx == 0

        cams, signals, label, logits, bias, score = get_gam_cams(
            window_idx, single_id_dataset, classifier
        )

        for ax_key, (signal_key, caption, caption_xy, plot_kwargs) in (
            (top_key, top_panel),
            (bottom_key, bottom_panel),
        ):
            ax = ax_dict[ax_key]
            _plot_signal_panel(
                ax,
                fig,
                signals[signal_key],
                cams[signal_key],
                caption,
                caption_xy,
                mid_fontsize,
                **plot_kwargs,
            )
            # Tick positions are in samples; label them in seconds to match "time (s)".
            ax.set_xticklabels([0, 15, 30])
            if is_first_column:
                ax.set_ylabel("normalized signal", fontsize=mid_fontsize)
        ax_dict[bottom_key].set_xlabel("time (s)", fontsize=mid_fontsize)

        _plot_contribution_bars(
            ax_dict[bar_key],
            logits,
            y_pos,
            colors,
            title,
            letter,
            bar_chart_lim,
            bar_chart_ticks,
            mid_fontsize,
            large_fontsize,
            tick_labels=bar_labels if is_first_column else None,
        )

    # NOTE(review): colour range assumed to match plot_class_activation's — confirm.
    cmap = mpl.cm.cool
    norm = mpl.colors.Normalize(vmin=-3.0, vmax=3.0)

    fig.colorbar(
        mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
        cax=ax_dict["M"],
        orientation="vertical",
    )
    ax_dict["M"].set_ylabel("Activation", fontsize=mid_fontsize)
    ax_dict["N"].axis("off")  # unused mosaic cell above the colour bar

    fig.tight_layout()
    plt.show()
