In [11]:
import pandas as pd
import json
from collections import defaultdict 
from functools import partial
from tqdm import tqdm

import jax

# Global flag to set a specific platform, must be used at startup.
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_debug_infs", True)

jax.config.update('jax_log_compiles', False)
jax.config.update('jax_check_tracer_leaks', False)

In [12]:
import jax.numpy as jnp
a = jnp.array(range(1000)) 
a.sum()

DeviceArray(499500, dtype=int32)

In [13]:
# Good read: https://iq-inc.com/importerror-attempted-relative-import/

import sys
import importlib
from mimicnet import concept
from mimicnet import jax_interface
from mimicnet import dag
from mimicnet import glove
from mimicnet import gram
from mimicnet import train_snonet
from mimicnet import models

importlib.reload(sys.modules['mimicnet.concept'])
importlib.reload(sys.modules['mimicnet.dag'])
importlib.reload(sys.modules['mimicnet.jax_interface'])
importlib.reload(sys.modules['mimicnet.glove'])
importlib.reload(sys.modules['mimicnet.gram'])
importlib.reload(sys.modules['mimicnet.train_snonet'])
importlib.reload(sys.modules['mimicnet.models'])

<module 'mimicnet.models' from '/home/asem/GP/MIMIC-SNONET/mimicnet/models.py'>

In [14]:
KG = dag.CCSDAG()


In [15]:
dir(KG)

['CCS_DIR',
 'DIAG_MULTI_CCS_FILE',
 'DIAG_SINGLE_CCS_FILE',
 'DIR',
 'PROC_MULTI_CCS_FILE',
 'PROC_SINGLE_CCS_FILE',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'ancestors_linkage',
 'common_parent',
 'diag_ccs_children_traversal',
 'diag_icd_codes',
 'diag_icd_label',
 'diag_multi_ccs2icd',
 'diag_multi_ccs_codes',
 'diag_multi_ccs_df',
 'diag_multi_ccs_pt2ch',
 'diag_multi_icd2ccs',
 'diag_single_ccs2icd',
 'diag_single_ccs_codes',
 'diag_single_ccs_df',
 'diag_single_icd2ccs',
 'digraph_from_dataframe',
 'find_diag_icd_name',
 'find_proc_icd_name',
 'get_ccs_parents',
 'get_diag_ccs_children',
 'get_diag_multi_ccs',
 'get_proc_ccs_children',
 'get_proc_multi_cc

In [20]:
KG.diag_multi_ccs_codes

{'1',
 '1.1',
 '1.1.1',
 '1.1.2',
 '1.1.2.1',
 '1.1.2.2',
 '1.1.2.3',
 '1.1.2.4',
 '1.1.2.5',
 '1.1.2.6',
 '1.1.3',
 '1.1.4',
 '1.2',
 '1.2.1',
 '1.2.2',
 '1.3',
 '1.3.1',
 '1.3.2',
 '1.3.3',
 '1.3.3.1',
 '1.3.3.2',
 '1.3.3.3',
 '1.4',
 '1.5',
 '10',
 '10.1',
 '10.1.1',
 '10.1.2',
 '10.1.2.1',
 '10.1.2.2',
 '10.1.3',
 '10.1.4',
 '10.1.4.1',
 '10.1.4.2',
 '10.1.4.3',
 '10.1.5',
 '10.1.5.1',
 '10.1.5.2',
 '10.1.5.3',
 '10.1.6',
 '10.1.6.1',
 '10.1.6.2',
 '10.1.7',
 '10.1.7.1',
 '10.1.7.2',
 '10.1.8',
 '10.1.8.1',
 '10.1.8.2',
 '10.1.8.3',
 '10.2',
 '10.2.1',
 '10.2.2',
 '10.2.3',
 '10.3',
 '10.3.1',
 '10.3.2',
 '10.3.2.1',
 '10.3.2.2',
 '10.3.2.3',
 '10.3.2.4',
 '10.3.3',
 '10.3.4',
 '10.3.5',
 '10.3.6',
 '10.3.7',
 '10.3.8',
 '10.3.9',
 '10.3.9.1',
 '10.3.9.2',
 '11',
 '11.1',
 '11.1.1',
 '11.1.2',
 '11.2',
 '11.2.1',
 '11.2.2',
 '11.2.3',
 '11.3',
 '11.3.1',
 '11.3.2',
 '11.3.2.1',
 '11.3.2.2',
 '11.3.2.3',
 '11.3.3',
 '11.3.3.1',
 '11.3.3.2',
 '11.3.4',
 '11.3.4.1',
 '11.3.4.2',
 '11.

In [17]:
KG.get_ccs_parents(KG.diag_multi_ccs_codes[0])

['1.1', '1']

In [4]:
# multi_visit_mimic_dir = '/home/am8520/GP/ehr-data/mimic3-multi-visit'
multi_visit_mimic_dir = '/home/asem/GP/ehr-data/mimic3-multi-visit'
transformed_mimic_dir = '/home/asem/GP/ehr-data/mimic3-transforms'
mimic_dir = '/home/asem/GP/ehr-data/mimic3-v1.4/physionet.org/files/mimiciii/1.4'
# mimic_dir = '/home/asem/GP/MIMIC-SNONET/RAW/mimic-iii-clinical-database-1.4'

experiments_dir = '/home/asem/GP/ehr-data/mimic3-snonet-exp'
experiment_prefix = 'DEC03'

### [FORK] Skip the cell below to load the jaxified data from a stored file on disc

In [6]:
static_df = pd.read_csv(f'{transformed_mimic_dir}/static_df.csv.gz')
adm_df = pd.read_csv(f'{transformed_mimic_dir}/adm_df.csv.gz')
diag_df = pd.read_csv(f'{transformed_mimic_dir}/diag_df.csv.gz', dtype={'ICD9_CODE': str})
proc_df = pd.read_csv(f'{transformed_mimic_dir}/proc_df.csv.gz', dtype={'ICD9_CODE': str})
test_df = pd.read_csv(f'{transformed_mimic_dir}/test_df.csv.gz')


# Cast columns of dates to datetime64

static_df['DOB'] = pd.to_datetime(static_df.DOB, infer_datetime_format=True).dt.normalize()
adm_df['ADMITTIME'] = pd.to_datetime(adm_df.ADMITTIME, infer_datetime_format=True).dt.normalize()
adm_df['DISCHTIME'] = pd.to_datetime(adm_df.DISCHTIME, infer_datetime_format=True).dt.normalize()
test_df['DATE'] = pd.to_datetime(test_df.DATE, infer_datetime_format=True).dt.normalize()


patients = concept.Subject.to_list(static_df, adm_df, diag_df, proc_df, test_df)

KG = dag.CCSDAG()

subjects_interface = jax_interface.SubjectJAXInterface(patients, set(test_df.ITEMID), KG)
import pickle
with open(f'{experiments_dir}/{experiment_prefix}_subjects_interface.pkl', 'wb') as pickleFile:
    pickle.dump(subjects_interface, pickleFile)

In [7]:
import pickle
with open(f'{experiments_dir}/{experiment_prefix}_subjects_interface.pkl', 'rb') as pickleFile:
    subjects_interface = pickle.load(pickleFile)

## GloVe Initialization

In [8]:
%load_ext autoreload
%autoreload 2

In [9]:
glove_args = {
    'diag_idx': subjects_interface.diag_multi_ccs_idx,
    'proc_idx': subjects_interface.proc_multi_ccs_idx,
    'ccs_dag': subjects_interface.dag,
    'subjects': subjects_interface.subjects.values(),
    'diag_vector_size': 100,
    'proc_vector_size': 60,
    'iterations': 30,
    'window_size_days': 2 * 365
}

diag_glove_rep, proc_glove_rep = glove.glove_representation(**glove_args)

In [10]:
print(f'#point_indices: {len(subjects_interface.nth_points)}')
print(f'#total_points: {sum(len(points) for n, points in subjects_interface.nth_points.items())}')

#[len(points) for n, points in subjects_interface.nth_points.items()]

#point_indices: 1085
#total_points: 128657


## GRAM objects

In [11]:
from datetime import datetime
daily_tracer = "/tmp/tensorboard/"+ datetime.now().strftime("%Y%m%d-%H%M%S") 
print(daily_tracer)

/tmp/tensorboard/20211220-214936


In [12]:


logs = '/tmp/tensorboard/20210708-182059'
#server = jax.profiler.start_server(9999)

In [13]:
       
# config = {
#     'gram_config': {
#         'diag': {
#             'ccs_dag': KG,
#             'code2index': subjects_interface.diag_multi_ccs_idx,
#             'attention_method': 'tanh', #l2, tanh
#             'attention_dim': 50,
#             'ancestors_mat': subjects_interface.diag_multi_ccs_ancestors_mat,
#             'basic_embeddings': diag_glove_rep
#         },
#         'proc': {
#             'ccs_dag': KG,
#             'code2index': subjects_interface.proc_multi_ccs_idx,
#             'attention_method': 'tanh',
#             'attention_dim': 50,
#             'ancestors_mat': subjects_interface.proc_multi_ccs_ancestors_mat,
#             'basic_embeddings': proc_glove_rep
#         }
#     },
#     'model': {
#         'ode_dyn': 'mlp', # gru, mlp
#         'state_size': 50,
#         'numeric_hidden_size': 50,
#         'bias': True
#     },
#     'training': {
#         'train_validation_split': 0.8,
#         'batch_size': 4,
#         'epochs': 200,
#         'lr': 1e-3,
#         'diag_loss': 'balanced_focal', # balanced_focal, bce
#         'tay_reg': 3, # Order of regularized derivative of the dynamics function (None for disable).
#         'loss_mixing': {
#             'num_alpha': 0.1,
#             'diag_alpha': 0.1,
#             'ode_alpha': 1e-3,
#             'l1_reg': 1e-6,
#             'l2_reg': 1e-5,
#             'dyn_reg': 1e-5
#         },
#         'eval_freq': 10,
#         'save_freq': 100,
#         'save_params_prefix': None
#     }
# }


In [19]:
       
config = {
    'gram_config': {
        'diag': {
            'ccs_dag': subjects_interface.dag,
            'code2index': subjects_interface.diag_multi_ccs_idx,
            'attention_method': 'tanh', #l2, tanh
            'attention_dim': 150,
            'ancestors_mat': subjects_interface.diag_multi_ccs_ancestors_mat,
            'basic_embeddings': diag_glove_rep
        },
        'proc': {
            'ccs_dag': subjects_interface.dag,
            'code2index': subjects_interface.proc_multi_ccs_idx,
            'attention_method': 'tanh',
            'attention_dim': 100,
            'ancestors_mat': subjects_interface.proc_multi_ccs_ancestors_mat,
            'basic_embeddings': proc_glove_rep
        }
    },
    'model': {
        'ode_dyn': 'gru', # gru, mlp, res
        'ode_depth': 2,
        'state_size': 120,
        'numeric_hidden_size': 200,
        'init_depth': 2,
        'bias': True,
        'max_odeint_days': 8 * 7 # two months
    },
    'training': {
        'train_validation_split': 0.8,
        'batch_size': 20,
        'epochs': 200,
        'lr': 1e-3,
        'diag_loss': 'balanced_focal', # balanced_focal, bce
        'tay_reg': 3, # Order of regularized derivative of the dynamics function (None for disable).
        'loss_mixing': {
            'num_alpha': 0.1,
            'diag_alpha': 0.1,
            'ode_alpha': 1e-6,
            'l1_reg': 1e-6,
            'l2_reg': 1e-5,
            'dyn_reg': 1e3
        },
        'eval_freq': 5,
        'save_freq': 100,
        'output_dir': f'{experiments_dir}/{experiment_prefix}'
    }
}


In [20]:
diag_gram = gram.DAGGRAM(**config['gram_config']['diag'])

In [21]:
proc_gram = gram.DAGGRAM(**config['gram_config']['proc'])

## GRU-ODE-Bayes

In [24]:
import random
%load_ext autoreload
%autoreload 2
from absl import logging
logging.set_verbosity(logging.INFO)

#with jax.profiler.trace(logs):
res = train_snonet.train_ehr(subject_interface=subjects_interface,
                diag_gram=diag_gram,
                proc_gram=proc_gram,
                rng=random.Random(42),
                model_config=config['model'],
                **config['training'])

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


INFO:absl:#params: 688714
INFO:absl:shape(params): {'diag_gram': ((589, 100), FlatMap({
  'None_DAG_Attention/~/linear': FlatMap({'b': (150,), 'w': (200, 150)}),
  'None_DAG_Attention/~/linear_1': FlatMap({'w': (150, 1)}),
})), 'f_dec': FlatMap({
  'f_dec/~/lin_gram': FlatMap({'b': (100,), 'w': (50, 100)}),
  'f_dec/~/lin_h_hidden': FlatMap({'b': (50,), 'w': (120, 50)}),
  'f_dec/~/lin_out': FlatMap({'b': (284,), 'w': (100, 284)}),
}), 'f_n_ode': FlatMap({
  'n_ode/~/ode_dyn_augment/~/ode_dyn/~/hc_r': FlatMap({'b': (120,), 'w': (190, 120)}),
  'n_ode/~/ode_dyn_augment/~/ode_dyn/~/hc_z': FlatMap({'b': (120,), 'w': (190, 120)}),
  'n_ode/~/ode_dyn_augment/~/ode_dyn/~/rhc_g': FlatMap({'b': (120,), 'w': (190, 120)}),
}), 'f_num': FlatMap({
  'f_numeric/linear': FlatMap({'b': (200,), 'w': (120, 200)}),
  'f_numeric/~/linear': FlatMap({'b': (550,), 'w': (200, 550)}),
  'f_numeric/~/linear_1': FlatMap({'b': (550,), 'w': (200, 550)}),
}), 'f_state_init': FlatMap({
  'f_init/~/lin_0': FlatMap({

KeyboardInterrupt: 

#### Possible modifications:
- Add more layers to the adjustment function
- Use days instead of weeks for odeint