In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

from tqdm import tqdm
import jax
jax.config.update('jax_platform_name', 'gpu')

In [2]:


sys.path.append("../..")

from lib import utils as U
from lib.ehr.dataset import load_dataset
from lib.ehr.inpatient_interface import Inpatients

In [3]:
import logging
logging.root.level = logging.DEBUG

In [4]:
from lib.ehr.coding_scheme import MIMIC4Procedures, MIMIC4ProcedureGroups
from lib.ehr.coding_scheme import MIMIC4Input, MIMIC4InputGroups

cproc = MIMIC4Procedures()
cproc_g = MIMIC4ProcedureGroups()
cinp = MIMIC4Input()
cinp_g = MIMIC4InputGroups()

In [5]:
cinp_g.groups


{'Abciximab': {'Abciximab (Reopro)'},
 'Acetaminophen-IV': {'Acetaminophen-IV'},
 'Acetylcysteine': {'Acetylcysteine'},
 'Adenosine': {'Adenosine'},
 'Alteplase (TPA)': {'Alteplase (TPA)'},
 'Ambisome': {'Ambisome'},
 'Amikacin': {'Amikacin'},
 'Amino Acids': {'Amino Acids'},
 'Aminocaproic acid (Amicar)': {'Aminocaproic acid (Amicar)'},
 'Aminophylline': {'Aminophylline'},
 'Amiodarone': {'Amiodarone', 'Amiodarone 450/250', 'Amiodarone 600/500'},
 'Ampicillin': {'Ampicillin'},
 'Ampicillin/Sulbactam (Unasyn)': {'Ampicillin/Sulbactam (Unasyn)'},
 'Angiotensin II (Giapreza)': {'Angiotensin II (Giapreza)'},
 'Argatroban': {'Argatroban'},
 'Atovaquone': {'Atovaquone'},
 'Atropine': {'Atropine'},
 'Azithromycin': {'Azithromycin', 'Erythromycin'},
 'Aztreonam': {'Aztreonam'},
 'Bactrim (SMX/TMP)': {'Bactrim (SMX/TMP)'},
 'Beneprotein': {'Beneprotein', 'Beneprotein.'},
 'Bivalirudin (Angiomax)': {'Bivalirudin (Angiomax)',
  'Bivalirudin (Angiomax) (Impella)'},
 'Boost Glucose Control': {'Boo

In [None]:
# Assign the folder of the dataset to `DATA_FILE`.
import dask

HOME = os.environ.get('HOME')
DATA_DIR = f'{HOME}/GP/ehr-data'
SOURCE_DIR = os.path.abspath("..")

with U.modified_environ(DATA_DIR=DATA_DIR), dask.config.set(scheduler='processes'):
    m4inpatient_dataset = load_dataset('M4ICU')
   

In [None]:
splits = m4inpatient_dataset.random_splits([0.1, 0.7], 42, 'subjects')



In [None]:
preprocessing = m4inpatient_dataset.fit_preprocessing(splits[0])

In [None]:

m4inpatient_dataset.apply_preprocessing(preprocessing)

In [None]:
# from concurrent.futures import ThreadPoolExecutor
# with dask.config.set(pool=ThreadPoolExecutor(12)):
with dask.config.set(scheduler='processes', num_workers=12):
    m4inpatients = Inpatients(m4inpatient_dataset, splits[0][:10], num_workers=12)

In [None]:
m4inpatients.size_in_bytes() / 1024 ** 3

In [None]:
m4inpatients_jax = m4inpatients.to_jax_arrays(splits[0][:10])

In [None]:
m4inpatients_jax.size_in_bytes() / 1024 ** 3

In [None]:
len(m4inpatients_jax.subjects)

In [None]:
m4inpatients_jax.n_admissions()

In [None]:
m4inpatients_jax.n_segments()

In [None]:
m4inpatients_jax.n_obs_times()

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt

# a = m4inpatients_jax.obs_coocurrence_matrix
# a = np.array(a)
# plt.imshow(a, cmap='hot', interpolation='nearest')
# plt.show()

In [None]:
s = m4inpatients_jax.subjects[splits[0][6]].admissions[0].interventions.input_
s

In [None]:
m4inpatients_jax.interval_hours(splits[0][:10])

In [None]:
s

### التدريب على نموذج المعادلات التفاضلية الاعتيادية العصبية


In [None]:
from lib.ml.in_icenode import InICENODE, InICENODEDimensions
import jax.random as jrandom

In [None]:
dims = InICENODEDimensions(state_m=15, 
                state_dx_e=10,
                state_obs_e=25,
                input_e=10,
                proc_e=10,
                demo_e=5,
                int_e=15)
key = jrandom.PRNGKey(0)

m = InICENODE(dims=dims, 
              scheme=m4inpatient_dataset.scheme,
              key=key)

In [None]:
res = m.batch_predict(m4inpatients_jax, splits[0][:10])

In [None]:
m4inpatients_jax.subjects[splits[0][0]].admissions[0].interventions

In [None]:
m.f_emb.f_inp_agg.splits

In [None]:
m4inpatients_jax.subjects[splits[0][0]].admissions[0].admission_dates