In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd
import dask

from tqdm import tqdm
import jax
# Set JAX's `jax_platform_name` option to 'cpu' for this session.
# NOTE(review): presumably this keeps computation on the CPU backend; recent
# JAX documentation points to `jax_platforms` as the preferred option —
# confirm against the installed JAX version before changing it.
jax.config.update('jax_platform_name', 'cpu')
# Optional JAX debugging / precision switches, currently left disabled:
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [2]:


# Make the directory two levels up importable so the `lib` package imports in
# this cell resolve. The membership check keeps the cell idempotent:
# re-running it no longer appends duplicate entries to `sys.path`.
# NOTE(review): the entry is relative, so it resolves against the kernel's
# current working directory — confirm the notebook is launched from its own folder.
if "../.." not in sys.path:
    sys.path.append("../..")

from lib import utils as U
from lib.ehr.dataset import load_dataset, load_dataset_scheme, load_dataset_config, Dataset
from lib.ehr.interface import Patients, InterfaceConfig
from lib.ehr.concepts import DemographicVectorConfig


In [3]:
import logging
# Emit INFO-level (and above) records from the root logger. `setLevel` is the
# supported API: unlike assigning `.level` directly, it validates the value
# and clears the logging module's cached `isEnabledFor` results.
logging.getLogger().setLevel(logging.INFO)

In [4]:
# Dataset tag used to look up the dataset configuration and scheme, and to
# name the cache location below.
tag = 'M4'
# Cohort directory under the user's home. `expanduser` is used instead of
# interpolating os.environ.get("HOME"), which would silently produce the
# literal path "None/GP/..." whenever HOME is not set.
PATH = os.path.expanduser('~/GP/ehr-data/mimic4-cohort')
# Forwarded as-is to `load_dataset_config`.
# NOTE(review): presumably `None` means "use the full cohort" — confirm.
sample = None
# NOTE(review): "inteface" looks like a typo, but this string is an on-disk
# cache location, so it is kept unchanged to keep pointing at any existing cache.
cache =  f'cached_inteface/patients_{tag}'
dataset_config = load_dataset_config(tag,
                                     sample=sample,
                                     path=PATH)


def dataset_gen(c):
    """Load the dataset described by the configuration object `c`."""
    return load_dataset(config=c)


In [5]:
# Display the dataset configuration object for inspection.
dataset_config

##### Possible Interface Scheme Configurations

In [6]:
import json
# Load the scheme registered for `tag` and pretty-print the target-scheme
# options it reports as supported (keys sorted, 4-space indentation).
dataset_scheme = load_dataset_scheme(tag)
interface_schem_options = dataset_scheme.supported_target_scheme_options
print(
    json.dumps(
        interface_schem_options,
        indent=4,
        sort_keys=True,
    )
)

In [7]:
# Display the dataset scheme object for inspection.
dataset_scheme

In [8]:
# Target scheme selection: one choice each for diagnosis codes, outcome
# definition and ethnicity coding.
interface_scheme = dataset_scheme.make_target_scheme_config(
    dx='DxICD9',
    outcome='dx_icd9_filter_v3_groups',
    ethnicity='MIMIC4Eth5',
)

# Demographic vector attributes: age, gender and ethnicity are all switched off.
demographic_vector_conf = DemographicVectorConfig(age=False, gender=False, ethnicity=False)

# Bundle the target scheme, the source dataset scheme, the demographic vector
# settings and the cache location into a single interface configuration.
interface_config = InterfaceConfig(
    scheme=interface_scheme,
    dataset_scheme=dataset_scheme,
    demographic_vector=demographic_vector_conf,
    cache=cache,
)

In [9]:
# Display the assembled interface configuration for inspection.
interface_config

In [10]:
# Obtain the patient interface for this configuration.
# NOTE(review): judging by the method name, this presumably reuses a cached
# interface when one exists and otherwise builds it through `dataset_gen`
# with 8 workers — confirm against `Patients.try_load_cached`.
m4patients = Patients.try_load_cached(
    interface_config,
    dataset_config=dataset_config,
    dataset_generator=dataset_gen,
    num_workers=8,
)

In [11]:
# Partition the dataset using a fixed random seed of 42.
# NOTE(review): [0.8, 0.9] are presumably cumulative cut points giving
# 80% / 10% / 10% train / validation / test partitions, balanced by admission
# count — confirm against `Dataset.random_splits`.
splits = m4patients.dataset.random_splits([0.8, 0.9], random_seed=42, balanced='admissions')

In [12]:
# Number of entries in the loaded interface's `subjects` collection.
len(m4patients.subjects)

In [13]:
# Inspect one subject: look it up by the first element of the first split.
m4patients.subjects[splits[0][0]]

In [14]:
# m4patients.size_in_bytes() / 1024 ** 3

In [15]:
# val_batch = m4patients.device_batch(splits[1])

In [16]:
# tst_batch = m4patients.device_batch(splits[2])

In [17]:
# val_batch.size_in_bytes() / 1024 ** 3, tst_batch.size_in_bytes() / 1024 ** 3

In [18]:
# batch = m4patients.device_batch(splits[0][:32])

In [19]:
# batch.size_in_bytes() / 1024 ** 3

In [20]:
# len(batch.subjects)

In [21]:
# batch.n_admissions()

In [22]:
# batch.n_segments()

In [23]:
# batch.n_obs_times()

In [24]:
# s = batch.subjects[splits[0][6]].admissions[0]

### Training the Neural Ordinary Differential Equations Model


In [25]:
from lib.ml import (ICENODE, ICENODEDimensions, 
                    GRU, GRUDimensions,
                    RETAIN, RETAINDimensions,
                    PatientEmbeddingDimensions, 
                    Trainer, TrainerReporting, OptimizerConfig, WarmupConfig,
                    TrainerConfig, ReportingConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric)

import jax.random as jrandom

In [26]:
# Embedding dimensions passed to every model builder in this cell
# (dx=30, demo=0).
emb_dims = PatientEmbeddingDimensions(dx=30, demo=0)
# PRNG key created from seed 0. The three model builders in this cell all
# pass this same key, so it is not split per model.
key = jrandom.PRNGKey(0)

def icenode_model():
    """Build an ICENODE model (mem=15) on the shared embedding dimensions.

    Reads the notebook-level `emb_dims`, `m4patients` and `key`.
    """
    return ICENODE(
        dims=ICENODEDimensions(mem=15, emb=emb_dims),
        schemes=m4patients.schemes,
        demographic_vector_config=m4patients.config.demographic_vector,
        key=key,
    )

def gru_model():
    """Build a GRU model on the shared embedding dimensions.

    Reads the notebook-level `emb_dims`, `m4patients` and `key`.
    """
    return GRU(
        dims=GRUDimensions(emb=emb_dims),
        schemes=m4patients.schemes,
        demographic_vector_config=m4patients.config.demographic_vector,
        key=key,
    )

def retain_model():
    """Build a RETAIN model (mem_a=25, mem_b=25) on the shared embedding dimensions.

    Reads the notebook-level `emb_dims`, `m4patients` and `key`.
    """
    return RETAIN(
        dims=RETAINDimensions(mem_a=25, mem_b=25, emb=emb_dims),
        schemes=m4patients.schemes,
        demographic_vector_config=m4patients.config.demographic_vector,
        key=key,
    )

# Registry of model instances keyed by experiment name. All three models are
# constructed immediately when this cell runs.
models = dict(
    dx_icenode=icenode_model(),
    dx_gru=gru_model(),
    dx_retain=retain_model(),
)

In [27]:
# res = m.batch_predict(m4inpatients.device_batch(), leave_pbar=True)

In [28]:
# Main training configuration: Adam at lr=1e-3, 150 epochs, batches of 32,
# optimizing the 'allpairs_sigmoid_rank' diagnosis loss.
trainer_conf = TrainerConfig(
    optimizer=OptimizerConfig(opt='adam', lr=1e-3),
    epochs=150,
    batch_size=32,
    dx_loss='allpairs_sigmoid_rank',
)
trainer = Trainer(trainer_conf)

# Warm-up settings. These only take effect if `warmup_config=warmup` is
# passed to the trainer call.
warmup = WarmupConfig(epochs=0.1, batch_size=8, opt='adam', lr=1e-3, decay_rate=0.5)

# Loss metric configured with several diagnosis loss names, including the
# one used for training.
loss_metric = LossMetric(
    m4patients,
    dx_loss=(
        'softmax_bce',
        'balanced_focal_softmax_bce',
        'balanced_focal_bce',
        'allpairs_exp_rank',
        'allpairs_hard_rank',
        'allpairs_sigmoid_rank',
    ),
)

# Metrics handed to the trainer's reporting object.
metrics = [
    CodeAUC(m4patients),
    AdmissionAUC(m4patients),
    CodeGroupTopAlarmAccuracy(
        m4patients,
        n_partitions=5,
        top_k_list=[3, 5, 10, 15, 20],
        train_split=splits[0],
    ),
    loss_metric,
]

In [None]:
# Train the selected models one after another and keep what the trainer
# returns for each, keyed by model name. 'dx_icenode' is in `models` but is
# not part of this run.
res = {}
for name in ('dx_gru', 'dx_retain'):
    print(name)
    model = models[name]
    # Each model reports into an output directory named after itself.
    reporting_conf = ReportingConfig(
        output_dir=name,
        console=True,
        model_stats=False,
        parameter_snapshots=True,
        config_json=True,
    )
    reporting = TrainerReporting(reporting_conf, metrics=metrics)
    res[name] = trainer(
        model,
        m4patients,
        splits=splits,
        reporting=reporting,
        n_evals=100,
        # warmup_config=warmup,
        continue_training=False,
    )