In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd
import dask

from tqdm import tqdm
import jax
jax.config.update('jax_platform_name', 'cpu')
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [2]:


sys.path.append("../..")

from lib import utils as U
from lib.ehr import load_dataset, load_dataset_scheme, load_dataset_config, Dataset
from lib.ehr.interface import Patients, InterfaceConfig
from lib.ehr.concepts import DemographicVectorConfig, LeadingObservableConfig


In [3]:
import logging

# Verbose logging while loading the dataset and building the patient interface.
# Use Logger.setLevel() instead of assigning `.level` directly: setLevel also
# clears the per-logger isEnabledFor() cache, so loggers that have already
# emitted a record pick up the new level.
logging.getLogger().setLevel(logging.DEBUG)

In [4]:
# Dataset tag and on-disk location of the cohort data.
tag = 'M4ICU'
# Path.home() resolves the home directory even when $HOME is unset, instead of
# silently producing the bogus prefix 'None/...' that os.environ.get("HOME") would.
PATH = f'{Path.home()}/GP/ehr-data/mimic4icu-cohort'
# Forwarded to load_dataset_config as `sample`; presumably caps the cohort size
# (None/0 disables it) — confirm against load_dataset_config.
sample = 30000
# NOTE(review): 'inteface' is a typo, but renaming it would orphan existing
# cache directories, so the path is kept as-is.
cache = f'cached_inteface/patients_{tag}_{sample or ""}'
dataset_config = load_dataset_config(tag,
                                     sample=sample,
                                     path=PATH)

In [5]:
# Display the cache path used for the patient interface.
cache

#### Supported Interface Scheme Configurations

In [6]:
import json

# Load the dataset scheme and pretty-print the target-scheme options it supports.
dataset_scheme = load_dataset_scheme(tag)
interface_schem_options = dataset_scheme.supported_target_scheme_options
options_json = json.dumps(interface_schem_options, indent=4, sort_keys=True)
print(options_json)

#### Leading Observable for Early Prediction Task

In [7]:
# Search the observation scheme for AKI-related entries. The matching rows are
# presumably where the leading-observable index below comes from — confirm.
scheme_df = dataset_scheme.obs.as_dataframe()
# case=False also matches 'AKI'; na=False keeps missing descriptions from
# turning the boolean mask into NaN (which would make the indexing fail).
aki_mask = scheme_df.desc.str.contains('aki', case=False, na=False)
display(scheme_df[aki_mask])

In [8]:

# Demographic vector attributes (all disabled).
demographic_vector_conf = DemographicVectorConfig(
    age=False,
    gender=False,
    ethnicity=False
)

# Leading-observable (early prediction) target configuration.
# NOTE(review): 42 is presumably the AKI observable found by the 'aki' search of
# the observation scheme — re-check it whenever the scheme changes.
AKI_OBS_INDEX = 42
# Prediction horizons in hours: 6, 12, ..., 72.
LEADING_HOURS = tuple(range(6, 73, 6))
leading_AKI = LeadingObservableConfig(leading_hours=LEADING_HOURS,
                                      window_aggregate='max',
                                      scheme=dataset_scheme.obs,
                                      index=AKI_OBS_INDEX)

In [9]:
# Target coding schemes for the interface: ICD-9 diagnoses, grouped ICD-9 outcome
# codes, and a 5-group ethnicity scheme (names as given by the dataset scheme).
interface_scheme = dataset_scheme.make_target_scheme_config(
    dx='DxICD9',
    outcome='dx_icd9_filter_v3_groups',
    ethnicity='MIMIC4Eth5',
)

# Bundle schemes, demographic/leading-observable settings and cache location.
interface_config = InterfaceConfig(
    scheme=interface_scheme,
    dataset_scheme=dataset_scheme,
    demographic_vector=demographic_vector_conf,
    leading_observable=leading_AKI,
    cache=cache,
)

def dataset_gen(dataset_config, random_seed=42, split_proportions=(0.8, 0.9)):
    """Load the dataset, remove outliers and scale it with training-split statistics.

    Args:
        dataset_config: configuration forwarded to ``load_dataset(config=...)``.
        random_seed: seed passed to ``dataset.random_splits`` (default 42).
        split_proportions: split boundaries passed to ``dataset.random_splits``;
            presumably cumulative, so (0.8, 0.9) gives 80/10/10 — confirm. Only
            the first resulting split is used for fitting.

    Returns:
        The dataset returned by ``apply_scalers``.

    NOTE(review): the training notebook later draws its own splits with different
    proportions and no explicit seed, so its held-out subjects may overlap the
    split used here to fit the outlier remover and scalers.
    """
    dataset = load_dataset(config=dataset_config)

    # Fit preprocessing on the training split only, so held-out subjects
    # do not leak into the outlier thresholds or scaling statistics.
    train_split = dataset.random_splits(list(split_proportions),
                                        random_seed=random_seed,
                                        balanced='admissions')[0]

    # Outlier removal
    outlier_remover = dataset.fit_outlier_remover(train_split)
    dataset = dataset.remove_outliers(outlier_remover)

    # Scale
    scalers = dataset.fit_scalers(train_split)
    return dataset.apply_scalers(scalers)
        

In [10]:
# Load the patient interface via Patients.try_load_cached, handing it
# `dataset_gen` to produce the dataset (presumably only when no cached copy
# exists at `cache` — confirm).
m4patients = Patients.try_load_cached(
    interface_config,
    dataset_config=dataset_config,
    dataset_generator=dataset_gen,
    num_workers=8,
)

In [14]:
# Number of subjects in the loaded patient interface.
len(m4patients.subjects)

In [18]:
# Sanity check: presumably echoes the index configured in leading_AKI (42).
m4patients.config.leading_observable.index

In [21]:
# Pick one admission to inspect: admission number `aid` of the `sid`-th subject,
# ordering subjects by the sorted keys of m4patients.subjects.
sid, aid = 0, 3
sub_ids = sorted(m4patients.subjects)
adm = m4patients.subjects[sub_ids[sid]].admissions[aid]
adm

In [38]:
# Inspect the `time` attribute of the selected admission's interventions.
adm.interventions.time

In [31]:
# Concatenate all entries of adm.observables into one object, using the first
# entry's `concat` (presumably one entry per time segment — confirm).
all_obs = adm.observables[0].concat(adm.observables)

In [37]:
# Tabulate the concatenated observables against the dataset's observation scheme.
# TODO(review): the positional `True` is opaque here — pass it by keyword.
all_obs.as_dataframe(dataset_scheme.obs, True)

In [None]:
# Values of the configured leading-observable column across all observation rows.
# The index comes from the interface config, so it cannot drift from the
# LeadingObservableConfig.
all_obs.value[:, m4patients.config.leading_observable.index]

In [None]:
# Concatenate every leading-observable entry of the admission, then show the values.
leading_segments = adm.leading_observable
all_leading = leading_segments[0].concat(leading_segments)
all_leading.value

In [None]:
# Mask of the concatenated leading observables (presumably marks which target
# entries are defined — confirm).
all_leading.mask

In [None]:
# Mask of the first leading-observable entry only, to compare with the concatenated mask.
adm.leading_observable[0].mask

In [None]:
# m4inpatients.size_in_bytes() / 1024 ** 3

In [None]:
# val_batch = m4inpatients.device_batch(splits[1])

In [None]:
# tst_batch = m4inpatients.device_batch(splits[2])

In [None]:
# val_batch.size_in_bytes() / 1024 ** 3, tst_batch.size_in_bytes() / 1024 ** 3

In [None]:
# batch = m4inpatients.device_batch(splits[0][:32])

In [None]:
# batch.size_in_bytes() / 1024 ** 3

In [None]:
# len(batch.subjects)

In [None]:
# batch.n_admissions()

In [None]:
# batch.n_segments()

In [None]:
# batch.n_obs_times()

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt

# a = m4inpatients_jax.obs_coocurrence_matrix
# a = np.array(a)
# plt.imshow(a, cmap='hot', interpolation='nearest')
# plt.show()

In [None]:
# s = batch.subjects[splits[0][6]].admissions[0]
# s.observables[0].value

In [None]:
# batch.interval_hours(splits[0][:10])

### Training the Neural ODE Model (InICENODE) — التدريب على نموذج المعادلات التفاضلية الاعتيادية العصبية


In [None]:
from lib.ml import (InICENODE, InICENODEDimensions, InpatientEmbeddingDimensions, 
                    InTrainer, TrainerReporting, OptimizerConfig, WarmupConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric)

import jax.random as jrandom

In [None]:
# Embedding sizes for the model inputs; demo=0 matches the demographic vector
# config, where every attribute is disabled.
emb_dims = InpatientEmbeddingDimensions(dx=30, inp=15, proc=15,
                                        demo=0,
                                        inp_proc_demo=10)
dims = InICENODEDimensions(mem=15, obs=25, emb=emb_dims)
# Fixed PRNG key so parameter initialisation is reproducible.
key = jrandom.PRNGKey(0)

# The interface loaded earlier is `m4patients`; the name `m4inpatients` is never
# defined in this notebook and raised NameError on a fresh kernel.
m = InICENODE(dims=dims,
              schemes=m4patients.schemes,
              demographic_vector_config=m4patients.demographic_vector_config,
              key=key)

In [None]:
# res = m.batch_predict(m4inpatients.device_batch(), leave_pbar=True)

In [None]:
splits = m4inpatients.random_splits([0.9, 0.95], 
                                    balanced='admissions')

trainer = InTrainer(optimizer_config=OptimizerConfig(opt='adam', lr=1e-3),
                    reg_hyperparams=None,
                    epochs=80,
                    batch_size=128,
                    dx_loss='allpairs_sigmoid_rank',
                    obs_loss='mse')

warmup = WarmupConfig(epochs=0.1, 
                      batch_size=8,
                      opt='adam', lr=1e-3, 
                      decay_rate=0.5)

loss_metric =  LossMetric(m4inpatients, 
                          dx_loss=('softmax_bce', 'balanced_focal_softmax_bce', 
                                   'balanced_focal_bce', 'allpairs_exp_rank', 'allpairs_hard_rank', 
                                   'allpairs_sigmoid_rank'),
                         obs_loss=('mse', 'mae', 'rms'))

metrics = [CodeAUC(m4inpatients), 
           AdmissionAUC(m4inpatients), 
           CodeGroupTopAlarmAccuracy(m4inpatients, n_partitions=5, 
                                     top_k_list=[3, 5, 10, 15, 20],
                                     train_split=splits[0]), 
           loss_metric]


reporting = TrainerReporting(output_dir='inicenode',
                             metrics=metrics,
                             console=True,
                             parameter_snapshots=True,
                             config_json=True)

In [None]:
res = trainer(m, m4inpatients, 
              splits=splits,
              reporting=reporting,
              n_evals=100,
              warmup_config=warmup,
              continue_training=False)

In [None]:
import numpy as np
import pandas as pd


def sigmoid(x):
    """Logistic sigmoid 1 / (1 + exp(-x)), applied element-wise."""
    return 1 / (1 + np.exp(-x))


# The lambda previously named `sigmoid` computed 1 / (1 + exp(x)), i.e. sigmoid(-x).
# The arguments below are flipped so both tables keep exactly the same values:
#   leading_loss = 1 / (1 + exp( p_delta * scale))   (decreasing in p_delta)
#   lagging_loss = 1 / (1 + exp(-p_delta * scale))   (increasing in p_delta)
p_delta = np.linspace(0, 1, 11)
scales = 2 ** np.arange(5)                       # 1, 2, 4, 8, 16
p_delta_scaled = np.outer(p_delta, scales)       # rows: p_delta, cols: scale
leading_loss = sigmoid(-p_delta_scaled)
lagging_loss = sigmoid(p_delta_scaled)
df1 = pd.DataFrame(leading_loss, columns=[f'scale:{s}' for s in scales],
                   index=p_delta).rename_axis('p_delta')
df2 = pd.DataFrame(lagging_loss, columns=[f'scale:{s}' for s in scales],
                   index=p_delta).rename_axis('p_delta')
# Ratio leading / lagging; algebraically equal to exp(-p_delta * scale).
df3 = df1 / df2
df1

In [None]:
# Lagging-loss table (rows: p_delta, columns: scale).
df2

In [None]:
# Ratio of the leading-loss table to the lagging-loss table (df1 / df2).
df3