In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

from tqdm import tqdm
import jax
jax.config.update('jax_platform_name', 'gpu')
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [2]:


sys.path.append("../..")

from lib import utils as U
from lib.ehr.dataset import load_dataset, load_dataset_scheme
from lib.ehr.interface import Patients


In [3]:
# scheme = load_dataset_scheme('M4ICU')

In [4]:
import logging
logging.root.level = logging.DEBUG

In [5]:
# from lib.ehr.coding_scheme import MIMIC4Procedures, MIMIC4ProcedureGroups
# from lib.ehr.coding_scheme import MIMIC4Input, MIMIC4InputGroups

# cproc = MIMIC4Procedures()
# cproc_g = MIMIC4ProcedureGroups()
# cinp = MIMIC4Input()
# cinp_g = MIMIC4InputGroups()

In [6]:
# Assign the folder of the dataset to `DATA_FILE`.
import dask

HOME = os.environ.get('HOME')
DATA_DIR = f'{HOME}/GP/ehr-data'
SOURCE_DIR = os.path.abspath("..")

with U.modified_environ(DATA_DIR=DATA_DIR), dask.config.set(scheduler='processes'):
    m4icu_dataset = load_dataset('M4ICU', sample=10000)
   

DEBUG:root:Loading dataframe files
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/adm_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/dx_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/static_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/obs_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/int_input.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/int_proc.csv
DEBUG:root:[DONE] Loading dataframe files
DEBUG:root:Preprocess admissions
DEBUG:root:Removing subjects with at least one negative adm_interval: 18
DEBUG:root:adm: Merging overlapping admissions
DEBUG:root:adm: Merged 84 overlapping admissions
DEBUG:root:[DONE] Preprocess admissions
DEBUG:root:Matching admission_id
DEBUG:root:[DONE] Matching admission_id
DEBUG:root:Time casting..
DEBUG:root:[DONE] Time casting..
DEBUG:root:Constructing dx_icd9 (<class 'lib.ehr.coding_

In [7]:
splits = m4icu_dataset.random_splits([0.8, 0.9], random_seed=42, balanced='admissions')

In [8]:
outlier_remover = m4icu_dataset.fit_outlier_remover(splits[0])
m4icu_dataset.remove_outliers(outlier_remover)

DEBUG:root:Removed 472034 (0.024) outliers from obs


In [9]:
scalers = m4icu_dataset.fit_scalers(splits[0])
m4icu_dataset.apply_scalers(scalers)

In [10]:
from lib.ehr.concepts import DemographicVectorConfig

demographic_vector_conf = DemographicVectorConfig(age=True,
                                                  gender=True,
                                                  ethnicity=True)

In [11]:
with dask.config.set(scheduler='processes', num_workers=12):
    m4inpatients = Patients(m4icu_dataset, demographic_vector_conf).load_subjects(num_workers=12)

  dob = anchor_date + anchor_age
DEBUG:root:Constructing mimic4_eth32 (<class 'lib.ehr.coding_scheme.MIMICEth32'>) scheme
DEBUG:root:Constructing mimic4_eth5 (<class 'lib.ehr.coding_scheme.MIMICEth5'>) scheme
DEBUG:root:Extracting dx codes...
DEBUG:root:Constructing dx_icd10 (<class 'lib.ehr.coding_scheme.DxICD10'>) scheme
DEBUG:root:Constructing dx_icd9 (<class 'lib.ehr.coding_scheme.DxICD9'>) scheme
                            dx_icd10->dx_icd9 Unrecognised t_codes
                            (169):
                            ['041.41', '041.42', '041.43', '041.49', '173.00', '173.01', '173.02', '173.09', '173.10', '173.11', '173.12', '173.19', '173.20', '173.21', '173.22', '173.29', '173.30', '173.31', '173.32', '173.39']...
                            dx_icd10->dx_icd9 Unrecognised s_codes
                            (49910):
                            ['E08.3211', 'E08.3212', 'E08.3213', 'E08.3219', 'E08.3291', 'E08.3292', 'E08.3293', 'E08.3299', 'E08.3311', 'E08.3312', 'E08.331

In [None]:
m4inpatients.save('m4inpatients.pickle')

In [None]:
l_m4inpatients = Patients.load('m4inpatients.pickle')

In [14]:
len(m4inpatients.subjects)

9981

In [12]:
# m4inpatients.size_in_bytes() / 1024 ** 3

In [13]:
# val_batch = m4inpatients.device_batch(splits[1])

In [14]:
# tst_batch = m4inpatients.device_batch(splits[2])

In [15]:
# val_batch.size_in_bytes() / 1024 ** 3, tst_batch.size_in_bytes() / 1024 ** 3

In [16]:
# batch = m4inpatients.device_batch(splits[0][:32])

In [17]:
# batch.size_in_bytes() / 1024 ** 3

In [18]:
# len(batch.subjects)

In [19]:
# batch.n_admissions()

In [20]:
# batch.n_segments()

In [21]:
# batch.n_obs_times()

In [22]:
# import numpy as np
# import matplotlib.pyplot as plt

# a = m4inpatients_jax.obs_coocurrence_matrix
# a = np.array(a)
# plt.imshow(a, cmap='hot', interpolation='nearest')
# plt.show()

In [23]:
# s = batch.subjects[splits[0][6]].admissions[0]
# s

In [24]:
# batch.interval_hours(splits[0][:10])

### التدريب على نموذج المعادلات التفاضلية الاعتيادية العصبية


In [25]:
import numpy as np
a = np.array([4, 5])
b = np.array([5, 6])

c1 = a.sum
c2 = b.sum

print(id(c1), id(c2))

139749651832640 139760821493824


In [17]:
from lib.ml import (InICENODE, InICENODEDimensions, InpatientEmbeddingDimensions, 
                    InTrainer, TrainerReporting, OptimizerConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric)

import jax.random as jrandom

DEBUG:matplotlib:matplotlib data path: /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/matplotlib/mpl-data
DEBUG:matplotlib:CONFIGDIR=/home/asem/.config/matplotlib
DEBUG:matplotlib:interactive is False
DEBUG:matplotlib:platform is linux
DEBUG:matplotlib:CACHEDIR=/home/asem/.cache/matplotlib
DEBUG:matplotlib.font_manager:Using fontManager instance from /home/asem/.cache/matplotlib/fontlist-v330.json


In [27]:
emb_dims = InpatientEmbeddingDimensions(dx=10, inp=10, proc=10, demo=5, inp_proc_demo=15)
dims = InICENODEDimensions(mem=15, obs=25, emb=emb_dims)
key = jrandom.PRNGKey(0)

m = InICENODE(dims=dims, 
              scheme=m4icu_dataset.scheme,
              demographic_vector_config=demographic_vector_conf,
              key=key)

In [28]:
# res = m.batch_predict(m4inpatients.device_batch(), leave_pbar=True)

In [18]:
trainer = InTrainer(optimizer_config=OptimizerConfig(opt='adam', lr=1e-3),
                    reg_hyperparams=None,
                    epochs=150,
                    batch_size=32)
loss_metric =  LossMetric(m4inpatients, 
                          dx_loss=('softmax_bce', 'balanced_focal_softmax_bce', 'balanced_focal_bce'),
                          obs_loss=('mse', 'rms', 'mae'))
obs_code_loss_metric = ObsCodeLevelLossMetric(m4inpatients, obs_loss=('mse', 'rms', 'mae'))

metrics = [CodeAUC(m4inpatients), AdmissionAUC(m4inpatients), obs_code_loss_metric, loss_metric]

reporting = TrainerReporting(output_dir='inicenode',
                             metrics=metrics,
                             console=True,
                             parameter_snapshots=True,
                             config_json=True)

NameError: name 'm4inpatients' is not defined

In [32]:
    
splits = m4inpatients.random_splits([0.9, 0.95], 
                                    balanced='admissions')
res = trainer(m, m4inpatients, 
              splits=splits,
              reporting=reporting)

Loading to device:   0%|          | 0/516 [00:00<?, ?subject/s]

  0%|          | 0/150 [00:00<?, ?Epoch/s]

  0%|          | 0/5326 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/5 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/5 [00:00<?, ?subject/s]

  0%|          | 0.00/113.04 [00:00<?, ?odeint-days/s]

Embedding:   0%|          | 0/5 [00:00<?, ?subject/s]

  0%|          | 0.00/113.04 [00:00<?, ?odeint-days/s]

Embedding:   0%|          | 0/516 [00:00<?, ?subject/s]

  0%|          | 0.00/10606.41 [00:00<?, ?odeint-days/s]



Loading to device:   0%|          | 0/10 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/10 [00:00<?, ?subject/s]

  0%|          | 0.00/180.51 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/8 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/8 [00:00<?, ?subject/s]

  0%|          | 0.00/179.42 [00:00<?, ?odeint-days/s]

Embedding:   0%|          | 0/8 [00:00<?, ?subject/s]

  0%|          | 0.00/179.42 [00:00<?, ?odeint-days/s]

Embedding:   0%|          | 0/516 [00:00<?, ?subject/s]

  0%|          | 0.00/10606.41 [00:00<?, ?odeint-days/s]

KeyboardInterrupt: 

In [None]:
# import jax.tree_util as jtu
# import jax.numpy as jnp
# import equinox as eqx

# jtu.tree_map(lambda x: f'{x.shape} {jnp.any(jnp.isnan(x)).item()}' if eqx.is_array(x) else None , m)

In [None]:
# emb_subj = {i: m.f_emb(s) for i, s in m4inpatients.device_batch().subjects.items()}