In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd
import dask

from tqdm import tqdm
import jax
# Select 'gpu' as the JAX platform for this notebook.
# NOTE(review): presumably this makes the first JAX computation fail on a host
# without a GPU backend -- confirm, or switch to 'cpu' when running elsewhere.
jax.config.update('jax_platform_name', 'gpu')
# Optional debugging switches (disabled); uncomment to enable while debugging.
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [2]:


# Put the directory two levels above the working directory on the import path;
# the `lib` package imported below is expected to be found there.
# The membership guard keeps the cell idempotent: re-running it no longer
# appends a duplicate "../.." entry to sys.path each time.
if "../.." not in sys.path:
    sys.path.append("../..")

from lib import utils as U
from lib.ehr.dataset import load_dataset, load_dataset_scheme, Dataset
from lib.ehr.interface import Patients
from lib.ehr.concepts import DemographicVectorConfig


In [3]:
# import logging
# logging.root.level = logging.DEBUG

In [4]:
# Notebook configuration: dataset location and interface-cache settings.

# Home directory of the current user. When the HOME environment variable is
# unset or empty, fall back to expanding "~" so that DATA_DIR never silently
# becomes 'None/GP/ehr-data'.
HOME = os.environ.get('HOME') or os.path.expanduser('~')

# Assign the folder of the dataset to `DATA_DIR`.
DATA_DIR = f'{HOME}/GP/ehr-data'
# Absolute path of the parent of the current working directory.
SOURCE_DIR = os.path.abspath("..")

# Destination used when a freshly built interface is saved to disk.
cache_to_disk = 'cached_inteface/m4inpatients_5000'
# Path of a previously saved interface to load instead of rebuilding it;
# a falsy value (False) forces a rebuild from the raw dataset.
# Alternatives: 'cached_inteface/m4inpatients' or cache_to_disk.
use_cached = False


In [5]:
# Obtain the `m4inpatients` interface and the subject `splits`, either by
# building everything from the raw dataset or by loading a saved interface.
if not use_cached:
    # Build from scratch. `U.modified_environ` is given DATA_DIR (presumably it
    # exposes it as an environment variable for the duration of the block --
    # confirm), and dask is configured with the process scheduler, 12 workers.
    with U.modified_environ(DATA_DIR=DATA_DIR), \
            dask.config.set(scheduler='processes', num_workers=12):

        # Step 1: load the 'M4ICU' dataset, requesting sample=5000.
        m4icu_dataset = load_dataset('M4ICU', sample=5000)

        # Step 2: split; only the first split (splits[0]) is used below to fit
        # the outlier remover and the scalers.
        splits = m4icu_dataset.random_splits(
            [0.8, 0.9],
            random_seed=42,
            balanced='admissions',
        )

        # Step 3: fit the outlier remover on splits[0], then apply it.
        outlier_remover = m4icu_dataset.fit_outlier_remover(splits[0])
        m4icu_dataset = m4icu_dataset.remove_outliers(outlier_remover)

        # Step 4: fit the scalers on splits[0], then apply them.
        scalers = m4icu_dataset.fit_scalers(splits[0])
        m4icu_dataset = m4icu_dataset.apply_scalers(scalers)

        # Step 5: choose which demographic attributes go into the vector.
        demographic_vector_conf = DemographicVectorConfig(
            age=True,
            gender=True,
            ethnicity=True,
        )

        # Step 6: wrap the processed dataset in the interface and load subjects.
        m4inpatients = Patients(
            m4icu_dataset,
            demographic_vector_conf,
        ).load_subjects(num_workers=12)

        # Step 7: persist the interface, replacing any existing cache.
        m4inpatients.save(cache_to_disk, overwrite=True)
else:
    # Reuse the interface saved at the path held in `use_cached`, and derive
    # the splits from its dataset with the same fractions and seed as above.
    m4inpatients = Patients.load(use_cached)
    splits = m4inpatients.dataset.random_splits(
        [0.8, 0.9],
        random_seed=42,
        balanced='admissions',
    )

  dob = anchor_date + anchor_age
                            dx_icd10->dx_icd9 Unrecognised t_codes
                            (169):
                            ['041.41', '041.42', '041.43', '041.49', '173.00', '173.01', '173.02', '173.09', '173.10', '173.11', '173.12', '173.19', '173.20', '173.21', '173.22', '173.29', '173.30', '173.31', '173.32', '173.39']...
                            dx_icd10->dx_icd9 Unrecognised s_codes
                            (49910):
                            ['E08.3211', 'E08.3212', 'E08.3213', 'E08.3219', 'E08.3291', 'E08.3292', 'E08.3293', 'E08.3299', 'E08.3311', 'E08.3312', 'E08.3313', 'E08.3319', 'E08.3391', 'E08.3392', 'E08.3393', 'E08.3399', 'E08.3411', 'E08.3412', 'E08.3413', 'E08.3419']...
                            dx_icd10->dx_icd9 Unrecognised t_codes
                            (169):
                            ['041.41', '041.42', '041.43', '041.49', '173.00', '173.01', '173.02', '173.09', '173.10', '173.11', '173.12', '173.19', '173.2

In [6]:
# Number of subjects held by the loaded interface.
len(m4inpatients.subjects)

4990

In [7]:
# m4inpatients.size_in_bytes() / 1024 ** 3

In [8]:
# val_batch = m4inpatients.device_batch(splits[1])

In [9]:
# tst_batch = m4inpatients.device_batch(splits[2])

In [10]:
# val_batch.size_in_bytes() / 1024 ** 3, tst_batch.size_in_bytes() / 1024 ** 3

In [11]:
# batch = m4inpatients.device_batch(splits[0][:32])

In [12]:
# batch.size_in_bytes() / 1024 ** 3

In [13]:
# len(batch.subjects)

In [14]:
# batch.n_admissions()

In [15]:
# batch.n_segments()

In [16]:
# batch.n_obs_times()

In [17]:
# import numpy as np
# import matplotlib.pyplot as plt

# a = m4inpatients_jax.obs_coocurrence_matrix
# a = np.array(a)
# plt.imshow(a, cmap='hot', interpolation='nearest')
# plt.show()

In [18]:
# s = batch.subjects[splits[0][6]].admissions[0]
# s.observables[0].value

In [19]:
# batch.interval_hours(splits[0][:10])

### التدريب على نموذج المعادلات التفاضلية الاعتيادية العصبية (Training the Neural ODE model)


In [20]:
from lib.ml import (InICENODE, InICENODEDimensions, InpatientEmbeddingDimensions, 
                    InTrainer, TrainerReporting, OptimizerConfig, WarmupConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric)

import jax.random as jrandom

In [21]:
# PRNG key passed to the model constructor (fixed seed 0).
key = jrandom.PRNGKey(0)

# Dimension settings: embedding sizes first, then the model-level sizes
# that wrap them.
emb_dims = InpatientEmbeddingDimensions(
    dx=10,
    inp=10,
    proc=10,
    demo=5,
    inp_proc_demo=15,
)
dims = InICENODEDimensions(mem=15, obs=25, emb=emb_dims)

# Build the model from the dataset scheme and the demographic vector
# configuration carried by the loaded interface.
m = InICENODE(
    dims=dims,
    scheme=m4inpatients.dataset.scheme,
    demographic_vector_config=m4inpatients.demographic_vector_config,
    key=key,
)

In [22]:
# res = m.batch_predict(m4inpatients.device_batch(), leave_pbar=True)

In [23]:
# Trainer settings for the main training run.
trainer = InTrainer(
    optimizer_config=OptimizerConfig(opt='adam', lr=1e-3),
    reg_hyperparams=None,
    epochs=20,
    batch_size=512,
    dx_loss='balanced_focal_bce',
    obs_loss='mse',
)

# Warm-up settings, handed to the trainer call as `warmup_config`.
warmup = WarmupConfig(
    epochs=0.1,
    batch_size=32,
    opt='adam',
    lr=1e-3,
    decay_rate=0.5,
)

# Loss-based metrics; each is configured with several loss names.
loss_metric = LossMetric(
    m4inpatients,
    dx_loss=(
        'softmax_bce',
        'balanced_focal_softmax_bce',
        'balanced_focal_bce',
        'allpairs_sigmoid_rank',
    ),
    obs_loss=('mse', 'rms', 'mae'),
)
obs_code_loss_metric = ObsCodeLevelLossMetric(
    m4inpatients,
    obs_loss=('mse', 'rms', 'mae'),
)

# CodeAUC(m4inpatients) and AdmissionAUC(m4inpatients) are currently left out.
metrics = [
    obs_code_loss_metric,
    loss_metric,
]

# Reporting options for the run; output goes under 'inicenode'.
reporting = TrainerReporting(
    output_dir='inicenode',
    metrics=metrics,
    console=True,
    parameter_snapshots=True,
    config_json=True,
)

In [24]:
# Sanity check: print every subject whose `d2d_interval_days` is negative,
# together with the ids of that subject's admissions.
for s in m4inpatients.subjects.values():
    if not (s.d2d_interval_days < 0):
        continue
    admission_ids = [a.admission_id for a in s.admissions]
    print(s.subject_id, s.d2d_interval_days, admission_ids)

In [None]:
# pbar_metric = LossMetric(m4inpatients, 
#                           dx_loss=('allpairs_sigmoid_rank',),
#                           obs_loss=('mse', ))
# rank_loss_extractor = pbar_metric.value_extractor({'field': 'allpairs_sigmoid_rank'})
# obs_loss_extractor = pbar_metric.value_extractor({'field': 'mse'})
# def pbar_desc_update(predictions):
#     df = pbar_metric.to_df(0, predictions)
#     return f'dx_rank_loss={rank_loss_extractor(df).item(): .3f}, obs_mse_loss={obs_loss_extractor(df).item(): .3f}'


# Re-split the subjects for training and launch the trainer.
# NOTE(review): this rebinds `splits` with fractions [0.9, 0.95], which differ
# from the [0.8, 0.9] split used when fitting the outlier remover and scalers,
# and no random_seed is passed here -- confirm the default seed and that the
# evaluation subjects do not overlap the subjects used for fitting.
splits = m4inpatients.random_splits([0.9, 0.95], balanced='admissions')

res = trainer(
    m,
    m4inpatients,
    splits=splits,
    reporting=reporting,
    n_evals=100,
    warmup_config=warmup,
)

Loading to device: 0subject [00:00, ?subject/s]

  0%|          | 0/1 [00:00<?, ?Epoch/s]

  0%|          | 0/50 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/9 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/9 [00:00<?, ?subject/s]

  0%|          | 0.00/142.18 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/8 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/8 [00:00<?, ?subject/s]

  0%|          | 0.00/190.33 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/9 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/9 [00:00<?, ?subject/s]

  0%|          | 0.00/134.67 [00:00<?, ?odeint-days/s]

In [None]:
# import jax.tree_util as jtu
# import jax.numpy as jnp
# import equinox as eqx

# jtu.tree_map(lambda x: f'{x.shape} {jnp.any(jnp.isnan(x)).item()}' if eqx.is_array(x) else None , m)

In [None]:
# emb_subj = {i: m.f_emb(s) for i, s in m4inpatients.device_batch().subjects.items()}