In [1]:
import numpy as np
import pandas as pd
import json
import collections
from collections import defaultdict 
from functools import partial
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

import jax

# JAX runtime flags, applied in the order listed.
# The platform flag must be set at startup, before any backend is initialised.
_JAX_FLAGS = (
    ('jax_platform_name', 'gpu'),
    # Debugging aids: NaN / Inf checking switched on.
    ("jax_debug_nans", True),
    ("jax_debug_infs", True),
    # Compile logging and tracer-leak checking switched off.
    ('jax_log_compiles', False),
    ('jax_check_tracer_leaks', False),
)
for _flag_name, _flag_value in _JAX_FLAGS:
    jax.config.update(_flag_name, _flag_value)

In [2]:
# Good read: https://iq-inc.com/importerror-attempted-relative-import/

import sys
import importlib
from mimicnet import concept
from mimicnet import jax_interface
from mimicnet import dag
from mimicnet import glove
from mimicnet import gram
from mimicnet import ode

# Force-reload every mimicnet submodule imported above, in the same order as before,
# so that edits made on disk are picked up without restarting the kernel.
_MIMICNET_SUBMODULES = ('concept', 'dag', 'jax_interface', 'glove', 'gram', 'ode')
for _submodule in _MIMICNET_SUBMODULES:
    _last_reloaded = importlib.reload(sys.modules['mimicnet.' + _submodule])

# Keep the final reload result as the cell's last expression (it is what the cell displays).
_last_reloaded

<module 'mimicnet.ode' from '/home/asem/GP/MIMIC-SNONET/mimicnet/ode.py'>

In [3]:

# Instantiate the CCS DAG from mimicnet.dag. The same object is passed as `ccs_dag`
# to the GloVe initialisation and to both GRAM configurations further down.
KG = dag.CCSDAG()

In [4]:
import os

# Root directory that holds every EHR dataset used by this notebook.
# It differs between machines, so override it with the EHR_DATA_ROOT environment
# variable instead of editing the paths below; the default keeps the original location.
ehr_data_root = os.environ.get('EHR_DATA_ROOT', '/home/asem/GP/ehr-data')

# Dataset locations, all derived from the single root above.
multi_visit_mimic_dir = f'{ehr_data_root}/mimic3-multi-visit'
transformed_mimic_dir = f'{ehr_data_root}/mimic3-transforms'
mimic_dir = f'{ehr_data_root}/mimic3-v1.4/physionet.org/files/mimiciii/1.4'
experiments_dir = f'{ehr_data_root}/mimic3-snonet-exp'

# Prefix used in the file names of stored experiment artefacts (e.g. the pickled interface).
experiment_prefix = 'NOV29'

In [5]:
# Load the two MIMIC item dictionaries (lab items and chart items).
D_LABITEMS, D_ITEMS = (
    pd.read_csv(dictionary_path)
    for dictionary_path in (f'{mimic_dir}/D_LABITEMS.csv.gz', f'{mimic_dir}/D_ITEMS.csv.gz')
)

# Stack both dictionaries, keeping only the columns they have in common.
D_TEST = pd.concat([D_LABITEMS, D_ITEMS], join='inner')

# Lookups from ITEMID to its LABEL and to its CATEGORY.
test_label_dict = {
    item_id: label for item_id, label in zip(D_TEST['ITEMID'], D_TEST['LABEL'])
}
test_cat_dict = {
    item_id: category for item_id, category in zip(D_TEST['ITEMID'], D_TEST['CATEGORY'])
}

In [6]:
# ICD9 codes must be read as text so they are not coerced to numbers.
_icd9_as_text = {'ICD9_CODE': str}

# Tables read with default dtypes.
static_df, adm_df, test_df = (
    pd.read_csv(table_path)
    for table_path in (
        f'{transformed_mimic_dir}/static_df.csv.gz',
        f'{transformed_mimic_dir}/adm_df.csv.gz',
        f'{transformed_mimic_dir}/test_df.csv.gz',
    )
)

# Tables that carry an ICD9_CODE column.
diag_df, proc_df = (
    pd.read_csv(table_path, dtype=_icd9_as_text)
    for table_path in (
        f'{transformed_mimic_dir}/diag_df.csv.gz',
        f'{transformed_mimic_dir}/proc_df.csv.gz',
    )
)


In [7]:
# Cast the date columns to datetime64 and normalise them to midnight (time of day dropped).
# `infer_datetime_format` is intentionally not passed: it is deprecated in pandas 2.x
# and only affected parsing speed, not the parsed values.
_date_columns = (
    (static_df, 'DOB'),
    (adm_df, 'ADMITTIME'),
    (adm_df, 'DISCHTIME'),
    (test_df, 'DATE'),
)
for _frame, _column in _date_columns:
    _frame[_column] = pd.to_datetime(_frame[_column]).dt.normalize()

In [8]:
# Number of distinct values in the ETHNIC_GROUP column of the static table.
static_df.ETHNIC_GROUP.nunique()

7

In [9]:
# Display the test-measurements table (the notebook truncates the rendering of large frames).
test_df

Unnamed: 0,SUBJECT_ID,ITEMID,DATE,VALUE
0,17,50852,2134-12-22,-1.201339
1,17,50861,2134-12-22,-0.809370
2,17,50862,2134-12-22,1.634410
3,17,50863,2134-12-22,-1.452596
4,17,50867,2134-12-22,-0.608894
...,...,...,...,...
4546204,99982,227456,2157-02-22,1.794272
4546205,99982,227457,2157-02-22,-0.723781
4546206,99982,227465,2157-02-22,2.247434
4546207,99982,227466,2157-02-22,-0.029395


In [10]:
# Build the subjects from the five loaded tables (static, admissions, diagnoses,
# procedures, tests). Presumably a list of concept.Subject objects — confirm in mimicnet.concept.
patients = concept.Subject.to_list(static_df, adm_df, diag_df, proc_df, test_df)

In [11]:
# Number of subjects returned by concept.Subject.to_list.
len(patients)

4434

In [12]:
# ehr_dfs = concept.Subject.to_df(patients)

In [13]:
# def df_eq(df1, df2):
#     df11 = df1.sort_values(by=df1.columns.tolist()).reset_index(drop=True)
#     df22 = df2.sort_values(by=df2.columns.tolist()).reset_index(drop=True)
#     return df11[df22.columns].equals(df22)

# all(map(df_eq, ehr_dfs, [static_df, adm_df, diag_df, proc_df, test_df]))

### [FORK] Skip the cell below to load the jaxified data from a file stored on disk instead of rebuilding it

In [14]:
# subjects_interface = jax_interface.SubjectJAXInterface(patients, set(test_df.ITEMID), KG)
# import pickle
# with open(f'{experiments_dir}/{experiment_prefix}_subjects_interface.pkl', 'wb') as pickleFile:
#     pickle.dump(subjects_interface, pickleFile)

In [15]:
import pickle

# Location of the previously pickled subjects interface for this experiment.
_interface_pickle_path = f'{experiments_dir}/{experiment_prefix}_subjects_interface.pkl'

# NOTE: pickle.load executes arbitrary code from the file; only load files you created yourself.
with open(_interface_pickle_path, 'rb') as pickleFile:
    subjects_interface = pickle.load(pickleFile)

## GloVe Initialization

In [16]:
# Load IPython's autoreload extension; mode 2 re-imports all modules before each
# cell is executed, so edits to the mimicnet sources are picked up automatically.
%load_ext autoreload
%autoreload 2

In [17]:
import logging

# Configure the root logger so that DEBUG records and above are emitted.
logging.basicConfig(level=logging.DEBUG)

# Send one record through the root logger to confirm DEBUG output is visible.
logging.getLogger().debug("test")

DEBUG:root:test


In [18]:
# Inputs and hyper-parameters for the GloVe initialisation of the code embeddings.
glove_args = dict(
    diag_idx=subjects_interface.diag_multi_ccs_idx,
    proc_idx=subjects_interface.proc_multi_ccs_idx,
    ccs_dag=KG,
    subjects=patients,
    diag_vector_size=100,
    proc_vector_size=60,
    iterations=30,
    window_size_days=2 * 365,  # two years, expressed in days
)

# The call returns a pair: diagnosis representation first, procedure representation second.
_glove_representations = glove.glove_representation(**glove_args)
diag_glove_rep, proc_glove_rep = _glove_representations

In [19]:
# Report how many point indices exist and how many points they hold in total.
_nth_points = subjects_interface.nth_points
_total_points = sum(map(len, _nth_points.values()))
print(f'#point_indices: {len(_nth_points)}')
print(f'#total_points: {_total_points}')

#[len(points) for n, points in subjects_interface.nth_points.items()]

#point_indices: 1085
#total_points: 129334


## GRAM objects

In [20]:
from datetime import datetime

# Timestamped TensorBoard directory, one per run: /tmp/tensorboard/<YYYYmmdd-HHMMSS>.
_run_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
daily_tracer = f"/tmp/tensorboard/{_run_stamp}"
print(daily_tracer)

/tmp/tensorboard/20211202-222535


In [21]:


# Trace directory for the (currently commented-out) `jax.profiler.trace(logs)` call
# in the training cell.
# NOTE(review): this is a fixed, older timestamp rather than the freshly generated
# `daily_tracer` — confirm which one is intended before re-enabling the profiler.
logs = '/tmp/tensorboard/20210708-182059'
# Disabled: start a JAX profiler server on port 9999.
#server = jax.profiler.start_server(9999)

In [22]:
       
def _gram_settings(code2index, ancestors_mat, basic_embeddings, attention_dim):
    """Assemble the keyword arguments for one `gram.DAGGRAM` instance.

    The attention method is fixed to 'tanh'; the original configuration comment
    lists 'l2' and 'tanh' as the available choices.
    """
    return {
        'ccs_dag': KG,
        'code2index': code2index,
        'attention_method': 'tanh',
        'attention_dim': attention_dim,
        'ancestors_mat': ancestors_mat,
        'basic_embeddings': basic_embeddings
    }


# Keyword arguments of the dynamics model (choices noted for 'ode_dyn': gru, mlp).
_model_settings = {
    'ode_dyn': 'mlp',
    'state_size': 50,
    'numeric_hidden_size': 40,
    'bias': True
}

# Weights used to mix the individual loss terms.
_loss_mixing = {
    'num_alpha': 0.1,
    'diag_alpha': 0.1,
    'ode_alpha': 1e-3,
    'l1_reg': 1e-6,
    'l2_reg': 1e-5,
    'dyn_reg': 1e-5
}

# Keyword arguments of the training loop.
_training_settings = {
    'train_validation_split': 0.8,
    'batch_size': 12,
    'epochs': 5,
    'lr': 1e-3,
    # Choices noted for 'diag_loss': balanced_focal, bce.
    'diag_loss': 'balanced_focal',
    # Order of regularized derivative of the dynamics function (None for disable).
    'tay_reg': None,
    'loss_mixing': _loss_mixing,
    'eval_freq': 10,
    'save_freq': 100,
    'save_params_prefix': None
}

# Full experiment configuration: GRAM embeddings, model, and training.
config = {
    'gram_config': {
        'diag': _gram_settings(
            code2index=subjects_interface.diag_multi_ccs_idx,
            ancestors_mat=subjects_interface.diag_multi_ccs_ancestors_mat,
            basic_embeddings=diag_glove_rep,
            attention_dim=100,
        ),
        'proc': _gram_settings(
            code2index=subjects_interface.proc_multi_ccs_idx,
            ancestors_mat=subjects_interface.proc_multi_ccs_ancestors_mat,
            basic_embeddings=proc_glove_rep,
            attention_dim=60,
        )
    },
    'model': _model_settings,
    'training': _training_settings
}


In [23]:
# Build the GRAM object for diagnosis codes from the 'diag' entry of the GRAM configuration.
diag_gram = gram.DAGGRAM(**config['gram_config']['diag'])

DEBUG:absl:Initializing backend 'interpreter'
DEBUG:absl:Backend 'interpreter' initialized
DEBUG:absl:Initializing backend 'cpu'
DEBUG:absl:Backend 'cpu' initialized
DEBUG:absl:Initializing backend 'tpu_driver'
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: 
DEBUG:absl:Initializing backend 'gpu'
DEBUG:absl:Backend 'gpu' initialized
DEBUG:absl:Initializing backend 'tpu'
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.


In [24]:
# Build the GRAM object for procedure codes from the 'proc' entry of the GRAM configuration.
proc_gram = gram.DAGGRAM(**config['gram_config']['proc'])

## GRU-ODE-Bayes

In [None]:
import random
%load_ext autoreload
%autoreload 2
from absl import logging
logging.set_verbosity(logging.INFO)

#with jax.profiler.trace(logs):
res = ode.train_ehr(subject_interface=subjects_interface,
                diag_gram=diag_gram,
                proc_gram=proc_gram,
                rng=random.Random(42),
                **config['model'],
                **config['training'],
                verbose_debug=False,
                shape_debug=False,
                nan_debug=False,
                memory_profile=False)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


INFO:ode:#params: 293044
INFO:numexpr.utils:NumExpr defaulting to 8 threads.
info retrieval:   0%|          | 1/1477 [04:19<106:35:09, 259.97s/it]                 INFO:ode:
                        Training    Validation
prejump_num_loss       14.317818    917.747375
postjump_num_loss     890.755676  61111.664062
prejump_diag_loss       0.002124      0.076024
postjump_diag_loss      0.002105      0.074312
num_loss              101.961604   6937.139044
diag_loss               0.002122      0.075853
ode_loss                0.104082      7.012916
l1_loss             14247.754883  14247.754883
l1_loss_per_point      31.732194      0.562463
l2_loss              1785.656128   1785.656128
l2_loss_per_point       3.976962      0.070493
dyn_loss                 14200.0      775924.0
dyn_loss_per_week      10.344469     7.9545236
loss                 0.104449585      7.013223
INFO:ode:
                          Training     Valdation
accuracy                  0.506583      0.486989
recall        

info retrieval:   3%|▎         | 41/1477 [33:01<10:43:15, 26.88s/it]                 INFO:ode:
                        Training    Validation
prejump_num_loss       16.301437    978.553467
postjump_num_loss     781.436096  49125.906250
prejump_diag_loss       0.001322      0.049367
postjump_diag_loss      0.001341      0.049680
num_loss               92.814903   5793.288745
diag_loss               0.001324      0.049398
ode_loss                0.094137      5.842637
l1_loss             14348.020508  14348.020508
l1_loss_per_point      46.584482      0.566421
l2_loss              1760.824951   1760.824951
l2_loss_per_point       5.716964      0.069513
dyn_loss                  9208.0      782686.0
dyn_loss_per_week       8.288029      8.023846
loss                  0.09451132      5.842947
INFO:ode:
                          Training     Valdation
accuracy                  0.804049      0.799993
recall                    0.830409      0.780383
npv                       0.993489      0.9

INFO:ode:
                          Training     Valdation
accuracy                  0.793427      0.790047
recall                    0.768340      0.802502
npv                       0.986937      0.990483
specificity               0.794566      0.789568
precision                 0.145150      0.127777
f1-score                  0.244172      0.220453
tp                        0.033367      0.029687
tn                        0.760060      0.760360
fp                        0.196512      0.202647
fn                        0.010060      0.007306
points_count            229.000000  25331.000000
odeint_weeks_per_visit    4.089832      3.850815
nfe_per_point             0.000000      0.000000
nfe_per_week              0.000000      0.000000
nfex1000                  0.000000      0.000000
INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0       0.000000        0.000000         0.009000          0.006429
p1       0.000000        0.000000         0.013609        

INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0       0.000000        0.000000         0.007715          0.010929
p1       0.035714        0.107143         0.024556          0.044970
p2       0.191176        0.132353         0.364190          0.288063
p3       0.901961        0.882353         0.714646          0.747790
p4       1.000000        0.963636         0.998612          0.991396
info retrieval:   9%|▉         | 131/1477 [1:34:27<10:44:32, 28.73s/it]                 INFO:ode:
                        Training    Validation
prejump_num_loss       16.610407    995.001038
postjump_num_loss     618.306885  37928.191406
prejump_diag_loss       0.000999      0.046697
postjump_diag_loss      0.001009      0.047370
num_loss               76.780055   4688.320074
diag_loss               0.001000      0.046764
ode_loss                0.077779      4.735038
l1_loss             14359.715820  14359.715820
l1_loss_per_point      39.558446      0.566883
l2_loss 

info retrieval:  12%|█▏        | 171/1477 [2:03:17<9:36:20, 26.48s/it]                  INFO:ode:
                        Training    Validation
prejump_num_loss       17.161905    999.980896
postjump_num_loss     652.004089  37614.722656
prejump_diag_loss       0.001277      0.046602
postjump_diag_loss      0.001314      0.047388
num_loss               80.646124   4661.455072
diag_loss               0.001280      0.046680
ode_loss                0.081925      4.708089
l1_loss             14207.440430  14207.440430
l1_loss_per_point      57.057994      0.560872
l2_loss              1780.918457   1780.918457
l2_loss_per_point       7.152283      0.070306
dyn_loss                  8208.0      800272.0
dyn_loss_per_week      3.5327106      8.204131
loss                  0.08234761      4.708406
INFO:ode:
                          Training     Valdation
accuracy                  0.778169      0.785526
recall                    0.782051      0.791563
npv                       0.990139      

In [None]:
# Show both summaries. A bare `res.keys()` on a non-final line is evaluated but never
# displayed, so the keys are printed explicitly; the count stays the cell's output.
print(list(res.keys()))
len(res['res_val'])

In [None]:
# Take the `diag_detectability_df` table from entry 442 of the validation results.
# NOTE(review): 442 is hard-coded; the variable name suggests the last evaluation is
# intended — confirm it is the final entry of res['res_val'] for this run.
last_detections = res['res_val'][442]['diag_detectability_df']
last_detections

In [None]:
# Frequency of each single-level CCS diagnosis code over the training subjects.
ccs_single_frequency = subjects_interface.diag_single_ccs_frequency(res['trn_ids'])

# One row per code, ordered from the rarest code to the most frequent one.
ccs_single_frequency_df = pd.DataFrame(
    {
        'code': ccs_single_frequency.keys(),
        'frequency': ccs_single_frequency.values(),
    }
).sort_values('frequency')
ccs_single_frequency_df

In [None]:
# Add the running total of frequencies and its share of the grand total
# (as a percentage rounded to two decimals).
_total_frequency = ccs_single_frequency_df["frequency"].sum()
ccs_single_frequency_df = ccs_single_frequency_df.assign(
    cum_sum=lambda frame: frame['frequency'].cumsum(),
    cum_perc=lambda frame: (100 * frame['cum_sum'] / _total_frequency).round(2),
)
ccs_single_frequency_df

In [None]:
# Inspect the column dtypes of the frequency table.
ccs_single_frequency_df.dtypes

In [None]:
# Half-open (lower, upper] bins over the cumulative percentage; the first lower bound
# is slightly negative so that a cumulative percentage of exactly 0 falls in the first bin.
_percentile_bounds = [(-0.1, 20), (20, 40), (40, 60), (60, 80), (80, 100)]

cum_perc = ccs_single_frequency_df.cum_perc

# One set of codes per bin, in the same order as the bounds above.
codes_by_percentiles = [
    set(ccs_single_frequency_df.loc[cum_perc.gt(lower) & cum_perc.le(upper), 'code'])
    for lower, upper in _percentile_bounds
]
codes_by_percentiles

In [None]:
# For every frequency bin: print how many detection rows fall in it, then record the
# mean of `pre_detected` over those rows.
detectability_by_percentiles = []
for percentile_codes in codes_by_percentiles:
    in_bin = last_detections['code'].isin(percentile_codes)
    bin_detections = last_detections[in_bin]
    print(len(bin_detections))
    detectability_by_percentiles.append(bin_detections['pre_detected'].mean())

detectability_by_percentiles

In [None]:
# Sum of the `post_detected` column over all rows of the detections table.
last_detections.post_detected.sum()