In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd
import dask

from tqdm import tqdm
import jax
# Set JAX's `jax_platform_name` option to 'gpu' for this session.
# NOTE(review): presumably this requires a GPU-enabled jaxlib on the host — confirm
# before running on a CPU-only machine.
jax.config.update('jax_platform_name', 'gpu')
# Optional JAX switches, left disabled (compile logging, NaN debugging, x64 mode):
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [2]:


# Make the project-local `lib` package importable from this notebook.
# The membership check keeps the cell idempotent: re-running it does not
# append a duplicate entry to `sys.path` each time.
if "../.." not in sys.path:
    sys.path.append("../..")

from lib import utils as U
from lib.ehr.dataset import load_dataset, load_dataset_scheme, Dataset
from lib.ehr.interface import Patients
from lib.ehr.concepts import DemographicVectorConfig


In [3]:
import logging

# Show INFO-level messages from the root logger.
# Use `setLevel` rather than assigning `.level` directly: `setLevel` validates
# the level and clears the loggers' cached `isEnabledFor` results, so loggers
# that were already queried before this cell ran pick up the new threshold.
logging.getLogger().setLevel(logging.INFO)

In [4]:
# Notebook configuration: dataset location and on-disk cache of the interface.

# Assign the folder of the dataset to `DATA_DIR`.
# Fall back to the expanded `~` when the HOME environment variable is unset,
# so the path never degenerates to 'None/GP/ehr-data'.
HOME = os.environ.get('HOME') or os.path.expanduser('~')
DATA_DIR = f'{HOME}/GP/ehr-data'
SOURCE_DIR = os.path.abspath("..")

# Destination used when the interface is rebuilt from the raw dataset and saved.
# It must always be defined, because the rebuild branch of the loading cell
# passes it to `m4inpatients.save(...)`.
cache_to_disk = 'cached_inteface/m4inpatients_8000'

# Path of the cached interface to load; set to False to rebuild from scratch.
use_cached = cache_to_disk


In [5]:
# Produce `m4inpatients` and `splits`: either load a previously saved interface
# from the path in `use_cached`, or rebuild it from the raw dataset and save it.
if use_cached:
    # Load the saved interface and split its dataset with the same arguments
    # (fractions, seed, balancing) as the rebuild branch below.
    m4inpatients = Patients.load(use_cached)
    splits = m4inpatients.dataset.random_splits([0.8, 0.9], random_seed=42, balanced='admissions')

else:
    # `DATA_DIR` is handed to `U.modified_environ`, and dask is configured with
    # the 'processes' scheduler and 12 workers for the duration of this block.
    with U.modified_environ(DATA_DIR=DATA_DIR), dask.config.set(scheduler='processes', num_workers=12):
        
        # Load dataset
        # NOTE(review): `sample=8000` presumably limits the number of subjects — confirm.
        m4icu_dataset = load_dataset('M4ICU', sample=8000)
        # Use training-split for fitting the outlier_remover and the scalers.
        splits = m4icu_dataset.random_splits([0.8, 0.9], random_seed=42, balanced='admissions')
        # Outlier removal: fit on splits[0], then apply to the dataset.
        outlier_remover = m4icu_dataset.fit_outlier_remover(splits[0])
        m4icu_dataset = m4icu_dataset.remove_outliers(outlier_remover)
        
        # Scale: fit on splits[0] (after outlier removal), then apply to the dataset.
        scalers = m4icu_dataset.fit_scalers(splits[0])
        m4icu_dataset = m4icu_dataset.apply_scalers(scalers)
        
        # Demographic vector attributes: age, gender and ethnicity are all enabled.
        demographic_vector_conf = DemographicVectorConfig(age=True,
                                                      gender=True,
                                                      ethnicity=True)
        # Load interface
        m4inpatients = Patients(m4icu_dataset, demographic_vector_conf).load_subjects(num_workers=12)

        # Cache to disk
        # NOTE(review): requires `cache_to_disk` to be defined in the configuration
        # cell before this branch runs; otherwise this line raises NameError.
        m4inpatients.save(cache_to_disk, overwrite=True)

In [6]:
# Sanity check: number of subjects held by the loaded interface.
len(m4inpatients.subjects)

7984

In [7]:
# m4inpatients.size_in_bytes() / 1024 ** 3

In [8]:
# val_batch = m4inpatients.device_batch(splits[1])

In [9]:
# tst_batch = m4inpatients.device_batch(splits[2])

In [10]:
# val_batch.size_in_bytes() / 1024 ** 3, tst_batch.size_in_bytes() / 1024 ** 3

In [11]:
# batch = m4inpatients.device_batch(splits[0][:32])

In [12]:
# batch.size_in_bytes() / 1024 ** 3

In [13]:
# len(batch.subjects)

In [14]:
# batch.n_admissions()

In [15]:
# batch.n_segments()

In [16]:
# batch.n_obs_times()

In [17]:
# import numpy as np
# import matplotlib.pyplot as plt

# a = m4inpatients_jax.obs_coocurrence_matrix
# a = np.array(a)
# plt.imshow(a, cmap='hot', interpolation='nearest')
# plt.show()

In [18]:
# s = batch.subjects[splits[0][6]].admissions[0]
# s.observables[0].value

In [19]:
# batch.interval_hours(splits[0][:10])

### التدريب على نموذج المعادلات التفاضلية الاعتيادية العصبية


In [20]:
from lib.ml import (InICENODE, InICENODEDimensions, InpatientEmbeddingDimensions, 
                    Trainer, TrainerReporting, OptimizerConfig, WarmupConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric)

import jax.random as jrandom

In [21]:
# PRNG key handed to the model constructor.
key = jrandom.PRNGKey(0)

# Embedding sizes per input group, nested inside the overall model dimensions.
emb_dims = InpatientEmbeddingDimensions(
    dx=30,
    inp=15,
    proc=15,
    demo=5,
    inp_proc_demo=15,
)
dims = InICENODEDimensions(mem=15, obs=25, emb=emb_dims)

# The model takes its coding scheme and demographic vector configuration
# from the loaded patients interface.
m = InICENODE(
    dims=dims,
    scheme=m4inpatients.dataset.scheme,
    demographic_vector_config=m4inpatients.demographic_vector_config,
    key=key,
)

In [22]:
# res = m.batch_predict(m4inpatients.device_batch(), leave_pbar=True)

In [23]:
# Partition of the patients used by the trainer and by the metrics below.
# Note that this rebinds `splits` with different fractions than the loading cell.
splits = m4inpatients.random_splits([0.9, 0.95], balanced='admissions')

# Main training configuration.
trainer = Trainer(
    optimizer_config=OptimizerConfig(opt='adam', lr=1e-3),
    reg_hyperparams=None,
    epochs=80,
    batch_size=256,
    dx_loss='allpairs_sigmoid_rank',
)

# Warm-up configuration passed to the trainer call.
warmup = WarmupConfig(
    epochs=0.1,
    batch_size=8,
    opt='adam',
    lr=1e-3,
    decay_rate=0.5,
)

# Loss metric reporting several diagnosis-loss variants side by side.
loss_metric = LossMetric(
    m4inpatients,
    dx_loss=(
        'softmax_bce',
        'balanced_focal_softmax_bce',
        'balanced_focal_bce',
        'allpairs_exp_rank',
        'allpairs_hard_rank',
        'allpairs_sigmoid_rank',
    ),
)

# Metrics evaluated during training; the top-alarm accuracy metric is given
# the training split (`splits[0]`).
metrics = [
    CodeAUC(m4inpatients),
    AdmissionAUC(m4inpatients),
    CodeGroupTopAlarmAccuracy(
        m4inpatients,
        n_partitions=5,
        top_k_list=[3, 5, 10, 15, 20],
        train_split=splits[0],
    ),
    loss_metric,
]

# Reporting: output goes under 'inicenode', with console output disabled and
# parameter snapshots plus a JSON config enabled.
reporting = TrainerReporting(
    output_dir='inicenode',
    metrics=metrics,
    console=False,
    parameter_snapshots=True,
    config_json=True,
)

In [None]:
# Run the trainer on the model and the patients interface, starting fresh
# (`continue_training=False`); the returned value is kept in `res`.
res = trainer(
    m,
    m4inpatients,
    splits=splits,
    reporting=reporting,
    n_evals=100,
    warmup_config=warmup,
    continue_training=False,
)

INFO:root:Warming up...


Loading to device: 0subject [00:00, ?subject/s]

  0%|          | 0/1 [00:00<?, ?Epoch/s]

  0%|          | 0/294 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/5 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/5 [00:00<?, ?subject/s]

  0%|          | 0.00/73.24 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/3 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/3 [00:00<?, ?subject/s]

  0%|          | 0.00/60.41 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/3 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/3 [00:00<?, ?subject/s]

  0%|          | 0.00/85.71 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/3 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/3 [00:00<?, ?subject/s]

  0%|          | 0.00/71.58 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/1 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/1 [00:00<?, ?subject/s]

  0%|          | 0.00/74.77 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/2 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/2 [00:00<?, ?subject/s]

  0%|          | 0.00/72.27 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/2 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/2 [00:00<?, ?subject/s]

  0%|          | 0.00/29.72 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/4 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/4 [00:00<?, ?subject/s]

  0%|          | 0.00/116.34 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/6 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/6 [00:00<?, ?subject/s]

  0%|          | 0.00/121.09 [00:00<?, ?odeint-days/s]