In [1]:
import pandas as pd
import json
from collections import defaultdict 
from functools import partial
from tqdm import tqdm

import jax

# JAX runtime flags. These must be set at startup, before any JAX computation
# runs, or the platform selection will not take effect.
_jax_flags = {
    'jax_platform_name': 'cpu',       # run everything on the CPU backend
    'jax_debug_nans': True,           # raise FloatingPointError when a jitted op yields NaN
    'jax_debug_infs': True,           # same check for +/-Inf
    'jax_log_compiles': False,
    'jax_check_tracer_leaks': False,
}
for _flag_name, _flag_value in _jax_flags.items():
    jax.config.update(_flag_name, _flag_value)

In [2]:
import jax.numpy as jnp
# Smoke test that JAX works on the configured platform: 0 + 1 + ... + 999 == 499500.
# `jnp.arange` builds the array directly instead of materialising a Python range first.
a = jnp.arange(1000)
a.sum()

DeviceArray(499500, dtype=int32)

In [3]:
# Good read: https://iq-inc.com/importerror-attempted-relative-import/

import sys
import importlib
from mimicnet import concept
from mimicnet import jax_interface
from mimicnet import dag
from mimicnet import glove
from mimicnet import gram
from mimicnet import train
from mimicnet import models

# Re-execute each mimicnet module so source edits are picked up without a kernel
# restart. Modules are reloaded one after another in exactly this order.
for _module_name in ('concept', 'dag', 'jax_interface', 'glove', 'gram', 'train', 'models'):
    importlib.reload(sys.modules[f'mimicnet.{_module_name}'])



<module 'mimicnet.models' from '/home/asem/GP/MIMIC-SNONET/mimicnet/models.py'>

In [4]:
import os

# Root directory holding `ehr-data/` and `MIMIC-SNONET/`. Set the GP_ROOT environment
# variable (e.g. GP_ROOT=/home/am8520/GP on the other machine) instead of editing paths.
GP_ROOT = os.environ.get('GP_ROOT', '/home/asem/GP')

multi_visit_mimic_dir = f'{GP_ROOT}/ehr-data/mimic3-multi-visit'
transformed_mimic_dir = f'{GP_ROOT}/ehr-data/mimic3-transforms'
mimic_dir = f'{GP_ROOT}/ehr-data/mimic3-v1.4/physionet.org/files/mimiciii/1.4'
# Alternative raw-data location:
# mimic_dir = f'{GP_ROOT}/MIMIC-SNONET/RAW/mimic-iii-clinical-database-1.4'

# Experiment artifacts (e.g. the cached subject interface) are written/read here,
# with file names prefixed by `experiment_prefix`.
experiments_dir = f'{GP_ROOT}/ehr-data/mimic3-snonet-exp'
experiment_prefix = 'DEC03'

### [FORK] To load the preprocessed (JAX-ready) subject interface from the pickle cached on disk, skip the cell below; it is only needed to rebuild that cache.

In [5]:
# Build the SubjectJAXInterface from the transformed MIMIC-III tables and cache it as
# a pickle. The build is skipped when the cache file already exists, so this cell is
# safe under Restart & Run All; the next cell loads the cached interface either way.
import os
import pickle

subjects_interface_cache = f'{experiments_dir}/{experiment_prefix}_subjects_interface.pkl'

if not os.path.exists(subjects_interface_cache):
    static_df = pd.read_csv(f'{transformed_mimic_dir}/static_df.csv.gz')
    adm_df = pd.read_csv(f'{transformed_mimic_dir}/adm_df.csv.gz')
    diag_df = pd.read_csv(f'{transformed_mimic_dir}/diag_df.csv.gz', dtype={'ICD9_CODE': str})
    proc_df = pd.read_csv(f'{transformed_mimic_dir}/proc_df.csv.gz', dtype={'ICD9_CODE': str})
    test_df = pd.read_csv(f'{transformed_mimic_dir}/test_df.csv.gz')

    # Cast date columns to datetime64, normalized to midnight (time of day dropped).
    static_df['DOB'] = pd.to_datetime(static_df.DOB, infer_datetime_format=True).dt.normalize()
    adm_df['ADMITTIME'] = pd.to_datetime(adm_df.ADMITTIME, infer_datetime_format=True).dt.normalize()
    adm_df['DISCHTIME'] = pd.to_datetime(adm_df.DISCHTIME, infer_datetime_format=True).dt.normalize()
    test_df['DATE'] = pd.to_datetime(test_df.DATE, infer_datetime_format=True).dt.normalize()

    patients = concept.Subject.to_list(static_df, adm_df, diag_df, proc_df, test_df)

    KG = dag.CCSDAG()

    subjects_interface = jax_interface.SubjectJAXInterface(patients, set(test_df.ITEMID), KG)

    # Make sure the experiments directory exists before writing the cache into it.
    os.makedirs(experiments_dir, exist_ok=True)
    with open(subjects_interface_cache, 'wb') as pickleFile:
        pickle.dump(subjects_interface, pickleFile)

In [6]:
import pickle

# Load the cached SubjectJAXInterface (the pickle written by the data-building code in
# the previous cell). pickle can execute arbitrary code on load, so only open caches
# you produced yourself.
subjects_interface_pkl = f'{experiments_dir}/{experiment_prefix}_subjects_interface.pkl'
with open(subjects_interface_pkl, mode='rb') as cache_file:
    subjects_interface = pickle.load(cache_file)

## GloVe Initialization

In [7]:
# Enable IPython's autoreload extension; mode 2 reloads all modules before each cell
# executes, so edits to the mimicnet package source are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [8]:
import logging

# Configure the stdlib root logger to emit records at DEBUG level and above, using the
# default "LEVEL:logger:message" format. (The stray `logging.debug("test")` probe that
# used to follow has been removed.)
logging.basicConfig(level=logging.DEBUG)

In [9]:
# Initial code embeddings from mimicnet's GloVe implementation, computed over the
# subjects' histories: 100-d vectors for diagnosis codes, 60-d for procedure codes,
# 30 iterations, and a window of two years expressed in days.
glove_args = dict(
    diag_idx=subjects_interface.diag_multi_ccs_idx,
    proc_idx=subjects_interface.proc_multi_ccs_idx,
    ccs_dag=subjects_interface.dag,
    subjects=subjects_interface.subjects.values(),
    diag_vector_size=100,
    proc_vector_size=60,
    iterations=30,
    window_size_days=2 * 365,
)

diag_glove_rep, proc_glove_rep = glove.glove_representation(**glove_args)

In [10]:
# Size of `subjects_interface.nth_points`: the number of keys, and the total number
# of points summed over all keys (only the values are needed, so `.values()` is used).
print(f'#point_indices: {len(subjects_interface.nth_points)}')
print(f'#total_points: {sum(len(points) for points in subjects_interface.nth_points.values())}')

#point_indices: 1085
#total_points: 129334


## GRAM objects

In [11]:
from datetime import datetime

# Timestamped TensorBoard directory name for this session,
# e.g. /tmp/tensorboard/20211212-152131 (only the string is built; nothing is created).
daily_tracer = f"/tmp/tensorboard/{datetime.now():%Y%m%d-%H%M%S}"
print(daily_tracer)

/tmp/tensorboard/20211212-152131


In [12]:


# Profiler trace directory. Reuse the timestamped path computed in the previous cell
# instead of a stale hard-coded run directory, so each session traces into its own folder.
logs = daily_tracer
#server = jax.profiler.start_server(9999)

In [13]:
       
# NOTE(review): earlier, superseded configuration kept for reference only.
# It references `KG` for 'ccs_dag', whereas the active configuration in the next
# cell uses `subjects_interface.dag`; it also uses the smaller sizes/regularizers
# (e.g. 'dyn_reg': 1e-5) that the active config departs from.
# config = {
#     'gram_config': {
#         'diag': {
#             'ccs_dag': KG,
#             'code2index': subjects_interface.diag_multi_ccs_idx,
#             'attention_method': 'tanh', #l2, tanh
#             'attention_dim': 50,
#             'ancestors_mat': subjects_interface.diag_multi_ccs_ancestors_mat,
#             'basic_embeddings': diag_glove_rep
#         },
#         'proc': {
#             'ccs_dag': KG,
#             'code2index': subjects_interface.proc_multi_ccs_idx,
#             'attention_method': 'tanh',
#             'attention_dim': 50,
#             'ancestors_mat': subjects_interface.proc_multi_ccs_ancestors_mat,
#             'basic_embeddings': proc_glove_rep
#         }
#     },
#     'model': {
#         'ode_dyn': 'mlp', # gru, mlp
#         'state_size': 50,
#         'numeric_hidden_size': 50,
#         'bias': True
#     },
#     'training': {
#         'train_validation_split': 0.8,
#         'batch_size': 4,
#         'epochs': 200,
#         'lr': 1e-3,
#         'diag_loss': 'balanced_focal', # balanced_focal, bce
#         'tay_reg': 3, # Order of regularized derivative of the dynamics function (None for disable).
#         'loss_mixing': {
#             'num_alpha': 0.1,
#             'diag_alpha': 0.1,
#             'ode_alpha': 1e-3,
#             'l1_reg': 1e-6,
#             'l2_reg': 1e-5,
#             'dyn_reg': 1e-5
#         },
#         'eval_freq': 10,
#         'save_freq': 100,
#         'save_params_prefix': None
#     }
# }


In [14]:
       
def _gram_settings(code2index, attention_dim, ancestors_mat, basic_embeddings):
    """Return the GRAM settings for one code vocabulary (diagnoses or procedures).

    Both vocabularies share the CCS DAG from `subjects_interface` and tanh attention;
    they differ in code index, attention width, ancestor matrix and initial embeddings.
    """
    return {
        'ccs_dag': subjects_interface.dag,
        'code2index': code2index,
        'attention_method': 'tanh',  # options: l2, tanh
        'attention_dim': attention_dim,
        'ancestors_mat': ancestors_mat,
        'basic_embeddings': basic_embeddings,
    }


_model_settings = {
    'ode_dyn': 'gru',  # options: gru, mlp, res
    'ode_depth': 2,
    'state_size': 120,
    'numeric_hidden_size': 200,
    'init_depth': 2,
    'bias': True,
    'max_odeint_days': 8 * 7,  # 8 weeks (roughly two months)
}

_training_settings = {
    'train_validation_split': 0.8,
    'batch_size': 20,
    'epochs': 200,
    'lr': 1e-3,
    'diag_loss': 'balanced_focal',  # options: balanced_focal, bce
    'tay_reg': 3,  # Order of regularized derivative of the dynamics function (None to disable).
    'loss_mixing': {
        'num_alpha': 0.1,
        'diag_alpha': 0.1,
        'ode_alpha': 1e-6,
        'l1_reg': 1e-6,
        'l2_reg': 1e-5,
        # NOTE(review): 1e3 is eight orders of magnitude above the 1e-5 used in the
        # earlier config; confirm this is intended (the training run hit a NaN error).
        'dyn_reg': 1e3,
    },
    'eval_freq': 5,
    'save_freq': 100,
    'save_params_prefix': None,
}

config = {
    'gram_config': {
        'diag': _gram_settings(subjects_interface.diag_multi_ccs_idx, 150,
                               subjects_interface.diag_multi_ccs_ancestors_mat,
                               diag_glove_rep),
        'proc': _gram_settings(subjects_interface.proc_multi_ccs_idx, 100,
                               subjects_interface.proc_multi_ccs_ancestors_mat,
                               proc_glove_rep),
    },
    'model': _model_settings,
    'training': _training_settings,
}


In [15]:
# GRAM module for diagnosis codes, built from the 'diag' entry of the GRAM config.
_diag_gram_kwargs = config['gram_config']['diag']
diag_gram = gram.DAGGRAM(**_diag_gram_kwargs)

In [16]:
# GRAM module for procedure codes, built from the 'proc' entry of the GRAM config.
_proc_gram_kwargs = config['gram_config']['proc']
proc_gram = gram.DAGGRAM(**_proc_gram_kwargs)

## GRU-ODE-Bayes

In [17]:
import random

# absl logging is imported under an alias so it does not shadow the stdlib `logging`
# module imported earlier in the notebook. Autoreload is already enabled earlier, so
# the extension is not loaded a second time here.
from absl import logging as absl_logging
absl_logging.set_verbosity(absl_logging.INFO)

# To capture a profiler trace, wrap the call below in `with jax.profiler.trace(logs):`.
res = train.train_ehr(subject_interface=subjects_interface,
                      diag_gram=diag_gram,
                      proc_gram=proc_gram,
                      rng=random.Random(42),  # seeded RNG handed to the trainer
                      model_config=config['model'],
                      **config['training'],
                      verbose_debug=False,
                      shape_debug=False,
                      nan_debug=False,
                      memory_profile=False)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


INFO:ode:#params: 731364
INFO:ode:shape(params): {'diag_gram': ((589, 100), FlatMap({
  'None_DAG_Attention/~/linear': FlatMap({'b': (150,), 'w': (200, 150)}),
  'None_DAG_Attention/~/linear_1': FlatMap({'w': (150, 1)}),
})), 'f_dec': FlatMap({
  'f_dec/~/lin_gram': FlatMap({'b': (100,), 'w': (150, 100)}),
  'f_dec/~/lin_h_hidden': FlatMap({'b': (50,), 'w': (120, 50)}),
  'f_dec/~/lin_num_hidden1': FlatMap({'b': (50,), 'w': (550, 50)}),
  'f_dec/~/lin_num_hidden2': FlatMap({'b': (100,), 'w': (50, 100)}),
  'f_dec/~/lin_out': FlatMap({'b': (284,), 'w': (100, 284)}),
}), 'f_num': FlatMap({
  'f_numeric/linear': FlatMap({'b': (200,), 'w': (120, 200)}),
  'f_numeric/~/linear': FlatMap({'b': (550,), 'w': (200, 550)}),
  'f_numeric/~/linear_1': FlatMap({'b': (550,), 'w': (200, 550)}),
}), 'f_state_init': FlatMap({
  'f_init/~/lin_0': FlatMap({'b': (100,), 'w': (110, 100)}),
  'f_init/~/lin_1': FlatMap({'b': (100,), 'w': (100, 100)}),
  'f_init/~/lin_out': FlatMap({'b': (120,), 'w': (100, 120

Invalid value encountered in the output of a jit/pmap-ed function. Calling the de-optimized version.
Invalid value encountered in the output of a jit/pmap-ed function. Calling the de-optimized version.


  File "/home/asem/.conda/envs/mimic3-snonet/lib/python3.9/site-packages/jax/interpreters/xla.py", line 690, in _xla_call_impl
    out = compiled_fun(*args)
  File "/home/asem/.conda/envs/mimic3-snonet/lib/python3.9/site-packages/jax/interpreters/xla.py", line 1101, in _execute_compiled
    check_special(name, out_bufs)
  File "/home/asem/.conda/envs/mimic3-snonet/lib/python3.9/site-packages/jax/interpreters/xla.py", line 483, in check_special
    _check_special(name, buf.xla_shape(), buf)
  File "/home/asem/.conda/envs/mimic3-snonet/lib/python3.9/site-packages/jax/interpreters/xla.py", line 489, in _check_special
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in transpose(jvp(apply_fn))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/asem/.conda/envs/mimic3-snonet/lib/python3.9/site-packages/jax/interpreters/xla.py", line 690, in _xla_c

#### Possible modifications
- Add more layers to the adjustment function.
- Use days instead of weeks as the time unit for `odeint`.