In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

from tqdm import tqdm
import jax
# Pin JAX to the CPU platform for this notebook session.
jax.config.update('jax_platform_name', 'cpu')

In [2]:


# Make the directory two levels up importable so that `lib` resolves.
# NOTE(review): the path is relative, so it depends on the kernel's working
# directory being this notebook's folder.
sys.path.append("../..")

from lib import utils as U
from lib.ehr.dataset import load_dataset
from lib.ehr.inpatient_interface import Inpatients

In [3]:
import logging
# logging.root.level = logging.DEBUG


In [4]:
# Point `DATA_DIR` at the folder holding the EHR data, then load the dataset.
import dask

# `expanduser` returns $HOME when it is set and otherwise falls back to a
# platform lookup, so the path can never degrade to the literal 'None/...'
# (which is what `os.environ.get('HOME')` produced when HOME was unset).
HOME = os.path.expanduser('~')
DATA_DIR = f'{HOME}/GP/ehr-data'
SOURCE_DIR = os.path.abspath("..")

# NOTE(review): `U.modified_environ` presumably sets DATA_DIR in the process
# environment only for the duration of the block — confirm in `lib.utils`.
# Dask is configured to use the 'processes' scheduler while loading.
with U.modified_environ(DATA_DIR=DATA_DIR), dask.config.set(scheduler='processes'):
    m4inpatient_dataset = load_dataset('M4ICU')
   

In [5]:
# Partition the dataset; only `splits[0]` is used in the cells below.
# NOTE(review): presumably [0.1, 0.7] are split boundaries/fractions, 42 is the
# random seed and 'subjects' the unit being split — confirm against `random_splits`.
splits = m4inpatient_dataset.random_splits([0.1, 0.7], 42, 'subjects')



In [6]:
# Fit the preprocessing using the first split only (`splits[0]`).
preprocessing = m4inpatient_dataset.fit_preprocessing(splits[0])

In [7]:

# Apply the fitted preprocessing to the dataset.
# NOTE(review): the return value is discarded, so this presumably mutates
# `m4inpatient_dataset` in place; re-running the cell may not be idempotent — confirm.
m4inpatient_dataset.apply_preprocessing(preprocessing)

In [8]:
# Single worker count shared by the dask scheduler and the `Inpatients` loader.
NUM_WORKERS = 12

# Thread-pool alternative, kept for reference:
# from concurrent.futures import ThreadPoolExecutor
# with dask.config.set(pool=ThreadPoolExecutor(12)):
with dask.config.set(scheduler='processes', num_workers=NUM_WORKERS):
    # Build the inpatient interface from the dataset and the first split.
    m4inaptients = Inpatients(
        m4inpatient_dataset,
        splits[0],
        num_workers=NUM_WORKERS,
    )

  dob = anchor_date + anchor_age
                            dx_icd10->dx_icd9 Unrecognised t_codes
                            (169):
                            ['041.41', '041.42', '041.43', '041.49', '173.00', '173.01', '173.02', '173.09', '173.10', '173.11', '173.12', '173.19', '173.20', '173.21', '173.22', '173.29', '173.30', '173.31', '173.32', '173.39']...
                            dx_icd10->dx_icd9 Unrecognised s_codes
                            (49910):
                            ['E08.3211', 'E08.3212', 'E08.3213', 'E08.3219', 'E08.3291', 'E08.3292', 'E08.3293', 'E08.3299', 'E08.3311', 'E08.3312', 'E08.3313', 'E08.3319', 'E08.3391', 'E08.3392', 'E08.3393', 'E08.3399', 'E08.3411', 'E08.3412', 'E08.3413', 'E08.3419']...
                            dx_icd10->dx_icd9 Unrecognised t_codes
                            (169):
                            ['041.41', '041.42', '041.43', '041.49', '173.00', '173.01', '173.02', '173.09', '173.10', '173.11', '173.12', '173.19', '173.2

In [9]:
# Size reported by the interface, converted from bytes to GiB (2**30 bytes).
m4inaptients.size_in_bytes() / 2 ** 30

0.7889224896207452

In [10]:
# Build the JAX-array counterpart of the interface for the first split
# (going by the method name `to_jax_arrays`).
m4inaptients_jax = m4inaptients.to_jax_arrays(splits[0])

In [11]:
# Size reported by the JAX-backed interface, converted from bytes to GiB (2**30 bytes).
m4inaptients_jax.size_in_bytes() / 2 ** 30

0.7730227569118142

In [12]:
# Number of subjects held by the JAX-backed interface.
len(m4inaptients_jax.subjects)

5092

In [13]:
# Admission count reported by the interface (presumably summed over all loaded subjects).
m4inaptients_jax.n_admissions()

17292

In [14]:
# Segment count reported by the interface.
# NOTE(review): the definition of a "segment" is not visible here — see `Inpatients.n_segments`.
m4inaptients.n_segments()

893563

In [15]:
# Count of observation timestamps reported by the interface (going by the method name).
m4inaptients.n_obs_times()

724696

In [17]:
# NOTE(review): presumably the proportion of observable entries that are
# actually observed (non-masked) — confirm against `Inpatients.p_obs`.
m4inaptients.p_obs()

0.22646098042029578

In [18]:
# Disabled: heat-map of `m4inaptients.obs_coocurrence_matrix`.
# Uncomment to render it; requires numpy and matplotlib.
# import numpy as np
# import matplotlib.pyplot as plt

# a = m4inaptients.obs_coocurrence_matrix
# a = np.array(a)
# plt.imshow(a, cmap='hot', interpolation='nearest')
# plt.show()

In [19]:
# Take the subject keyed by the first id in the first split as a sample record to inspect.
s = m4inaptients_jax.subjects[splits[0][0]]

In [20]:
# Display the sample subject record.
s

Inpatient(
  subject_id=14825539,
  static_info=InpatientStaticInfo(
    gender='M',
    date_of_birth=Timestamp('2136-01-01 00:00:00'),
    ethnicity=bool[5],
    ethnicity_scheme=<lib.ehr.coding_scheme.MIMIC4Eth5 object at 0x7f5cb5075c40>,
    constant_vec=bool[6]
  ),
  admissions=[
    InpatientAdmission(
      admission_id=24350756,
      admission_dates=(
        Timestamp('2181-01-27 13:21:00'),
        Timestamp('2181-02-02 13:15:00')
      ),
      dx_codes=CodesVector(
        vec=bool[17375],
        scheme=<lib.ehr.coding_scheme.DxICD9 object at 0x7f5cb49bb910>
      ),
      dx_codes_history=CodesVector(
        vec=bool[17375],
        scheme=<lib.ehr.coding_scheme.DxICD9 object at 0x7f5cb49bb910>
      ),
      outcome=CodesVector(
        vec=bool[2081],
        scheme=<lib.ehr.outcome.OutcomeExtractor object at 0x7f5cb507bd60>
      ),
      observables=[
        InpatientObservables(time=f32[0], value=f16[0,60], mask=bool[0,60]),
        InpatientObservables(time=f32[

In [21]:
# Partition outcome codes into 5 groups using the first split.
# NOTE(review): presumably groups are formed by outcome frequency within
# `splits[0]` — confirm against `outcome_frequency_partitions`.
m4inaptients_jax.outcome_frequency_partitions(5, splits[0])



[Array([   0,    1,    2, ..., 1376, 1618, 1511], dtype=int32),
 Array([2039, 1591, 1989, 1817, 1414, 1419, 1897, 1478, 1424,  374, 1861,
        1402,  382, 1486, 1946, 1405,  380, 1606, 1830,  367, 1592, 1959,
        1596, 1811, 1862, 1885, 1950, 1504, 1466, 1417, 1412, 1422, 1549,
        1392, 1395, 1398, 1616, 1401, 1562, 1404, 1607, 1810], dtype=int32),
 Array([1621, 1516, 1550, 1958, 1489, 1615, 1828, 1408, 1620, 1791, 1894,
        1835, 1961, 1482, 1631], dtype=int32),
 Array([1852, 1962, 2061, 1407, 1386, 1387, 1954], dtype=int32),
 Array([1595, 1490, 1610, 1388], dtype=int32)]

In [22]:
# Timestamps of the fifth observables entry (index 4) in the sample subject's first admission.
s.admissions[0].observables[4].time

Array([0.65], dtype=float32)

In [23]:
# Static (per-subject) information of the sample subject.
s.static_info

InpatientStaticInfo(
  gender='M',
  date_of_birth=Timestamp('2136-01-01 00:00:00'),
  ethnicity=bool[5],
  ethnicity_scheme=<lib.ehr.coding_scheme.MIMIC4Eth5 object at 0x7f5cb5075c40>,
  constant_vec=bool[6]
)

In [24]:
# Age of the sample subject at the first timestamp of the first admission's date pair.
first_admission = s.admissions[0]
s.static_info.age(first_admission.admission_dates[0])

45.073237508555785

### التدريب على نموذج المعادلات التفاضلية الاعتيادية العصبية


In [25]:
from lib.ml.in_icenode import InICENODE, InICENODEDimensions
import jax.random as jrandom

- the fwd and bwd functions take an extra `perturbed` argument, which     indicates which primals actually need a gradient. You can use this     to skip computing the gradient for any unperturbed value. (You can     also safely just ignore this if you wish.)
- `None` was previously passed to indicate a symbolic zero gradient for     all objects that weren't inexact arrays, but all inexact arrays     always had an array-valued gradient. Now, `None` may also be passed     to indicate that an inexact array has a symbolic zero gradient.


In [34]:
# Dimension settings handed to the model.
# NOTE(review): presumably state / embedding widths — confirm in `InICENODEDimensions`.
dims = InICENODEDimensions(
    state_m=15,
    state_dx_e=10,
    state_obs_e=25,
    input_e=10,
    proc_e=10,
    demo_e=5,
    int_e=15,
)

# Fixed PRNG seed so the key passed to the model is reproducible.
key = jrandom.PRNGKey(0)

# Instantiate the model against the dataset's coding scheme.
m = InICENODE(dims=dims, scheme=m4inpatient_dataset.scheme, key=key)

In [37]:
# Smoke-test prediction on the first 10 entries of the first split.
# NOTE(review): in the recorded run this raised XlaRuntimeError
# ("Must have (t1 - t0) * dt0 >= 0"), i.e. the ODE solver was given an
# integration interval that runs backwards. Likely related to the negative
# `adm_interval` values found in the admissions table — TODO confirm and fix
# the data before re-running.
m.batch_predict(m4inaptients_jax, splits[0][:10])

  0%|                                                                           | 1.65/3986.9999999999995 [00:00<01:50, 36.05it/s]

0.06666664282480872 <class 'float'>
-2.3841857821338408e-08 <class 'float'>





XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: RuntimeError: Must have (t1 - t0) * dt0 >= 0, we instead got t1 with value Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)> and type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>, t0 with value Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)> and type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>, dt0 with value Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)> and type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>

At:
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/equinox/_errors.py(56): raises
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/jax/_src/callback.py(186): _flat_callback
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/jax/_src/callback.py(46): pure_callback_impl
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/jax/_src/callback.py(108): _callback
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/jax/_src/interpreters/mlir.py(1964): _wrapped_callback
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py(1229): __call__
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/jax/_src/profiler.py(314): wrapper
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/jax/_src/pjit.py(1148): _pjit_call_impl_python
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/jax/_src/pjit.py(1192): call_impl_cache_miss
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/jax/_src/pjit.py(1209): _pjit_call_impl
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/jax/_src/core.py(821): process_primitive
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/jax/_src/core.py(389): bind_with_trace
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/jax/_src/core.py(2596): bind
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/jax/_src/pjit.py(166): _python_pjit_helper
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/jax/_src/pjit.py(253): cache_miss
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/equinox/_jit.py(103): _call
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/equinox/_jit.py(107): __call__
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/equinox/_module.py(522): __call__
  /home/asem/GP/ICENODE/notebooks/mimic_icu/../../lib/ml/in_icenode.py(203): step_segment
  /home/asem/GP/ICENODE/notebooks/mimic_icu/../../lib/ml/in_icenode.py(229): __call__
  /home/asem/GP/ICENODE/notebooks/mimic_icu/../../lib/ml/in_icenode.py(252): batch_predict
  /tmp/ipykernel_104230/2853606103.py(1): <module>
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3505): run_code
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3445): run_ast_nodes
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3266): run_cell_async
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3061): _run_cell
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/IPython/core/interactiveshell.py(3006): run_cell
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/ipykernel/zmqshell.py(531): run_cell
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/ipykernel/ipkernel.py(411): do_execute
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/ipykernel/kernelbase.py(729): execute_request
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/ipykernel/kernelbase.py(406): dispatch_shell
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/ipykernel/kernelbase.py(499): process_one
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/ipykernel/kernelbase.py(510): dispatch_queue
  /home/asem/GP/env/icenode-dev/lib/python3.9/asyncio/events.py(80): _run
  /home/asem/GP/env/icenode-dev/lib/python3.9/asyncio/base_events.py(1905): _run_once
  /home/asem/GP/env/icenode-dev/lib/python3.9/asyncio/base_events.py(601): run_forever
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/tornado/platform/asyncio.py(215): start
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/ipykernel/kernelapp.py(711): start
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/traitlets/config/application.py(992): launch_instance
  /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/ipykernel_launcher.py(17): <module>
  /home/asem/GP/env/icenode-dev/lib/python3.9/runpy.py(87): _run_code
  /home/asem/GP/env/icenode-dev/lib/python3.9/runpy.py(197): _run_module_as_main


In [43]:
# Sanity check: smallest `adm_interval` in the admissions table.
# A negative value is suspicious for an interval and would indicate inconsistent admission timing.
m4inpatient_dataset.df['adm']['adm_interval'].min()

-22.683333333333334

In [None]:
# TODO: check whether each subject's admissions are sorted chronologically.
# FIXME: `adm_interval` contains negative values, so checking the order is not
# enough — remove every subject that has at least one negative `adm_interval`.