In [None]:
import sys
import os
import pandas as pd
import seaborn as sns
import glob

In [None]:
# Make the parent directory importable so that the `icenode` package resolves
# when the notebook runs from its own sub-directory. The guard keeps sys.path
# from growing every time this cell is re-executed.
if '..' not in sys.path:
    sys.path.append('..')

# Home directory used as the root of every data/experiment path below.
# `expanduser` honours $HOME when it is set and otherwise falls back to the
# platform's notion of the user directory, so HOME is never None (which would
# silently produce paths such as 'None/GP/ehr-data/...').
HOME = os.path.expanduser('~')


In [None]:
from icenode.train_icenode_2lr import ICENODE
from icenode.train_icenode_uniform2lr import ICENODE as ICENODE_UNIFORM
from icenode.train_gram import GRAM
from icenode.train_retain import RETAIN
from icenode.metrics import codes_auc_pairwise_tests

%load_ext autoreload
%autoreload 2

In [None]:
class LazyDict(dict):
    """Dictionary whose callable values are evaluated on first access.

    A value stored as a callable is treated as a zero-argument factory: the
    first lookup through `[]` or `get` calls it and stores the result under
    the same key, so later lookups return the cached result.
    """

    def __getitem__(self, k):
        stored = super().__getitem__(k)
        if not callable(stored):
            return stored
        # First access: materialise the value and cache it in place.
        resolved = stored()
        super().__setitem__(k, resolved)
        return resolved

    def get(self, k, default=None):
        """Like `dict.get`, but resolves lazy values through `__getitem__`."""
        return self[k] if k in self else default

In [None]:
from icenode.utils import load_config, load_params


# Dataset directories, keyed by dataset tag; passed to
# `create_patient_interface` in `get_patient_interface`.
mimic_dir = {
    'M3': f'{HOME}/GP/ehr-data/mimic3-transforms',
    'M4': f'{HOME}/GP/ehr-data/mimic4-transforms'
}

# Experiment output directories, keyed by dataset tag; `get_trained_models`
# looks for evaluation CSVs, parameter pickles and config.json underneath.
trained_dir = {
    'M3': f'{HOME}/GP/ehr-data/icenode-m3-exp/train_config_v0.2.25_M3',
    'M4': f'{HOME}/GP/ehr-data/icenode-m4-exp/train_config_v0.2.25_M4'
}

# Sub-directory of `trained_dir[tag]` that holds each classifier's outputs.
model_dir = {
    'ICENODE': 'icenode_2lr',
    'ICENODE_UNIFORM': 'icenode_uniform2lr',
    'GRU': 'gru',
    'RETAIN': 'retain'
}

# Class used to build each classifier (`create_model` / `create_patient_interface`).
# NOTE(review): 'GRU' is backed by the GRAM class while its outputs are read
# from the 'gru' directory -- confirm this pairing is intended.
model_cls = {
    'ICENODE': ICENODE,
    'ICENODE_UNIFORM': ICENODE_UNIFORM,
    'GRU': GRAM,
    'RETAIN': RETAIN
}   

def _params_file_for(csv_file):
    """Return the path of the parameters pickle that pairs with `csv_file`.

    Only the file name is rewritten (`<prefix>_<anything>.csv` becomes
    `<prefix>_params.pickle`), so underscores in directory names can never
    corrupt the resulting path.
    """
    folder, filename = os.path.split(csv_file)
    prefix, sep, _ = filename.rpartition('_')
    return os.path.join(folder, f'{prefix}{sep}params.pickle')


def get_trained_models(data_tag, clfs, criterion, comp):
    """Load the config and the best saved parameters of each classifier.

    For every classifier, all `*.csv` evaluation tables in its experiment
    directory are read, the table that is best with respect to `criterion`
    (row) on the 'VAL' column is selected with `comp`, and the parameters
    file sharing its filename prefix is loaded.

    Args:
        data_tag: key of `trained_dir` selecting the experiment directory.
        clfs: iterable of classifier names (keys of `model_dir`).
        criterion: row label of the evaluation tables used for selection.
        comp: selector such as `max` or `min`, called as
            `comp(iterable, key=...)`.

    Returns:
        `(config, params)`: two dicts keyed by classifier name.

    Raises:
        FileNotFoundError: if a classifier directory has no evaluation CSV.
    """
    params = {}
    config = {}
    clfs_params_dir = trained_dir[data_tag]

    for clf in clfs:
        clf_dir = model_dir[clf]
        # Sorted so that the reported index and tie-breaking do not depend on
        # the arbitrary order in which the file system lists the files.
        csv_files = sorted(glob.glob(f'{clfs_params_dir}/{clf_dir}/*.csv', recursive=False))
        if not csv_files:
            raise FileNotFoundError(
                f'No evaluation CSV files found for {clf} in {clfs_params_dir}/{clf_dir}')

        scores = [pd.read_csv(csv_file, index_col=[0]).loc[criterion, 'VAL']
                  for csv_file in csv_files]
        max_i = comp(range(len(scores)), key=scores.__getitem__)

        print(f'{clf}@{max_i} {criterion}={scores[max_i]}')
        params[clf] = load_params(_params_file_for(csv_files[max_i]))
        config[clf] = load_config(f'{clfs_params_dir}/{clf_dir}/config.json')
    return config, params

def get_patient_interface(data_tag, clfs):
    """Return a `{classifier: patient interface}` dict for `clfs`.

    Classifiers of the same kind ('timestamped' or 'sequential') share one
    interface object, which is only created when a classifier of that kind
    is actually requested.
    """
    kind_of = {
        'ICENODE':  'timestamped',
        'ICENODE_UNIFORM': 'timestamped',
        'GRU': 'sequential',
        'RETAIN': 'sequential'
    }

    def build_timestamped():
        return ICENODE.create_patient_interface(mimic_dir[data_tag], data_tag)

    def build_sequential():
        return GRAM.create_patient_interface(mimic_dir[data_tag], data_tag)

    # Factories are resolved (and cached) by LazyDict on first lookup.
    lazy_interfaces = LazyDict({
        'timestamped': build_timestamped,
        'sequential': build_sequential
    })

    interfaces = {}
    for clf in clfs:
        interfaces[clf] = lazy_interfaces[kind_of[clf]]
    return interfaces
    

# Classifiers compared in this notebook; each name must be a key of
# `model_dir` and `model_cls`.
clfs = (
    'ICENODE', 
    'ICENODE_UNIFORM',
    'GRU',
    'RETAIN'
)

# Dataset tag; selects the entries of `mimic_dir` and `trained_dir`.
data_tag = 'M3'

## Params

In [None]:
# For each classifier, pick the evaluation table with the highest 'MICRO-AUC'
# on the 'VAL' column and load the matching parameters and config.
config, params = get_trained_models(data_tag, clfs, 'MICRO-AUC', comp=max)

## Patient Interface

In [None]:
# One patient interface per classifier; classifiers of the same kind share the same object.
interface = get_patient_interface(data_tag, clfs)

## Dataset Partitioning

In [None]:
# Partition taken from the first classifier's interface with a fixed seed.
# NOTE(review): the same ids are later used with every classifier's interface --
# assumes all interfaces cover the same subjects; confirm.
train_ids, valid_ids, test_ids = interface[clfs[0]].random_splits(split1=0.7, split2=0.85, random_seed=42)

## Load Models

In [None]:
def get_model_eval(clfs):
    """Build one evaluation callable per classifier.

    Each classifier's model is created from its config and interface and is
    initialised with its loaded parameters. The returned callables take a
    collection of subject ids and return the 'diag_detectability' entry of
    the model's evaluation result.

    Args:
        clfs: iterable of classifier names (keys of `model_cls`, `config`,
            `params` and `interface`).

    Returns:
        dict mapping classifier name to `callable(ids)`.
    """
    def make_eval(model, state):
        # Bind *this* classifier's model and state now. A lambda written
        # directly in the loop would look `model`/`state` up when called,
        # i.e. after the loop has finished, so every classifier would be
        # evaluated with the last model.
        def eval_(ids):
            return model.eval(state, ids)['diag_detectability']
        return eval_

    evals = {}
    for clf in clfs:
        model = model_cls[clf].create_model(config[clf], interface[clf], train_ids, None)
        state = model.init_with_params(config[clf], params[clf])
        evals[clf] = make_eval(model, state)
    return evals

# One evaluation callable per classifier, keyed by classifier name.
evals = get_model_eval(clfs)

## Per-code performance

In [None]:
# Evaluate every classifier on the test partition.
test_res = {clf: eval_(test_ids) for clf, eval_ in evals.items()} 


In [None]:
# Per-code comparison of the classifiers. The cells below index the result by
# 'CODE_INDEX', 'AUC(<clf>)', 'VAR[AUC(<clf>)]', 'N_POSITIVE_CODES' and
# 'P0(AUC_<clf1>==AUC_<clf2>)' columns.
tests_raw = codes_auc_pairwise_tests(test_res)

In [None]:
# Same comparison with fast=True -- presumably a cheaper variant of the tests; TODO confirm.
tests_raw_fast = codes_auc_pairwise_tests(test_res, fast=True)

In [None]:
tests_raw

In [None]:
tests_raw_fast

## Correlation between AUC and N_POS_CODES

In [None]:
# AUC column of every classifier next to the 'N_POSITIVE_CODES' column.
tests_raw_auc_corr = tests_raw[[f'AUC({clf})' for clf in test_res] + ['N_POSITIVE_CODES']]
# sns.pairplot(tests_raw_auc_corr)

In [None]:
# Attach a textual description to every code. The timestamped interface built
# for `data_tag` is used, so this cell works for either dataset without
# referring to dataset-specific variables that are not defined in this notebook.
flatccs_idx2code = {idx: code for code, idx in interface['ICENODE'].diag_flatccs_idx.items()}
idx2desc = lambda i: interface['ICENODE'].dag.diag_flatccs_desc[flatccs_idx2code[i]]
tests_raw['DESC'] = tests_raw['CODE_INDEX'].apply(idx2desc)
tests_raw

In [None]:
# Keep only the codes on which at least one classifier reaches an AUC above 0.7.
at_least_AUC_07 = tests_raw[[f'AUC({clf})' for clf in test_res]].max(axis=1).gt(0.7)
tests = tests_raw.loc[at_least_AUC_07]
tests

In [None]:
tests.describe()

In [None]:
from collections import defaultdict
import itertools

# Sets of code indices, keyed by a classifier name, a pair of classifiers, or
# the full tuple of classifiers (filled in by the cells below).
auc_sets = defaultdict(set)

# Classifiers that were actually evaluated, in sorted order. Derived from
# `test_res` rather than from a dataset-specific variable that is not defined
# in this notebook.
clfs = tuple(sorted(test_res.keys()))

# All unordered classifier pairs; because `clfs` is sorted, the pairs come out
# in sorted order as well.
clfs_pairs = tuple(itertools.combinations(clfs, 2))

In [None]:
# Column names holding the p-value of each pairwise test, one per classifier pair.
# NOTE(review): assumes `codes_auc_pairwise_tests` names its columns with the
# classifiers of each pair in sorted order -- confirm.
test_cols = tuple(f'P0(AUC_{clf1}==AUC_{clf2})' for (clf1, clf2) in clfs_pairs)

In [None]:
# Codes predicted with equivalent performance by all classifiers: none of the
# pairwise tests rejects equal AUC at the 5% level. Every pair is checked, so
# this holds for any number of classifiers, not just three.
common_perf = tests[(tests[list(test_cols)] > 0.05).all(axis=1)]
auc_sets[clfs] = set(common_perf.CODE_INDEX)

In [None]:
common_perf

In [None]:
# Remaining codes: at least one pairwise test shows a significant AUC difference.
competing_tests = tests.drop(index=common_perf.index)

In [None]:
# Credit each remaining code to the classifier with the highest AUC: on its
# own when it differs significantly from every other classifier, otherwise
# jointly with each classifier whose AUC is not significantly different.
for _, code_row in competing_tests.iterrows():
    code_index = int(code_row['CODE_INDEX'])
    max_auc_clf = max(clfs, key=lambda clf: code_row[f'AUC({clf})'])

    # Pairs involving the top classifier whose AUC difference is insignificant.
    tied_pairs = [
        (clf1, clf2)
        for (clf1, clf2) in clfs_pairs
        if max_auc_clf in (clf1, clf2)
        and code_row[f'P0(AUC_{clf1}==AUC_{clf2})'] > 0.05
    ]

    if tied_pairs:
        for pair in tied_pairs:
            auc_sets[pair].add(code_index)
    else:
        auc_sets[max_auc_clf].add(code_index)

In [None]:
# Prepare the input of the UpSet plot: for every classifier, the union of the
# codes it wins alone, the codes on which all classifiers tie, and the codes
# of every pair it belongs to.
best_sets = {
    clf: set().union(
        auc_sets[clf],
        auc_sets[clfs],
        *(auc_sets[pair] for pair in clfs_pairs if clf in pair),
    )
    for clf in clfs
}
    

In [None]:
from upsetplot import from_contents, plot, UpSet
import matplotlib.pyplot as plt

In [None]:
# Convert the per-classifier code sets into the membership table used by UpSet.
upset_contents = from_contents(best_sets)

In [None]:
# Draw the UpSet plot and save it to "auc_upset.pdf" before showing it.
UpSet(upset_contents, subset_size='count', show_counts=True).plot()
current_figure = plt.gcf()
current_figure.savefig("auc_upset.pdf")

plt.show()


In [None]:
# Rows of `competing_tests` whose code is in a classifier's `best_sets` entry
# (best alone or jointly) ...
model_best_tests = {clf: competing_tests[competing_tests['CODE_INDEX'].isin(best_sets[clf])] for clf in clfs}
# ... and rows whose code the classifier wins exclusively (`auc_sets[clf]`).
model_exc_best_tests = {clf: competing_tests[competing_tests['CODE_INDEX'].isin(auc_sets[clf])] for clf in clfs}

In [None]:
from IPython.display import display

for clf, best_tests in model_best_tests.items():
    print(clf)
    display(best_tests)

In [None]:
for clf, best_tests in model_exc_best_tests.items():
    print(clf)
    display(best_tests)

In [None]:
# Codes where ICENODE is among the best classifiers, restricted to high AUC.
# NOTE(review): the variable is named `above07` but the threshold applied is
# 0.8 -- confirm which of the two is intended.
icenode_best_tests = model_best_tests['ICENODE']
icenode_best_test_above07 = icenode_best_tests[icenode_best_tests['AUC(ICENODE)'] > 0.8]
icenode_best_test_above07

## AUC Distribution

In [None]:
import numpy as np
# ICENODE AUC, description and AUC variance of the codes on which all
# classifiers tie, sorted by AUC and restricted to AUC > 0.65.
df = common_perf[['AUC(ICENODE)', 'DESC', 'VAR[AUC(ICENODE)]']].sort_values('AUC(ICENODE)')
df = df[df['AUC(ICENODE)'] > 0.65]
df.columns = ['AUC', 'CCS', 'VAR']
# Standard deviation (square root of the variance), used as the error bar.
error = df['VAR'].apply(np.sqrt)

In [None]:
len(df)

In [None]:
# The theme must be active *before* the axes are created: style settings are
# read when the axes are built, so setting them afterwards would only take
# effect the next time the cell is run.
sns.set_theme()
sns.set_style("darkgrid", {"axes.facecolor": ".9"})
fig, ax = plt.subplots(figsize=(8,20))

# Horizontal bars of the ICENODE AUC per CCS code, with the standard deviation as error bar.
ax = sns.barplot(x="AUC", y="CCS", color="salmon", xerr=error*1,capsize=.2, data=df, ax=ax)
# plt.title('ICE-NODE AUC on CCS Codes of Comparable AUC with GRU/RETAIN', fontsize=20)

fig.tight_layout(pad=4)
plt.xlabel('AUC', fontsize=24)
plt.xlim(0.65, 1.0)
plt.xticks(fontsize=20)

plt.ylabel('CCS', fontsize=24)
plt.yticks(fontsize=14)

sns.despine(left=True)
ax.grid(True)
ax.tick_params(bottom=True, left=False)

current_figure = plt.gcf()
current_figure.savefig("common_performance.pdf")


plt.show()

In [None]:
# Long-format table (one row per code and classifier) for the grouped bar plot.
best_auc_per_code = competing_tests[[f'AUC({clf})' for clf in clfs]].max(axis=1)
comp_tests = competing_tests[best_auc_per_code > 0.7]

# Descriptions truncated to keep the axis labels short.
short_desc = comp_tests['DESC'].apply(lambda t: t if len(t) < 15 else t[:14] + '...')

per_clf_frames = []
for clf in clfs:
    frame = pd.DataFrame({
        'AUC': comp_tests[f'AUC({clf})'],
        'std': comp_tests[f'VAR[AUC({clf})]'].apply(np.sqrt),
        'CCS': short_desc,
        'Classifier': clf,
    })
    per_clf_frames.append(frame.sort_values('AUC').reset_index(drop=True))

competing_df = pd.concat(per_clf_frames)

In [None]:
# Activate the theme before creating the axes so that it applies on the first run.
sns.set_theme()
sns.set_style("darkgrid", {"axes.facecolor": ".9"})
fig, ax = plt.subplots(figsize=(10, 15))

# One colour per classifier shown; the first three match the original palette.
color_names = ["windows blue", "amber", "greyish", "faded green", "dusty purple"]
colors = sns.xkcd_palette(color_names[:competing_df['Classifier'].nunique()])

ax = sns.barplot(x="AUC", y="CCS", hue='Classifier', palette =colors , data=competing_df, ax=ax)
# plt.title('Performance of ICE-NODE/GRU/RETAIN', fontsize=40)

fig.tight_layout(pad=10)
plt.xlabel('AUC', fontsize=32)
plt.xlim(0.5, 1.0)
plt.xticks(fontsize=24)
plt.yticks(fontsize=14)

plt.ylabel('CCS', fontsize=36)
# A single legend call: a second `plt.legend(...)` would replace the legend
# and drop the font settings of the first one.
plt.legend(fontsize='xx-large', title_fontsize='40',
           bbox_to_anchor=(1, 1), loc=2, borderaxespad=0.)

sns.despine(left=True)
ax.grid(True)
ax.tick_params(bottom=True, left=False)
current_figure = plt.gcf()
current_figure.savefig("competing_performance.pdf")

plt.show()


## Trajectories for Patients with CCS codes best predicted with ICENODE

### Analyse AUC for Each Admission in the Test Partition

In [None]:
# Re-create the ICENODE model with the interface, training ids and parameters
# loaded at the top of the notebook (for the dataset selected by `data_tag`).
icenode = ICENODE.create_model(config['ICENODE'], interface['ICENODE'], train_ids, None)
icenode_state = icenode.init_with_params(config['ICENODE'], params['ICENODE'])

In [None]:
# AUC scores of ICENODE on the test partition defined above.
icenode_visit_auc_df = icenode.admissions_auc_scores(icenode_state, test_ids)

In [None]:
# Number of rows per subject, computed once with `value_counts` and mapped
# back onto the rows (the per-row comparison it replaces was quadratic).
icenode_visit_auc_df['N_VISITS'] = icenode_visit_auc_df['SUBJECT_ID'].map(
    icenode_visit_auc_df['SUBJECT_ID'].value_counts())
icenode_visit_auc_df

In [None]:
# Per-subject summary of the rows above; columns form a (column, statistic) MultiIndex.
visit_auc_subject = icenode_visit_auc_df.groupby('SUBJECT_ID').agg({
    'AUC': 'mean',
    'N_VISITS': 'max',
    'N_CODES': ['min', 'max', 'mean', 'median'],
    'INTERVALS': ['mean', 'max', 'min'],
    'R/T': ['min', 'max', 'mean'],
})


In [None]:
# Subjects with a high mean AUC, between 2 and 9 visits, and a maximum
# 'INTERVALS' value below 90.
best_visit_auc_subjects = visit_auc_subject[
    (visit_auc_subject[('AUC', 'mean')] > 0.85)
    & (visit_auc_subject[('N_VISITS', 'max')] > 1)
    & (visit_auc_subject[('N_VISITS', 'max')] < 10)
    & (visit_auc_subject[('INTERVALS', 'max')] < 90)
]
best_visit_auc_subjects

In [None]:
# Second element of each selected subject's `diag_flatccs_history` result,
# read from the interface loaded at the top of the notebook.
ccs_history = {i: interface['ICENODE'].diag_flatccs_history(i)[1] for i in best_visit_auc_subjects.index}

In [None]:
# Keep the subjects whose history shares at least one code with the codes
# best predicted by ICENODE.
icenode_best_codes = set(icenode_best_test_above07['CODE_INDEX'])
ccs_history_icenode_best = {
    subject_id: codes
    for subject_id, codes in ccs_history.items()
    if not icenode_best_codes.isdisjoint(codes)
}

In [None]:
ccs_history_icenode_best

In [None]:
len(ccs_history_icenode_best)

In [None]:
# Trajectories of the selected subjects; each entry is read below through its
# 't' and 'd' keys. NOTE(review): the last argument (1) is presumably the
# sampling rate/step -- confirm against `sample_trajectory`.
trajectory = icenode.sample_trajectory(icenode_state, ccs_history_icenode_best.keys(), 1)

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Indices and descriptions of the codes best predicted by ICENODE, using the
# interface loaded at the top of the notebook.
ccs_index = list(icenode_best_test_above07['CODE_INDEX'])
idx2desc = lambda idx: interface['ICENODE'].dag.diag_flatccs_desc[flatccs_idx2code[idx]]
ccs_description = list(map(idx2desc, ccs_index))


In [None]:
# For every subject, build a long-format table with one row per
# (time point, code) holding the value of `d` for that code, together with
# the subject's `diag_times` from the interface loaded at the top of the notebook.
data = {}
for i, traj in trajectory.items():
    diag_times = interface['ICENODE'].diag_times(i)

    t = traj['t']
    d = traj['d']

    prob = []
    time = []
    code = []

    for ccs_desc, ccs_idx in zip(ccs_description, ccs_index):
        time.append(t)
        code.extend([ccs_desc]*len(t))
        # Column `ccs_idx` of `d`, aligned with the time points in `t`.
        prob.append(d[:, ccs_idx])

    prob = np.hstack(prob)
    time = np.hstack(time)

    df = pd.DataFrame({'t': time, r'$\hat{v}$': prob, 'code': code})
    data[i] = (df, diag_times)
    

In [None]:
# One figure per subject: a line per code over time, with a dashed red
# vertical line at each of the subject's `diag_times`.
plt.rcParams['figure.figsize']=(10,10)
import math
for i, (df, diag_times) in data.items():

    # A separate figure for each subject, numbered by the subject key.
    plt.figure(i)
    
    g = sns.lineplot(data=df, x="t", y=r'$\hat{v}$', hue='code', marker='o')
    for diag_time in diag_times:
        g.axvline(x=diag_time, ymin=0, ymax=1, c="red", ls='--', linewidth=0.8, zorder=0, clip_on=False)

It seems that we cannot capture the smoothness of the trajectory, because it evolves very quickly towards its saturation value.