# Table of Contents

<a name="outline"></a>

## Setup

- [A](#seca) External Imports
- [B](#secb) Internal Imports
- [C](#secc) Configurations and Paths
- [D](#secd) Patient Interface and Train/Val/Test Partitioning


## Evaluations

- [1](#sec1) Snooping/Selecting Best Models from the Validation Set
- [2](#sec2) Predictive Performance: MIMIC-III (Test Set)
- [3](#sec3) Predictive Performance: MIMIC-IV (Test Set)
- [4](#sec4) Predictive Performance: from MIMIC-IV (Training Set) to MIMIC-III (All)

<a name="seca"></a>

### A External Imports [^](#outline)

In [1]:
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path
from IPython.display import display
from upsetplot import from_contents, plot, UpSet, from_indicators


<a name="secb"></a>

### B Internal Imports [^](#outline)

In [2]:
# HOME and DATA_STORE are arbitrary, change as appropriate.
# Fall back to the expanded '~' when $HOME is unset or empty; otherwise the
# f-string below would silently produce the literal path 'None/GP/ehr-data'.
HOME = os.environ.get('HOME') or os.path.expanduser('~')
DATA_STORE = f'{HOME}/GP/ehr-data'



%load_ext autoreload
%autoreload 2

import analysis as A
import common as C

<a name="secc"></a>

### C Configurations and Paths [^](#outline)

In [3]:
# Look up both datasets inside the modified_environ context, where DATA_DIR
# is set to DATA_STORE.
with C.modified_environ(DATA_DIR=DATA_STORE):
    mimic3_dataset, mimic4_dataset = C.datasets['M3'], C.datasets['M4']

The following cell configures the location of the models pretrained on MIMIC-III (M3) and MIMIC-IV (M4), with and without GRAM embeddings (G). Each training experiment produces 100 parameter snapshots (training checkpoints taken throughout the training iterations), and this notebook picks the snapshot that maximizes the average visit-level AUC (the probability of assigning higher risks to the present codes than to the absent codes of the same visit).

In [4]:
# Location of the training experiments, keyed by dataset tag
# ('M3': MIMIC-III, 'M4': MIMIC-IV).
train_dir = {tag: f'{DATA_STORE}/icd9v3/{tag}' for tag in ('M3', 'M4')}

In [5]:
# Keyword arguments forwarded to A.relative_performance_upset in the
# evaluation sections.
relative_auc_config = dict(pvalue=0.01, min_auc=0.9)

# Values of k passed as `top_k_list` to the evaluation tables.
top_k_list = [1, 2, 3, 5, 7, 10, 15, 20]

# Percentile step (in %) and the number of percentile groups it implies.
percentile_range = 20
n_percentiles = 100 // percentile_range

import matplotlib.font_manager as font_manager
# Reset matplotlib to its default settings, then apply the font settings
# used by the figures of this notebook.
plt.rcParams.update(plt.rcParamsDefault)
font_settings = {
    'font.family': 'sans-serif',
    'font.sans-serif': 'Helvetica',
    'font.weight': 'normal',
}
plt.rcParams.update(font_settings)

In [6]:
# Folder receiving the generated figures and tables; created (with any
# missing parents) if it does not exist yet.
output_dir = 'artefacts'
os.makedirs(output_dir, exist_ok=True)


<a name="secd"></a>

### D Patient Interface and Train/Val/Test Partitioning [^](#outline)

In [7]:
# Code scheme for the models without GRAM embeddings ('M').
# (The original literal listed the "pr" key twice; the duplicate is dropped,
# the resulting dictionary is unchanged.)
code_scheme_M = {
    "dx": "dx_icd9",
    "pr": "pr_icd9",
    "dx_outcome": "dx_icd9_filter_v3_groups",
}

# Code scheme for the GRAM-embedding models ('G'): same codes and outcome
# as above, plus the two `*_dagvec` flags.
code_scheme_G = {
    "dx": "dx_icd9",
    "pr": "pr_icd9",
    "dx_dagvec": True,
    "pr_dagvec": True,
    "dx_outcome": "dx_icd9_filter_v3_groups",
}

# One patient interface per (dataset, code scheme) pair, built in the same
# order as before: both 'M' interfaces first, then both 'G' interfaces.
_build_interface = C.Subject_JAX.from_dataset

m3_interface_M = _build_interface(mimic3_dataset, code_scheme=code_scheme_M)
m4_interface_M = _build_interface(mimic4_dataset, code_scheme=code_scheme_M)

m3_interface_G = _build_interface(mimic3_dataset, code_scheme=code_scheme_G)
m4_interface_G = _build_interface(mimic4_dataset, code_scheme=code_scheme_G)

In [8]:
# Both datasets are partitioned with the same split points and seed.
_split_options = dict(split1=0.7, split2=0.85, random_seed=42)
m4_splits = m4_interface_M.random_splits(**_split_options)
m3_splits = m3_interface_M.random_splits(**_split_options)

In [9]:
# Name the three partitions (train / validation / test) of each dataset.
(
    (m3_train_ids, m3_valid_ids, m3_test_ids),
    (m4_train_ids, m4_valid_ids, m4_test_ids),
) = (m3_splits, m4_splits)


In [10]:
# Group the outcome codes by percentiles computed on the training partition.
# The step comes from `percentile_range` (config cell) instead of a second
# hard-coded 20, so it stays consistent with `n_percentiles` used by the
# top-k tables.
m3_percentiles = m3_interface_M.dx_outcome_by_percentiles(percentile_range, m3_train_ids)
m4_percentiles = m4_interface_M.dx_outcome_by_percentiles(percentile_range, m4_train_ids)

<a name="sec1"></a>

## 1 Snooping/Selecting Best Models from the Validation Set [^](#outline)

In [11]:
from glob import glob

# One entry per item found under the MIMIC-III training folder; the entry
# name is used as the classifier name. glob() returns matches in arbitrary
# order, so sort them to make the classifier order (and everything derived
# from it) reproducible across runs and machines.
clfs = sorted(os.path.basename(d) for d in glob(f"{train_dir['M3']}/*"))
# Each classifier's model folder is named after the classifier itself.
model_dir = dict(zip(clfs, clfs))

In [12]:
# Arguments shared by both selections; only the dataset tag differs.
_selection_options = dict(clfs=clfs, train_dir=train_dir, model_dir=model_dir,
                          criterion='MICRO-AUC', comp=max)

print('> Models trained on MIMIC-III')
m3_top = A.get_trained_models(data_tag='M3', **_selection_options)
display(m3_top['summary'])

print('> Models trained on MIMIC-IV')
m4_top = A.get_trained_models(data_tag='M4', **_selection_options)
display(m4_top['summary'])

In [13]:
# Interfaces keyed by embedding flavour: 'M' (no GRAM) and 'G' (GRAM).
m3_interface = dict(M=m3_interface_M, G=m3_interface_G)
m4_interface = dict(M=m4_interface_M, G=m4_interface_G)

# Instantiate the selected models of each dataset from their config/params.
m3_models = C.lsr_get_models(clfs, m3_top["config"], m3_top["params"], m3_interface)
m4_models = C.lsr_get_models(clfs, m4_top["config"], m4_top["params"], m4_interface)


In [14]:

def cross_predictor(clf, source_tag, target_tag):
    """Build the model `clf` trained on one dataset and bound to another.

    Args:
        clf: classifier name, a key of the selected `params`/`config` tables.
        source_tag: 'M3' or 'M4'; dataset whose trained config and
            parameters are loaded.
        target_tag: 'M3' or 'M4'; dataset whose patient interface the model
            is attached to.

    Returns:
        Whatever `C.lsr_get_model` returns for that combination.
    """
    selected = {'M3': m3_top, 'M4': m4_top}
    # Parameters and configs of `clf` are looked up for both datasets, as before.
    params_by_source = {tag: top['params'][clf] for tag, top in selected.items()}
    config_by_source = {tag: top['config'][clf] for tag, top in selected.items()}
    interfaces_by_target = {'M3': m3_interface, 'M4': m4_interface}

    # Classifiers with '_G' in their name use the GRAM ('G') interface.
    if '_G' in clf:
        embedding_key = 'G'
    else:
        embedding_key = 'M'

    return C.lsr_get_model(clf=clf,
                           config=config_by_source[source_tag],
                           params=params_by_source[source_tag],
                           interface=interfaces_by_target[target_tag][embedding_key])


<a name="sec2"></a>

## 2 Predictive Performance on MIMIC-III (Test Set) [^](#outline)

In [15]:
# Models trained on MIMIC-III, bound to the MIMIC-III interface.
m3_predictors = dict((name, cross_predictor(name, 'M3', 'M3')) for name in clfs)

In [16]:
# Evaluate every MIMIC-III predictor on the third (test) partition.
test_res_m3 = dict((name, C.eval2_(predictor, m3_splits[2]))
                   for name, predictor in m3_predictors.items())

In [26]:
# Per-classifier risk predictions, fed to the pairwise per-code AUC tests.
m3_risk_predictions = {name: result['risk_prediction']
                       for name, result in test_res_m3.items()}
auctests_m3 = A.codes_auc_pairwise_tests(m3_risk_predictions, fast=True)


In [27]:
# Columns whose name starts with 'P0' (presumably the pairwise-test
# p-values — TODO confirm against A.codes_auc_pairwise_tests).
test_cols = [col for col in auctests_m3.columns if col[:2] == 'P0']
# Number of rows with at least one missing value among those columns.
auctests_m3.loc[:, test_cols].isnull().max(axis=1).sum()

In [None]:
# upset_clfs = ['ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN', 'LogReg', 
#               'ICE-NODE/G', 'ICE-NODE_UNIFORM/G', 'GRU/G']
# Frequency vector computed from the MIMIC-III training partition, attached
# to the UpSet data as the 'Code frequency' attribute (position -> value).
m3_freq_v = m3_interface_M.dx_outcome_frequency_vec(m3_train_ids)
m3_code_attrs = {'Code frequency': dict(enumerate(m3_freq_v))}

m3_upset_result = A.relative_performance_upset(
    auctests_m3, clfs,
    code_attrs=m3_code_attrs,
    interface=m3_interface_M,
    **relative_auc_config)


def upset_ctx():
    """Return the seaborn plotting context used for the UpSet figure."""
    return sns.plotting_context(
        "paper", font_scale=1.5,
        rc={"font.family": "sans-serif",
            'axes.labelsize': 'medium',
            'ytick.labelsize': 'medium'})


with sns.axes_style("darkgrid"), upset_ctx():
    upset_format = from_indicators(m3_upset_result['indicator_df'],
                                   data=m3_upset_result['data'])
    upset_object = UpSet(upset_format, subset_size='count', show_counts=True)
    # Highlight in red the subsets containing ICE-NODE, GRU and RETAIN but
    # not LogReg.
    upset_object.style_subsets(absent=['LogReg'],
                               present=('ICE-NODE', 'GRU', 'RETAIN'),
                               facecolor="red",
                               edgecolor="red",
                               linewidth=3)
    upset_object.add_catplot(value='Code frequency', kind="strip")

    g = upset_object.plot()

    # Keep the figure height and set the width to a w:h = 2.5:3 ratio of it.
    current_figure = plt.gcf()
    w, h = 2.5, 3
    wi, hi = current_figure.get_size_inches()
    current_figure.set_size_inches(hi*(w/h), hi)
    current_figure.savefig(f"{output_dir}/upset_M3.pdf", bbox_inches='tight')
    plt.show()

In [36]:
# Evaluation table of the MIMIC-III test results; only the first element of
# the returned pair is kept (the second is discarded into `_`).
results_m3_eval,_ = A.evaluation_table(test_res_m3, m3_percentiles, top_k_list=top_k_list)

In [37]:
# table_clfs = ('LogReg', 
#               'RETAIN',
#               'GRU',
#               'GRU/G',
#               'ICE-NODE_UNIFORM',
#               'ICE-NODE_UNIFORM/G',
#               'ICE-NODE', 
#               'ICE-NODE/G'
#               )
# Tables list the classifiers in alphabetical order.
table_clfs = list(clfs)
table_clfs.sort()
results_m3_tables = A.top_k_tables(
    table_clfs, results_m3_eval,
    top_k_list=top_k_list,
    n_percentiles=n_percentiles,
    out_prefix=f'{output_dir}/M3')

In [30]:
# Codes whose 'LogReg' indicator is False in the UpSet indicator table.
m3_indicator_df = m3_upset_result['indicator_df']
logreg_false_codes = m3_indicator_df[m3_indicator_df['LogReg']==False].index

# Restrict the pairwise AUC tests to those codes and show the 'P0*' columns.
competing_tests_df = auctests_m3[auctests_m3.CODE_INDEX.isin(logreg_false_codes)]
p0_columns = [col for col in competing_tests_df.columns if col[:2]=='P0']
competing_tests_df.loc[:, p0_columns]

In [31]:
# Classifiers shown in the bar plot (with and without GRAM embeddings).
upset_clfs = [
    'ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN', 'LogReg',
    'ICE-NODE/G', 'ICE-NODE_UNIFORM/G', 'GRU/G',
]

ax = A.selected_auc_barplot(upset_clfs, competing_tests_df, horizontal=True)
ax.legend(fontsize=22,
          title_fontsize=32,
          bbox_to_anchor=(-0.02, 1),
          ncol=2)

tick_font = dict(fontsize=30)
plt.xticks(**tick_font)
plt.yticks(**tick_font)

# The figure keeps its default size here; it is saved as-is.
current_figure = plt.gcf()
current_figure.savefig(f"{output_dir}/icenode_m3.pdf", bbox_inches='tight')
plt.show()


<a name="sec3"></a>

## 3 Relative AUC Performance on MIMIC-IV (Test Set) [^](#outline)

In [20]:
# Models trained on MIMIC-IV, bound to the MIMIC-IV interface.
m4_predictors = dict((name, cross_predictor(name, 'M4', 'M4')) for name in clfs)

In [21]:
# Evaluate every MIMIC-IV predictor on the MIMIC-IV test partition.
test_res_m4 = dict((name, C.eval2_(predictor, m4_test_ids))
                   for name, predictor in m4_predictors.items())

In [38]:
# Per-classifier risk predictions, fed to the pairwise per-code AUC tests.
m4_risk_predictions = {name: result['risk_prediction']
                       for name, result in test_res_m4.items()}
auctests_m4 = A.codes_auc_pairwise_tests(m4_risk_predictions, fast=True)


In [41]:
# upset_clfs = ('ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN', 'LogReg',)
#                'ICE-NODE/G', 'ICE-NODE_UNIFORM/G', 'GRU/G')

# Frequency vector computed from the MIMIC-IV training partition, attached
# to the UpSet data as the 'Code frequency' attribute (position -> value).
m4_freq_v = m4_interface_M.dx_outcome_frequency_vec(m4_train_ids)
m4_code_attrs = {'Code frequency': dict(enumerate(m4_freq_v))}

m4_upset_result = A.relative_performance_upset(
    auctests_m4, clfs,
    code_attrs=m4_code_attrs,
    interface=m4_interface_M,
    **relative_auc_config)


def upset_ctx():
    """Return the seaborn plotting context used for the UpSet figure."""
    return sns.plotting_context(
        "paper", font_scale=1.5,
        rc={"font.family": "Loma",
            'axes.labelsize': 'medium',
            'ytick.labelsize': 'medium'})


with sns.axes_style("darkgrid"), upset_ctx():
    upset_format = from_indicators(m4_upset_result['indicator_df'],
                                   data=m4_upset_result['data'])
    upset_object = UpSet(upset_format, subset_size='count', show_counts=True)
    upset_object.add_catplot(value='Code frequency', kind="strip")

    g = upset_object.plot()

    # Saved at the size chosen by UpSet (no subset highlighting, no resizing).
    current_figure = plt.gcf()
    current_figure.savefig(f"{output_dir}/upset_M4.pdf", bbox_inches='tight')
    plt.show()


In [42]:
# Evaluation table of the MIMIC-IV test results; only the first element of
# the returned pair is kept (the second is discarded into `_`).
results_m4_eval, _ = A.evaluation_table(test_res_m4, m4_percentiles, top_k_list=top_k_list)

In [43]:
# table_clfs = ('LogReg', 
#               'RETAIN',
#               'GRU',
#               'GRU/G',
#               'ICE-NODE_UNIFORM',
#               'ICE-NODE_UNIFORM/G',
#               'ICE-NODE', 
#               'ICE-NODE/G',
#               )
# Tables list the classifiers in alphabetical order.
table_clfs = list(clfs)
table_clfs.sort()
results_m4_tables = A.top_k_tables(
    table_clfs, results_m4_eval,
    top_k_list=top_k_list,
    n_percentiles=n_percentiles,
    out_prefix=f'{output_dir}/M4')

In [38]:
# Codes in the ICE-NODE content set that are in none of the content sets of
# RETAIN, GRU and ICE-NODE_UNIFORM.
m4_content_sets = m4_upset_result['content_sets']
m4_other_codes = set.union(*[m4_content_sets[clf]
                             for clf in ('RETAIN', 'GRU', 'ICE-NODE_UNIFORM')])
icenode_m4_excl = m4_content_sets['ICE-NODE'] - m4_other_codes

# Replace the code set by the matching rows of the competing-performance table.
m4_competing_performance = m4_upset_result['competing_performance']
icenode_m4_excl = m4_competing_performance[
    m4_competing_performance['CODE_INDEX'].isin(icenode_m4_excl)]
icenode_m4_excl

In [39]:
# icenode_ratain_gru_m4 = upsetcontents_m4['ICE-NODE']
# icenode_ratain_gru_m4 = compete_codesm4[compete_codesm4['CODE_INDEX'].isin(icenode_ratain_gru_m4)]
# icenode_ratain_gru_m4.sort_values('AUC(ICE-NODE)', ascending=False)[['CODE_INDEX', 'N_POSITIVE_CODES', 'DESC', 'AUC(ICE-NODE)']]

In [40]:
# Classifiers shown in the bar plot (models without GRAM embeddings only).
upset_clfs = ('ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN', 'LogReg')
ax = A.selected_auc_barplot(upset_clfs, icenode_m4_excl, horizontal=True)

plt.xticks(fontsize=30)
plt.yticks(fontsize=30)
ax.legend(fontsize=22, title_fontsize=32,
          bbox_to_anchor=(-0.02, 1), ncol=2)

# Keep the figure height and make the width follow a w:h = 4:4 ratio.
# (The earlier `w, h = 4, 3` assignment was dead code: it was overwritten
# here before ever being read.)
current_figure = plt.gcf()
w, h = 4, 4
wi, hi = current_figure.get_size_inches()
current_figure.set_size_inches(hi*(w/h), hi)

current_figure.savefig(f"{output_dir}/icenode_m4.pdf", bbox_inches='tight')
plt.show()


<a name="sec4"></a>

## 4 Relative AUC Performance From MIMIC-IV (Training Set) to MIMIC-III (All) [^](#outline)

In [44]:
# clfs_ordered = ('ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN')
# Every MIMIC-III subject is used as evaluation data in this section.
m3_subjects = [subject_id for subject_id in m3_interface_M.subjects.keys()]
# Models trained on MIMIC-IV, bound to the MIMIC-III interface.
m4m3_predictors = dict((name, cross_predictor(name, 'M4', 'M3')) for name in clfs)

In [23]:
# Evaluate every MIMIC-IV-trained predictor on all MIMIC-III subjects.
test_res_m4m3 = dict((name, C.eval2_(predictor, m3_subjects))
                     for name, predictor in m4m3_predictors.items())

In [39]:
# Per-classifier risk predictions, fed to the pairwise per-code AUC tests.
m4m3_risk_predictions = {name: result['risk_prediction']
                         for name, result in test_res_m4m3.items()}
auctests_m4m3 = A.codes_auc_pairwise_tests(m4m3_risk_predictions, fast=True)

In [47]:
# upset_clfs = ('ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN', 'LogReg',)
#                 'ICE-NODE/G', 'ICE-NODE_UNIFORM/G', 'GRU/G')

# The frequency vector comes from the MIMIC-IV training partition (the data
# the models were trained on). The attribute dict is now built from
# `m4m3_freq_v`, the vector computed in this cell; it previously read
# `m4_freq_v`, a variable that only exists if the MIMIC-IV UpSet cell ran first.
m4m3_freq_v = m4_interface_M.dx_outcome_frequency_vec(m4_train_ids)
m4m3_code_attrs = {'Code frequency': dict(zip(range(len(m4m3_freq_v)), m4m3_freq_v))}

m4m3_upset_result = A.relative_performance_upset(auctests_m4m3, clfs,
                                                 code_attrs=m4m3_code_attrs,
                                                 interface=m3_interface_M,
                                                 **relative_auc_config)

# Seaborn plotting context used for the UpSet figure.
upset_ctx = lambda : sns.plotting_context("paper", font_scale=1.5, rc={"font.family": "Loma",
                                                                        'axes.labelsize': 'medium',
                                                                        'ytick.labelsize': 'medium'})
with sns.axes_style("darkgrid"), upset_ctx():
    upset_format = from_indicators(m4m3_upset_result['indicator_df'], data=m4m3_upset_result['data'])
    upset_object = UpSet(upset_format, subset_size='count', show_counts=True)
    upset_object.add_catplot(value='Code frequency', kind="strip")

    g = upset_object.plot()

    current_figure = plt.gcf()
    current_figure.savefig(f"{output_dir}/upset_M4M3.pdf", bbox_inches='tight')
    plt.show()

In [49]:
# Evaluation table of the MIMIC-IV -> MIMIC-III results, using the MIMIC-IV
# percentile groups; only the first element of the returned pair is kept.
results_m4m3_eval, _ = A.evaluation_table(test_res_m4m3, m4_percentiles, top_k_list=top_k_list)

In [50]:
# table_clfs = ('LogReg', 
#               'RETAIN',
#               'GRU',
#               'GRU/G',
#               'ICE-NODE', 
#               'ICE-NODE/G',
#               'ICE-NODE_UNIFORM',
#               'ICE-NODE_UNIFORM/G'
#               )
# Tables list the classifiers in alphabetical order.
table_clfs = list(clfs)
table_clfs.sort()
results_m4m3_tables = A.top_k_tables(
    table_clfs, results_m4m3_eval,
    top_k_list=top_k_list,
    n_percentiles=n_percentiles,
    out_prefix=f'{output_dir}/M4M3')

In [47]:
# Codes in the ICE-NODE content set that are in none of the content sets of
# RETAIN, GRU and ICE-NODE_UNIFORM.
m4m3_content_sets = m4m3_upset_result['content_sets']
icenode_m4m3_excl_codes = m4m3_content_sets['ICE-NODE'] - set.union(
    *[m4m3_content_sets[clf] for clf in ('RETAIN', 'GRU', 'ICE-NODE_UNIFORM')])

# Rows of the competing-performance table for those codes. The original cell
# indexed an undefined `compete_codesm4m3` (NameError on a fresh kernel); the
# table is read from `m4m3_upset_result`, mirroring the MIMIC-IV cell.
m4m3_competing_performance = m4m3_upset_result['competing_performance']
icenode_m4m3_excl = m4m3_competing_performance[
    m4m3_competing_performance['CODE_INDEX'].isin(icenode_m4m3_excl_codes)]
icenode_m4m3_excl

In [None]:
# `upset_clfs` is the classifier selection defined in the MIMIC-IV bar-plot
# cell of section 3.
ax = A.selected_auc_barplot(upset_clfs, icenode_m4m3_excl, horizontal=True)

# Tick fonts are set once (the cell used to repeat both calls verbatim).
plt.xticks(fontsize=30)
plt.yticks(fontsize=30)
ax.legend(fontsize=22, title_fontsize=32,
          bbox_to_anchor=(1, 1.25), ncol=2)

current_figure = plt.gcf()
current_figure.savefig(f"{output_dir}/icenode_m4m3.pdf", bbox_inches='tight')
plt.show()


## Trajectories for Patients with CCS Codes Best Predicted by ICE-NODE

### Analyse AUC for Each Admission in the Test Partition

In [34]:
def admissions_auc_scores(model, test_ids):
    """Compute admission-level AUC scores for a (model, state) pair.

    Args:
        model: a 2-tuple `(predictor, state)`, as stored in the
            `*_predictors` dictionaries.
        test_ids: subject identifiers forwarded to the predictor.

    Returns:
        The result of `predictor.admissions_auc_scores(state, test_ids)`.
    """
    predictor, predictor_state = model
    return predictor.admissions_auc_scores(predictor_state, test_ids)

In [35]:
# Admission-level AUC of ICE-NODE on the MIMIC-IV test partition.
m4_icenode_visit_auc_df = admissions_auc_scores(m4_predictors['ICE-NODE'], m4_test_ids)
# N_VISITS = number of rows sharing the row's SUBJECT_ID. A single
# value_counts() pass replaces the previous per-row scan of the whole column
# (quadratic in the number of rows) and yields the same counts.
m4_icenode_visit_auc_df['N_VISITS'] = m4_icenode_visit_auc_df['SUBJECT_ID'].map(
    m4_icenode_visit_auc_df['SUBJECT_ID'].value_counts())

In [36]:
# Per-subject summary of the admission-level table; the result has
# two-level columns of the form (column, statistic).
_subject_aggregations = {
    'AUC': 'mean',
    'N_VISITS': 'max',
    'N_CODES': ['min', 'max', 'mean', 'median'],
    'INTERVALS': ['mean', 'max', 'min'],
}
m4_visit_auc_subject = (m4_icenode_visit_auc_df
                        .groupby('SUBJECT_ID')
                        .agg(_subject_aggregations))

In [37]:
# Admission-level AUC of the MIMIC-IV-trained ICE-NODE on all MIMIC-III
# subjects. `m3_interface` is a {'M': ..., 'G': ...} dictionary and has no
# `.subjects`; the subjects are read from `m3_interface_M`, the interface
# matching the non-GRAM ICE-NODE model.
m4m3_icenode_visit_auc_df = admissions_auc_scores(m4m3_predictors['ICE-NODE'],
                                                  m3_interface_M.subjects.keys())
# N_VISITS = number of rows sharing the row's SUBJECT_ID (single
# value_counts() pass instead of a per-row scan of the whole column).
m4m3_icenode_visit_auc_df['N_VISITS'] = m4m3_icenode_visit_auc_df['SUBJECT_ID'].map(
    m4m3_icenode_visit_auc_df['SUBJECT_ID'].value_counts())
# Per-subject summary; the result has two-level (column, statistic) columns.
m4m3_visit_auc_subject = m4m3_icenode_visit_auc_df.groupby('SUBJECT_ID').agg(
    {'AUC': 'mean',
     'N_VISITS': 'max',
     'N_CODES': ['min', 'max', 'mean', 'median'],
     'INTERVALS': ['mean', 'max', 'min']
    })


In [38]:
# MIMIC-IV subjects with more than two visits and a maximum interval below 150.
m4_enough_visits = m4_visit_auc_subject.N_VISITS['max'] > 2
m4_short_intervals = m4_visit_auc_subject.INTERVALS['max'] < 150
m4_best_visit_auc_subjects = m4_visit_auc_subject[m4_enough_visits & m4_short_intervals]

# MIMIC-III subjects with more than one visit and a maximum interval below 150.
m4m3_enough_visits = m4m3_visit_auc_subject.N_VISITS['max'] > 1
m4m3_short_intervals = m4m3_visit_auc_subject.INTERVALS['max'] < 150
m4m3_best_visit_auc_subjects = m4m3_visit_auc_subject[m4m3_enough_visits & m4m3_short_intervals]


In [39]:
# Number of selected subjects: (MIMIC-IV, MIMIC-III).
tuple(len(selected) for selected in (m4_best_visit_auc_subjects, m4m3_best_visit_auc_subjects))

In [40]:
# `m4_interface` / `m3_interface` are {'M': ..., 'G': ...} dictionaries, so
# the history/frequency methods must be called on an interface object; the
# 'M' interfaces are used, matching the non-GRAM ICE-NODE model analysed here.
m4_ccs_history = {i: m4_interface_M.dx_flatccs_history(i) for i in m4_best_visit_auc_subjects.index}
m4m3_ccs_history = {i: m3_interface_M.dx_flatccs_history(i) for i in m4m3_best_visit_auc_subjects.index}

m4_ccs_idx_frequency = m4_interface_M.dx_flatccs_frequency(list(m4_best_visit_auc_subjects.index))
m3_ccs_idx_frequency = m3_interface_M.dx_flatccs_frequency(list(m4m3_best_visit_auc_subjects.index))

In [41]:
# Union of the history keys over the selected subjects of each dataset.
m4_history_keys = set.union(*[set(h.keys()) for h in m4_ccs_history.values()])
m3_history_keys = set.union(*[set(h.keys()) for h in m4m3_ccs_history.values()])

# Map the keys through C.ccs_dag.dx_flatccs_idx ...
m4_history_all_ccs_codes = set(map(C.ccs_dag.dx_flatccs_idx.get, m4_history_keys))
m3_history_all_ccs_codes = set(map(C.ccs_dag.dx_flatccs_idx.get, m3_history_keys))

# ... and keep the indices whose frequency among those subjects is below 10.
m4_history_all_ccs_codes = {idx for idx in m4_history_all_ccs_codes
                            if m4_ccs_idx_frequency[idx] < 10}
m3_history_all_ccs_codes = {idx for idx in m3_history_all_ccs_codes
                            if m3_ccs_idx_frequency[idx] < 10}

len(m4_history_all_ccs_codes), len(m3_history_all_ccs_codes)

In [42]:
# Codes of the ICE-NODE content set on MIMIC-IV. The original cell read an
# undefined `upsetcontents_m4` (NameError on a fresh kernel); the content
# sets live in `m4_upset_result['content_sets']`.
icenode_m4_competent_codes = m4_upset_result['content_sets']['ICE-NODE']
icenode_m4_competent = auctests_m4[auctests_m4['CODE_INDEX'].isin(icenode_m4_competent_codes)]
icenode_m4_competent = icenode_m4_competent[['N_POSITIVE_CODES', 'AUC(ICE-NODE)', 'DESC']].sort_values('N_POSITIVE_CODES',ascending=False)
# icenode_m4_competent.head(50)
trajectory_ccs_codes_level2 = [
    173, 168, 169, 156, 165, 216, 171, 100, 167
]
# NOTE(review): this filters on the row index, which assumes the index of
# `auctests_m4` coincides with CODE_INDEX — confirm.
icenode_m4_competent[icenode_m4_competent.index.isin(trajectory_ccs_codes_level2)]

In [43]:
# First selection of flat-CCS indices used for the trajectory analysis:
#   64  -> renal fail
#   6   -> pulm heart dx
#   236 -> ear dx
trajectory_ccs_codes_level1 = [64, 6, 236]


In [44]:
def _histories_with_any(histories, ccs_indices):
    """Keep the subjects whose history has at least one of the selected codes.

    `ccs_indices` are mapped through `A.dx_flatccs_idx2code` and compared with
    the keys of each subject's history.
    """
    return {
        subject_id: history
        for subject_id, history in histories.items()
        if len(set(map(A.dx_flatccs_idx2code.get, ccs_indices)) & set(history.keys())) > 0
    }


m4_ccs_history_level1 = _histories_with_any(m4_ccs_history, trajectory_ccs_codes_level1)
m4m3_ccs_history_level1 = _histories_with_any(m4m3_ccs_history, trajectory_ccs_codes_level1)

m4_ccs_history_level2 = _histories_with_any(m4_ccs_history, trajectory_ccs_codes_level2)
m4m3_ccs_history_level2 = _histories_with_any(m4m3_ccs_history, trajectory_ccs_codes_level2)

In [45]:
# Subjects per selection: (M4 level 1, M3 level 1, M4 level 2, M3 level 2).
tuple(map(len, (m4_ccs_history_level1, m4m3_ccs_history_level1,
                m4_ccs_history_level2, m4m3_ccs_history_level2)))


In [46]:
# Subjects appearing in either selection (level 1 or level 2), per dataset.
m4_cases = {*m4_ccs_history_level1.keys(), *m4_ccs_history_level2.keys()}
m4m3_cases = {*m4m3_ccs_history_level1.keys(), *m4m3_ccs_history_level2.keys()}
len(m4_cases), len(m4m3_cases)

In [47]:
# Split the stored (model, state) pair of the MIMIC-IV ICE-NODE predictor and
# sample trajectories for the selected MIMIC-IV subjects.
# NOTE(review): the meaning of the trailing `1` argument is not visible
# here — confirm against `sample_trajectory`.
m4_icenode, m4_icenode_state = m4_predictors['ICE-NODE']
m4_trajectory = m4_icenode.sample_trajectory(m4_icenode_state, m4_cases, 1)

In [48]:
# Same as for MIMIC-IV, using the MIMIC-IV-trained ICE-NODE bound to the
# MIMIC-III interface and the selected MIMIC-III subjects.
# NOTE(review): the meaning of the trailing `1` argument is not visible
# here — confirm against `sample_trajectory`.
m4m3_icenode, m4m3_icenode_state = m4m3_predictors['ICE-NODE']
m4m3_trajectory = m4m3_icenode.sample_trajectory(m4m3_icenode_state, m4m3_cases, 1)

In [49]:
# m4_selected_subjects = [
#     13798593, #acute-renal
#     13965528, #acute-renal
#     11907876, #pulmonary heart dx
#     13557547, #ear dx
#     10139504, #acute renal fail
#     12367864, #pulomonary-heart dx
# ]

# m4_selected_trajectory = {i: m4_trajectory[i] for i in m4_selected_subjects}

# m3_selected_subjects = [
#     50093 #pulmonary-heart dx
# ]

# m3_selected_trajectory = {i: m4m3_trajectory[i] for i in m3_selected_subjects}


In [174]:
import random

# Flat-CCS indices plotted as the "L1" selection.
trajectory_ccs_codes_level1 = [
    64, #renal fail 
    6, # pulm heart dx
    236, # ear dx 
    # Others
    100, # Brnch/lng ca
    168, # Kidney/rnl ca
    194, # Immunity dx
]

# Flat-CCS indices plotted as the "L2" selection.
# icenode_m4_competent.head(50)
trajectory_ccs_codes_level2 = [
    173, 168, 169, 156, 165, 216, 171, 100, 167
]

# Seed so the randomly drawn colours are the same on every run.
random.seed(42)
ccs_color = {
    # Random '#RRGGBB' colour for every level-2 code. These entries come
    # FIRST so that the named colours below take precedence: codes 100 and
    # 168 are in both selections, and with the previous ordering their named
    # colours ('salmon', 'navy') were silently replaced by random ones.
    **{idx: "#"+''.join([random.choice('0123456789ABCDEF') for j in range(6)])
                   for idx in trajectory_ccs_codes_level2},
    6: 'blue',
    64: 'purple',
    236: 'orange',
    # Others
    100: 'salmon', # Brnch/lng ca
    168: 'navy', # Kidney/rnl ca
    194: 'pink', # Immunity dx
}


# Default size (width, height in inches) for the figures created from here on.
plt.rcParams['figure.figsize'] = (10, 7)



In [183]:
plt.close('all')

# (dataset label, interface, trajectories). `m4_interface` / `m3_interface`
# are {'M': ..., 'G': ...} dictionaries, so the 'M' interface objects are
# passed instead; they match the non-GRAM ICE-NODE model that produced the
# trajectories.
trajectory_sources = [
    ("M4", m4_interface_M, m4_trajectory),
    ("M3", m3_interface_M, m4m3_trajectory),
]
# (selection label, flat-CCS indices): level 1, level 2, and both concatenated.
ccs_selections = [
    ("L1", trajectory_ccs_codes_level1),
    ("L2", trajectory_ccs_codes_level2),
    ("L1UL2", trajectory_ccs_codes_level1 + trajectory_ccs_codes_level2),
]

for data_label, interface, trajectory_set in trajectory_sources:
    for indices_label, ccs_indices in ccs_selections:
        # One output folder per (dataset, selection) combination.
        out_dir = f'{output_dir}/trajectories/{data_label}_{indices_label}' 
        Path(out_dir).mkdir(parents=True, exist_ok=True)
        A.plot_trajectory(trajectories=trajectory_set, 
                          interface=interface, 
                          flatccs_selection=ccs_indices, 
                          ccs_color=ccs_color,
                          out_dir=out_dir)