In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

from tqdm import tqdm
import jax
# Explicitly request the 'gpu' platform for JAX in this session.
# NOTE(review): presumably this needs a GPU-enabled jaxlib and a visible
# device; confirm before running the notebook on a CPU-only machine.
jax.config.update('jax_platform_name', 'gpu')
# Optional JAX debugging switches, left disabled:
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [2]:


# Add the directory two levels up (relative to the kernel's working directory)
# to the import path; presumably this is the repository root that holds `lib`.
# Guarded so that re-running this cell does not append a duplicate entry.
if "../.." not in sys.path:
    sys.path.append("../..")

from lib import utils as U
from lib.ehr.dataset import load_dataset, load_dataset_scheme
from lib.ehr.interface import Patients


In [3]:
# scheme = load_dataset_scheme('M4ICU')

In [4]:
# import logging
# logging.root.level = logging.DEBUG

In [5]:
# from lib.ehr.coding_scheme import MIMIC4Procedures, MIMIC4ProcedureGroups
# from lib.ehr.coding_scheme import MIMIC4Input, MIMIC4InputGroups

# cproc = MIMIC4Procedures()
# cproc_g = MIMIC4ProcedureGroups()
# cinp = MIMIC4Input()
# cinp_g = MIMIC4InputGroups()

In [None]:
# Point `DATA_DIR` at the folder that holds the dataset.
import dask

# `expanduser('~')` honours $HOME when it is set and falls back to the
# platform's user database otherwise, so DATA_DIR never becomes the literal
# string 'None/GP/ehr-data' when HOME is missing from the environment.
HOME = os.path.expanduser('~')
DATA_DIR = f'{HOME}/GP/ehr-data'
# Absolute path of the parent of the working directory (not used by the cells
# visible in this section).
SOURCE_DIR = os.path.abspath("..")

# Load the 'M4ICU' dataset with DATA_DIR passed through `U.modified_environ`
# (presumably a temporary environment-variable override; confirm in lib.utils)
# and with dask using its process-based scheduler.
# NOTE(review): `sample=None` presumably means "load everything"; confirm.
with U.modified_environ(DATA_DIR=DATA_DIR), dask.config.set(scheduler='processes'):
    m4icu_dataset = load_dataset('M4ICU', sample=None)
   

In [None]:
# Seeded random split of the dataset, balanced by subjects.
# NOTE(review): [0.8, 0.9] presumably are cumulative cut points giving three
# partitions (80% / 10% / 10%); confirm against `random_splits`.
splits = m4icu_dataset.random_splits(
    [0.8, 0.9],
    random_seed=42,
    balanced='subjects',
)

In [None]:
# Fit the preprocessing on the first split only (`splits[0]`).
# NOTE(review): presumably splits[0] is the training partition, so held-out
# subjects do not influence the fitted preprocessing; confirm.
preprocessing = m4icu_dataset.fit_preprocessing(splits[0])

In [None]:
# Apply the fitted preprocessing to the dataset.
# NOTE(review): the return value is discarded, so this presumably updates
# `m4icu_dataset` in place; if so, re-running this cell without reloading the
# dataset could apply the preprocessing twice. Confirm.
m4icu_dataset.apply_preprocessing(preprocessing)

In [None]:
from lib.ehr.concepts import DemographicVectorConfig

# Demographic vector configuration with the age, gender and ethnicity flags
# all switched on.
demographic_vector_conf = DemographicVectorConfig(
    age=True,
    gender=True,
    ethnicity=True,
)

In [None]:
# Build the patients interface from the dataset and the demographic
# configuration, then load the subjects. The same worker count (12) is given
# both to dask's process scheduler and to `load_subjects`.
with dask.config.set(scheduler='processes', num_workers=12):
    m4inpatients = Patients(
        m4icu_dataset,
        demographic_vector_conf,
    ).load_subjects(num_workers=12)

In [None]:
# m4inpatients.size_in_bytes() / 1024 ** 3

In [None]:
# val_batch = m4inpatients.device_batch(splits[1])

In [None]:
# tst_batch = m4inpatients.device_batch(splits[2])

In [None]:
# val_batch.size_in_bytes() / 1024 ** 3, tst_batch.size_in_bytes() / 1024 ** 3

In [None]:
# batch = m4inpatients.device_batch(splits[0][:32])

In [None]:
# batch.size_in_bytes() / 1024 ** 3

In [None]:
# len(batch.subjects)

In [None]:
# batch.n_admissions()

In [None]:
# batch.n_segments()

In [None]:
# batch.n_obs_times()

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt

# a = m4inpatients_jax.obs_coocurrence_matrix
# a = np.array(a)
# plt.imshow(a, cmap='hot', interpolation='nearest')
# plt.show()

In [None]:
# s = batch.subjects[splits[0][6]].admissions[0]
# s

In [None]:
# batch.interval_hours(splits[0][:10])

### التدريب على نموذج المعادلات التفاضلية الاعتيادية العصبية


In [None]:
import numpy as np

# Scratch experiment: compare the identities of two bound-method objects.
a, b = np.array([4, 5]), np.array([5, 6])

# `.sum` is referenced, not called: c1 and c2 are bound methods, not totals.
c1, c2 = a.sum, b.sum

print(*map(id, (c1, c2)))

In [None]:
from lib.ml import (InICENODE, InICENODEDimensions, InpatientEmbeddingDimensions, 
                    InTrainer, TrainerReporting, OptimizerConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC,
                      CodeGroupTopAlarmAccuracy, LossMetric)

import jax.random as jrandom

In [None]:
# Embedding sizes handed to the model's dimension config.
emb_dims = InpatientEmbeddingDimensions(
    dx=10,
    inp=10,
    proc=10,
    demo=5,
    inp_proc_demo=15,
)
dims = InICENODEDimensions(mem=15, obs=25, emb=emb_dims)

# Fixed PRNG key (seed 0) used when constructing the model.
key = jrandom.PRNGKey(0)

# The model takes the dataset's scheme and the same demographic configuration
# that was used to load the patients.
m = InICENODE(
    dims=dims,
    scheme=m4icu_dataset.scheme,
    demographic_vector_config=demographic_vector_conf,
    key=key,
)

In [None]:
# res = m.batch_predict(m4inpatients.device_batch(), leave_pbar=True)

In [None]:
# Trainer: Adam optimizer with lr=1e-3, 150 epochs, batch_size=32, and no
# regularisation hyperparameters.
trainer = InTrainer(
    optimizer_config=OptimizerConfig(opt='adam', lr=1e-3),
    reg_hyperparams=None,
    epochs=150,
    batch_size=32,
)

# Evaluation metrics, each constructed over the loaded patients.
metrics = [
    CodeAUC(m4inpatients),
    UntilFirstCodeAUC(m4inpatients),
    AdmissionAUC(m4inpatients),
    LossMetric(m4inpatients),
]

# Reporting configuration: output directory 'inicenode', with the console,
# parameter-snapshot and config-JSON options enabled.
reporting = TrainerReporting(
    output_dir='inicenode',
    metrics=metrics,
    console=True,
    parameter_snapshots=True,
    config_json=True,
)

In [None]:
    
# Re-split at the patients-interface level, balanced by admissions, then train.
# NOTE(review): this rebinds `splits`, replacing the dataset-level split
# ([0.8, 0.9], balanced by subjects) whose first partition was used to fit the
# preprocessing. Subjects held out here may therefore have contributed to the
# fitted preprocessing; confirm this is intended.
# NOTE(review): no `random_seed` is passed here, unlike the dataset-level split
# (random_seed=42); confirm that the default yields a reproducible split.
splits = m4inpatients.random_splits([0.9, 0.95], balanced='admissions')
res = trainer(m, m4inpatients, splits=splits, reporting=reporting)

In [None]:
# import jax.tree_util as jtu
# import jax.numpy as jnp
# import equinox as eqx

# jtu.tree_map(lambda x: f'{x.shape} {jnp.any(jnp.isnan(x)).item()}' if eqx.is_array(x) else None , m)

In [None]:
# emb_subj = {i: m.f_emb(s) for i, s in m4inpatients.device_batch().subjects.items()}