In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

from tqdm import tqdm
import jax
# Ask JAX to use the GPU platform and turn on NaN debugging (jax_debug_nans),
# which makes JAX raise as soon as a computation produces a NaN. The NaN check
# adds overhead, so consider disabling it for long training runs.
jax.config.update('jax_platform_name', 'gpu')
jax.config.update("jax_debug_nans", True)
# 64-bit mode is intentionally left off; uncomment to enable double precision.
# jax.config.update("jax_enable_x64", True)

In [2]:


# Make the directory two levels above the notebook importable (the `lib`
# package imported below is resolved through it).
# The entry is stored as an absolute path so imports keep working if the
# working directory changes later, and the membership test keeps sys.path
# free of duplicates when this cell is re-run.
if os.path.abspath("../..") not in sys.path:
    sys.path.append(os.path.abspath("../.."))

from lib import utils as U
from lib.ehr.dataset import load_dataset
from lib.ehr.inpatient_interface import Patients

In [3]:
import logging

# Use the public setter instead of assigning `.level` directly:
# Logger.setLevel() validates the level and clears the logging manager's
# cached isEnabledFor() results, so loggers that were already used pick up
# the new threshold.
# NOTE: DEBUG on the root logger is very verbose (third-party libraries log
# through it too); raise to logging.INFO for quieter output.
logging.root.setLevel(logging.DEBUG)

In [4]:
# from lib.ehr.coding_scheme import MIMIC4Procedures, MIMIC4ProcedureGroups
# from lib.ehr.coding_scheme import MIMIC4Input, MIMIC4InputGroups

# cproc = MIMIC4Procedures()
# cproc_g = MIMIC4ProcedureGroups()
# cinp = MIMIC4Input()
# cinp_g = MIMIC4InputGroups()

In [5]:
# Point `DATA_DIR` at the folder that holds the EHR datasets.
import dask

# expanduser('~') honours $HOME when it is set and otherwise falls back to the
# platform's own home-directory lookup, so DATA_DIR can never silently become
# 'None/GP/ehr-data' the way os.environ.get('HOME') would allow.
HOME = os.path.expanduser('~')
DATA_DIR = f'{HOME}/GP/ehr-data'
SOURCE_DIR = os.path.abspath("..")

# Load the 'M4ICU' dataset with dask's process-based scheduler.
# U.modified_environ presumably exposes DATA_DIR as an environment variable
# for the duration of the `with` block — confirm in lib.utils.
with U.modified_environ(DATA_DIR=DATA_DIR), dask.config.set(scheduler='processes'):
    m4icu_dataset = load_dataset('M4ICU')
   

DEBUG:root:Loading dataframe files
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/adm_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/dx_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/static_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/obs_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/int_input.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/int_proc.csv
DEBUG:root:[DONE] Loading dataframe files
DEBUG:root:Preprocess admissions
DEBUG:root:Removing subjects with at least one negative adm_interval: 84
DEBUG:root:adm: Merging overlapping admissions
DEBUG:root:adm: Merged 356 overlapping admissions
DEBUG:root:[DONE] Preprocess admissions
DEBUG:root:Matching admission_id
DEBUG:root:[DONE] Matching admission_id
DEBUG:root:Time casting..
DEBUG:root:[DONE] Time casting..
DEBUG:root:Constructing dx_icd9 (<class 'lib.ehr.coding

In [6]:
# Seeded (random_seed=42) random split of the dataset, balanced by 'subjects'.
# The cut points [0.8, 0.9] presumably yield three partitions
# (train / validation / test) — confirm against random_splits.
splits = m4icu_dataset.random_splits([0.8, 0.9], random_seed=42, balanced='subjects')

In [7]:
# Fit the preprocessing on the first split only (splits[0]).
# NOTE(review): presumably the training partition, so that no statistics leak
# from held-out subjects — confirm.
preprocessing = m4icu_dataset.fit_preprocessing(splits[0])

In [8]:
# Apply the fitted preprocessing to the dataset. The return value is discarded,
# so this presumably mutates m4icu_dataset in place.
# NOTE(review): if so, re-running this cell may filter/transform the data a
# second time — confirm that apply_preprocessing is idempotent.
m4icu_dataset.apply_preprocessing(preprocessing)

DEBUG:root:Removed 2323089 (0.023) outliers from obs


In [9]:
# Wrap the preprocessed dataset in the Patients interface; subjects are loaded
# in a later cell through load_subjects().
m4inpatients = Patients(m4icu_dataset)

In [10]:
# df = m4icu_dataset.df['int_input']
# df[df['normalised_amount_per_hour'].isnull()]

In [11]:
# Load only the first 100 subject ids of the first split, using dask's
# process-based scheduler with 12 workers (the same worker count is passed to
# load_subjects itself).
# Alternative kept for reference — a thread pool instead of processes:
# from concurrent.futures import ThreadPoolExecutor
# with dask.config.set(pool=ThreadPoolExecutor(12)):
with dask.config.set(scheduler='processes', num_workers=12):
    m4inpatients = m4inpatients.load_subjects(splits[0][:100], num_workers=12)

  dob = anchor_date + anchor_age
DEBUG:root:Constructing mimic4_eth32 (<class 'lib.ehr.coding_scheme.MIMIC4Eth32'>) scheme
DEBUG:root:Constructing mimic4_eth5 (<class 'lib.ehr.coding_scheme.MIMIC4Eth5'>) scheme
DEBUG:root:Extracting dx codes...
DEBUG:root:Constructing dx_icd10 (<class 'lib.ehr.coding_scheme.DxICD10'>) scheme
DEBUG:root:Constructing dx_icd9 (<class 'lib.ehr.coding_scheme.DxICD9'>) scheme
                            dx_icd10->dx_icd9 Unrecognised t_codes
                            (169):
                            ['041.41', '041.42', '041.43', '041.49', '173.00', '173.01', '173.02', '173.09', '173.10', '173.11', '173.12', '173.19', '173.20', '173.21', '173.22', '173.29', '173.30', '173.31', '173.32', '173.39']...
                            dx_icd10->dx_icd9 Unrecognised s_codes
                            (49910):
                            ['E08.3211', 'E08.3212', 'E08.3213', 'E08.3219', 'E08.3291', 'E08.3292', 'E08.3293', 'E08.3299', 'E08.3311', 'E08.3312', 'E08.3

In [None]:
# m4inpatients.size_in_bytes() / 1024 ** 3

In [None]:
# val_batch = m4inpatients.device_batch(splits[1])

In [None]:
# tst_batch = m4inpatients.device_batch(splits[2])

In [None]:
# val_batch.size_in_bytes() / 1024 ** 3, tst_batch.size_in_bytes() / 1024 ** 3

In [12]:
# Build a device batch from the first 32 subject ids of the first split
# (a subset of the 100 subjects loaded above).
batch = m4inpatients.device_batch(splits[0][:32])

Loading to device:   0%|          | 0/32 [00:00<?, ?subject/s]

DEBUG:jax._src.xla_bridge:Initializing backend 'cpu'
DEBUG:jax._src.xla_bridge:Backend 'cpu' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'cuda'
DEBUG:jax._src.xla_bridge:Backend 'cuda' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'rocm'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA
DEBUG:jax._src.xla_bridge:Initializing backend 'tpu'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'






In [13]:
# Batch footprint in GiB: size_in_bytes() divided by 1024**3.
batch.size_in_bytes() / 1024 ** 3

0.004255837760865688

In [14]:
# Number of subjects held by the batch.
len(batch.subjects)

32

In [15]:
# Admission count reported by the batch.
batch.n_admissions()

95

In [16]:
# Segment count reported by the batch.
# NOTE(review): what constitutes a "segment" is not visible in this notebook.
batch.n_segments()

4765

In [17]:
# Count of observation times reported by the batch.
batch.n_obs_times()

3793

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt

# a = m4inpatients_jax.obs_coocurrence_matrix
# a = np.array(a)
# plt.imshow(a, cmap='hot', interpolation='nearest')
# plt.show()

In [None]:
# Inspect the `input_` interventions of the first admission of the subject whose
# id sits at position 6 of the first split.
# NOTE(review): this lookup only succeeds while `splits` is still the
# dataset-level split used to build `batch`; the training cell further down
# rebinds `splits`, so re-running this cell afterwards may raise a KeyError.
s = batch.subjects[splits[0][6]].admissions[0].interventions.input_
s

In [None]:
# Query interval_hours for the first 10 subject ids of the first split.
# NOTE(review): presumably the total time (in hours) covered by those subjects — confirm.
batch.interval_hours(splits[0][:10])

### Training the neural ordinary differential equations (Neural ODE) model


In [None]:
from lib.ml.in_icenode import InICENODE, InICENODEDimensions
import jax.random as jrandom

In [None]:
# PRNG key with a fixed seed (0), handed to the model constructor below.
key = jrandom.PRNGKey(0)

# Dimension settings for the model, forwarded as keyword arguments.
dims = InICENODEDimensions(**{
    "state_m": 15,
    "state_dx_e": 10,
    "state_obs_e": 25,
    "input_e": 10,
    "proc_e": 10,
    "demo_e": 5,
    "int_e": 15,
})

# Instantiate the model against the dataset's coding scheme.
m = InICENODE(dims=dims, scheme=m4icu_dataset.scheme, key=key)

In [None]:
# res = m.batch_predict(m4inpatients.device_batch(), leave_pbar=True)

In [None]:
from lib.ml import InTrainer, MetricsHistory
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC,
                      CodeGroupTopAlarmAccuracy, LossMetric, MetricsCollection)

from lib.ml import MinibatchLogger, EvaluationDiskWriter, ParamsDiskWriter, ConfigDiskWriter

In [None]:
# Name of the experiment output directory (referenced by the disk-writer
# reporters, which are currently commented out).
expt_dir = 'inicenode'

# Training hyperparameters; forwarded verbatim to InTrainer as keyword arguments.
config = dict(
    batch_size=32,
    lr=1e-3,
    epochs=150,
    opt="adam",
    reg_hyperparams=None,
)
trainer = InTrainer(**config)

In [None]:
# Instantiate every evaluation metric against the loaded patients interface.
metrics = [
    metric_cls(m4inpatients)
    for metric_cls in (CodeAUC, UntilFirstCodeAUC, AdmissionAUC, LossMetric)
]

# Only minibatch logging is active; the disk writers stay disabled.
reporters = [
    MinibatchLogger(config),
    # EvaluationDiskWriter(output_dir=expt_dir),
    # ParamsDiskWriter(output_dir=expt_dir),
    # ConfigDiskWriter(output_dir=expt_dir, config=config)
]

# Bundle the metrics into one collection and track it in a history object.
metrics = MetricsCollection(metrics)
history = MetricsHistory(metrics)

In [None]:
    
# Re-split the *loaded* patients (not the full dataset) at the [0.9, 0.95] cut
# points, balanced by 'admissions'.
# NOTE(review): this rebinds `splits`, shadowing the dataset-level splits created
# earlier — cells above that index `splits[0]` behave differently if re-run
# after this one.
# NOTE(review): no random_seed is passed here, unlike the earlier dataset split
# (random_seed=42) — confirm the default is deterministic.
splits = m4inpatients.random_splits([0.9, 0.95], 
                                    balanced='admissions')
# Run the trainer on model `m`; `res` holds whatever the trainer returns.
res = trainer(m, m4inpatients, splits=splits, history=history, 
             reporters=reporters)

In [None]:
# import jax.tree_util as jtu
# import jax.numpy as jnp
# import equinox as eqx

# jtu.tree_map(lambda x: f'{x.shape} {jnp.any(jnp.isnan(x)).item()}' if eqx.is_array(x) else None , m)

In [None]:
# emb_subj = {i: m.f_emb(s) for i, s in m4inpatients.device_batch().subjects.items()}