In [None]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd
import dask

from tqdm import tqdm
import jax
# Select the CPU platform for JAX in this notebook.
jax.config.update('jax_platform_name', 'cpu')
# Optional JAX debugging switches, left disabled; uncomment when needed:
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [None]:


# Extend the import search path two directories up so that the `lib` imports
# below can resolve. NOTE(review): the path is relative, so it depends on the
# kernel's working directory — confirm the notebook is launched from its own folder.
sys.path.append("../..")

from lib import utils as U
from lib.ehr.dataset import load_dataset, load_dataset_scheme, Dataset
from lib.ehr.interface import Patients
from lib.ehr.concepts import DemographicVectorConfig


In [None]:
# import logging
# logging.root.level = logging.DEBUG

In [None]:
# Notebook configuration.
#
# DATA_DIR      folder that holds the dataset; it is exported as the
#               `DATA_DIR` environment variable while the dataset is loaded.
# SOURCE_DIR    absolute path of the parent of the current working directory.
# cache_to_disk location where the loaded patients interface is saved.
# use_cached    False to rebuild from the raw dataset, or the path of a
#               previously saved interface to load instead.

# Resolve the home directory through `expanduser` rather than reading the
# HOME variable directly: `os.environ.get('HOME')` yields None when HOME is
# unset, which would silently produce the bogus path 'None/GP/ehr-data'.
HOME = os.path.expanduser('~')
DATA_DIR = f'{HOME}/GP/ehr-data'
SOURCE_DIR = os.path.abspath("..")
cache_to_disk = 'cached_inteface/patients'  # alternative used before: suffix '_200'
use_cached = False  # alternatives: 'cached_inteface/patients' or cache_to_disk


In [None]:
# Produce `m3patients` and `splits`, either by building them from the raw
# dataset (and caching the result) or by loading a previously saved interface.
if not use_cached:
    # Expose DATA_DIR through the environment and run dask with a
    # process-based scheduler while the dataset is being loaded.
    with U.modified_environ(DATA_DIR=DATA_DIR):
        with dask.config.set(scheduler='processes', num_workers=12):

            # Raw dataset, without sub-sampling.
            m3_dataset = load_dataset('M3', sample=None)

            # The training split is the one intended for fitting the
            # outlier remover and the scalers.
            splits = m3_dataset.random_splits(
                [0.8, 0.9],
                random_seed=42,
                balanced='admissions',
            )

            # Demographic attributes that go into the demographic vector.
            demographic_vector_conf = DemographicVectorConfig(
                age=True,
                gender=True,
                ethnicity=True,
            )

            # Build the patients interface and load its subjects.
            m3patients = Patients(
                m3_dataset, demographic_vector_conf
            ).load_subjects(num_workers=12)

            # Persist the interface, replacing any previous cache.
            m3patients.save(cache_to_disk, overwrite=True)

else:
    # `use_cached` holds the location of a saved interface: load it and
    # derive the splits from the dataset attached to it.
    m3patients = Patients.load(use_cached)
    splits = m3patients.dataset.random_splits(
        [0.8, 0.9],
        random_seed=42,
        balanced='admissions',
    )

In [None]:
# Sanity check: display how many entries `m3patients.subjects` holds.
len(m3patients.subjects)

In [None]:
# m3patients.size_in_bytes() / 1024 ** 3

In [None]:
# val_batch = m3patients.device_batch(splits[1])

In [None]:
# tst_batch = m3patients.device_batch(splits[2])

In [None]:
# val_batch.size_in_bytes() / 1024 ** 3, tst_batch.size_in_bytes() / 1024 ** 3

In [None]:
# batch = m3patients.device_batch(splits[0][:32])

In [None]:
# batch.size_in_bytes() / 1024 ** 3

In [None]:
# len(batch.subjects)

In [None]:
# batch.n_admissions()

In [None]:
# batch.n_segments()

In [None]:
# batch.n_obs_times()

In [None]:
# s = batch.subjects[splits[0][6]].admissions[0]

### التدريب على نموذج المعادلات التفاضلية الاعتيادية العصبية


In [None]:
from lib.ml import (ICENODE, ICENODEDimensions, PatientEmbeddingDimensions, 
                    Trainer, TrainerReporting, OptimizerConfig, WarmupConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric)

import jax.random as jrandom

In [None]:
# Model dimensions: embedding sizes (dx=10, demo=5) nested inside the
# ICENODE dimensions together with mem=15.
emb_dims = PatientEmbeddingDimensions(
    dx=10,
    demo=5,
)
dims = ICENODEDimensions(
    mem=15,
    emb=emb_dims,
)

# Fixed PRNG key (seed 0) handed to the model constructor.
key = jrandom.PRNGKey(0)

# The model takes its coding scheme and demographic configuration from the
# loaded patients interface.
m = ICENODE(
    dims=dims,
    scheme=m3patients.dataset.scheme,
    demographic_vector_config=m3patients.demographic_vector_config,
    key=key,
)

In [None]:
# res = m.batch_predict(m4inpatients.device_batch(), leave_pbar=True)

In [None]:
# Main training configuration: Adam with lr=1e-3, 20 epochs, batches of 512,
# no regularisation hyper-parameters, and the 'allpairs_sigmoid_rank' dx loss.
trainer = Trainer(
    optimizer_config=OptimizerConfig(opt='adam', lr=1e-3),
    reg_hyperparams=None,
    epochs=20,
    batch_size=512,
    dx_loss='allpairs_sigmoid_rank',
)

# Warm-up configuration passed to the trainer call later on.
# NOTE(review): epochs=0.1 presumably means a fraction of one epoch — confirm.
warmup = WarmupConfig(
    epochs=0.1,
    batch_size=8,
    opt='adam',
    lr=1e-3,
    decay_rate=0.5,
)

# Loss metric reported for four dx losses.
loss_metric = LossMetric(
    m3patients,
    dx_loss=(
        'softmax_bce',
        'balanced_focal_softmax_bce',
        'balanced_focal_bce',
        'allpairs_sigmoid_rank',
    ),
)

# Only the loss metric is active; CodeAUC(m3patients) and
# AdmissionAUC(m3patients) are currently disabled.
metrics = [loss_metric]

# Reporting: console output, parameter snapshots and a JSON config,
# with 'dx_icenode' as the output directory.
reporting = TrainerReporting(
    output_dir='dx_icenode',
    metrics=metrics,
    console=True,
    parameter_snapshots=True,
    config_json=True,
)

In [None]:
# Split the subjects again and train the model.
#
# NOTE(review): this overwrites the `splits` computed in the loading cell,
# which used [0.8, 0.9] with random_seed=42 on the dataset object. Here the
# cut points are [0.9, 0.95], no random_seed is passed, and the method is
# called on `m3patients` itself — confirm that this is intended and that the
# split is reproducible across runs.
splits = m3patients.random_splits([0.9, 0.95], 
                                    balanced='admissions')
# Run the trainer on model `m` with the reporting and warm-up settings
# configured above; the return value is kept in `res`.
# NOTE(review): n_evals=100 presumably sets how many evaluations happen
# during training — confirm against the Trainer implementation.
res = trainer(m, m3patients, 
              splits=splits,
              reporting=reporting,
              n_evals=100,
              warmup_config=warmup)