In [None]:
%reload_ext autoreload
%autoreload 2

import pickle
from copy import copy
from pathlib import Path
from collections import defaultdict
from functools import partial
from joblib import Parallel, delayed

import numpy as np
from scipy import stats
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.metrics import (balanced_accuracy_score,
                             average_precision_score,
                             precision_recall_curve,
                             precision_score,
                             recall_score,
                             roc_auc_score,
                             roc_curve,
                             auc,
                             f1_score)

from koafusion.various import calc_metrics_v2

In [None]:
# Root directories. DIR_PROJECT_ROOT must be set before running the notebook;
# the data and results directories are resolved relative to it.
DIR_PROJECT_ROOT = # TODO: set to project root directory
DIR_DATA_ROOT= Path(DIR_PROJECT_ROOT, "data/")
DIR_RESULTS_ROOT = Path(DIR_PROJECT_ROOT, "results/")

# Output directory for the metric tables and figures written by this notebook
# (created, including parents, if missing).
DIR_OUT = Path(DIR_RESULTS_ROOT, "temporary/")
DIR_OUT.mkdir(parents=True, exist_ok=True)

In [None]:
# Aggregated OAI metadata. The CSV has a two-level column header; only the
# columns grouped under the "-" top level are kept.
DIR_META = Path(DIR_DATA_ROOT, "meta_agg_oai.csv")

df_meta = pd.read_csv(DIR_META, header=[0, 1]).loc[:, "-"]

display(df_meta.head())
display(df_meta.columns)

In [None]:
def select_subset_v2(df_data, subset):
    """Select the rows of `df_data` that belong to a patient sub-group.

    Parameters
    ----------
    df_data : pd.DataFrame
        Per-knee data. The columns "P01INJ-", "P01KSURG-", "XRKL" and
        "WOMTS-" are required only for the criteria that are used.
    subset : str
        Comma-separated criteria, e.g. "INJ-,SURG-,KLG_2,WOMAC_0-10".
        Supported criteria: "INJ-", "INJ+", "SURG-", "SURG+", "KLG_c01",
        "KLG_2", "KLG_3", "WOMAC_0-10", "WOMAC_11-96". Any other token
        (e.g. "all") applies no filter.

    Returns
    -------
    pd.DataFrame
        A filtered copy; `df_data` itself is not modified. Rows with a
        missing "WOMTS-" value are excluded by both WOMAC criteria.
    """
    # Criterion -> row mask. Criteria are matched as whole comma-separated
    # tokens rather than substrings, so e.g. "KLG_23" does not trigger "KLG_2".
    criteria = {
        "INJ-": lambda d: d["P01INJ-"].isin((0, )),
        "INJ+": lambda d: d["P01INJ-"].isin((1, )),
        "SURG-": lambda d: d["P01KSURG-"].isin((0, )),
        "SURG+": lambda d: d["P01KSURG-"].isin((1, )),
        "KLG_c01": lambda d: d["XRKL"].isin((0, 1)),
        "KLG_2": lambda d: d["XRKL"].isin((2, )),
        "KLG_3": lambda d: d["XRKL"].isin((3, )),
        "WOMAC_0-10": lambda d: d["WOMTS-"] <= 10.,
        "WOMAC_11-96": lambda d: d["WOMTS-"] > 10.,
    }
    tokens = {token.strip() for token in subset.split(",")}

    df = df_data.copy()
    for name, mask_fn in criteria.items():
        if name in tokens:
            df = df[mask_fn(df)]
    return df

In [None]:
def read_cache(fn, presel=None):
    """Load a pickled evaluation cache and return it as a DataFrame.

    Parameters
    ----------
    fn : path-like
        Pickle file holding a dict of per-sample evaluation results.
    presel : hashable, optional
        If given, only ``cache[presel]`` (e.g. the results of one model) is
        converted to a DataFrame.

    Returns
    -------
    pd.DataFrame
        One row per sample. A "predict_proba" column, if present, holds one
        entry (e.g. a probability vector) per row.
    """
    # NOTE: pickle can execute arbitrary code -- load only trusted caches.
    with open(fn, "rb") as f:
        cache = pickle.load(f)

    # Convert 2D ndarrays to lists. pandas cannot parse a 2D ndarray as a
    # single column, so the "predict_proba__*" entries of the "LR" and "DT"
    # results are turned into nested lists first.
    for model_name, model_res in cache.items():
        if model_name not in ("LR", "DT"):
            continue
        for res_name, res_val in model_res.items():
            if res_name.startswith("predict_proba__") and isinstance(res_val, np.ndarray):
                model_res[res_name] = res_val.tolist()

    selected = cache if presel is None else cache[presel]
    if "predict_proba" in selected:
        # Split the probabilities into one element per sample (row)
        selected["predict_proba"] = list(selected["predict_proba"])
    return pd.DataFrame.from_dict(selected)


def one_hot(array):
    """One-hot encode the values of `array` (expected to be 1D).

    Columns follow the sorted order of the unique values; row i has a single
    1 in the column of the i-th element's value.
    """
    categories, codes = np.unique(array, return_inverse=True)
    return np.eye(len(categories))[codes]

# Experiment list

In [None]:
# Registry of the evaluated experiments.
#   key:   (inputs, target, model, data_train, data_eval)
#   value: directory with the cached evaluation outputs of the experiment
#
# TODO: fill in EXPERIMENT_IDS with the experiment_id (directory name under
# TODO: DIR_RESULTS_ROOT) of every (inputs, target) pair. Pairs without an
# TODO: entry keep the "____" placeholder.
EXPERIMENT_IDS = {
    # ("age,sex,BMI", "prog_kl_12"): "<experiment_id>",
}

# Prediction targets evaluated for every input configuration
_EXPERIM_TARGETS = ("prog_kl_12", "prog_kl_24", "prog_kl_36", "prog_kl_48",
                    "tiulpin2019_prog_bin")

# (inputs, model) pairs, in registration order
_EXPERIM_INPUTS_MODELS = (
    # Clinical variables (LR model)
    ("age,sex,BMI", "LR"),
    ("age,sex,BMI,KL", "LR"),
    ("age,sex,BMI,Surg,Inj,WOMAC", "LR"),
    ("age,sex,BMI,KL,Surg,Inj,WOMAC", "LR"),
    # Single input
    ("XR", "XR1"),
    ("DESS", "MR1"),
    ("TSE", "MR1"),
    ("T2_MAP", "MR1"),
    # Two MR inputs
    ("DESS,TSE", "MR2"),
    ("DESS,T2_MAP", "MR2"),
    ("TSE,T2_MAP", "MR2"),
    # XR + one MR input
    ("XR,DESS", "XR1MR1"),
    ("XR,TSE", "XR1MR1"),
    ("XR,T2_MAP", "XR1MR1"),
    # XR + two MR inputs
    ("XR,DESS,TSE", "XR1MR2"),
    ("XR,DESS,T2_MAP", "XR1MR2"),
    ("XR,TSE,T2_MAP", "XR1MR2"),
    # XR + two MR inputs + clinical variables
    ("XR,DESS,T2_MAP,clin", "XR1MR2C1"),
)

# All experiments are trained and evaluated on the "incid" data; for each
# (inputs, model) pair the targets are registered in _EXPERIM_TARGETS order.
PATHS_EXPERIMS = {
    (inputs, target, model, "incid", "incid"):
        Path(DIR_RESULTS_ROOT, EXPERIMENT_IDS.get((inputs, target), "____"),
             "logs_eval", "incid")
    for inputs, model in _EXPERIM_INPUTS_MODELS
    for target in _EXPERIM_TARGETS
}

# Read predictions from cache

In [None]:
# Sample selection name -> {experiment key -> predictions DataFrame}
expers_data = {}

In [None]:
# Load the cached per-sample predictions of every registered experiment.
data = {}

for k, d in PATHS_EXPERIMS.items():
    inputs = k[0]
    if not ("age,sex,BMI" in inputs or inputs.startswith("clin")):
        # Imaging / fusion experiments
        p = Path(d, "eval_fus_raw_ens.pkl")
        data[k] = read_cache(p)
        continue
    # Clinical experiments: take only the LR model from the cache
    p = Path(d, "eval_clin_raw_ens.pkl")
    data[k] = read_cache(p, presel="LR")

expers_data["full"] = data

## Add sample selection: only samples available at all prediction horizons (targets)

In [None]:
targets = ("prog_kl_12", "prog_kl_24", "prog_kl_36", "prog_kl_48", "tiulpin2019_prog_bin")

def select_samples_in_all_targets(d_i, targets=targets):
    """Keep only the samples evaluated in every experiment of every target.

    Parameters
    ----------
    d_i : dict
        Experiment key -> predictions DataFrame with an "exam_knee_id"
        column. The target name is the 2nd element of each key.
    targets : iterable of str, optional
        Targets whose experiments define the common sample set. Defaults to
        the `targets` tuple defined together with this function.

    Returns
    -------
    dict
        Same keys as `d_i` (including experiments of other targets). Every
        DataFrame is a copy restricted to the common "exam_knee_id"s, with
        the original row order kept and the index reset.

    Raises
    ------
    ValueError
        If no experiment in `d_i` has one of `targets`.
    """
    targets = set(targets)

    # Intersecting over all experiments of the selected targets equals the
    # intersection of the per-target intersections.
    uniqs_across = None
    for k_proc, df_proc in d_i.items():
        if k_proc[1] not in targets:
            continue
        uniqs = set(df_proc["exam_knee_id"].tolist())
        uniqs_across = uniqs if uniqs_across is None else uniqs_across & uniqs

    if uniqs_across is None:
        raise ValueError(f"No experiments found for targets: {sorted(targets)}")

    d_o = dict()
    for k_proc, df_proc in d_i.items():
        df_proc = df_proc[df_proc["exam_knee_id"].isin(uniqs_across)]
        d_o[k_proc] = df_proc.reset_index(drop=True).copy()
    return d_o

# Restrict every experiment to the samples evaluated for all targets
expers_data["in_all_targets"] = select_samples_in_all_targets(expers_data["full"])

## Add sample selection: only samples available in the SAG_T2_MAP experiments

In [None]:
targets = ("prog_kl_12", "prog_kl_24", "prog_kl_36", "prog_kl_48", "tiulpin2019_prog_bin")

def select_samples_w_t2(d_i, targets=targets):
    """Restrict each target's experiments to the samples of its T2_MAP experiment.

    For every target, the reference is the experiment
    ("T2_MAP", target, "MR1", "incid", "incid"). All experiments of that
    target keep only the rows whose "exam_knee_id" occurs in the reference.

    Parameters
    ----------
    d_i : dict
        Experiment key -> predictions DataFrame with an "exam_knee_id"
        column. The target name is the 2nd element of each key.
    targets : iterable of str, optional
        Targets to process. Defaults to the `targets` tuple defined together
        with this function.

    Returns
    -------
    dict
        Filtered copies (row order kept, index reset) of the experiments
        whose target is in `targets`. Experiments of other targets are not
        included.

    Raises
    ------
    KeyError
        If the T2_MAP reference experiment is missing for a target.
    """
    d_o = dict()

    for target in targets:
        print(f"Target: {target}")

        k_ref = ("T2_MAP", target, "MR1", "incid", "incid")
        if k_ref not in d_i:
            raise KeyError(f"Reference experiment not found: {k_ref}")
        # The reference sample set is shared by all experiments of this target
        sel = set(d_i[k_ref]["exam_knee_id"].tolist())

        for k_proc, df_proc in d_i.items():
            if k_proc[1] != target:
                continue

            df_proc = df_proc[df_proc["exam_knee_id"].isin(sel)]
            d_o[k_proc] = df_proc.reset_index(drop=True).copy()
    return d_o

# Restrict to the samples of the T2_MAP experiments, starting from (1) all
# samples and (2) the samples common to all targets.
expers_data["w_t2"] = select_samples_w_t2(expers_data["full"])
expers_data["in_all_targets__w_t2"] = select_samples_w_t2(expers_data["in_all_targets"])

# Calculate metrics

In [None]:
def _calc_mx_for_exp_subset(df_meta, df_exp, code_exp, subset):
    """Compute the metrics of one experiment on one patient sub-group.

    Parameters
    ----------
    df_meta : pd.DataFrame
        Per-knee metadata, merged onto the predictions on "exam_knee_id".
    df_exp : pd.DataFrame
        Cached predictions with "exam_knee_id", "target", "predict_proba".
    code_exp : tuple
        Experiment key; its 2nd element is the target name.
    subset : str
        Sub-group specification passed to `select_subset_v2`.

    Returns
    -------
    dict
        Union of the single-pass metrics (with curves) and the bootstrapped
        metrics from `calc_metrics_v2`. Bootstrapped entries overwrite
        single-pass entries with the same name.
    """
    target = code_exp[1]

    # Attach the metadata used by the sub-group criteria. Overlapping metadata
    # columns get a "_dup" suffix, so the prediction columns keep their names.
    df_merged = df_exp.merge(df_meta, how="left", on="exam_knee_id",
                             suffixes=(None, "_dup"))
    df_sub = select_subset_v2(df_merged, subset=subset)

    y_true = np.asarray([np.asarray(e) for e in df_sub["target"].tolist()]).ravel()
    y_proba = np.asarray([np.asarray(e) for e in df_sub["predict_proba"].tolist()])

    mxs = {}
    # Single-pass metrics (no resampling), including curves
    mxs.update(calc_metrics_v2(prog_target=y_true,
                               prog_pred_proba=y_proba,
                               target=target,
                               with_curves=True,
                               kws_ppv={"pi0": 0.15}))
    # Bootstrapped metrics
    mxs.update(calc_metrics_v2(prog_target=y_true,
                               prog_pred_proba=y_proba,
                               target=target,
                               bootstrap=True,
                               kws_ppv={"pi0": 0.15},
                               kws_bs={"verbose": False}))
    return mxs

In [None]:
### Subsets
# Patient sub-groups to evaluate (see `select_subset_v2` for the criteria)
SUBSETS_v3 = [
    "all",
    
    "INJ-,SURG-",
    "INJ-,SURG-,KLG_c01,WOMAC_0-10",
    "INJ-,SURG-,KLG_c01,WOMAC_11-96",
    "INJ-,SURG-,KLG_2,WOMAC_0-10",
    "INJ-,SURG-,KLG_2,WOMAC_11-96",
    "INJ-,SURG-,KLG_3,WOMAC_0-10",
    "INJ-,SURG-,KLG_3,WOMAC_11-96",

    "INJ+,SURG-",
    "INJ+,SURG-,KLG_c01,WOMAC_0-10",
    "INJ+,SURG-,KLG_c01,WOMAC_11-96",
    "INJ+,SURG-,KLG_2,WOMAC_0-10",
    "INJ+,SURG-,KLG_2,WOMAC_11-96",
    "INJ+,SURG-,KLG_3,WOMAC_0-10",
    "INJ+,SURG-,KLG_3,WOMAC_11-96",

    "SURG+",
    "SURG+,KLG_c01,WOMAC_0-10",
    "SURG+,KLG_c01,WOMAC_11-96",
    "SURG+,KLG_2,WOMAC_0-10",
    "SURG+,KLG_2,WOMAC_11-96",
    "SURG+,KLG_3,WOMAC_0-10",
    "SURG+,KLG_3,WOMAC_11-96",
]

selections = ("w_t2", "in_all_targets__w_t2")

# Number of parallel joblib workers for the metric computation
N_JOBS = 24

# selection -> {(*experiment key, subset) -> metrics dict}
expers_mx_sub = dict()
# selection -> the same metrics as a DataFrame (MultiIndex rows)
dfs_mx_sub = dict()

for selection in selections:
    print(f"Selection: {selection}")
    data_in = expers_data[selection]

    t_keys = []
    t_tasks = []

    print(f"Total: {len(data_in) * len(SUBSETS_v3)}")

    for code_exp, df_exp in data_in.items():
        for subset in SUBSETS_v3:
            t_keys.append((*code_exp, subset))
            # Arguments are passed without copying: `_calc_mx_for_exp_subset`
            # does not modify its inputs, and joblib's default process-based
            # backend serializes them for the workers anyway. Note that
            # `copy.copy` on a DataFrame is a *deep* copy, so per-task copies
            # would keep one full copy of `df_meta` per task alive in `t_tasks`.
            t_tasks.append(delayed(_calc_mx_for_exp_subset)(df_meta, df_exp,
                                                            code_exp, subset))

    t_ret = Parallel(n_jobs=N_JOBS, verbose=5)(t_tasks)

    data_out = dict(zip(t_keys, t_ret))
    expers_mx_sub[selection] = data_out
    
    t = pd.DataFrame.from_dict(data_out, orient="index")
    display(t)
    dfs_mx_sub[selection] = t

In [None]:
# Save the group-wise metric tables, one pickle per sample selection
for selection in selections:
    path_out = Path(DIR_OUT, f"metrics_groupw__{selection}.pkl")
    dfs_mx_sub[selection].to_pickle(path_out)

In [None]:
# Load the group-wise metric tables written by the save step (lets the cells
# below run without recomputing the metrics)
selections = ("w_t2", "in_all_targets__w_t2")
dfs_mx_sub = {
    selection: pd.read_pickle(Path(DIR_OUT, f"metrics_groupw__{selection}.pkl"))
    for selection in selections
}

# Reformat dfs for convenient visualization

In [None]:
def _component_or_nan(values, idx):
    """Return `e[idx]` for every entry `e` of `values`; NaN where `e` contains NaN."""
    return [e[idx] if not np.any(np.isnan(e)) else np.nan for e in values]


# Flatten the metric tables for plotting: name the index levels and split the
# multi-valued metric entries into scalar columns (element 0 -> "*_mean",
# element 1 -> "*_se", following the column names).
dfs_vis = {}

for s, df in dfs_mx_sub.items():
    level_names = ("inputs", "target", "model", "data_train", "data_eval", "group")
    df_t = df.reset_index().rename(
        columns={f"level_{i}": name for i, name in enumerate(level_names)})

    # (new column, source column, element index)
    for col_new, col_src, idx in (
        ("ap_mean", "avg_precision", 0),
        ("ap_se", "avg_precision", 1),
        ("ap_calib_mean", "avg_ppv_calib", 0),
        ("ap_calib_se", "avg_ppv_calib", 1),
        ("avg_npv_mean", "avg_npv", 0),
        ("roc_auc_mean", "roc_auc", 0),
        ("roc_auc_se", "roc_auc", 1),
    ):
        df_t[col_new] = _component_or_nan(df_t[col_src].tolist(), idx)

    dfs_vis[s] = df_t

# Find the highest ranked MRI and multimodal fusion models

In [None]:
# Rank the MRI and fusion models per target on the full test set. For every
# (metric, target) the models are sorted by metric value and earn points by
# position (score_map); a higher total score means a better-ranked model.
rank_metrics = ["ap_calib_mean", "roc_auc_mean"]
ranking = dict()

for selection in selections:
    df_t = dfs_vis[selection]
    
    # Keep only full testset scores
    df_t = df_t[df_t["group"] == "all"]
    
    # Keep only subset of models
    df_t = df_t[df_t["model"].isin(['MR1', 'MR2', 'XR1MR1',
                                    'XR1MR2', 'XR1MR2C1'])]
    
    uniq_inputs = list(pd.unique(df_t["inputs"]))
    n_inputs = len(uniq_inputs)
    n_targets = len(pd.unique(df_t["target"]))
    n_metrics = len(rank_metrics)
    # Linear scoring rule: the i-th lowest value earns i points (0, 1, 2, ...)
    score_map = {e: e for e in range(n_inputs)}
                   
    score_matrix = {k: 0 for k in uniq_inputs}
    
    for rank_metric in rank_metrics:
        for _, df_gb in df_t.groupby(by=["target"]):
            df_gb = df_gb.sample(frac=1, random_state=0)  # reshuffle to randomly break sorting ties
            t_inputs = df_gb["inputs"].tolist()
            t_values = np.asarray(df_gb[rank_metric].tolist(), dtype=float)
            # Missing (NaN) metrics rank lowest. A plain argsort puts NaN last,
            # which would give a model without a score the *highest* score.
            t_values = np.where(np.isnan(t_values), -np.inf, t_values)
            t_sort_idcs = np.argsort(t_values)  # ascending order
            
            for i, idx in enumerate(t_sort_idcs):
                score_matrix[t_inputs[idx]] += score_map[i]
        
    ranking[selection] = score_matrix

In [None]:
# Ranking scores: rows = inputs, columns = sample selections
display(pd.DataFrame(ranking))

# Table. Metrics summary

In [None]:
# Summary tables on the full test set of the "w_t2" selection: one pivot per
# metric with rows = (inputs, model) and columns = target.
t_df = dfs_vis["w_t2"]
t_df = t_df[t_df["group"] == "all"]

# Another available column: "avg_npv_mean"
t_metrics = ["ap_mean", "ap_calib_mean", "roc_auc_mean"]

for t_metric in t_metrics:
    print(f"---- {t_metric} ----")
    t_res = t_df.pivot_table(values=t_metric,
                             index=["inputs", "model"],
                             columns=["target"])
    display(t_res.sort_values(by=["inputs", "model"]))

# Figure. Group-wise radar

In [None]:
# Peek at each loaded metrics table: the selection name, then its first rows
for s, d in dfs_mx_sub.items():
    display(s, d.head())

In [None]:
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.path import Path as MPLPath
from matplotlib.projections.polar import PolarAxes
from matplotlib.projections import register_projection
from matplotlib.spines import Spine
from matplotlib.transforms import Affine2D


def radar_factory(num_vars, offset=0, frame='circle'):
    """Create a radar chart with `num_vars` axes.

    This function creates a RadarAxes projection and registers it under the
    name 'radar' (use ``subplot_kw=dict(projection='radar')``).

    Parameters
    ----------
    num_vars : int
        Number of variables for radar chart.
    offset : float, default 0
        Rotation of the variable axes, in units of the angular step
        ``2*pi/num_vars``.
    frame : {'circle' | 'polygon'}
        Shape of frame surrounding axes.

    Returns
    -------
    np.ndarray
        Angles (radians, wrapped to [0, 2*pi)) of the `num_vars` variable axes.
    """
    # calculate evenly-spaced axis angles
    theta = (np.linspace(0, 2*np.pi, num_vars, endpoint=False) +
             (2*np.pi/num_vars) * offset)
    theta = theta % (2*np.pi)

    class RadarTransform(PolarAxes.PolarTransform):
        def transform_path_non_affine(self, path):
            # Paths with non-unit interpolation steps correspond to gridlines,
            # in which case we force interpolation (to defeat PolarTransform's
            # autoconversion to circular arcs).
            if path._interpolation_steps > 1:
                path = path.interpolated(num_vars)
            return MPLPath(self.transform(path.vertices), path.codes)

    class RadarAxes(PolarAxes):
        name = 'radar'
        PolarTransform = RadarTransform

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Rotate plot such that the first axis is at the top.
            # NOTE(review): `set_theta_zero_location` expects `offset` in
            # degrees, but a radian value is passed here. This has no effect for
            # offset=0 (the only value used in the visible code); confirm the
            # intent before using a non-zero offset.
            self.set_theta_zero_location("N", offset=(2*np.pi/num_vars) * offset)

        def fill(self, *args, closed=True, **kwargs):
            """Override fill so that line is closed by default"""
            return super().fill(closed=closed, *args, **kwargs)

        def plot(self, *args, **kwargs):
            """Override plot so that line is closed by default.

            Returns the list of Line2D objects, like `Axes.plot`.
            """
            lines = super().plot(*args, **kwargs)
            for line in lines:
                self._close_line(line)
            return lines

        def _close_line(self, line):
            x, y = line.get_data()
            # FIXME: markers at x[0], y[0] get doubled-up
            if x[0] != x[-1]:
                x = np.append(x, x[0])
                y = np.append(y, y[0])
                line.set_data(x, y)

        def set_varlabels(self, labels):
            self.set_thetagrids(np.degrees(theta), labels)

        def _gen_axes_patch(self):
            # The Axes patch must be centered at (0.5, 0.5) and of radius 0.5
            # in axes coordinates.
            if frame == 'circle':
                return Circle((0.5, 0.5), 0.5)
            elif frame == 'polygon':
                return RegularPolygon((0.5, 0.5), num_vars,
                                      radius=0.5, edgecolor="k")
            else:
                raise ValueError("unknown value for 'frame': %s" % frame)

        def _gen_axes_spines(self):
            if frame == 'circle':
                return super()._gen_axes_spines()
            elif frame == 'polygon':
                # spine_type must be 'left'/'right'/'top'/'bottom'/'circle'.
                spine = Spine(axes=self,
                              spine_type='circle',
                              path=MPLPath.unit_regular_polygon(num_vars))
                # unit_regular_polygon gives a polygon of radius 1 centered at
                # (0, 0) but we want a polygon of radius 0.5 centered at (0.5,
                # 0.5) in axes coordinates.
                spine.set_transform(Affine2D().scale(.5).translate(.5, .5)
                                    + self.transAxes)
                return {'polar': spine}
            else:
                raise ValueError("unknown value for 'frame': %s" % frame)

    register_projection(RadarAxes)
    return theta

In [None]:
#### Common configs for radar plots
# Reset to matplotlib's default style. `rcdefaults()` leaves the
# style-blacklisted params (e.g. "backend", "interactive") untouched. Updating
# rcParams from rcParamsDefault directly would also reset the backend, which
# can break inline figure display in Jupyter.
matplotlib.rcdefaults()

# Selection for Figure: inputs drawn in the radar plots (uncomment to add)
t_inputs = [
#     'age,sex,BMI',  # clin
#     'age,sex,BMI,KL',
    'age,sex,BMI,Surg,Inj,WOMAC',
#     'age,sex,BMI,KL,Surg,Inj,WOMAC',
    'XR',
    'DESS',  # MR1
#     'TSE',
#     'T2_MAP',
#     'DESS,TSE',  # MR2
#     'DESS,T2_MAP',
#     'TSE,T2_MAP',
#     'XR,DESS',  # XR1MR1
#     'XR,TSE',
#     'XR,T2_MAP',
#     'XR,DESS,TSE',  # XR1MR2
    'XR,DESS,T2_MAP',
#     'XR,TSE,T2_MAP',
#     'XR,DESS,T2_MAP,clin',  # XR1MR2C1
]

# inputs -> model label used in the figure legends
inputs_to_vis_model = {
    "age,sex,BMI": "C1",
    "age,sex,BMI,KL": "C2",
    "age,sex,BMI,Surg,Inj,WOMAC": "C3",
    "age,sex,BMI,KL,Surg,Inj,WOMAC": "C4",
    "XR": "X", "DESS": "M1", "TSE": "M2", "T2_MAP": "M3",
    "DESS,T2_MAP": "F5", "DESS,TSE": "F4", "TSE,T2_MAP": "F6",
    "XR,DESS": "F1", "XR,TSE": "F2", "XR,T2_MAP": "F3",
    "XR,DESS,T2_MAP": "F8", "XR,DESS,TSE": "F7", "XR,TSE,T2_MAP": "F9",
    "XR,DESS,T2_MAP,clin": "U",
}

# plot group -> {sub-group (subset string) -> position (vertex index) on the radar}
plot_group_to_group_to_x = {
    "inj0_surg0": {
        "INJ-,SURG-": 0,
        "INJ-,SURG-,KLG_c01,WOMAC_0-10": 1,
        "INJ-,SURG-,KLG_c01,WOMAC_11-96": 2,
        "INJ-,SURG-,KLG_2,WOMAC_0-10": 3,
        "INJ-,SURG-,KLG_2,WOMAC_11-96": 4,
        "INJ-,SURG-,KLG_3,WOMAC_0-10": 5,
        "INJ-,SURG-,KLG_3,WOMAC_11-96": 6,
    },
    "inj1_surg0": {
        "INJ+,SURG-": 0,
        "INJ+,SURG-,KLG_c01,WOMAC_0-10": 1,
        "INJ+,SURG-,KLG_c01,WOMAC_11-96": 2,
        "INJ+,SURG-,KLG_2,WOMAC_0-10": 3,
        "INJ+,SURG-,KLG_2,WOMAC_11-96": 4,
        "INJ+,SURG-,KLG_3,WOMAC_0-10": 5,
        "INJ+,SURG-,KLG_3,WOMAC_11-96": 6,
    },
    "surg1": {
        "SURG+": 0,
        "SURG+,KLG_c01,WOMAC_0-10": 1,
        "SURG+,KLG_c01,WOMAC_11-96": 2,
        "SURG+,KLG_2,WOMAC_0-10": 3,
        "SURG+,KLG_2,WOMAC_11-96": 4,
        "SURG+,KLG_3,WOMAC_0-10": 5,
        "SURG+,KLG_3,WOMAC_11-96": 6,
    },
}

# plot group -> prefix stripped from the subset strings before label lookup
plot_group_to_drop_prefix = {
    "inj0_surg0": "INJ-,SURG-,",
    "inj1_surg0": "INJ+,SURG-,",
    "surg1": "SURG+,",
    "inj0_surg1": "INJ-,SURG+,",
    "inj1_surg1": "INJ+,SURG+,",
}

# short sub-group name (subset without the prefix) -> axis label text
short_group_to_vis_group = {
    "KLG_c01": "KLG 0/1",
    "KLG_c23": "KLG 2/3",
    "KLG_0": "KLG 0",
    "KLG_1": "KLG 1",
    "KLG_2": "KLG 2",
    "KLG_3": "KLG 3",
    "WOMAC_0-10": "WOMAC \n 0-10",
    "WOMAC_11-96": "WOMAC \n 11-96",
    #
    "KLG_c01,WOMAC_0-10": "KLG 0/1\nSx-",
#     "KLG_2,WOMAC_0-10": "KLG 2\nSx-",
    "KLG_2,WOMAC_0-10": "KLG 2  Sx-",
    "KLG_3,WOMAC_0-10": "KLG 3\nSx-",
    "KLG_c01,WOMAC_11-96": "KLG 0/1\nSx+",
#     "KLG_2,WOMAC_11-96": "KLG 2\nSx+",
    "KLG_2,WOMAC_11-96": "KLG 2  Sx+",
    "KLG_3,WOMAC_11-96": "KLG 3\nSx+",
    #
    "INJ-,SURG-": "all",
    "INJ+,SURG-": "all",
    "INJ-,SURG+": "all",
    "INJ+,SURG+": "all",
    "SURG+": "all",
    "all": "all",
}

In [None]:
#### Figure. __Target-average__ metrics at sub-groups
# One radar plot per (plot group, metric). Each polygon is one input
# configuration, each vertex one sub-group, and the radius is the metric
# averaged over the targets in `t_targets`. Figures are saved as PNG and PDF
# to DIR_OUT/radar__groupw__targetavg.
import matplotlib.transforms

colors = sns.color_palette(palette="tab10", n_colors=10, as_cmap=True)
inputs_to_color = lambda x: colors(t_inputs.index(x) / 10)

t_plot_groups = ("inj0_surg0", "inj1_surg0", "surg1")
t_targets = ("prog_kl_12", "prog_kl_24", "prog_kl_36", "prog_kl_48", "tiulpin2019_prog_bin")

# TODO: select (other available column: "ap_mean")
t_metrics = ["ap_calib_mean",
             "roc_auc_mean"]

# Radial axis limits per metric (default used for every metric not listed)
metric_to_ylim = defaultdict(lambda: (0.0, 1.005), {})

# Sample selection used for the figures (alternative: dfs_vis['w_t2'])
df_t = dfs_vis['in_all_targets__w_t2'].copy()

# Shift of the radial (y) tick labels, in inches
dx = -50/300; dy = -25/300

for plot_group in t_plot_groups:
    print(f"Plot group: {plot_group}")
    for metric in t_metrics:
        print(f"Metric: {metric}")

        group_to_x = plot_group_to_group_to_x[plot_group]
        t_groups = list(group_to_x.keys())

        N = len(t_groups)
        theta = radar_factory(N, offset=0, frame='polygon')

        fig, ax = plt.subplots(figsize=(4.1, 4.1),
                               nrows=1, ncols=1,
                               subplot_kw=dict(projection='radar', zorder=1))

        # Ylim
        ax.set_ylim(metric_to_ylim[metric])

        # Outer frame. Do not show
        ax.spines["polar"].set_visible(False)

        for inputs in t_inputs:
            print(f"Inputs: {inputs}")

            color = inputs_to_color(inputs)

            df_t_i = df_t[df_t["inputs"] == inputs]
            df_t_i = df_t_i[df_t_i["group"].isin(t_groups)]
            df_t_i = df_t_i.assign(group_idx=[group_to_x[e] for e in df_t_i["group"].tolist()])
            df_t_i = df_t_i.sort_values(by="group_idx")
            df_t_i = df_t_i[df_t_i["target"].isin(t_targets)]

            # Mean over targets per sub-group (NaNs skipped). Reindexing to all
            # N vertices keeps the values aligned with `theta` and the labels;
            # a missing sub-group becomes NaN instead of shifting the others.
            y = (df_t_i.groupby(by="group_idx")[metric]
                 .mean().reindex(range(N)))
            ss = (df_t_i.groupby(by="group_idx")["sample_size"]
                  .mean().reindex(range(N)))

            ax.plot(theta, y, color=color, zorder=1e3, label=inputs)

        # Xticklabels. The sample sizes come from the last plotted inputs
        # (presumably identical across inputs for this selection -- TODO confirm).
        variables = [short_group_to_vis_group[e.replace(plot_group_to_drop_prefix[plot_group], '')]
                     for e in t_groups]
        ss = [s if not np.isnan(s) else 0 for s in ss]
        # "\\it" is the mathtext italic command; a bare "\i" would be an
        # invalid escape sequence (SyntaxWarning on Python >= 3.12).
        variables = [f"{v}\n($\\it{{n={round(s)}}}$)" for v, s in zip(variables, ss)]
        ax.set_varlabels(variables)

        for tl in ax.get_xticklabels():
            tl.set_bbox(dict(boxstyle="round", facecolor="white",
                             edgecolor="lightgray",
                             alpha=0.6,
                             zorder=1e4))
            tl.set_zorder(1e4)

        # Axes
        ax.yaxis.set_zorder(-1)

        # Yticklabels: shift by (dx, dy) inches via an offset transform
        offset = matplotlib.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)
        for tl in ax.yaxis.get_majorticklabels():
            tl.set_transform(tl.get_transform() + offset)

        # Legend
        labels = [f"${inputs_to_vis_model[e]}$" for e in t_inputs]
        ax.legend(labels, bbox_to_anchor=(-0.1, 1.12, 1.23, .18),
                  ncol=4, mode="expand", borderaxespad=0.)

        fig.tight_layout()

        # Output
        t_dir_out = Path(DIR_OUT, "radar__groupw__targetavg")
        t_dir_out.mkdir(exist_ok=True)

        for fmt in ("png", "pdf"):
            t_path_out = Path(t_dir_out, f"{plot_group}__{metric}.{fmt}")
            fig.savefig(t_path_out, dpi=300)
        plt.close(fig)

# Statistical testing

In [None]:
# Results of the paired tests: (target, key_ref, key_cmp) -> dict of results
stat_tests_all = {}

# (reference, comparison) experiment selectors. A selector is a tuple of
# fields that must all be elements of an experiment key, e.g. ("MR1",)
# selects every MR1 experiment; every matching (ref, cmp) key pair is tested.
# The test statistics are computed as cmp minus ref. The list order defines
# the order in which the tests are run.
FIELDS_REF_CMP = [
    # ---- Clinical ----
    (("age,sex,BMI", "LR"), ("age,sex,BMI,KL", "LR")),
    (("age,sex,BMI", "LR"), ("age,sex,BMI,Surg,Inj,WOMAC", "LR")),
    (("age,sex,BMI", "LR"), ("age,sex,BMI,KL,Surg,Inj,WOMAC", "LR")),
     # ---- \\\\ ----
    
    # ---- Single modality ----
    (("age,sex,BMI,Surg,Inj,WOMAC", "LR"), ("XR1",)),
    (("age,sex,BMI,Surg,Inj,WOMAC", "LR"), ("MR1",)),
    (("XR1",),                             ("MR1",)),
    # ---- \\\\ ----
    
    # ---- Fusion vs the rest ----
    (("age,sex,BMI,Surg,Inj,WOMAC", "LR"), ("XR1MR1",)),
    (("XR1",),                             ("XR1MR1",)),
    (("MR1",),                             ("XR1MR1",)),

    (("age,sex,BMI,Surg,Inj,WOMAC", "LR"), ("MR2",)),
    (("XR1",),                             ("MR2",)),
    (("MR1",),                             ("MR2",)),

    (("age,sex,BMI,Surg,Inj,WOMAC", "LR"), ("XR1MR2",)),
    (("XR1",),                             ("XR1MR2",)),
    (("MR1",),                             ("XR1MR2",)),

    (("age,sex,BMI,Surg,Inj,WOMAC", "LR"), ("XR1MR2C1",)),
    (("XR1",),                             ("XR1MR2C1",)),
    (("MR1",),                             ("XR1MR2C1",)),
    (("XR,DESS,T2_MAP", "XR1MR2"),         ("XR1MR2C1",)),
    # ---- \\\\ ----
]

# Targets for which the statistical tests are run
TARGETS = ("prog_kl_12", "prog_kl_24", "prog_kl_36", "prog_kl_48", "tiulpin2019_prog_bin")


def statistic_roc_auc(x_pred_0, x_pred_1, x_target):
    """Paired test statistic: ROC AUC of model 1 minus ROC AUC of model 0.

    Args:
        x_pred_0: scores of the reference model, one per sample.
        x_pred_1: scores of the compared model, same samples and order.
        x_target: ground-truth labels shared by both models.

    Returns:
        ``roc_auc(x_pred_1) - roc_auc(x_pred_0)``; positive means model 1 is better.
    """
    baseline, candidate = (roc_auc_score(x_target, scores)
                           for scores in (x_pred_0, x_pred_1))
    return candidate - baseline


def statistic_ap(x_pred_0, x_pred_1, x_target):
    """Paired test statistic: average precision of model 1 minus that of model 0.

    Args:
        x_pred_0: scores of the reference model, one per sample.
        x_pred_1: scores of the compared model, same samples and order.
        x_target: ground-truth labels shared by both models.

    Returns:
        ``ap(x_pred_1) - ap(x_pred_0)``; positive means model 1 is better.
    """
    baseline, candidate = (average_precision_score(x_target, scores)
                           for scores in (x_pred_0, x_pred_1))
    return candidate - baseline

In [None]:
### Permutation testing
# expers_data -- cached predictions (read here via expers_data["w_t2"]).
# NOTE: the "Model explanation" section further below rebinds `expers_data` to
# explanation data, so this cell must be run before that section.
N_RESAMPLES = 1000
# Seed for the permutation resampling, so that the reported p-values are reproducible.
RANDOM_STATE = 0

# Paired statistics to test: column suffix -> statistic(x_pred_0, x_pred_1, x_target).
# Dict order fixes the column order of the result table.
STAT_FUNCS = {
    "roc_auc": statistic_roc_auc,
    "ap": statistic_ap,
}


def _stack_column(col):
    """Stack a Series whose cells are array-likes into a single ndarray."""
    return np.asarray(list(map(np.asarray, col.tolist())))


t_data = expers_data["w_t2"]

for target in TARGETS:
    print(f"Target: {target}")

    expers_sel = {k: v for k, v in t_data.items() if target in k}

    for fields_ref, fields_cmp in FIELDS_REF_CMP:
        # Select experiments whose key contains every requested field
        keys_ref = [k for k in expers_sel if all(f in k for f in fields_ref)]
        keys_cmp = [k for k in expers_sel if all(f in k for f in fields_cmp)]

        for key_ref in keys_ref:
            for key_cmp in keys_cmp:
                print(f"Ref: {key_ref!r} || Cmp: {key_cmp!r}")

                key_m = (target, key_ref, key_cmp)
                stat_tests_all[key_m] = dict()

                v_ref = expers_sel[key_ref]
                v_cmp = expers_sel[key_cmp]

                proba_target = _stack_column(v_ref["target"]).ravel()
                # The test is paired (samples permuted between the two models), so
                # both models must have been evaluated on the same samples in the
                # same order; otherwise the comparison is silently meaningless.
                if not np.array_equal(proba_target, _stack_column(v_cmp["target"]).ravel()):
                    raise ValueError(f"Targets of {key_ref!r} and {key_cmp!r} are not aligned")

                # Probability of the positive class (column 1)
                x_ref = _stack_column(v_ref["predict_proba"])[:, 1]
                x_cmp = _stack_column(v_cmp["predict_proba"])[:, 1]

                for metric, stat_fn in STAT_FUNCS.items():
                    ret = stats.permutation_test(data=(x_ref, x_cmp),
                                                 statistic=partial(stat_fn, x_target=proba_target),
                                                 permutation_type="samples",
                                                 n_resamples=N_RESAMPLES,
                                                 alternative="two-sided",
                                                 random_state=RANDOM_STATE)
                    stat_tests_all[key_m].update({f"pvalue__{metric}": ret.pvalue,
                                                  f"statistic__{metric}": ret.statistic})


t = pd.DataFrame.from_dict(stat_tests_all, orient="index")
display(t)

In [None]:
### Do thresholding and save to file
def apply_signif_thresh(x, level=0.05):
    """Flag p-values strictly below the significance `level`.

    Works element-wise for arrays / Series and returns a boolean (mask); NaN
    p-values compare as not significant.
    """
    return level > x

path_out = Path(DIR_OUT, "t__stat_tests.csv")

# Flatten the (target, left, right) index into columns, order the rows and add
# 0/1 significance flags at the 0.05 level for both metrics.
t_out = (
    t.reset_index()
    .rename(columns={"level_0": "target",
                     "level_1": "left",
                     "level_2": "right"})
    .sort_values(by=["target", "right", "left"])
    .assign(**{
        "signif_0.05__roc_auc":
            lambda d: apply_signif_thresh(d["pvalue__roc_auc"], level=0.05).astype(int),
        "signif_0.05__ap":
            lambda d: apply_signif_thresh(d["pvalue__ap"], level=0.05).astype(int),
    })
)

t_out.to_csv(path_out)

In [None]:
# Split the experiment keys into readable columns: the reference inputs (item 0 of
# the left key), the compared inputs (item 0 of the right key) and the compared
# architecture (item 2 of the right key). Raw p-values/statistics are dropped.
t_summary = (
    t_out
    .assign(
        ref=lambda d: [key[0] for key in d["left"]],
        data=lambda d: [key[0] for key in d["right"]],
        architecture=lambda d: [key[2] for key in d["right"]],
    )
    .drop(columns=[
        "left", "right",
        "pvalue__roc_auc", "statistic__roc_auc", "pvalue__ap", "statistic__ap",
    ])
)
display(t_summary)

In [None]:
# TABLE. Significant differences
#
# For each (data, architecture, target) the cell lists the short codes of the
# reference models from which the experiment differs significantly (p < 0.05).

# Reference inputs -> short code; "" hides the reference in the table.
t_replacements = {
    "age,sex,BMI,Surg,Inj,WOMAC": "c",
    "XR": "x",
    "DESS": "m",
    # Ignore
    "age,sex,BMI": "",
    "age,sex,BMI,KL": "",
    "TSE": "",
    "T2_MAP": "",
    "XR,DESS": "",
    "XR,TSE": "",
    "XR,T2_MAP": "",
    "DESS,TSE": "",
    # Previously spelled "DESS,T2_map", which never matched the "T2_MAP" input names.
    "DESS,T2_MAP": "",
}

# Input combinations shown in the final pivot tables.
DATA_SHOWN = [
    "DESS,T2_MAP",
    "DESS,TSE",
    "XR,DESS",
    "XR,DESS,T2_MAP",
    "XR,DESS,T2_MAP,clin",
]


def _signif_refs(df_summary, col_signif):
    """Replace the 0/1 flag in `col_signif` by the (short) name of the reference.

    Works on an explicit copy and replaces the whole column, which avoids both the
    SettingWithCopyWarning and storing strings into an int column via `.loc`.
    """
    df = df_summary[["data", "architecture", "ref", "target", col_signif]].copy()
    # 1 * "XR" == "XR" and 0 * "XR" == "": keep the reference only where significant.
    df[col_signif] = (df[col_signif] * df["ref"]).replace(t_replacements)
    return df.drop(columns=["ref"])


def _pivot_signif(df, col_signif):
    """Pivot to (data, architecture) x target, joining the reference codes."""
    return df.pivot_table(index=["data", "architecture"],
                          columns="target",
                          aggfunc=lambda x: " ".join(x.astype(str)),
                          values=col_signif)


t_roc_auc = _signif_refs(t_summary, "signif_0.05__roc_auc")
display(t_roc_auc)

t_ap = _signif_refs(t_summary, "signif_0.05__ap")
display(t_ap)

t_roc_auc = t_roc_auc[t_roc_auc["data"].isin(DATA_SHOWN)]
t_ap = t_ap[t_ap["data"].isin(DATA_SHOWN)]

print("ROC AUC")
display(_pivot_signif(t_roc_auc, "signif_0.05__roc_auc"))

print("AP")
display(_pivot_signif(t_ap, "signif_0.05__ap"))

# Model explanation

In [None]:
# TODO: replace every "____" placeholder below with the matching experiment ID
#  (the IDs can be copied from the beginning of the notebook).


def _eval_dir(experiment_id):
    """Return the `logs_eval/incid` directory of one experiment under DIR_RESULTS_ROOT."""
    return Path(DIR_RESULTS_ROOT, experiment_id, "logs_eval", "incid")


# (inputs, target, model, data_train, data_eval) -> evaluation directory
PATHS_EXPERIMS = {
    ("DESS,T2_MAP", "prog_kl_12", "MR2", "incid", "incid"): _eval_dir("____"),
    ("DESS,T2_MAP", "prog_kl_24", "MR2", "incid", "incid"): _eval_dir("____"),
    ("DESS,T2_MAP", "prog_kl_36", "MR2", "incid", "incid"): _eval_dir("____"),
    ("DESS,T2_MAP", "prog_kl_48", "MR2", "incid", "incid"): _eval_dir("____"),
    ("DESS,T2_MAP", "tiulpin2019_prog_bin", "MR2", "incid", "incid"): _eval_dir("____"),

    ("XR,DESS", "prog_kl_12", "XR1MR1", "incid", "incid"): _eval_dir("____"),
    ("XR,DESS", "prog_kl_24", "XR1MR1", "incid", "incid"): _eval_dir("____"),
    ("XR,DESS", "prog_kl_36", "XR1MR1", "incid", "incid"): _eval_dir("____"),
    ("XR,DESS", "prog_kl_48", "XR1MR1", "incid", "incid"): _eval_dir("____"),
    ("XR,DESS", "tiulpin2019_prog_bin", "XR1MR1", "incid", "incid"): _eval_dir("____"),

    ("XR,DESS,T2_MAP", "prog_kl_12", "XR1MR2", "incid", "incid"): _eval_dir("____"),
    ("XR,DESS,T2_MAP", "prog_kl_24", "XR1MR2", "incid", "incid"): _eval_dir("____"),
    ("XR,DESS,T2_MAP", "prog_kl_36", "XR1MR2", "incid", "incid"): _eval_dir("____"),
    ("XR,DESS,T2_MAP", "prog_kl_48", "XR1MR2", "incid", "incid"): _eval_dir("____"),
    ("XR,DESS,T2_MAP", "tiulpin2019_prog_bin", "XR1MR2", "incid", "incid"): _eval_dir("____"),

    ("XR,DESS,T2_MAP,clin", "prog_kl_12", "XR1MR2", "incid", "incid"): _eval_dir("____"),
    ("XR,DESS,T2_MAP,clin", "prog_kl_24", "XR1MR2", "incid", "incid"): _eval_dir("____"),
    ("XR,DESS,T2_MAP,clin", "prog_kl_36", "XR1MR2", "incid", "incid"): _eval_dir("____"),
    ("XR,DESS,T2_MAP,clin", "prog_kl_48", "XR1MR2", "incid", "incid"): _eval_dir("____"),
    ("XR,DESS,T2_MAP,clin", "tiulpin2019_prog_bin", "XR1MR2", "incid", "incid"): _eval_dir("____"),
}

In [None]:
def read_pkl_explain(p):
    """Load a pickled explanation dump from path `p` and return it as a DataFrame.

    The unpickled object is passed to `pd.DataFrame.from_dict`, so it is expected
    to be a dict-like mapping of columns. Only use on trusted, locally produced
    files: unpickling can execute arbitrary code.
    """
    with open(p, "rb") as fh:
        payload = pickle.load(fh)
    return pd.DataFrame.from_dict(payload)

In [None]:
# NOTE: this rebinds `expers_data`, which the permutation-testing cell reads as
# cached predictions (expers_data["w_t2"]); run that cell before this one.
expers_data = {}

for exp_key, exp_dir in PATHS_EXPERIMS.items():
    inputs = exp_key[0]
    # Clinical-only experiments are skipped (presumably they have no fusion
    # explanation dump -- confirm).
    if "age,sex,BMI" in inputs or inputs.startswith("clin"):
        continue
    expers_data[exp_key] = read_pkl_explain(Path(exp_dir, "explain_fus_raw_ens.pkl"))

In [None]:
# Target -> horizon in months (x axis of the utilization plot).
# NOTE(review): 96 for "tiulpin2019_prog_bin" is used as its plotting position -- confirm.
map_target_to_timepoint = {
    **{f"prog_kl_{months}": months for months in (12, 24, 36, 48)},
    "tiulpin2019_prog_bin": 96,
}

# TODO: set to exactly one of
#  "XR,DESS", "DESS,T2_MAP", "XR,DESS,T2_MAP", "XR,DESS,T2_MAP,clin"
SEL_INPUTS = "XR,DESS,T2_MAP,clin"


def sel_cond(x, inputs=SEL_INPUTS):
    """True if experiment key `x` uses the selected input combination."""
    return x[0] == inputs


sel_k = [k for k in PATHS_EXPERIMS if sel_cond(k)]
sel_expers = {k: v for k, v in expers_data.items() if k in sel_k}

In [None]:
# One frame per selected experiment: a "timepoint" column (months) followed by one
# column per modality holding that modality's entry of "modal_abl_percent".
t_dfs = []

# Without this check an empty selection would leave `names_modals` from a previous
# run in the kernel and the next cell would fail with an unrelated error.
if not sel_expers:
    raise ValueError("No explanation results match `sel_cond`; check the selection cell.")

for k, df_exp in sel_expers.items():
    target = k[1]

    names_modals = df_exp["modal_names"].iloc[0]
    abl_percent = df_exp["modal_abl_percent"].tolist()

    # Build a new frame instead of adding columns to `df_exp`: those frames are the
    # cached ones in `expers_data`, and mutating them (or a column slice of them)
    # made the cell non-idempotent and triggered SettingWithCopyWarning.
    df = pd.DataFrame(
        {name: [e[i] for e in abl_percent] for i, name in enumerate(names_modals)},
        index=df_exp.index,
    )
    df.insert(0, "timepoint", map_target_to_timepoint[target])

    t_dfs.append(df)

    print(k)
    print(df.mean(axis=0))

In [None]:
# Long format (timepoint, modality, percent) with display-friendly modality names.
df_vis = (
    pd.concat(t_dfs, axis=0, ignore_index=True)
    .melt(id_vars="timepoint", value_vars=names_modals,
          var_name="modality", value_name="percent")
    .replace({"modality": {
        "xr_pa": "XR",
        "sag_3d_dess": "DESS",
        "sag_t2_map": "T$_2$map",
        "cor_iw_tse": "TSE",
        "clin": "clinical",
    }})
)

display(df_vis)

In [None]:
# Figure. Utilization
# Mean (line) and standard deviation (band) of the per-modality utilization rate
# over prediction horizons; saved as PDF under DIR_OUT/explain.
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
matplotlib.rcParams.update({'font.size': 12})

fig, axes = plt.subplots(figsize=(3.9, 3.4))

# Fixed modality -> color mapping so colors agree across figures
palette = dict(zip(["XR", "DESS", "T$_2$map", "TSE", "clinical"], sns.color_palette()))

sns.lineplot(x="timepoint", y="percent", hue="modality",
             data=df_vis,
             err_style="band", errorbar="sd",
             palette=palette,
             ax=axes,
             marker="o")
axes.legend(title="Modality", loc="center right")
# Raw string: "\i" is an invalid escape sequence in a normal string literal.
axes.set_xlabel(r"Horizon, $\it{months}$")
axes.set_ylabel("Relative utilization rate")
axes.set_ylim((-0.05, 1.05))
axes.grid(axis="y", alpha=0.5)

dir_out = Path(DIR_OUT, "explain")
dir_out.mkdir(parents=True, exist_ok=True)
fname = "__".join(names_modals)
path_out = Path(dir_out, f"{fname}.pdf")

fig.tight_layout()
fig.savefig(path_out, dpi=300)
plt.close(fig)