# Table of Content

<a name="outline"></a>

## Setup

- [A](#seca) External Imports
- [B](#secb) Internal Imports
- [C](#secd) Configurations and Paths 
- [D](#sece) Patient Interface and Train/Val/Test Partitioning


## [Evaluations: Predictive Performance on CPRD](#eval)


<a name="seca"></a>

### A External Imports [^](#outline)

In [1]:
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path
from IPython.display import display
from upsetplot import from_contents, plot, UpSet, from_indicators


<a name="secb"></a>

### B Internal Imports [^](#outline)

In [2]:
# HOME and DATA_STORE are arbitrary, change as appropriate.
HOME = os.environ.get('HOME')
DATA_STORE = f'{HOME}/GP/ehr-data'
SOURCE_DIR = os.path.abspath("..")

DATA_FILE = os.path.join(DATA_STORE, 'cprd-data/DUMMY_DATA.csv')
ARTEFACTS_DIR = 'cprd_artefacts'
TRAIN_DIR = os.path.join(ARTEFACTS_DIR, 'train')


%load_ext autoreload
%autoreload 2

import analysis as A
import common as C



  PyTreeDef = type(jax.tree_structure(None))


<a name="secc"></a>

### C Configurations and Paths [^](#outline)

In [3]:
with C.modified_environ(DATA_FILE=DATA_FILE):
    cprd_dataset = C.datasets['CPRD']

In [4]:
relative_auc_config = {
    'pvalue': 0.01, 
    'min_auc': 0.9
}
top_k_list=[1, 2, 3, 5, 7, 10, 15, 20]
percentile_range=20 
n_percentiles=int(100/percentile_range)


import matplotlib.font_manager as font_manager
plt.rcParams.update(plt.rcParamsDefault)
plt.rcParams.update({'font.family': 'sans-serif',
                     'font.sans-serif': 'Helvetica',
                     'font.weight':  'normal'})

In [5]:
output_dir = 'cprd_analysis_artefacts'
Path(output_dir).mkdir(parents=True, exist_ok=True)


<a name="secd"></a>

### D Patient Interface and Train/Val/Test Patitioning [^](#outline)

In [6]:
code_scheme = {
    'dx': 'dx_cprd_ltc9809',
    'dx_outcome': 'dx_cprd_ltc9809'
}

cprd_interface = C.Subject_JAX.from_dataset(cprd_dataset, code_scheme=code_scheme)

In [7]:
cprd_splits = cprd_interface.random_splits(split1=0.7, split2=0.85, random_seed=42)
cprd_train_ids, cprd_valid_ids, cprd_test_ids = cprd_splits


In [8]:
cprd_percentiles = cprd_interface.dx_outcome_by_percentiles(20, cprd_splits[0])


<a name="sec1"></a>

## 1 Snooping/Selecting Best Models from the Validation Set [^](#outline)

In [9]:
from glob import glob
clfs = [os.path.basename(d) for d in glob(f"{TRAIN_DIR}/*")]
model_dir = dict(zip(clfs, clfs))

In [10]:

cprd_top = A.get_trained_models(clfs=clfs, train_dir={'cprd': TRAIN_DIR}, 
                                model_dir=model_dir, data_tag='cprd', 
                               criterion='MICRO-AUC',  comp=max)
display(cprd_top['summary'])


Unnamed: 0,Clf,Best_i,MICRO-AUC
0,ICE-NODE_UNIFORM,0,0.440169
1,RETAIN,38,0.898837
2,LogReg,0,0.499683
3,ICE-NODE,0,0.440169
4,GRU,0,0.358739


In [11]:

def select_predictor(clf):
    config = cprd_top['config'][clf] 
    params = cprd_top['params'][clf]
    model = C.model_cls[clf].create_model(config, cprd_interface, [])
    state = model.init_with_params(config, params)
    return model, state



<a name="eval"></a>

## 2 Predictive Performance on CPRD [^](#outline)

In [12]:
cprd_predictors = {clf: select_predictor(clf) for clf in clfs}

In [13]:
test_res_cprd = {clf: C.eval2_(model, cprd_splits[2]) for clf, model in cprd_predictors.items()} 

  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)
  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)
  leaves, treedef = jax.tree_flatten(tree)
  return jax.tree_unflatten(treedef, leaves)


In [14]:
auctests_cprd = A.codes_auc_pairwise_tests({k: v['risk_prediction'] for k, v in test_res_cprd.items()}, fast=True)


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 9461/9461 [00:00<00:00, 409767.66it/s]


In [15]:
test_cols = [col for col in auctests_cprd.columns if col[:2] == 'P0']
auctests_cprd.loc[:, test_cols].isnull().max(axis=1).sum()

0.0

In [16]:
upset_clfs = ['ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN', 'LogReg']

cprd_freq_v = cprd_interface.dx_outcome_frequency_vec(cprd_train_ids) 
cprd_code_attrs = {'Code frequency': dict(zip(range(len(cprd_freq_v)), cprd_freq_v))}

cprd_upset_result = A.relative_performance_upset(auctests_cprd, upset_clfs, 
                                                 code_attrs=cprd_code_attrs,
                                                 interface=cprd_interface,
                                                 **relative_auc_config)
upset_ctx = lambda : sns.plotting_context("paper", font_scale=1.5, rc={"font.family": "sans-serif", 
                                                                        'axes.labelsize': 'medium',
                                                                       'ytick.labelsize': 'medium'})



with sns.axes_style("darkgrid"), upset_ctx():
    upset_format = from_indicators(cprd_upset_result['indicator_df'], data=cprd_upset_result['data'])
    upset_object = UpSet(upset_format, subset_size='count', show_counts=True)
    upset_object.style_subsets(absent=['LogReg'], present=('ICE-NODE', 'GRU', 'RETAIN'),
                               facecolor="red",
                               edgecolor="red", linewidth=3)
    upset_object.add_catplot(value='Code frequency', kind="strip")
    
    g = upset_object.plot()
        
    current_figure = plt.gcf()
    w, h = 2.5, 3
    wi, hi = current_figure.get_size_inches()
    current_figure.set_size_inches(hi*(w/h), hi)
    current_figure.savefig(f"{output_dir}/upset_CPRD.pdf", bbox_inches='tight')
    plt.show()

0 codes predicted an AUC higher than 0.9 by at least one model.
0 codes predicted an AUC higher than 0.9 by at least one model, with valid tests.


ValueError: Must pass non-zero number of levels/codes

In [17]:
results_cprd_eval,_ = A.evaluation_table(test_res_cprd, cprd_percentiles, top_k_list=top_k_list)

  rate[k][f'ACC-P{i}-k{k}'] = group_true_positive.sum(
  rate[k][f'ACC-P{i}-k{k}'] = group_true_positive.sum(
  rate[k][f'ACC-P{i}-k{k}'] = group_true_positive.sum(
  rate[k][f'ACC-P{i}-k{k}'] = group_true_positive.sum(
  rate[k][f'ACC-P{i}-k{k}'] = group_true_positive.sum(


In [18]:
table_clfs = ('LogReg', 
              'RETAIN',
              'GRU',
              'ICE-NODE_UNIFORM',
              'ICE-NODE'
              )
results_cprd_tables = A.top_k_tables(table_clfs, results_cprd_eval, top_k_list=top_k_list,
                                   n_percentiles=n_percentiles, out_prefix=output_dir)

Unnamed: 0,ACC-P0-k1,ACC-P1-k1,ACC-P2-k1,ACC-P3-k1,ACC-P4-k1
LogReg,0.0,,,,
RETAIN,0.0,,,,
GRU,0.0,,,,
ICE-NODE_UNIFORM,0.0,,,,
ICE-NODE,0.0,,,,


\begin{tabular}{lrrrrr}
{} & {ACC-P0-k1} & {ACC-P1-k1} & {ACC-P2-k1} & {ACC-P3-k1} & {ACC-P4-k1} \\
\texttt{LogReg} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{RETAIN} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{GRU} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{\texttt{ICE-NODE}\_UNI

Unnamed: 0,ACC-P0-k2,ACC-P1-k2,ACC-P2-k2,ACC-P3-k2,ACC-P4-k2
LogReg,0.0,,,,
RETAIN,0.0,,,,
GRU,0.0,,,,
ICE-NODE_UNIFORM,0.0,,,,
ICE-NODE,0.0,,,,


\begin{tabular}{lrrrrr}
{} & {ACC-P0-k2} & {ACC-P1-k2} & {ACC-P2-k2} & {ACC-P3-k2} & {ACC-P4-k2} \\
\texttt{LogReg} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{RETAIN} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{GRU} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{\texttt{ICE-NODE}\_UNI

Unnamed: 0,ACC-P0-k3,ACC-P1-k3,ACC-P2-k3,ACC-P3-k3,ACC-P4-k3
LogReg,0.0,,,,
RETAIN,0.0,,,,
GRU,0.0,,,,
ICE-NODE_UNIFORM,0.0,,,,
ICE-NODE,0.0,,,,


\begin{tabular}{lrrrrr}
{} & {ACC-P0-k3} & {ACC-P1-k3} & {ACC-P2-k3} & {ACC-P3-k3} & {ACC-P4-k3} \\
\texttt{LogReg} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{RETAIN} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{GRU} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{\texttt{ICE-NODE}\_UNI

Unnamed: 0,ACC-P0-k5,ACC-P1-k5,ACC-P2-k5,ACC-P3-k5,ACC-P4-k5
LogReg,0.0,,,,
RETAIN,0.0,,,,
GRU,0.0,,,,
ICE-NODE_UNIFORM,0.0,,,,
ICE-NODE,0.0,,,,


\begin{tabular}{lrrrrr}
{} & {ACC-P0-k5} & {ACC-P1-k5} & {ACC-P2-k5} & {ACC-P3-k5} & {ACC-P4-k5} \\
\texttt{LogReg} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{RETAIN} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{GRU} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{\texttt{ICE-NODE}\_UNI

Unnamed: 0,ACC-P0-k7,ACC-P1-k7,ACC-P2-k7,ACC-P3-k7,ACC-P4-k7
LogReg,0.0,,,,
RETAIN,0.0,,,,
GRU,0.0,,,,
ICE-NODE_UNIFORM,0.0,,,,
ICE-NODE,0.0,,,,


\begin{tabular}{lrrrrr}
{} & {ACC-P0-k7} & {ACC-P1-k7} & {ACC-P2-k7} & {ACC-P3-k7} & {ACC-P4-k7} \\
\texttt{LogReg} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{RETAIN} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{GRU} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{\texttt{ICE-NODE}\_UNI

Unnamed: 0,ACC-P0-k10,ACC-P1-k10,ACC-P2-k10,ACC-P3-k10,ACC-P4-k10
LogReg,0.0,,,,
RETAIN,0.0,,,,
GRU,0.0,,,,
ICE-NODE_UNIFORM,0.0,,,,
ICE-NODE,0.0,,,,


\begin{tabular}{lrrrrr}
{} & {ACC-P0-k10} & {ACC-P1-k10} & {ACC-P2-k10} & {ACC-P3-k10} & {ACC-P4-k10} \\
\texttt{LogReg} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{RETAIN} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{GRU} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{\texttt{ICE-NODE}

Unnamed: 0,ACC-P0-k15,ACC-P1-k15,ACC-P2-k15,ACC-P3-k15,ACC-P4-k15
LogReg,0.0,,,,
RETAIN,0.0,,,,
GRU,0.0,,,,
ICE-NODE_UNIFORM,0.0,,,,
ICE-NODE,0.0,,,,


\begin{tabular}{lrrrrr}
{} & {ACC-P0-k15} & {ACC-P1-k15} & {ACC-P2-k15} & {ACC-P3-k15} & {ACC-P4-k15} \\
\texttt{LogReg} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{RETAIN} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{GRU} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{\texttt{ICE-NODE}

Unnamed: 0,ACC-P0-k20,ACC-P1-k20,ACC-P2-k20,ACC-P3-k20,ACC-P4-k20
LogReg,0.0,,,,
RETAIN,0.0,,,,
GRU,0.0,,,,
ICE-NODE_UNIFORM,0.0,,,,
ICE-NODE,0.0,,,,


\begin{tabular}{lrrrrr}
{} & {ACC-P0-k20} & {ACC-P1-k20} & {ACC-P2-k20} & {ACC-P3-k20} & {ACC-P4-k20} \\
\texttt{LogReg} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{RETAIN} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{GRU} & {\cellcolor[HTML]{8E0152}} \color[HTML]{F1F1F1} 0.000 & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan & {\cellcolor[HTML]{276419}} \color[HTML]{F1F1F1} nan \\
\texttt{\texttt{ICE-NODE}

In [20]:
competing_tests_df = auctests_cprd[auctests_cprd.CODE_INDEX.isin(cprd_upset_result['indicator_df'][cprd_upset_result['indicator_df']['LogReg']==False].index)]
competing_tests_df.loc[:, [col for col in competing_tests_df.columns if col[:2]=='P0']]

Unnamed: 0,P0(AUC_GRU==AUC_ICE-NODE),P0(AUC_GRU==AUC_ICE-NODE_UNIFORM),P0(AUC_GRU==AUC_LogReg),P0(AUC_GRU==AUC_RETAIN),P0(AUC_ICE-NODE==AUC_ICE-NODE_UNIFORM),P0(AUC_ICE-NODE==AUC_LogReg),P0(AUC_ICE-NODE==AUC_RETAIN),P0(AUC_ICE-NODE_UNIFORM==AUC_LogReg),P0(AUC_ICE-NODE_UNIFORM==AUC_RETAIN),P0(AUC_LogReg==AUC_RETAIN)


In [21]:
upset_clfs = ['ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN', 'LogReg']

ax = A.selected_auc_barplot(upset_clfs, competing_tests_df, horizontal=True)
ax.legend(fontsize=22, title_fontsize=32,
          bbox_to_anchor=(-0.02, 1), ncol=2)
plt.xticks(fontsize=30)
plt.yticks(fontsize=30)

current_figure = plt.gcf()
# w, h = 4, 4
# wi, hi = current_figure.get_size_inches()
# current_figure.set_size_inches(hi*(w/h), hi)

current_figure.savefig(f"{output_dir}/icenode_cprd.pdf", bbox_inches='tight')
plt.show()


ValueError: cannot convert float NaN to integer

## Trajectories for Patients with CCS codes best predicted with ICENODE

### Analyse AUC for Each Admission in the Test Partition

In [34]:
def admissions_auc_scores(model, test_ids):
    model, state = model
    return model.admissions_auc_scores(state, test_ids)

In [35]:
m4_icenode_visit_auc_df = admissions_auc_scores(m4_predictors['ICE-NODE'], m4_test_ids)
m4_icenode_visit_auc_df['N_VISITS'] = m4_icenode_visit_auc_df['SUBJECT_ID'].apply(lambda i: (m4_icenode_visit_auc_df['SUBJECT_ID'] == i).sum())

In [36]:
m4_visit_auc_subject = m4_icenode_visit_auc_df.groupby('SUBJECT_ID').agg(
    {'AUC': 'mean', 
     'N_VISITS': 'max', 
     'N_CODES': ['min', 'max', 'mean', 'median'], 
     'INTERVALS': ['mean', 'max', 'min'] })

In [37]:
m4m3_icenode_visit_auc_df = admissions_auc_scores(m4m3_predictors['ICE-NODE'], m3_interface.subjects.keys())
m4m3_icenode_visit_auc_df['N_VISITS'] = m4m3_icenode_visit_auc_df['SUBJECT_ID'].apply(lambda i: (m4m3_icenode_visit_auc_df['SUBJECT_ID'] == i).sum())
m4m3_visit_auc_subject = m4m3_icenode_visit_auc_df.groupby('SUBJECT_ID').agg(
    {'AUC': 'mean', 
     'N_VISITS': 'max', 
     'N_CODES': ['min', 'max', 'mean', 'median'], 
     'INTERVALS': ['mean', 'max', 'min']
    })


In [38]:
m4_best_visit_auc_subjects =  m4_visit_auc_subject[(m4_visit_auc_subject.N_VISITS['max'] > 2) & (m4_visit_auc_subject.INTERVALS['max'] < 150)]
m4m3_best_visit_auc_subjects =  m4m3_visit_auc_subject[(m4m3_visit_auc_subject.N_VISITS['max'] > 1) & (m4m3_visit_auc_subject.INTERVALS['max'] < 150)]


In [39]:
len(m4_best_visit_auc_subjects), len(m4m3_best_visit_auc_subjects)

(365, 220)

In [40]:
m4_ccs_history = {i: m4_interface.dx_flatccs_history(i) for i in m4_best_visit_auc_subjects.index}
m4m3_ccs_history = {i: m3_interface.dx_flatccs_history(i) for i in m4m3_best_visit_auc_subjects.index}

m4_ccs_idx_frequency = m4_interface.dx_flatccs_frequency(list(m4_best_visit_auc_subjects.index))
m3_ccs_idx_frequency = m3_interface.dx_flatccs_frequency(list(m4m3_best_visit_auc_subjects.index))

In [41]:
m4_history_all_ccs_codes = set(map(C.ccs_dag.dx_flatccs_idx.get, set.union(*[set(h.keys()) for h in m4_ccs_history.values()])))
m3_history_all_ccs_codes = set(map(C.ccs_dag.dx_flatccs_idx.get, set.union(*[set(h.keys()) for h in m4m3_ccs_history.values()])))
m4_history_all_ccs_codes = {idx for idx in m4_history_all_ccs_codes if m4_ccs_idx_frequency[idx] < 10}
m3_history_all_ccs_codes = {idx for idx in m3_history_all_ccs_codes if m3_ccs_idx_frequency[idx] < 10}

len(m4_history_all_ccs_codes), len(m3_history_all_ccs_codes)

(51, 74)

In [42]:
icenode_m4_competent = upsetcontents_m4['ICE-NODE'] 
icenode_m4_competent = auctests_m4[auctests_m4['CODE_INDEX'].isin(icenode_m4_competent)]
icenode_m4_competent = icenode_m4_competent[['N_POSITIVE_CODES', 'AUC(ICE-NODE)', 'DESC']].sort_values('N_POSITIVE_CODES',ascending=False)
# icenode_m4_competent.head(50)
trajectory_ccs_codes_level2 = [
    173, 168, 169, 156, 165, 216, 171, 100, 167
]
icenode_m4_competent[icenode_m4_competent.index.isin(trajectory_ccs_codes_level2)]

Unnamed: 0,N_POSITIVE_CODES,AUC(ICE-NODE),DESC
173,1790.0,0.941892,Non-Hodg lym
168,520.0,0.951321,Kidny/rnl ca
216,244.0,0.974325,Meningitis
169,181.0,0.942847,Uriny org ca
156,148.0,0.965231,Uterus cancr
165,146.0,0.957154,Testis cancr
171,89.0,0.940833,Thyroid cncr
100,84.0,0.95271,Brnch/lng ca
167,67.0,0.957488,Bladder cncr


In [43]:
trajectory_ccs_codes_level1 = [
    64, #renal fail 
    6, # pulm heart dx
    236, # ear dx 
]


In [44]:
m4_ccs_history_level1 = {i: history for i, history in m4_ccs_history.items() 
                         if len(set(map(A.dx_flatccs_idx2code.get, trajectory_ccs_codes_level1)) & set(history.keys())) > 0}
m4m3_ccs_history_level1 = {i: history for i, history in m4m3_ccs_history.items() 
                         if len(set(map(A.dx_flatccs_idx2code.get, trajectory_ccs_codes_level1)) & set(history.keys())) > 0}

m4_ccs_history_level2 = {i: history for i, history in m4_ccs_history.items() 
                         if len(set(map(A.dx_flatccs_idx2code.get, trajectory_ccs_codes_level2)) & set(history.keys())) > 0}
m4m3_ccs_history_level2 = {i: history for i, history in m4m3_ccs_history.items() 
                         if len(set(map(A.dx_flatccs_idx2code.get, trajectory_ccs_codes_level2)) & set(history.keys())) > 0}

In [45]:
len(m4_ccs_history_level1), len(m4m3_ccs_history_level1), len(m4_ccs_history_level2), len(m4m3_ccs_history_level2) 


(180, 115, 68, 35)

In [46]:
m4_cases = set(m4_ccs_history_level1.keys()) | set(m4_ccs_history_level2.keys())
m4m3_cases = set(m4m3_ccs_history_level1.keys()) | set(m4m3_ccs_history_level2.keys())
len(m4_cases), len(m4m3_cases)

(206, 131)

In [47]:
m4_icenode, m4_icenode_state = m4_predictors['ICE-NODE']
m4_trajectory = m4_icenode.sample_trajectory(m4_icenode_state, m4_cases, 1)

  pos[[top10_idx]] = 1
100%|███████████████████████████████████████████| 22/22 [34:49<00:00, 94.97s/it]


In [48]:
m4m3_icenode, m4m3_icenode_state = m4m3_predictors['ICE-NODE']
m4m3_trajectory = m4m3_icenode.sample_trajectory(m4m3_icenode_state, m4m3_cases, 1)

100%|█████████████████████████████████████████████| 9/9 [12:57<00:00, 86.38s/it]


In [49]:
# m4_selected_subjects = [
#     13798593, #acute-renal
#     13965528, #acute-renal
#     11907876, #pulmonary heart dx
#     13557547, #ear dx
#     10139504, #acute renal fail
#     12367864, #pulomonary-heart dx
# ]

# m4_selected_trajectory = {i: m4_trajectory[i] for i in m4_selected_subjects}

# m3_selected_subjects = [
#     50093 #pulmonary-heart dx
# ]

# m3_selected_trajectory = {i: m4m3_trajectory[i] for i in m3_selected_subjects}


In [174]:
import random

trajectory_ccs_codes_level1 = [
    64, #renal fail 
    6, # pulm heart dx
    236, # ear dx 
    # Others
    100, # Brnch/lng ca
    168, # Kidney/rnl ca
    194, # Immunity dx
]



# icenode_m4_competent.head(50)
trajectory_ccs_codes_level2 = [
    173, 168, 169, 156, 165, 216, 171, 100, 167
]

random.seed(42)
ccs_color = {
    6: 'blue',
    64: 'purple',
    236: 'orange',
    # Others
    100: 'salmon', # Brnch/lng ca
    168: 'navy', # Kidney/rnl ca
    194: 'pink', # Immunity dx
    **{idx: "#"+''.join([random.choice('0123456789ABCDEF') for j in range(6)])
                   for idx in trajectory_ccs_codes_level2}
}


plt.rcParams['figure.figsize'] = (10, 7)



In [183]:
plt.close('all')
for data_label, interface, trajectory_set in zip(["M4", "M3"], [m4_interface, m3_interface], [m4_trajectory, m4m3_trajectory]):
    for indices_label, ccs_indices in zip(["L1", "L2", "L1UL2"], [trajectory_ccs_codes_level1, trajectory_ccs_codes_level2, trajectory_ccs_codes_level1 + trajectory_ccs_codes_level2]):
        out_dir = f'{output_dir}/trajectories/{data_label}_{indices_label}' 
        Path(out_dir).mkdir(parents=True, exist_ok=True)
        A.plot_trajectory(trajectories=trajectory_set, 
                          interface=interface, 
                          flatccs_selection=ccs_indices, 
                          ccs_color=ccs_color,
                          out_dir=out_dir)

  plt.figure(i)
