In [1]:
import sys
import os
import pandas as pd

sys.path.append('..')

In [2]:
from mimicnet.train_snonet_diag import SNONETDiag
from mimicnet.train_gram import GRAM
from mimicnet.train_retain import RETAIN

%load_ext autoreload
%autoreload 2



## Define Directories

In [3]:
HOME = os.environ.get('HOME')

# MIMIC-III Dataset Directory
mimic3_dir = f'{HOME}/GP/ehr-data/mimic3-transforms'

# ICE-NODE/M trained on MIMIC-III training partition (70%)
icenode_m3_dir = f'{HOME}/GP/ehr-data/mimicnet-m3-exp/v0.1.23M3_snonet_diag_M/frozen_trial_2220'

# RETAIN trained on MIMIC-III training partition (70%)
retain_m3_dir = f'{HOME}/GP/ehr-data/mimicnet-m3-exp/v0.1.23M3_retain_M/frozen_trial_142'

# GRU=GRAM/M trained on MIMIC-III training partition (70%)
gru_m3_dir = f'{HOME}/GP/ehr-data/mimicnet-m3-exp/v0.1.23M3_gram_M/frozen_trial_615'

# GRAM=GRAM/G trained on MIMIC-III training partition (70%)
gram_m3_dir = f'{HOME}/GP/ehr-data/mimicnet-m3-exp/v0.1.23M3_gram_G/frozen_trial_442'

## Patient Interface for each Model

In [4]:
icnode_patient_interface = SNONETDiag.create_patient_interface(mimic3_dir, 'M3')
retain_patient_interface = RETAIN.create_patient_interface(mimic3_dir, 'M3')
# Works for both GRAM/M and GRAM/G
gram_patient_interface = GRAM.create_patient_interface(mimic3_dir, 'M3')


## Dataset Partitioning

In [5]:
import random

# seed 42 is used in all our experiments in this work.
rng = random.Random(42)
subjects_id = list(icnode_patient_interface.subjects.keys())
rng.shuffle(subjects_id)

# splits = train:val:test = 0.7:.15:.15
splits = int(.7 * len(subjects_id)), int(.85 * len(subjects_id))

train_ids = subjects_id[:splits[0]]
valid_ids = subjects_id[splits[0]:splits[1]]
test_ids = subjects_id[splits[1]:]

## Load Configs and Trained Params

In [7]:
from mimicnet.utils import load_config, load_params

# icenode_config = load_config(f'{icenode_m3_dir}/config.json')
# icenode_params = load_params(f'{icenode_m3_dir}/step0100_params.pickle')

# retain_config = load_config(f'{retain_m3_dir}/config.json')
# retain_params = load_params(f'{retain_m3_dir}/step0100_params.pickle')

gru_config = load_config(f'{gru_m3_dir}/config.json')
gru_params = load_params(f'{gru_m3_dir}/step0030_params.pickle')

# gram_config = load_config(f'{gram_m3_dir}/config.json')
# gram_params = load_params(f'{gram_m3_dir}/step0100_params.pickle')

## Create Model Objects

In [8]:
# icenode = SNONETDiag.create_model(icenode_config, icnode_patient_interface, train_ids, None)
# retain = RETAIN.create_model(retain_config, retain_patient_interface, train_ids, None)
gru = GRAM.create_model(gru_config, gram_patient_interface, train_ids, None)
# gram = GRAM.create_model(gram_config, gram_patient_interface, train_ids, None)

In [9]:
code_partitions = GRAM.code_partitions(gram_patient_interface, train_ids) 

In [12]:
from mimicnet.metrics import evaluation_table
res = gru.eval(gru_config['training']['loss_mixing'], gru_params, test_ids)
eval_df = evaluation_table({'TST': res}, code_partitions)



In [13]:
eval_df

(                      TST
 MACRO-AUC        0.451112
 MICRO-AUC        0.462751
 accuracy         0.796344
 diag_loss       91.446327
 f1-score         0.019595
 fn               0.013922
 fp               0.189734
 l1_loss      28326.291016
 l2_loss       2184.802002
 loss            91.704056
 npv              0.982774
 pre_ACC-P0       0.032022
 pre_ACC-P1       0.006267
 pre_ACC-P2       0.052989
 pre_ACC-P3       0.000000
 pre_ACC-P4       0.000000
 precision        0.010613
 recall           0.127540
 specificity      0.807189
 tn               0.794308
 tp               0.002035,
 {'TST_diag_loss': 91.44632720947266,
  'TST_loss': 91.70405578613281,
  'TST_l1_loss': 28326.291015625,
  'TST_l2_loss': 2184.802001953125,
  'TST_accuracy': 0.7963436841964722,
  'TST_recall': 0.12753979861736298,
  'TST_npv': 0.9827744364738464,
  'TST_specificity': 0.8071891069412231,
  'TST_precision': 0.010612758807837963,
  'TST_f1-score': 0.019594993442296982,
  'TST_tp': 0.002035201992839575,


## Analyse AUC for Each Admission in the Test Partition

In [27]:
from mimicnet.metrics import evaluation_table

In [58]:
res = icenode.eval(icenode_config['training']['loss_mixing'], p0, valid_ids)

In [54]:
code_percentiles = icenode.code_partitions(icnode_patient_interface, train_ids) 

In [59]:
eval_df = evaluation_table({'TST': res}, code_percentiles)

In [60]:
eval_df


(                             TST
 MACRO-AUC               0.521618
 MICRO-AUC               0.537354
 accuracy                0.888518
 all_points_count      424.000000
 diag_loss              98.166565
 dyn_loss                0.271464
 dyn_loss_per_week       0.000004
 f1-score                0.029766
 fn                      0.015052
 fp                      0.096431
 l1_loss             33092.835938
 l2_loss              2647.580566
 loss                   98.219398
 nfe_per_week            1.250589
 nfex1000               83.050003
 npv                     0.983311
 postjump_diag_loss     98.488777
 pre_ACC-P0              0.024173
 pre_ACC-P1              0.024927
 pre_ACC-P2              0.038375
 pre_ACC-P3              0.047669
 pre_ACC-P4              0.000000
 precision               0.017425
 predictable_count    8966.000000
 prejump_diag_loss      97.660095
 recall                  0.102026
 specificity             0.901925
 tn                      0.886808
 tp           

In [61]:
res = retain.eval(retain_config['training']['loss_mixing'], retain_params, valid_ids)
eval_df = evaluation_table({'TST': res}, code_percentiles)

In [62]:
eval_df

(                      TST
 MACRO-AUC        0.516685
 MICRO-AUC        0.523802
 accuracy         0.839888
 diag_loss      105.477043
 f1-score         0.046056
 fn               0.012491
 fp               0.147621
 l1_loss      39069.394531
 l2_loss       4987.006836
 loss           105.513245
 npv              0.985279
 pre_ACC-P0       0.074388
 pre_ACC-P1       0.073510
 pre_ACC-P2       0.079752
 pre_ACC-P3       0.052980
 pre_ACC-P4       0.184758
 precision        0.025514
 recall           0.236307
 specificity      0.849925
 tn               0.836023
 tp               0.003865,
 {'TST_diag_loss': 105.47704315185547,
  'TST_loss': 105.51324462890625,
  'TST_l1_loss': 39069.39453125,
  'TST_l2_loss': 4987.0068359375,
  'TST_accuracy': 0.8398882150650024,
  'TST_recall': 0.23630741238594055,
  'TST_npv': 0.9852791428565979,
  'TST_specificity': 0.8499245047569275,
  'TST_precision': 0.02551409602165222,
  'TST_f1-score': 0.04605557769536972,
  'TST_tp': 0.003865026170387864,
  '

In [15]:
icenode_auc_df = icenode.admissions_auc_scores(icenode_params, test_ids)

In [18]:
retain_auc_df = retain.admissions_auc_scores(retain_params, test_ids)

In [19]:
gru_auc_df = gru.admissions_auc_scores(gru_params, test_ids)

In [20]:
gram_auc_df = gram.admissions_auc_scores(gram_params, test_ids)

In [None]:
icenode_auc_df.eval

In [26]:
icenode_auc_df.tail(20)

Unnamed: 0,SUBJECT_ID,HADM_IDX,AUC,DAYS_AHEAD,INTERVALS,NFE
9337,59496,15,0.557576,25,1,8.0
9338,59496,16,0.557226,26,1,8.0
9339,1949,1,0.646902,1,1,8.0
9340,1949,2,0.647283,2,1,8.0
9341,1949,3,0.651092,3,1,8.0
9342,1949,4,0.652489,4,1,8.0
9343,1949,5,0.597824,10,6,8.0
9344,1949,6,0.643611,11,1,8.0
9345,1949,7,0.660633,12,1,8.0
9346,1949,8,0.666828,13,1,8.0


In [22]:
gram_auc_df.AUC.min()

0.017906336088154284

In [23]:
gram_auc_df.AUC.mean()

0.49718712421225186

In [23]:
icenode_auc_df.AUC.mean()

0.5889630890815816

In [20]:
dir(icenode)

['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_diag_loss',
 '_emb_error',
 '_extract_nth_points',
 '_f_dec',
 '_f_init',
 '_f_n_ode',
 '_f_update',
 '_initialization_data',
 '_sample_ode_model_config',
 '_sample_ode_training_config',
 '_sample_training_config',
 'admissions_auc_scores',
 'code_partitions',
 'create_embedding',
 'create_model',
 'create_patient_interface',
 'detailed_loss',
 'diag_emb',
 'diag_loss',
 'diag_out_index',
 'dimensions',
 'eval',
 'eval_stats',
 'f_dec',
 'f_init',
 'f_n_ode',
 'f_update',
 'init_params',
 'initializers',
 'loss',
 'max_odeint_days',
 'sample_embeddings_config',
 'sample_experiment_config',
 'sample_model_co