In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd
import dask

from tqdm import tqdm
import jax
# Set JAX's platform name to 'cpu' for this notebook session.
jax.config.update('jax_platform_name', 'cpu')
# Optional JAX debugging / precision switches, currently left disabled:
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [2]:


# Put the directory two levels up on the import path so that the project-local
# `lib` package imported below can be resolved. The membership check makes the
# cell idempotent: re-running it no longer appends a duplicate entry each time.
# NOTE(review): the relative path depends on the kernel's working directory.
if "../.." not in sys.path:
    sys.path.append("../..")

from lib import utils as U
from lib.ehr import load_dataset, load_dataset_scheme, load_dataset_config, Dataset
from lib.ehr.interface import Patients, InterfaceConfig
from lib.ehr.concepts import DemographicVectorConfig, LeadingObservableConfig


In [3]:
import logging

# Emit DEBUG-level records from the root logger. `setLevel` is the documented
# API; unlike assigning `logging.root.level` directly, it validates the level
# and lets the logging machinery refresh any cached "is enabled" decisions.
logging.getLogger().setLevel(logging.DEBUG)

In [4]:
# Dataset tag handed to the `load_dataset_config` / `load_dataset_scheme` helpers.
tag = 'M4ICU'
# Cohort directory under the user's home. `expanduser` is used instead of
# `os.environ.get("HOME")` so that an unset HOME cannot silently produce a
# path that starts with the literal string "None".
PATH = f'{os.path.expanduser("~")}/GP/ehr-data/mimic4icu-cohort'
# Passed as `sample=` to the dataset config (presumably the number of subjects
# to draw -- TODO confirm); also part of the cache name below.
sample = 30000
# Cache location for the compiled patient interface.
# NOTE(review): 'inteface' looks like a typo for 'interface'; the string is kept
# as-is because renaming it would orphan any cache already written under it.
cache = f'cached_inteface/patients_{tag}_{sample or ""}'
# Dataset configuration for the selected tag, pointed at the cohort directory
# and carrying the requested sample setting.
dataset_config = load_dataset_config(
    tag,
    sample=sample,
    path=PATH,
)

In [5]:
# Display the cache location that is later passed to InterfaceConfig.
cache

'cached_inteface/patients_M4ICU_30000'

##### Possible Interface Scheme Configurations

In [6]:
import json

# Load the coding-scheme bundle for this dataset and pretty-print, as JSON,
# the target-scheme options it reports as supported.
dataset_scheme = load_dataset_scheme(tag)
interface_schem_options = dataset_scheme.supported_target_scheme_options
print(
    json.dumps(
        interface_schem_options,
        indent=4,
        sort_keys=True,
    )
)

DEBUG:root:Constructing dx_icd9 (<class 'lib.ehr.coding_scheme.DxICD9'>) scheme
DEBUG:root:Constructing dx_icd9 (<class 'lib.ehr.coding_scheme.DxICD9'>) scheme
DEBUG:root:Constructing dx_icd9_filter_v3_groups (<class 'lib.ehr.coding_scheme.OutcomeExtractor'>) scheme
DEBUG:root:Constructing mimic4_eth32 (<class 'lib.ehr.coding_scheme.MIMIC4Eth32'>) scheme
DEBUG:root:Constructing gender (<class 'lib.ehr.coding_scheme.Gender'>) scheme
DEBUG:root:Constructing int_mimic4_proc (<class 'lib.ehr.coding_scheme.MIMICProcedures'>) scheme
DEBUG:root:Constructing int_mimic4_input (<class 'lib.ehr.coding_scheme.MIMICInput'>) scheme
DEBUG:root:Constructing int_mimic4_obs (<class 'lib.ehr.coding_scheme.MIMICObservables'>) scheme
DEBUG:root:Constructing dx_icd9 (<class 'lib.ehr.coding_scheme.DxICD9'>) scheme
DEBUG:root:Constructing dx_icd10 (<class 'lib.ehr.coding_scheme.DxICD10'>) scheme


{
    "dx": [
        "DxICD9",
        "DxFlatCCS",
        "DxICD10",
        "DxCCS"
    ],
    "ethnicity": [
        "MIMIC4Eth32",
        "MIMIC4Eth5"
    ],
    "gender": [
        "Gender"
    ],
    "int_input": [
        "MIMICInput",
        "MIMICInputGroups"
    ],
    "int_proc": [
        "MIMICProcedures",
        "MIMICProcedureGroups"
    ],
    "obs": [
        "MIMICObservables"
    ],
    "outcome": [
        "dx_flatccs_mlhc_groups",
        "dx_icd9_filter_v3_groups",
        "dx_icd9_filter_v1",
        "dx_icd9_filter_v2_groups",
        "dx_flatccs_filter_v1"
    ]
}


#### Leading Observable for Early Prediction Task

In [7]:
# Tabulate the observable scheme and show the rows whose description mentions
# 'aki'; this is how the index of the leading observable is looked up.
scheme_df = dataset_scheme.obs.as_dataframe()
# regex=False: treat 'aki' as a plain substring rather than a pattern.
# na=False: a missing description counts as "no match" instead of producing a
# NaN in the mask, which boolean indexing would reject.
display(scheme_df[scheme_df.desc.str.contains('aki', regex=False, na=False)])

Unnamed: 0,code,desc
42,o42,aki_stage_smoothed


In [8]:

# Demographic vector attributes: age, gender and ethnicity are all switched off.
demographic_vector_conf = DemographicVectorConfig(age=False,
                                                  gender=False,
                                                  ethnicity=False)

# Leading observable for the early-prediction task: index 42 of the obs scheme
# (listed above as code o42, 'aki_stage_smoothed'), with leading hours
# 6, 12, ..., 72 and 'max' as the window aggregate.
leading_AKI = LeadingObservableConfig(
    leading_hours=tuple(range(6, 73, 6)),
    window_aggregate='max',
    scheme=dataset_scheme.obs,
    index=42,
)

In [9]:
# Target coding schemes chosen for the interface (one option per category
# among those printed earlier).
interface_scheme = dataset_scheme.make_target_scheme_config(
    dx='DxICD9',
    outcome='dx_icd9_filter_v3_groups',
    ethnicity='MIMIC4Eth5',
)

# Bundle the target schemes, the source dataset scheme, the demographic and
# leading-observable settings, and the cache location into one configuration.
interface_config = InterfaceConfig(
    scheme=interface_scheme,
    dataset_scheme=dataset_scheme,
    demographic_vector=demographic_vector_conf,
    leading_observable=leading_AKI,
    cache=cache,
)

def dataset_gen(dataset_config):
    """Load the dataset, drop outliers, and scale it.

    Both the outlier remover and the scalers are fitted on the first of the
    random splits (the training split) only; the scalers are fitted after the
    outliers have been removed.

    Args:
        dataset_config: configuration forwarded to ``load_dataset``.

    Returns:
        The dataset after outlier removal and scaling.
    """
    raw = load_dataset(config=dataset_config)

    # Only the first split is used: it plays the role of the training split.
    all_splits = raw.random_splits([0.8, 0.9], random_seed=42, balanced='admissions')
    train_split = all_splits[0]

    # Step 1: fit the outlier remover on the training split, apply it everywhere.
    remover = raw.fit_outlier_remover(train_split)
    without_outliers = raw.remove_outliers(remover)

    # Step 2: fit the scalers on the (outlier-free) training split, apply them.
    fitted_scalers = without_outliers.fit_scalers(train_split)
    return without_outliers.apply_scalers(fitted_scalers)
        

In [10]:
# Build the patient interface through the cache-aware loader, handing it
# `dataset_gen` as the dataset generator (per the log below, a cache that does
# not match the config is ignored and subjects are loaded from scratch).
m4patients = Patients.try_load_cached(
    interface_config,
    dataset_config=dataset_config,
    dataset_generator=dataset_gen,
    num_workers=8,
)

INFO:root:Cache does not match config, ignoring cache.
INFO:root:Loading subjects from scratch.
DEBUG:root:Constructing dx_icd9 (<class 'lib.ehr.coding_scheme.DxICD9'>) scheme
DEBUG:root:Constructing dx_icd9 (<class 'lib.ehr.coding_scheme.DxICD9'>) scheme
DEBUG:root:Constructing dx_icd9_filter_v3_groups (<class 'lib.ehr.coding_scheme.OutcomeExtractor'>) scheme
DEBUG:root:Constructing mimic4_eth32 (<class 'lib.ehr.coding_scheme.MIMIC4Eth32'>) scheme
DEBUG:root:Constructing gender (<class 'lib.ehr.coding_scheme.Gender'>) scheme
DEBUG:root:Constructing int_mimic4_proc (<class 'lib.ehr.coding_scheme.MIMICProcedures'>) scheme
DEBUG:root:Constructing int_mimic4_input (<class 'lib.ehr.coding_scheme.MIMICInput'>) scheme
DEBUG:root:Constructing int_mimic4_obs (<class 'lib.ehr.coding_scheme.MIMICObservables'>) scheme
DEBUG:root:Constructing dx_icd9 (<class 'lib.ehr.coding_scheme.DxICD9'>) scheme
DEBUG:root:Constructing dx_icd10 (<class 'lib.ehr.coding_scheme.DxICD10'>) scheme
DEBUG:root:Loading 

  dob = anchor_date + anchor_age
DEBUG:root:Extracting dx codes...
DEBUG:root:[DONE] Extracting dx codes
DEBUG:root:Extracting dx codes history...
DEBUG:root:[DONE] Extracting dx codes history
DEBUG:root:Extracting outcome...
DEBUG:root:[DONE] Extracting outcome
DEBUG:root:Extracting procedures...
DEBUG:root:[DONE] Extracting procedures
DEBUG:root:Extracting inputs...
DEBUG:root:[DONE] Extracting inputs
DEBUG:root:Extracting observables...
DEBUG:root:obs: filter adms
DEBUG:root:obs: dasking
DEBUG:root:obs: groupby
DEBUG:root:obs: undasking
DEBUG:root:obs: extract
DEBUG:root:obs: empty
DEBUG:root:[DONE] Extracting observables
DEBUG:root:Compiling admissions...
DEBUG:jax._src.xla_bridge:Initializing backend 'cpu'
DEBUG:jax._src.xla_bridge:Backend 'cpu' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'cuda'
DEBUG:jax._src.xla_bridge:Backend 'cuda' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'rocm'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_

In [14]:
# Number of subjects held by the loaded interface.
len(m4patients.subjects)

29948

In [18]:
# Index of the leading observable as recorded in the interface configuration.
m4patients.config.leading_observable.index

42

In [21]:
# Pick one admission to inspect: position `aid` in the admissions of the
# subject at position `sid` of the sorted subject identifiers.
sub_ids = sorted(m4patients.subjects)
sid, aid = 0, 3
adm = m4patients.subjects[sub_ids[sid]].admissions[aid]
adm

Admission(
  admission_id=25742920,
  admission_dates=(
    Timestamp('2180-08-05 23:44:00'),
    Timestamp('2180-08-07 17:50:00')
  ),
  dx_codes=CodesVector(
    vec=bool[17375](numpy),
    scheme=<lib.ehr.coding_scheme.DxICD9 object at 0x7fe59ba242e0>
  ),
  dx_codes_history=CodesVector(
    vec=bool[17375](numpy),
    scheme=<lib.ehr.coding_scheme.DxICD9 object at 0x7fe59ba242e0>
  ),
  outcome=CodesVector(
    vec=bool[2081](numpy),
    scheme=<lib.ehr.coding_scheme.OutcomeExtractor object at 0x7fe57ea33070>
  ),
  observables=[
    InpatientObservables(
      time=f64[33](numpy),
      value=f16[33,60](numpy),
      mask=bool[33,60](numpy)
    ),
    InpatientObservables(
      time=f64[0](numpy),
      value=f16[0,60](numpy),
      mask=bool[0,60](numpy)
    ),
    InpatientObservables(
      time=f64[12](numpy),
      value=f16[12,60](numpy),
      mask=bool[12,60](numpy)
    )
  ],
  interventions=InpatientInterventions(
    proc=None,
    input_=None,
    time=f32[100](numpy)

In [38]:
# Time array stored with this admission's interventions (the output below
# shows a few leading values followed by NaN entries).
adm.interventions.time

array([ 0.        ,  0.26666668,  1.2666667 , 42.1       ,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,         nan,
               nan,         nan,         nan,         nan,      

In [31]:
# Merge the admission's observable segments into a single object by calling
# `concat` on the first segment with the full list of segments.
# NOTE(review): confirm `concat` does not also include the receiver itself,
# which would duplicate the first segment.
all_obs = adm.observables[0].concat(adm.observables)

In [37]:
# Show the concatenated observables as a DataFrame, using the obs scheme.
# NOTE(review): the meaning of the positional `True` is not visible here;
# passing it by keyword would make the call self-explanatory.
all_obs.as_dataframe(dataset_scheme.obs, True)

Unnamed: 0,time,mch,mcv,hematocrit,wbc,platelet,hemoglobin,mchc,rdw,rbc,potassium,albumin,aniongap,bicarbonate,bun,calcium,sodium,chloride,creatinine,glucose
0,-579187.75,,,,,,,,,,,,,,,0.1875,,,,
1,-568031.75,,,,,,,,0.253418,,,,,,,,,,,
2,-537824.75,,,,,,,,,,,,,,,0.201416,,,,
3,-501273.71875,0.315186,,,,,,,,,,,,,,,,,,
4,-494247.71875,,,,,,,,,,,,0.0,,,,,,,
5,-478555.71875,,,,,,,,,,,,,,0.423584,,,,,
6,-475032.71875,,,,,,,,,,,,,,,,,0.06781,,
7,-456713.71875,,,,,,,,,0.363037,,,,,,,,,,
8,-449607.71875,,,,,,0.476807,,,,,,,,,,,,,
9,-445549.71875,,,,,,,,,,,,,,,,,,0.054535,


In [None]:
# Column 42 of the observation values -- presumably the same index that was
# configured for the leading observable (o42, aki_stage_smoothed); verify.
all_obs.value[:, 42]
# value[:, 42]

In [None]:
# Concatenate the admission's leading-observable segments the same way as the
# observables, then display the resulting values.
all_leading = adm.leading_observable[0].concat(adm.leading_observable)
all_leading.value

In [None]:
# Mask array of the concatenated leading observable.
all_leading.mask

In [None]:
# Mask of the first leading-observable segment alone, for comparison.
adm.leading_observable[0].mask

In [None]:
# m4inpatients.size_in_bytes() / 1024 ** 3

In [None]:
# val_batch = m4inpatients.device_batch(splits[1])

In [None]:
# tst_batch = m4inpatients.device_batch(splits[2])

In [None]:
# val_batch.size_in_bytes() / 1024 ** 3, tst_batch.size_in_bytes() / 1024 ** 3

In [None]:
# batch = m4inpatients.device_batch(splits[0][:32])

In [None]:
# batch.size_in_bytes() / 1024 ** 3

In [None]:
# len(batch.subjects)

In [None]:
# batch.n_admissions()

In [None]:
# batch.n_segments()

In [None]:
# batch.n_obs_times()

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt

# a = m4inpatients_jax.obs_coocurrence_matrix
# a = np.array(a)
# plt.imshow(a, cmap='hot', interpolation='nearest')
# plt.show()

In [None]:
# s = batch.subjects[splits[0][6]].admissions[0]
# s.observables[0].value

In [None]:
# batch.interval_hours(splits[0][:10])

### Training the Neural Ordinary Differential Equations (Neural ODE) Model


In [None]:
from lib.ml import (InICENODE, InICENODEDimensions, InpatientEmbeddingDimensions, 
                    InTrainer, TrainerReporting, OptimizerConfig, WarmupConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric)

import jax.random as jrandom

In [None]:
# Embedding sizes per input group; demo=0 presumably pairs with the
# demographic vector configured with every attribute disabled -- confirm.
emb_dims = InpatientEmbeddingDimensions(dx=30, inp=15, proc=15,
                                        demo=0,
                                        inp_proc_demo=10)
dims = InICENODEDimensions(mem=15, obs=25, emb=emb_dims)
# Fixed PRNG key so model initialisation is repeatable.
key = jrandom.PRNGKey(0)

# The interface object is bound to `m4patients` where it is loaded in this
# notebook; the previous name `m4inpatients` is never defined and raised
# NameError on a fresh kernel.
# NOTE(review): confirm `schemes` and `demographic_vector_config` are still
# attributes of the interface object.
m = InICENODE(dims=dims,
              schemes=m4patients.schemes,
              demographic_vector_config=m4patients.demographic_vector_config,
              key=key)

In [None]:
# res = m.batch_predict(m4inpatients.device_batch(), leave_pbar=True)

In [None]:
# Split the subjects for training / evaluation, balanced by admissions.
# `m4patients` is the name the interface is bound to in this notebook; the
# previous `m4inpatients` is never defined (NameError on a fresh kernel).
# NOTE(review): no random seed is passed here, and the fractions differ from
# the [0.8, 0.9] / random_seed=42 split used to fit the outlier remover and
# scalers in `dataset_gen` -- check that evaluation subjects were not part of
# that fitting split.
splits = m4patients.random_splits([0.9, 0.95],
                                  balanced='admissions')

# Main training configuration: Adam at lr=1e-3, 80 epochs, batches of 128,
# with the listed diagnosis and observation losses and no regularisation
# hyperparameters.
trainer = InTrainer(
    optimizer_config=OptimizerConfig(opt='adam', lr=1e-3),
    reg_hyperparams=None,
    epochs=80,
    batch_size=128,
    dx_loss='allpairs_sigmoid_rank',
    obs_loss='mse',
)

# Warm-up phase settings, passed to the trainer call as `warmup_config`.
warmup = WarmupConfig(
    epochs=0.1,
    batch_size=8,
    opt='adam',
    lr=1e-3,
    decay_rate=0.5,
)

# Evaluation metrics. They are built from `m4patients`, the name the interface
# is bound to in this notebook; the previous `m4inpatients` is never defined
# and raised NameError on a fresh kernel.
loss_metric = LossMetric(m4patients,
                         dx_loss=('softmax_bce', 'balanced_focal_softmax_bce',
                                  'balanced_focal_bce', 'allpairs_exp_rank', 'allpairs_hard_rank',
                                  'allpairs_sigmoid_rank'),
                         obs_loss=('mse', 'mae', 'rms'))

# The top-alarm accuracy metric receives the first split as its train split.
metrics = [CodeAUC(m4patients),
           AdmissionAUC(m4patients),
           CodeGroupTopAlarmAccuracy(m4patients, n_partitions=5,
                                     top_k_list=[3, 5, 10, 15, 20],
                                     train_split=splits[0]),
           loss_metric]


# Reporting setup: results go to the 'inicenode' output directory, with
# console output, parameter snapshots and a JSON copy of the config enabled.
reporting = TrainerReporting(
    output_dir='inicenode',
    metrics=metrics,
    console=True,
    parameter_snapshots=True,
    config_json=True,
)

In [None]:
# Train the model on the loaded interface. `m4patients` replaces the undefined
# name `m4inpatients`, which raised NameError on a fresh kernel.
# continue_training=False is passed explicitly (presumably to start from
# scratch rather than resume -- confirm).
res = trainer(m, m4patients,
              splits=splits,
              reporting=reporting,
              n_evals=100,
              warmup_config=warmup,
              continue_training=False)

In [None]:
import numpy as np
import pandas as pd


def sigmoid(x):
    """Return 1 / (1 + exp(x)), elementwise.

    NOTE(review): this is the mirror image of the usual logistic function
    1 / (1 + exp(-x)); it decreases with x. Kept as written -- confirm the
    sign is intended before relying on the "leading"/"lagging" labels below.
    """
    return 1 / (1 + np.exp(x))


# Probability differences 0.0, 0.1, ..., 1.0 and scale factors 1, 2, 4, 8, 16.
p_delta = np.linspace(0, 1, 11)
scales = 2 ** np.arange(5)

# Rows follow p_delta, columns follow scales.
p_delta_scaled = p_delta[:, None] * scales[None, :]
leading_loss = sigmoid(p_delta_scaled)
lagging_loss = sigmoid(-p_delta_scaled)

# One labelled table per loss, plus their element-wise ratio.
df1 = pd.DataFrame(leading_loss,
                   index=pd.Index(p_delta, name='p_delta'),
                   columns=[f'scale:{s}' for s in scales])
df2 = pd.DataFrame(lagging_loss,
                   index=pd.Index(p_delta, name='p_delta'),
                   columns=[f'scale:{s}' for s in scales])
df3 = df1.div(df2)
df1

In [None]:
# Table built from `lagging_loss` (sigmoid applied to the negated scaled deltas).
df2

In [None]:
# Element-wise ratio df1 / df2.
df3