In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd
import dask

from tqdm import tqdm
import jax
# Select the CPU platform for JAX in this notebook session.
jax.config.update('jax_platform_name', 'cpu')
# Optional JAX switches, left disabled; uncomment one to turn it on.
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [2]:


# Put the directory two levels up on the import path so the `lib` package
# imported below can be found (presumably the repository root - confirm).
# The entry is relative to the kernel's working directory.
# Guarded so that re-running this cell does not append duplicate entries.
if "../.." not in sys.path:
    sys.path.append("../..")

from lib import utils as U
from lib.ehr import load_dataset, load_dataset_scheme, load_dataset_config, Dataset
from lib.ehr.interface import Patients, InterfaceConfig
from lib.ehr.concepts import DemographicVectorConfig


In [3]:
import logging

# Set the root logger to INFO through the public setter. Assigning
# `logging.root.level` directly skips level validation and does not clear the
# per-logger `isEnabledFor` cache, so loggers that already cached a decision
# could keep ignoring INFO records.
logging.getLogger().setLevel(logging.INFO)

In [4]:
# Dataset tag handed to the dataset/scheme loaders below.
tag = 'M4'
# Cohort directory under the user's home. `expanduser` is used instead of
# `os.environ.get("HOME")` so an unset HOME cannot produce 'None/GP/...';
# when HOME is set the resulting string is unchanged.
PATH = f'{os.path.expanduser("~")}/GP/ehr-data/mimic4-cohort'
# Optional sample setting forwarded to the dataset config; None here.
sample = None
# Cache location for the patient interface. The 'cached_inteface' spelling is
# kept on purpose so existing cache directories stay valid.
cache = f'cached_inteface/patients_{tag}_{sample or ""}'
# Build the dataset configuration for the chosen tag, data path and sample.
dataset_config = load_dataset_config(
    tag,
    path=PATH,
    sample=sample,
)

##### Possible Interface Scheme Configurations

In [5]:
import json

# Load the dataset scheme for `tag` and pretty-print the target-scheme
# options it reports as supported (keys sorted, 4-space indent).
dataset_scheme = load_dataset_scheme(tag)
interface_schem_options = dataset_scheme.supported_target_scheme_options
print(
    json.dumps(
        interface_schem_options,
        indent=4,
        sort_keys=True,
    )
)

{
    "dx": [
        "DxICD10",
        "DxCCS",
        "DxICD9",
        "DxFlatCCS"
    ],
    "ethnicity": [
        "MIMIC4Eth32",
        "MIMIC4Eth5"
    ],
    "gender": [
        "Gender"
    ],
    "outcome": [
        "dx_flatccs_filter_v1",
        "dx_icd9_filter_v1",
        "dx_flatccs_mlhc_groups",
        "dx_icd9_filter_v2_groups",
        "dx_icd9_filter_v3_groups"
    ]
}


In [6]:
# Pick one target scheme per category from the options printed above.
interface_scheme = dataset_scheme.make_target_scheme_config(
    dx='DxICD9',
    outcome='dx_icd9_filter_v3_groups',
    ethnicity='MIMIC4Eth5',
)

# Demographic vector attributes: every flag is switched off here.
demographic_vector_conf = DemographicVectorConfig(
    age=False,
    gender=False,
    ethnicity=False,
)

# Bundle the target scheme, the source dataset scheme, the demographic
# settings and the cache location into a single interface configuration.
interface_config = InterfaceConfig(
    scheme=interface_scheme,
    dataset_scheme=dataset_scheme,
    demographic_vector=demographic_vector_conf,
    cache=cache,
)

In [7]:
from lib.ml import (ICENODE, ICENODEConfig, OutpatientEmbeddingConfig,  SplitConfig,
                    Trainer, TrainerConfig, TrainerReporting, OptimizerConfig, WarmupConfig, ReportingConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC, CodeLevelMetricConfig, MetricLevelsConfig,
                         LossMetricConfig,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric, CodeGroupTopAlarmAccuracyConfig)
from lib.ml import Experiment, ExperimentConfig, SplitConfig

import jax.random as jrandom

In [8]:
# Embedding settings: dx=30 and demo=0 (demo=0 presumably matches the
# all-False demographic vector configured earlier - TODO confirm).
emb_dims = OutpatientEmbeddingConfig(
    dx=30,
    demo=0,
)
# ICENODE model configuration with mem=15 and the embedding settings above.
model_config = ICENODEConfig(
    mem=15,
    emb=emb_dims,
)
# The model class is referenced by name in the experiment configuration.
model_classname = ICENODE.__name__

In [9]:
# Training setup: Adam optimizer (lr=1e-3), 80 epochs, batches of 128, and
# the loss names selected for the dx / obs / lead objectives.
trainer_config = TrainerConfig(
    optimizer=OptimizerConfig(opt='adam', lr=1e-3),
    epochs=80,
    batch_size=128,
    dx_loss='balanced_focal_bce',
    obs_loss='mse',
    lead_loss='mse',
)

# Warm-up settings.
# NOTE(review): this object is not referenced by the later visible cells;
# ExperimentConfig is constructed with warmup=None. Confirm that is intended.
warmup = WarmupConfig(
    epochs=0.1,
    batch_size=8,
    opt='adam',
    lr=1e-3,
    decay_rate=0.5,
)




In [10]:
# Loss names passed to LossMetricConfig below.
dx_loss = [
    "softmax_bce",
    "balanced_focal_softmax_bce",
    "balanced_focal_bce",
    "allpairs_exp_rank",
    "allpairs_hard_rank",
    "allpairs_sigmoid_rank",
]
# Observation loss names; not consumed anywhere in this cell.
obs_loss = ["mse", "mae", "rms"]

# Pair each metric class with its configuration and export the pair through
# the class's `export_module_class`, keeping the order listed here.
metrics_conf = [
    metric_cls.export_module_class(metric_cfg)
    for metric_cls, metric_cfg in (
        (CodeAUC,
         CodeLevelMetricConfig(aggregate_level=True, code_level=True)),
        (AdmissionAUC,
         MetricLevelsConfig(admission=False, aggregate=True, subject_aggregate=False)),
        (CodeGroupTopAlarmAccuracy,
         CodeGroupTopAlarmAccuracyConfig(n_partitions=5, top_k_list=[3, 5, 10, 15, 20])),
        (LossMetric,
         LossMetricConfig(dx_loss=dx_loss)),
    )
]

In [11]:
# Reporting options: results go under 'icenode', with console output,
# parameter snapshots and a JSON config dump on, and model stats off.
reporting_conf = ReportingConfig(
    output_dir='icenode',
    console=True,
    model_stats=False,
    parameter_snapshots=True,
    config_json=True,
)

In [12]:
# Assemble every piece configured above into one experiment configuration.
# Split fractions are 0.8 / 0.1 / 0.1 with balanced='admissions'.
# warmup=None: the WarmupConfig built earlier is not used here.
expt_config = ExperimentConfig(
    dataset=dataset_config,
    interface=interface_config,
    split=SplitConfig(
        train=0.8,
        val=0.1,
        test=0.1,
        balanced='admissions',
    ),
    trainer=trainer_config,
    metrics=metrics_conf,
    reporting=reporting_conf,
    model=model_config,
    model_classname=model_classname,
    n_evals=100,
    continue_training=True,
    warmup=None,
    reg_hyperparams=None,
)

In [13]:
# Create the experiment object from the assembled configuration.
experiment = Experiment(expt_config)

In [14]:
# Display the full experiment configuration for inspection.
expt_config

ExperimentConfig(
  dataset=DatasetConfig(
    path='/home/asem/GP/ehr-data/mimic4-cohort',
    scheme=DatasetSchemeConfig(
      dx={'10': 'DxICD10', '9': 'DxICD9'},
      ethnicity='MIMIC4Eth32',
      gender='Gender',
      outcome=None
    ),
    scheme_classname='MIMIC4DatasetScheme',
    colname={
      'adm':
      {
        'admittime':
        'admittime',
        'dischtime':
        'dischtime',
        'index':
        'hadm_id',
        'subject_id':
        'subject_id'
      },
      'dx':
      {'admission_id': 'hadm_id', 'code': 'icd_code', 'version': 'icd_version'},
      'static':
      {
        'anchor_age':
        'anchor_age',
        'anchor_year':
        'anchor_year',
        'ethnicity':
        'race',
        'gender':
        'gender',
        'index':
        'subject_id'
      }
    },
    files={
      'adm':
      'adm_df.csv.gz',
      'dx':
      'dx_df.csv.gz',
      'static':
      'static_df.csv.gz'
    },
    sample=None,
    meta_fpath='',
   

In [None]:
# Load the dataset described by the experiment's own dataset configuration.
ds = experiment.load_dataset(experiment.config.dataset)

Please ensure that each individual file can fit in memory and
use the keyword ``blocksize=None to remove this message``
Setting ``blocksize=None``
  warn(
Please ensure that each individual file can fit in memory and
use the keyword ``blocksize=None to remove this message``
Setting ``blocksize=None``
  warn(
Please ensure that each individual file can fit in memory and
use the keyword ``blocksize=None to remove this message``
Setting ``blocksize=None``
  warn(


In [None]:
# Show which concrete type `load_dataset` returned.
type(ds)

In [None]:
# Switch the root logger to DEBUG before launching the run.
# The previous `logging.level = ...` assignment only created an unused
# attribute on the `logging` module and changed no logger's level.
logging.getLogger().setLevel(logging.DEBUG)

In [None]:
# Run the configured experiment and keep whatever `run()` returns.
result = experiment.run()

###### 