# Table of Contents

<a name="outline"></a>

## Setup

- [A](#seca) External Imports
- [B](#secb) Internal Imports
- [C](#secc) Lazy Dictionary (Lazy Caching)
- [D](#secd) Configurations and Paths 
- [E](#sece) Patient Interface and Train/Val/Test Partitioning
- [F](#secf) General Utility Functions


## Evaluations

- [1](#sec1) Performance Analysis for Training/Testing on MIMIC-III
- [2](#sec2) Performance Analysis for Training/Testing on MIMIC-IV
- [3](#sec3) Performance Analysis for Training on MIMIC-IV and Testing on MIMIC-III
- [4](#sec4) Risk Trajectories

<a name="seca"></a>

### A External Imports [^](#outline)

In [None]:
import sys
import os
import glob
from collections import defaultdict

from IPython.display import display

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from upsetplot import from_contents, plot, UpSet, from_indicators


<a name="secb"></a>

### B Internal Imports [^](#outline)

In [None]:
# sys.path.append('..')
sys.path.append('repo')
from icenode.train_icenode_2lr import ICENODE
from icenode.train_icenode_uniform2lr import ICENODE as ICENODE_UNIFORM
from icenode.train_gram import GRAM
from icenode.train_retain import RETAIN
from icenode.metrics import codes_auc_pairwise_tests
from icenode.metrics import evaluation_table
from icenode.utils import write_params, load_config, load_params


from icenode.mimic3.dag import CCSDAG
from icenode.mimic3.concept import DiagSubject
from icenode.jax_interface import SubjectDiagSequenceJAXInterface,  DiagnosisJAXInterface 


%load_ext autoreload
%autoreload 2

<a name="secd"></a>

### D Configurations and Paths [^](#outline)

In [None]:
# Input tables (admissions and diagnoses) for each MIMIC version, as gzipped CSV files.
mimic3_files = {
    'adm_df': 'data/mimic3_adm_df.csv.gz',
    'diag_df': 'data/mimic3_diag_df.csv.gz'
}

mimic4_files = {
    'adm_df': 'data/mimic4_adm_df.csv.gz',
    'diag_df': 'data/mimic4_diag_df.csv.gz'
}

# Directories of the pretrained models, keyed by classifier label.
# Each directory is read for `config.json` (below) and `params.pickle` (parameter loading cell).
m3_trained_dir = {
    'ICE-NODE': 'pretrained_models/M3/icenode',
    'ICE-NODE_UNIFORM': 'pretrained_models/M3/icenode_uniform',
    'GRU': 'pretrained_models/M3/gru',
    'RETAIN': 'pretrained_models/M3/retain'
}

m4_trained_dir = {
    'ICE-NODE': 'pretrained_models/M4/icenode',
    'ICE-NODE_UNIFORM': 'pretrained_models/M4/icenode_uniform',
    'GRU': 'pretrained_models/M4/gru',
    'RETAIN': 'pretrained_models/M4/retain'
}

# Classifier label -> model class used to rebuild the model.
# NOTE(review): the 'GRU' label is bound to the `GRAM` class imported from
# `icenode.train_gram` — confirm this mapping is intended.
model_cls = {
    'ICE-NODE': ICENODE,
    'ICE-NODE_UNIFORM': ICENODE_UNIFORM,
    'GRU': GRAM,
    'RETAIN': RETAIN
}

# Same configurations for models between MIMIC-III and MIMIC-IV.
# Only the MIMIC-III directories are read here; the MIMIC-IV `config.json` files are not loaded.
model_config = {
    clf: load_config(f'{m_dir}/config.json') for clf, m_dir in m3_trained_dir.items()
}

# Classifier labels, in the insertion order of `m3_trained_dir`; `clfs[0]` is 'ICE-NODE'.
clfs = list(m3_trained_dir.keys())

# Keyword arguments forwarded to `relative_performance_upset` (`pvalue`, `min_auc`).
relative_auc_config = {
    'pvalue': 0.01,
    'min_auc': 0.9
}

# Global matplotlib font.
# NOTE(review): assumes the "Loma" font is installed on the machine — confirm.
plt.rcParams["font.family"] = "Loma"

<a name="sece"></a>

### E Patient Interface and Train/Val/Test Partitioning [^](#outline)

In [None]:
def get_patient_interface(mimic_files, clfs):
    """Load one MIMIC dataset and return the patient interface for each classifier.

    Args:
        mimic_files: Mapping with the keys 'adm_df' and 'diag_df' holding the
            CSV paths of the admissions and diagnoses tables.
        clfs: Classifier labels to build the mapping for. Each must be one of
            'ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU' or 'RETAIN'.

    Returns:
        Dict mapping each label in `clfs` to its interface object. The two
        ICE-NODE variants share one `DiagnosisJAXInterface`; 'GRU' and 'RETAIN'
        share one `SubjectDiagSequenceJAXInterface`.

    Raises:
        KeyError: If a label in `clfs` is not one of the four known labels.
    """
    admissions = pd.read_csv(mimic_files['adm_df'])
    diagnoses = pd.read_csv(mimic_files['diag_df'], dtype={'ICD9_CODE': str})

    # Parse both timestamp columns and truncate them to midnight.
    for time_col in ('ADMITTIME', 'DISCHTIME'):
        parsed = pd.to_datetime(admissions[time_col], infer_datetime_format=True)
        admissions[time_col] = parsed.dt.normalize()

    # DataFrames -> list of subjects, plus the CCS coding scheme.
    patients = DiagSubject.to_list(admissions, diagnoses)
    coding_scheme = CCSDAG()

    # Both interface flavours are always built, whatever `clfs` asks for.
    timestamped = DiagnosisJAXInterface(patients, coding_scheme)
    sequential = SubjectDiagSequenceJAXInterface(patients, coding_scheme)

    interface_for = {
        'ICE-NODE': timestamped,
        'ICE-NODE_UNIFORM': timestamped,
        'GRU': sequential,
        'RETAIN': sequential,
    }

    selected = {}
    for clf in clfs:
        selected[clf] = interface_for[clf]
    return selected
    
    
# Build the per-classifier patient interfaces for MIMIC-IV and MIMIC-III.
m4_interface = get_patient_interface(mimic4_files, clfs)
m3_interface = get_patient_interface(mimic3_files, clfs)

# Train/validation/test subject ids, drawn with a fixed seed from the first classifier's interface;
# the same id lists are reused for every classifier in the cells below.
# NOTE(review): split1=0.7 / split2=0.85 presumably are the cumulative cut points of a
# 70%/15%/15% split — confirm against `random_splits`.
m4_train_ids, m4_valid_ids, m4_test_ids = m4_interface[clfs[0]].random_splits(split1=0.7, split2=0.85, random_seed=42)
m3_train_ids, m3_valid_ids, m3_test_ids = m3_interface[clfs[0]].random_splits(split1=0.7, split2=0.85, random_seed=42)


In [None]:
# Code groups by percentile, first over all subjects, then restricted to the training subjects.
# NOTE(review): the argument 20 is presumably the percentile step (five groups, matching the
# 'ACC-P0'..'ACC-P4' rows used later) — confirm against `diag_flatccs_by_percentiles`.
m4_percentiles = m4_interface[clfs[0]].diag_flatccs_by_percentiles(20)
m3_percentiles = m3_interface[clfs[0]].diag_flatccs_by_percentiles(20)

m4_train_percentiles = m4_interface[clfs[0]].diag_flatccs_by_percentiles(20, m4_train_ids)
m3_train_percentiles = m3_interface[clfs[0]].diag_flatccs_by_percentiles(20, m3_train_ids)

In [None]:
# Pretrained parameters per classifier, one set trained on MIMIC-III and one on MIMIC-IV.
# NOTE(review): the files are named `params.pickle`; if `load_params` unpickles them, only
# load files from a trusted source — confirm.
m3_params = {clf: load_params(f'{m3_trained_dir[clf]}/params.pickle') for clf in clfs}

m4_params = {clf: load_params(f'{m4_trained_dir[clf]}/params.pickle') for clf in clfs}


<a name="secf"></a>

### F General Utility Functions [^](#outline)

In [None]:
def eval_(model, ids):
    """Evaluate a ``(model, state)`` pair on `ids` and return its 'diag_detectability' entry."""
    predictor, predictor_state = model
    evaluation = predictor.eval(predictor_state, ids)
    return evaluation['diag_detectability']

def eval2_(model, ids):
    """Evaluate a ``(model, state)`` pair on `ids` and return the full evaluation result."""
    predictor, predictor_state = model
    evaluation = predictor.eval(predictor_state, ids)
    return evaluation

def test_eval_table(dfs, metric):
    """Collect the "TST" column of several result tables into one comparison table.

    Args:
        dfs: Mapping of classifier label -> DataFrame that has a "TST" column
            and whose index contains every label in `metric`.
        metric: List of row labels (metrics) to extract.

    Returns:
        DataFrame with one row per classifier and one column per metric.
    """
    # One list of "TST" values per classifier, ordered like `metric`.
    data = {clf: df.loc[metric, "TST"].tolist() for clf, df in dfs.items()}
    return pd.DataFrame(data=data, index=metric).transpose()

def get_model(clf, config, params, interface):
    """Rebuild the model registered under `clf` and initialise its state from `params`.

    Returns:
        Tuple ``(model, state)``.
    """
    model_class = model_cls[clf]
    instance = model_class.create_model(config, interface, [], None)
    initial_state = instance.init_with_params(config, params)
    return instance, initial_state
        
def get_models(clfs, config, params, interface):
    """Build a ``(model, state)`` pair for every label in `clfs`.

    `config`, `params` and `interface` are mappings keyed by classifier label.
    """
    models = {}
    for clf in clfs:
        models[clf] = get_model(clf, config[clf], params[clf], interface[clf])
    return models

def cross_predictor(clf, source_tag, target_tag):
    """Build classifier `clf` with parameters from one dataset and the interface of another.

    Args:
        clf: Classifier label.
        source_tag: 'M3' or 'M4'; selects which pretrained parameters to use.
        target_tag: 'M3' or 'M4'; selects which patient interface to attach.

    Returns:
        Tuple ``(model, state)`` as produced by `get_model`.
    """
    # Both datasets are looked up eagerly, exactly one of each pair is used.
    params_by_dataset = dict(M3=m3_params[clf], M4=m4_params[clf])
    interface_by_dataset = dict(M3=m3_interface[clf], M4=m4_interface[clf])

    config = model_config[clf]
    trained_params = params_by_dataset[source_tag]
    evaluation_interface = interface_by_dataset[target_tag]
    return get_model(clf, config, trained_params, evaluation_interface)
            
    
def selected_auc_barplot(clfs, auctest_df, horizontal=False, rotate_ccs=True):
    """Draw a grouped bar plot of per-code AUC, with standard-deviation error bars.

    Args:
        clfs: Classifier labels. For each label, `auctest_df` must contain the
            columns ``AUC(<clf>)`` and ``VAR[AUC(<clf>)]``.
        auctest_df: DataFrame with one row per code and a ``DESC`` column used
            as the bar-group label.
        horizontal: Draw horizontal bars when True, vertical bars otherwise.
        rotate_ccs: Vertical layout only; rotate the code labels by 90 degrees.

    Returns:
        The matplotlib Axes holding the plot.
    """
    clfs = sorted(clfs)

    # Long descriptions are broken at spaces so they wrap on the axis.
    comp_desc = auctest_df['DESC'].apply(lambda t: t if len(t) < 15 else t.replace(' ', '\n'))

    # Long-format table: one row per (code, classifier).
    auc_df = pd.concat([
        pd.DataFrame({'AUC': auctest_df[f'AUC({clf})'],
                      'std': auctest_df[f'VAR[AUC({clf})]'].apply(np.sqrt),
                      'CCS': comp_desc,
                      'Classifier': clf})
        for clf in clfs
    ])

    # Axis limits snapped outwards to multiples of 0.05.
    min_auc_tick = int(auc_df['AUC'].min() * 20)/20
    max_auc_tick = int(auc_df['AUC'].max() * 20 + 1)/20

    vals = auc_df.pivot(index='CCS', columns='Classifier', values='AUC')
    err = auc_df.pivot(index='CCS', columns='Classifier', values='std')

    colors = ['green', 'gray', 'skyblue', 'brown', 'purple', 'navy', 'pink']
    patterns = ['o',    '',     '+',       '',      '',       '',     '/']

    # ICE-NODE gets a distinctive white, cross-hatched bar, but only when it
    # is actually among the plotted classifiers.
    if 'ICE-NODE' in clfs:
        icenode_idx = clfs.index('ICE-NODE')
        patterns[icenode_idx] = 'x'
        colors[icenode_idx] = 'white'

    pltbarconf = dict(rot=0, figsize=(10, 10), width=0.7,
                      error_kw=dict(lw=5, capsize=8, capthick=5, ecolor='salmon'),
                      color=colors,
                      edgecolor='black')
    if horizontal:
        # plot vals with xerr
        ax = vals.plot.barh(xerr=err, **pltbarconf)
        plt.xlabel('AUC', fontsize=32)
        plt.xticks(fontsize=30)
        plt.xlim(min_auc_tick, max_auc_tick)

        xstart, xend = ax.get_xlim()
        ax.xaxis.set_ticks(np.arange(xstart, xend+0.01, 0.05))

        plt.yticks(fontsize=24)

        plt.ylabel(None)
        ax.tick_params(bottom=True, left=False)

        ax.xaxis.grid(color='gray', linestyle='dashed')
        ax.xaxis.set_zorder(3)
    else:
        # plot vals with yerr
        ax = vals.plot.bar(yerr=err, **pltbarconf)
        plt.ylabel('AUC', fontsize=32)
        plt.yticks(fontsize=24)
        plt.ylim(min_auc_tick, max_auc_tick)

        ystart, yend = ax.get_ylim()
        ax.yaxis.set_ticks(np.arange(ystart, yend+0.01, 0.05))

        plt.xticks(fontsize=30, rotation=90 * rotate_ccs)

        plt.xlabel(None)
        ax.tick_params(bottom=False, left=True)

        ax.yaxis.grid(color='gray', linestyle='dashed')
        ax.yaxis.set_zorder(3)

    for axis in ['top', 'bottom', 'left', 'right']:
        ax.spines[axis].set_linewidth(6)  # change width
        ax.spines[axis].set_color('red')    # change color

    # Add hatches. One pattern per classifier, repeated once per bar group;
    # the group count is taken from the pivoted table that was plotted.
    bars = ax.patches
    hatches = [p for p in patterns for _ in range(len(vals))]
    for bar, hatch in zip(bars, hatches):
        bar.set_hatch(hatch)

    _ = ax.legend(loc='upper right',  fontsize=22)
    return ax

def make_clf_paris(clfs):
    """Return every unordered pair of classifiers as a sorted tuple of 2-tuples.

    Inside each pair the elements keep their relative order from `clfs`;
    the pairs themselves are sorted.
    """
    pairs = [(first, second)
             for pos, first in enumerate(clfs)
             for second in clfs[pos + 1:]]
    pairs.sort()
    return tuple(pairs)
    
def relative_performance_upset(auc_tests, selected_clfs, patient_interface, train_ids, pvalue, min_auc):
    """Assign each diagnosis code to the classifier(s) that predict it best, for an UpSet plot.

    Args:
        auc_tests: DataFrame with a 'CODE_INDEX' column, one 'AUC(<clf>)' column per
            classifier and pairwise test columns whose names start with 'P0'.
            It is modified in place: a 'DESC' column is added.
        selected_clfs: Classifier labels to compare.
        patient_interface: Object providing `diag_flatccs_idx`,
            `diag_flatccs_frequency(ids)` and `dag.diag_flatccs_desc`.
        train_ids: Subject ids passed to `diag_flatccs_frequency`.
        pvalue: Threshold; a test value above it is treated as "no significant difference".
        min_auc: Codes for which no selected classifier exceeds this AUC are dropped.

    Returns:
        Tuple ``(content_sets, indicator_df, data, common_perf, competing_tests)``:
        per-classifier sets of code indices, the same assignment as a boolean
        DataFrame, per-code descriptive statistics, the rows with no significant
        difference in any pair, and the remaining rows.
    """
    flatccs_idx2code = {idx: code for code, idx in patient_interface.diag_flatccs_idx.items()}
    flatccs_frequency_train = patient_interface.diag_flatccs_frequency(train_ids)
    
    idx2desc = lambda i: patient_interface.dag.diag_flatccs_desc[flatccs_idx2code[i]]
    # Side effect on the caller's DataFrame; later cells read this 'DESC' column.
    auc_tests['DESC'] = auc_tests['CODE_INDEX'].apply(idx2desc)
    
    # remove codes that no classifier has scored above `min_auc`
    accepted_aucs = auc_tests.loc[:,[f'AUC({clf})' for clf in selected_clfs]].max(axis=1) > min_auc
    accepted_auc_tests = auc_tests[accepted_aucs]
    print(f'{len(accepted_auc_tests)} codes predicted an AUC higher than {min_auc} by at least one model.')
    
    # All columns starting with 'P0' are used, including pairs outside `selected_clfs` if present.
    test_cols = [col for col in auc_tests.columns if col[:2] == 'P0']

    # exclude tests with nans
    accepted_auc_tests = accepted_auc_tests[accepted_auc_tests.loc[:,test_cols].isnull().max(axis=1) == 0]
    
    print(f'{len(accepted_auc_tests)} codes predicted an AUC higher than {min_auc} by at least one model, with valid tests.')
    
    tests = accepted_auc_tests
    
    # Codes when no significant difference of AUCs among all pairs of models.
    common_perf = tests[tests.loc[:,test_cols].min(axis=1) > pvalue]
    
    
    # Keys are either a single classifier label, a pair of labels, or the full sorted tuple.
    auc_sets = defaultdict(set)
    clfs = tuple(sorted(selected_clfs))
    auc_sets[clfs] = set(common_perf.CODE_INDEX)
    competing_tests = tests.drop(index=common_perf.index)

    clfs_pairs = make_clf_paris(clfs)

    # Assign each code to the best model (max AUC), then assign it as well 
    # to any model with no significant difference with the best.
    for index, row in competing_tests.iterrows():
        max_auc_clf = max(clfs, key=lambda clf: row[f'AUC({clf})'])
        # Pairs involving the best model whose test value exceeds `pvalue`.
        # NOTE(review): assumes the test columns are named with the pair in sorted order — confirm.
        insignificant_diff = {(clf1, clf2): f'P0(AUC_{clf1}==AUC_{clf2})' for (clf1, clf2) in clfs_pairs \
                          if max_auc_clf in (clf1, clf2) and row[f'P0(AUC_{clf1}==AUC_{clf2})'] > pvalue}

        # Case 1: The best model is significantly outperforming all others.
        if len(insignificant_diff) == 0:
            auc_sets[max_auc_clf].add(int(row['CODE_INDEX']))
        # Case 2: Some insigificant difference with others though.
        else:
            for (clf1, clf2), test_col in insignificant_diff.items():
                # Populate the intersections.
                auc_sets[(clf1, clf2)].add(int(row['CODE_INDEX']))            
            
    # Prepare for using Upset plot -> Set Layout (passed to `from_contents`)
    content_sets = {}
    for clf in clfs:
        # Own wins, plus the codes shared by all, plus every pairwise tie involving `clf`.
        content_sets[clf] = auc_sets[clf] | auc_sets[clfs]
        for clf1, clf2 in clfs_pairs:
            if clf in (clf1, clf2):
                content_sets[clf].update(auc_sets[(clf1, clf2)])
    
    # Prepare for using Upset plot -> DataFrame Layout (passed to `from_indicators`)
    code_index = tests.CODE_INDEX.tolist()
    competence_assignments = {}
    for clf in clfs:
        competence_assignments[clf] = [c in content_sets[clf] for c in code_index]
    indicator_df = pd.DataFrame(competence_assignments, index=code_index)
    
    # Descriptive statistics for each code.    
    avg_aucs, n_codes = [], []
    for c in code_index:
        competent_clfs = [clf for clf in clfs if indicator_df.loc[c, clf]]
        # NOTE(review): `tests.loc[c, ...]` looks rows up by index label using a CODE_INDEX value;
        # this assumes the index of `auc_tests` equals its CODE_INDEX column — confirm.
        avg_auc = tests.loc[c, list(f'AUC({clf})' for clf in competent_clfs)].mean()
        avg_aucs.append(avg_auc)
        n_codes.append(flatccs_frequency_train[c])
    data = pd.DataFrame({'Avg. AUC': avg_aucs, '#codes (train)': n_codes}, index=code_index)    
    return content_sets, indicator_df, data, common_perf, competing_tests


def styled_df(df):
    """Style a results table (column-wise maximum highlighted) and render it as LaTeX.

    Args:
        df: DataFrame whose index labels are strings (classifier names).

    Returns:
        Tuple ``(styler, latex_str)``: the pandas Styler, and its LaTeX rendering
        with the first occurrence of each index label wrapped in ``\\texttt{}``
        and every underscore escaped.
    """
    # The bare key 'precision' is ambiguous on pandas >= 1.4 (it matches both
    # 'display.precision' and 'styler.format.precision') and raises OptionError,
    # so the full option name is used.
    pd.set_option('display.precision', 3)

    def highlight_max(s, props=''):
        return np.where(s == np.nanmax(s.values), props, '')

    # Set the Styler precision explicitly, so it does not depend on which
    # global option this pandas version reads its default from.
    s_df = df.style.format(precision=3)
    s_df = s_df.apply(highlight_max, props='bfseries: ;color:white;background-color:darkblue', axis=0)

    latex_str = s_df.to_latex(convert_css=True)
    # NOTE(review): only the first textual occurrence of each label is wrapped; a label that
    # is a prefix of an earlier-listed label would be wrapped in the wrong place.
    for clf in df.index.tolist():
        latex_str = latex_str.replace(clf, f'\\texttt{{{clf}}}', 1)
    latex_str = latex_str.replace('_', '\\_')
    return s_df, latex_str

<a name="sec1"></a>

## 1 Performance Analysis for Training/Testing on MIMIC-III [^](#outline)

In [None]:
# Models with MIMIC-III parameters attached to the MIMIC-III interface.
m3_clfs =  ('ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN')
m3_predictors = {clf: cross_predictor(clf, 'M3', 'M3') for clf in m3_clfs}

In [None]:
# Full evaluation result of every model on the MIMIC-III test subjects.
test_res_m3 = {clf: eval2_(model, m3_test_ids) for clf, model in m3_predictors.items()} 

In [None]:
# Per-code AUCs and pairwise tests, computed from each model's 'diag_detectability' entry.
auctests_m3 = codes_auc_pairwise_tests({k: v['diag_detectability'] for k, v in test_res_m3.items()}, fast=True)


In [None]:
# Number of codes (rows) with at least one missing value among the 'P0...' test columns.
test_cols = [col for col in auctests_m3.columns if col[:2] == 'P0']
auctests_m3.loc[:, test_cols].isnull().max(axis=1).sum()

In [None]:

# Relative performance on MIMIC-III; this call also adds the 'DESC' column to `auctests_m3`.
upset_clfs = ['ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN']
upsetcontents_m3, upsetindicator_m3, data_m3,  _, compete_codesm3 = relative_performance_upset(auctests_m3, 
                                                                                               upset_clfs, 
                                                                                               m3_interface[clfs[0]],
                                                                                               m3_train_ids,
                                                                                               **relative_auc_config)

# Seaborn plotting context shared by the UpSet figures.
upset_ctx = lambda : sns.plotting_context("paper", font_scale=1.5, rc={"font.family": "Loma", 
                                                                        'axes.labelsize': 'medium',
                                                                       'ytick.labelsize': 'medium'})
with sns.axes_style("darkgrid"), upset_ctx():
    upset_format = from_indicators(upsetindicator_m3, data=data_m3)
    upset_object = UpSet(upset_format, subset_size='count', show_counts=True)
    # Subsets selected with `max_subset_size=1` are drawn in red.
    upset_object.style_subsets(max_subset_size=1,
                               facecolor="red",
                               edgecolor="red", linewidth=3)
    g = upset_object.plot()

    # Keep the current height and set the width to height * (w / h); `wi` is unused.
    current_figure = plt.gcf()
    w, h = 3.5, 3
    wi, hi = current_figure.get_size_inches()
    current_figure.set_size_inches(hi*(w/h), hi)
    current_figure.savefig(f"upset_M3.pdf", bbox_inches='tight')

In [None]:
# Evaluation table with top_k=15, using the code percentile groups of the MIMIC-III training subjects.
results_m3_k15, _ = evaluation_table(test_res_m3, m3_train_percentiles, top_k=15)


In [None]:
# Rows 'ACC-P0'..'ACC-P4', transposed, rounded to 3 decimals,
# then saved as CSV, displayed and printed as LaTeX.
df_acc15 = results_m3_k15.loc[list(f'ACC-P{i}' for i in range(5)), :].transpose()
df_acc15 = df_acc15.apply(lambda x: round(x, 3))
df_acc15.to_csv(f'acc15_mimic3.csv')
s_df, ltx_s = styled_df(df_acc15)
display(s_df)
print(ltx_s)

In [None]:
# Codes that are not assigned to every classifier (indicator row sum below the number of classifiers),
# followed by a display of their 'P0...' test columns.
competing_tests_df = auctests_m3[auctests_m3.CODE_INDEX.isin(upsetindicator_m3[upsetindicator_m3.sum(axis=1)<len(m3_clfs)].index)]
competing_tests_df.loc[:, [col for col in competing_tests_df.columns if col[:2]=='P0']]

In [None]:
# AUC bar plot of the competing codes on MIMIC-III, saved to 'icenode_m3.pdf'.
upset_clfs = ['ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN']#, 'ICE-NODE/G', 'ICE-NODE_UNIFORM/G', 'GRU/G']

ax = selected_auc_barplot(upset_clfs, competing_tests_df,  horizontal=True)
# Replaces the legend created inside `selected_auc_barplot`.
ax.legend(fontsize=22, title_fontsize=32,
          bbox_to_anchor=(-0.02, 1), ncol=2)
plt.xticks(fontsize=30)
plt.yticks(fontsize=30)

# With w == h the figure becomes square at its current height; `wi` is unused.
current_figure = plt.gcf()
w, h = 4, 4
wi, hi = current_figure.get_size_inches()
current_figure.set_size_inches(hi*(w/h), hi)

current_figure.savefig("icenode_m3.pdf", bbox_inches='tight')
plt.show()


<a name="sec2"></a>

## 2 Performance Analysis for Training/Testing on MIMIC-IV [^](#outline)

In [None]:
# Models with MIMIC-IV parameters attached to the MIMIC-IV interface.
m4_clfs =  ('ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN')
m4_predictors = {clf: cross_predictor(clf, 'M4', 'M4') for clf in m4_clfs}

In [None]:
# Full evaluation result of every model on the MIMIC-IV test subjects.
test_res_m4 = {clf: eval2_(model, m4_test_ids) for clf, model in m4_predictors.items()} 

In [None]:
# Per-code AUCs and pairwise tests, computed from each model's 'diag_detectability' entry.
auctests_m4 = codes_auc_pairwise_tests({k: v['diag_detectability'] for k, v in test_res_m4.items()}, fast=True)


In [None]:
# Relative performance on MIMIC-IV; this call also adds the 'DESC' column to `auctests_m4`.
upset_clfs = ('ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN')

upsetcontents_m4, upsetindicator_m4, data_m4,  _, compete_codesm4 = relative_performance_upset(auctests_m4, 
                                                                                               upset_clfs, 
                                                                                               m4_interface[clfs[0]], 
                                                                                               m4_train_ids,
                                                                                               **relative_auc_config)

# Seaborn plotting context shared by the UpSet figures.
upset_ctx = lambda : sns.plotting_context("paper",  font_scale=1.5, rc={"font.family": "Loma", 
                                                                        'axes.labelsize': 'medium',
                                                                       'ytick.labelsize': 'medium'})
with sns.axes_style("darkgrid"), upset_ctx():
    upset_format = from_indicators(upsetindicator_m4, data=data_m4)
    upset_object = UpSet(upset_format, subset_size='count', show_counts=True)
    # The subset containing ICE-NODE and none of the other classifiers is drawn in red.
    upset_object.style_subsets(present=['ICE-NODE'], absent=('ICE-NODE_UNIFORM', 'GRU', 'RETAIN'),
                               edgecolor="red", linewidth=3, facecolor="red")
#     upset_object.add_catplot(value='#codes (train)', kind="strip")

    g = upset_object.plot()
#     g['extra1'].set_yscale('log')
    
    # Keep the current height and set the width to height * (w / h); `wi` is unused.
    current_figure = plt.gcf()
    w, h = 5, 3
    wi, hi = current_figure.get_size_inches()
    current_figure.set_size_inches(hi*(w/h), hi)

    current_figure.savefig(f"upset_M4.pdf", bbox_inches='tight')

In [None]:
# Evaluation table with top_k=15, using the code percentile groups of the MIMIC-IV training subjects.
results_m4_k15, _ = evaluation_table(test_res_m4, m4_train_percentiles, top_k=15)

In [None]:
# Rows 'ACC-P0'..'ACC-P4', transposed, rounded to 3 decimals,
# then saved as CSV, displayed and printed as LaTeX.
# Note: `df_acc15` overwrites the MIMIC-III table of the same name.
df_acc15 = results_m4_k15.loc[list(f'ACC-P{i}' for i in range(5)), :].transpose()
df_acc15 = df_acc15.apply(lambda x: round(x, 3))
df_acc15.to_csv(f'acc15_mimic4.csv')
s_df, ltx_s = styled_df(df_acc15)
display(s_df)
print(ltx_s)

In [None]:
# Code indices in ICE-NODE's content set that appear in no other classifier's content set,
# then the matching rows of the competing tests (displayed as the cell output).
icenode_m4_excl = upsetcontents_m4['ICE-NODE'] - set.union(*list(upsetcontents_m4[clf] for clf in ('RETAIN', 'GRU', 'ICE-NODE_UNIFORM')))
icenode_m4_excl = compete_codesm4[compete_codesm4['CODE_INDEX'].isin(icenode_m4_excl)]
icenode_m4_excl

In [None]:
# AUC bar plot of the codes won exclusively by ICE-NODE on MIMIC-IV, saved to 'icenode_m4.pdf'.
upset_clfs = ('ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN')
# This first `w, h` assignment is overwritten below before it is read.
w, h = 4, 3
ax = selected_auc_barplot(upset_clfs, icenode_m4_excl, horizontal=True)

plt.xticks(fontsize=30)
plt.yticks(fontsize=30)
# Replaces the legend created inside `selected_auc_barplot`.
ax.legend(fontsize=22, title_fontsize=32,
          bbox_to_anchor=(-0.02, 1), ncol=2)
# With w == h the figure becomes square at its current height; `wi` is unused.
current_figure = plt.gcf()
w, h = 4, 4
wi, hi = current_figure.get_size_inches()
current_figure.set_size_inches(hi*(w/h), hi)

current_figure.savefig("icenode_m4.pdf", bbox_inches='tight')
plt.show()


<a name="sec3"></a>

## 3 Performance Analysis for Training on MIMIC-IV and Testing on MIMIC-III [^](#outline)

In [None]:
# Models with MIMIC-IV parameters attached to the MIMIC-III interface.
m4m3_clfs =  ('ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN')

# All MIMIC-III subjects (no train/test split) are used for this evaluation.
m3_subjects = list(m3_interface[clfs[0]].subjects.keys())
m4m3_predictors = {clf: cross_predictor(clf, 'M4', 'M3') for clf in m4m3_clfs}

In [None]:
# Full evaluation result of every MIMIC-IV-trained model on all MIMIC-III subjects.
test_res_m4m3 = {clf: eval2_(model, m3_subjects) for clf, model in m4m3_predictors.items()} 

In [None]:
# Per-code AUCs and pairwise tests, computed from each model's 'diag_detectability' entry.
auctests_m4m3 = codes_auc_pairwise_tests({k: v['diag_detectability'] for k, v in test_res_m4m3.items()}, fast=True)

In [None]:
upset_clfs = ('ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN')

# upset_clfs = ('ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN',
#                 'ICE-NODE/G', 'ICE-NODE_UNIFORM/G', 'GRU/G')

# Code descriptions and training frequencies come from the MIMIC-IV interface and
# MIMIC-IV training ids, while the tests were computed on MIMIC-III subjects.
# This call also adds the 'DESC' column to `auctests_m4m3`.
upsetcontents_m4m3, upsetindicator_m4m3, data_m4m3,  _, compete_codesm4m3 = relative_performance_upset(auctests_m4m3, 
                                                                                                       upset_clfs, 
                                                                                                       m4_interface[clfs[0]],
                                                                                                       m4_train_ids,
                                                                                                       **relative_auc_config)
# Seaborn plotting context shared by the UpSet figures.
upset_ctx = lambda : sns.plotting_context("paper", font_scale=1.5, rc={"font.family": "Loma", 
                                                                        'axes.labelsize': 'medium',
                                                                       'ytick.labelsize': 'medium'})
with sns.axes_style("darkgrid"), upset_ctx():
    upset_format = from_indicators(upsetindicator_m4m3, data=data_m4m3)
    upset_object = UpSet(upset_format, subset_size='count', show_counts=True)
    # The subset containing ICE-NODE and none of the other classifiers is drawn in red.
    upset_object.style_subsets(present='ICE-NODE', absent=['ICE-NODE_UNIFORM', 'GRU', 'RETAIN'],
                              edgecolor="red", facecolor="red")
    # upset_object.add_catplot(value='Avg. AUC', kind="strip")
#     upset_object.add_catplot(value='#codes (train)', kind="strip")
    g = upset_object.plot()
#     g['extra1'].set_yscale('log')

    # Unlike the M3 and M4 figures, this one is saved at its default size.
    current_figure = plt.gcf()
    current_figure.savefig(f"upset_M4M3.pdf", bbox_inches='tight')

In [None]:
# Code indices in ICE-NODE's content set that appear in no other classifier's content set,
# then the matching rows of the competing tests (displayed as the cell output).
icenode_m4m3_excl = upsetcontents_m4m3['ICE-NODE'] - set.union(*list(upsetcontents_m4m3[clf] for clf in ('RETAIN', 'GRU', 'ICE-NODE_UNIFORM')))
icenode_m4m3_excl = compete_codesm4m3[compete_codesm4m3['CODE_INDEX'].isin(icenode_m4m3_excl)]
icenode_m4m3_excl

In [None]:
# AUC bar plot of the codes won exclusively by ICE-NODE (trained on MIMIC-IV, tested on MIMIC-III),
# saved to 'icenode_m4m3.pdf'.
ax = selected_auc_barplot(upset_clfs, icenode_m4m3_excl, horizontal=True)

plt.xticks(fontsize=30)
plt.yticks(fontsize=30)
# ax.legend(fontsize=22, title_fontsize=32,
#           bbox_to_anchor=(0.02, 1), ncol=2)
# The two tick calls below repeat the ones above with the same font size.
plt.xticks(fontsize=30)
plt.yticks(fontsize=30)
# Replaces the legend created inside `selected_auc_barplot`.
ax.legend(fontsize=22, title_fontsize=32,
          bbox_to_anchor=(1, 1.17), ncol=2)

current_figure = plt.gcf()
current_figure.savefig("icenode_m4m3.pdf", bbox_inches='tight')
plt.show()


<a name="sec4"></a>

## 4 Risk Trajectories [^](#outline)

### Analyse AUC for Each Admission in the Test Partition

In [None]:
def admissions_auc_scores(model, test_ids):
    """Call `admissions_auc_scores` on a ``(model, state)`` pair for the given subject ids."""
    predictor, predictor_state = model
    scores = predictor.admissions_auc_scores(predictor_state, test_ids)
    return scores

In [None]:
# Lookups between flat-CCS codes and their indices, all taken from the MIMIC-IV interface.
# They are also used below with MIMIC-III histories.
# NOTE(review): this assumes both interfaces index the codes identically — confirm.
flatccs_idx2code = {idx: code for code, idx in m4_interface[clfs[0]].diag_flatccs_idx.items()}
flatccs_code2idx = m4_interface[clfs[0]].diag_flatccs_idx
idx2desc = lambda idx: m4_interface[clfs[0]].dag.diag_flatccs_desc[flatccs_idx2code[idx]]

In [None]:
# AUC scores from ICE-NODE (MIMIC-IV parameters) for the MIMIC-IV test subjects.
m4_icenode_visit_auc_df = admissions_auc_scores(m4_predictors['ICE-NODE'], m4_test_ids)
# Rows per subject, counted once with `value_counts` and mapped back onto each row
# (instead of re-scanning the whole column for every row).
m4_icenode_visit_auc_df['N_VISITS'] = m4_icenode_visit_auc_df['SUBJECT_ID'].map(
    m4_icenode_visit_auc_df['SUBJECT_ID'].value_counts())
# Per-subject summary; the result has two-level columns (column name, aggregation).
m4_visit_auc_subject = m4_icenode_visit_auc_df.groupby('SUBJECT_ID').agg({'AUC': 'mean', 'N_VISITS': 'max', 'N_CODES': ['min', 'max', 'mean', 'median'], 'INTERVALS': ['mean', 'max', 'min'], 'R/T': ['min', 'max', 'mean'] })

In [None]:
# AUC scores from ICE-NODE (MIMIC-IV parameters) for all MIMIC-III subjects.
m4m3_icenode_visit_auc_df = admissions_auc_scores(m4m3_predictors['ICE-NODE'], m3_interface[clfs[0]].subjects.keys())
# Rows per subject, counted once with `value_counts` and mapped back onto each row
# (instead of re-scanning the whole column for every row).
m4m3_icenode_visit_auc_df['N_VISITS'] = m4m3_icenode_visit_auc_df['SUBJECT_ID'].map(
    m4m3_icenode_visit_auc_df['SUBJECT_ID'].value_counts())
# Per-subject summary; the result has two-level columns (column name, aggregation).
m4m3_visit_auc_subject = m4m3_icenode_visit_auc_df.groupby('SUBJECT_ID').agg({'AUC': 'mean', 'N_VISITS': 'max', 'N_CODES': ['min', 'max', 'mean', 'median'], 'INTERVALS': ['mean', 'max', 'min'], 'R/T': ['min', 'max', 'mean'] })


In [None]:
# Subject selection: visit count above a threshold and longest interval below 150.
# The visit threshold differs between the two sets (> 2 for MIMIC-IV, > 1 for MIMIC-III).
# NOTE(review): the unit of INTERVALS is presumably days — confirm.
m4_best_visit_auc_subjects =  m4_visit_auc_subject[(m4_visit_auc_subject.N_VISITS['max'] > 2) & (m4_visit_auc_subject.INTERVALS['max'] < 150)]
m4m3_best_visit_auc_subjects =  m4m3_visit_auc_subject[(m4m3_visit_auc_subject.N_VISITS['max'] > 1) & (m4m3_visit_auc_subject.INTERVALS['max'] < 150)]


In [None]:
# Number of selected subjects in each set.
len(m4_best_visit_auc_subjects), len(m4m3_best_visit_auc_subjects)

In [None]:
# Diagnosis history of every selected subject, keyed by subject id.
m4_ccs_history = {i: m4_interface[clfs[0]].diag_flatccs_history(i) for i in m4_best_visit_auc_subjects.index}
m4m3_ccs_history = {i: m3_interface[clfs[0]].diag_flatccs_history(i) for i in m4m3_best_visit_auc_subjects.index}

# Code frequencies restricted to the selected subjects.
m4_ccs_idx_frequency = m4_interface[clfs[0]].diag_flatccs_frequency(list(m4_best_visit_auc_subjects.index))
m3_ccs_idx_frequency = m3_interface[clfs[0]].diag_flatccs_frequency(list(m4m3_best_visit_auc_subjects.index))

In [None]:
# Indices of all codes that appear in any selected subject's history
# (history keys are codes; they are mapped to indices with the MIMIC-IV lookup).
m4_history_all_ccs_codes = set(map(flatccs_code2idx.get, set.union(*[set(h.keys()) for h in m4_ccs_history.values()])))
m3_history_all_ccs_codes = set(map(flatccs_code2idx.get, set.union(*[set(h.keys()) for h in m4m3_ccs_history.values()])))
# Keep only the indices whose frequency among the selected subjects is below 10.
m4_history_all_ccs_codes = {idx for idx in m4_history_all_ccs_codes if m4_ccs_idx_frequency[idx] < 10}
m3_history_all_ccs_codes = {idx for idx in m3_history_all_ccs_codes if m3_ccs_idx_frequency[idx] < 10}

len(m4_history_all_ccs_codes), len(m3_history_all_ccs_codes)

In [None]:
# Rows of the MIMIC-IV tests for the codes in ICE-NODE's content set, sorted by
# N_POSITIVE_CODES in descending order. 'DESC' exists because `relative_performance_upset`
# added it to `auctests_m4`.
icenode_m4_competent = upsetcontents_m4['ICE-NODE'] 
icenode_m4_competent = auctests_m4[auctests_m4['CODE_INDEX'].isin(icenode_m4_competent)]
icenode_m4_competent = icenode_m4_competent[['N_POSITIVE_CODES', 'AUC(ICE-NODE)', 'DESC']].sort_values('N_POSITIVE_CODES',ascending=False)
# icenode_m4_competent.head(50)
# Hand-picked code indices for the "level 2" trajectory plots.
trajectory_ccs_codes_level2 = [
    173, 168, 169, 156, 165, 216, 171, 100, 167
]
# NOTE(review): this filter matches on the DataFrame index, not on the CODE_INDEX column —
# confirm the two coincide.
icenode_m4_competent[icenode_m4_competent.index.isin(trajectory_ccs_codes_level2)]

In [None]:
# Hand-picked code indices for the "level 1" trajectory plots.
# This list is redefined with more entries in a later cell.
trajectory_ccs_codes_level1 = [
    64, #renal fail 
    6, # pulm heart dx
    236, # ear dx 
]


In [None]:
# Subjects whose history contains at least one of the level-1 codes
# (indices are translated to codes, then intersected with the history keys).
m4_ccs_history_level1 = {i: history for i, history in m4_ccs_history.items() 
                         if len(set(map(flatccs_idx2code.get, trajectory_ccs_codes_level1)) & set(history.keys())) > 0}
m4m3_ccs_history_level1 = {i: history for i, history in m4m3_ccs_history.items() 
                         if len(set(map(flatccs_idx2code.get, trajectory_ccs_codes_level1)) & set(history.keys())) > 0}

# Same selection for the level-2 codes.
m4_ccs_history_level2 = {i: history for i, history in m4_ccs_history.items() 
                         if len(set(map(flatccs_idx2code.get, trajectory_ccs_codes_level2)) & set(history.keys())) > 0}
m4m3_ccs_history_level2 = {i: history for i, history in m4m3_ccs_history.items() 
                         if len(set(map(flatccs_idx2code.get, trajectory_ccs_codes_level2)) & set(history.keys())) > 0}

In [None]:
# Number of selected subjects per dataset and code level.
len(m4_ccs_history_level1), len(m4m3_ccs_history_level1), len(m4_ccs_history_level2), len(m4m3_ccs_history_level2) 


In [None]:
# Subjects selected at either code level, per dataset.
m4_cases = set(m4_ccs_history_level1.keys()) | set(m4_ccs_history_level2.keys())
m4m3_cases = set(m4m3_ccs_history_level1.keys()) | set(m4m3_ccs_history_level2.keys())
len(m4_cases), len(m4m3_cases)

In [None]:
# Sample trajectories for the selected MIMIC-IV subjects with ICE-NODE (MIMIC-IV parameters).
# NOTE(review): the meaning of the last argument (1) is not visible here — see `sample_trajectory`.
m4_icenode, m4_icenode_state = m4_predictors['ICE-NODE']
m4_trajectory = m4_icenode.sample_trajectory(m4_icenode_state, m4_cases, 1)

In [None]:
# Sample trajectories for the selected MIMIC-III subjects with ICE-NODE (MIMIC-IV parameters).
# NOTE(review): the meaning of the last argument (1) is not visible here — see `sample_trajectory`.
m4m3_icenode, m4m3_icenode_state = m4m3_predictors['ICE-NODE']
m4m3_trajectory = m4m3_icenode.sample_trajectory(m4m3_icenode_state, m4m3_cases, 1)

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# m4_selected_subjects = [
#     13798593, #acute-renal
#     13965528, #acute-renal
#     11907876, #pulmonary heart dx
#     13557547, #ear dx
#     10139504, #acute renal fail
#     12367864, #pulomonary-heart dx
# ]

# m4_selected_trajectory = {i: m4_trajectory[i] for i in m4_selected_subjects}

# m3_selected_subjects = [
#     50093 #pulmonary-heart dx
# ]

# m3_selected_trajectory = {i: m4m3_trajectory[i] for i in m3_selected_subjects}


In [None]:
import random

# Hand-picked code indices for the "level 1" trajectory plots (extends the earlier list).
trajectory_ccs_codes_level1 = [
    64, #renal fail 
    6, # pulm heart dx
    236, # ear dx 
    # Others
    100, # Brnch/lng ca
    168, # Kidney/rnl ca
    194, # Immunity dx
]



# icenode_m4_competent.head(50)
# Hand-picked code indices for the "level 2" trajectory plots.
trajectory_ccs_codes_level2 = [
    173, 168, 169, 156, 165, 216, 171, 100, 167
]


# Plot colour per code index. Random hex colours are generated for the level-2
# codes first; the explicit colours come afterwards so they take precedence for
# codes listed in both places (100 and 168). Previously the random entries came
# last and silently replaced 'salmon' and 'navy'.
# The random colours are not seeded, so they change between runs.
ccs_color = {
    **{idx: "#"+''.join([random.choice('0123456789ABCDEF') for j in range(6)])
                   for idx in trajectory_ccs_codes_level2},
    6: 'blue',
    64: 'purple',
    236: 'orange',
    # Others
    100: 'salmon', # Brnch/lng ca
    168: 'navy', # Kidney/rnl ca
    194: 'pink', # Immunity dx
}

In [None]:
# Inputs of the trajectory plotting cell: which interface and trajectories to plot,
# where to save the figures, and which code indices to draw.
interface = m4_interface[clfs[0]]
trajectories = m4_trajectory
save_dir = "m4_trajectories_level2"
ccs_indexes = trajectory_ccs_codes_level2# + trajectory_ccs_codes_level2

In [None]:
# Size of `diag_flatccs_codes` in the selected interface's coding scheme.
len(interface.dag.diag_flatccs_codes)

In [None]:
import math

# Default figure size for the trajectory figures created below.
plt.rcParams['figure.figsize']=(10,10)

def plot_codes(codes_dict):
    """Scatter the (time, value) points of each code as triangle markers.

    `codes_dict` maps a code index to a non-empty list of (time, value) pairs.
    Descriptions of 15 or more characters are cut to 15 and suffixed with "..".
    """
    for ccs_idx, points in codes_dict.items():
        description = idx2desc(ccs_idx)
        short_desc = description if len(description) < 15 else description[:15]+".."
        time, traj_vals = zip(*points)
        plt.scatter(time, traj_vals,
                    s=100, marker='^',
                    color=ccs_color[ccs_idx],
                    linewidths=2,
                    label=f'code: {short_desc}')
        
def plot_admission_lines(adms):
    """Mark each (admission, discharge) pair with two vertical lines and a shaded band.

    Only the first pair contributes legend labels.
    """
    admit_times, discharge_times = zip(*adms)
    is_first = True
    for adm_ti, disch_ti in zip(admit_times, discharge_times):
        adm_label = 'admission' if is_first else None
        disch_label = 'discharge' if is_first else None
        plt.axvline(x=adm_ti, color='green', linestyle='-.', label=adm_label)
        plt.axvline(x=disch_ti, color='red', linestyle='--', label=disch_label)
        # Light green band spanning the stay.
        plt.fill_between([adm_ti, disch_ti], [1.0, 1.0], alpha=0.1, color='green')
        is_first = False
        
        

def plot_risk_traj(trajs):
    """Plot one line per code from its list of (times, values) array segments.

    The segments of a code are concatenated before plotting. Descriptions of
    15 or more characters are cut to 15 and suffixed with "..".
    """
    for ccs_idx, segments in trajs.items():
        description = idx2desc(ccs_idx)
        short_desc = description if len(description) < 15 else description[:15]+".."
        time_parts, value_parts = zip(*segments)
        all_times = np.concatenate(time_parts)
        all_values = np.concatenate(value_parts)

        plt.plot(all_times, all_values,
                 color=ccs_color[ccs_idx],
                 marker='o', markersize=2, linewidth=1,
                 label=f'risk: {short_desc}')
    
# Create the output directory up front; `savefig` does not create missing folders.
os.makedirs(save_dir, exist_ok=True)

# One figure per subject that has at least one of the selected codes to mark.
for i, traj in trajectories.items():
    adm_times = interface.adm_times(i)
    history = interface.diag_flatccs_history(i)

    t = traj['t']
    d = traj['d']

    plt_codes = defaultdict(list)
    plt_trajs = defaultdict(list)
    # Running (maximum, minimum) of the plotted trajectory values, used for the y-limits.
    max_min = (-np.inf, np.inf)
    for code in history:
        ccs_idx = flatccs_code2idx[code]

        if ccs_idx not in ccs_indexes:
            continue

        code_history = history[code]
        code_history_adm, code_history_disch = zip(*code_history)

        # Code recorded at the very first admission: mark it at the first discharge time.
        if code_history_adm[0] == adm_times[0][0]:
            plt_codes[ccs_idx].append((adm_times[0][1], d[0][0, ccs_idx]))

        # Trajectory segments are paired with the admissions after the first one.
        for ti, di, (adm_time_i, disch_time_i) in zip(t, d, adm_times[1:]):
            max_min = max(max_min[0], di[:, ccs_idx].max()), min(max_min[1], di[:, ccs_idx].min())
            plt_trajs[ccs_idx].append((ti, di[:, ccs_idx]))

            if disch_time_i in code_history_disch:
                plt_codes[ccs_idx].append((disch_time_i, di[-1, ccs_idx]))

    # Nothing to mark for this subject: skip it without creating an empty figure.
    if not plt_codes:
        continue

    plt.figure(i)
    plot_admission_lines(adm_times)
    plot_codes(plt_codes)
    plot_risk_traj(plt_trajs)

    # Make the major grid
    plt.grid(which='major', linestyle=':', color='gray', linewidth=1)
    # Turn on the minor ticks on
    plt.minorticks_on()
    # Make the minor grid
    plt.grid(which='minor', linestyle=':', color='black', linewidth=0.5)

    # Snap the y-limits outwards to multiples of 0.05. The bounds stay infinite
    # when no trajectory segment was collected; keep the automatic limits then.
    if np.isfinite(max_min).all():
        plt.ylim(math.floor(max_min[1]/0.05)*0.05,
                 math.ceil(max_min[0]/0.05)*0.05)

    ystart, yend = plt.gca().get_ylim()
    plt.gca().yaxis.set_ticks(np.arange(ystart, yend+0.01, 0.05))

    # Raw string: `\w` is not a valid escape sequence in a normal string literal.
    plt.ylabel(r'Predicted Risk ($\widehat{v}(t)$)', fontsize=26)
    plt.yticks(fontsize=24)
    plt.xlabel('Days Since First Admission ($t$)', fontsize=26)
    plt.xticks(fontsize=20)
    plt.title(f'Disease Risk Trajectory for Subject ID: {i}', fontsize=28)
    plt.legend(fontsize=22, title_fontsize=32,
          loc='upper right', bbox_to_anchor=(1.5, 0.5), ncol=1)

    current_figure = plt.gcf()
    current_figure.savefig(f"{save_dir}/trajectory_{i}.pdf", bbox_inches='tight')
