In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd
import dask

from tqdm import tqdm
import jax
# Run JAX computations on the CPU by default.
# NOTE(review): per the INFO log printed after model creation, JAX still probes
# other backends (rocm, tpu) during initialisation despite this setting.
jax.config.update('jax_platform_name', 'cpu')
# Optional JAX debugging switches — uncomment as needed:
# log every XLA compilation (helps spot unexpected recompiles):
# jax.config.update('jax_log_compiles', True)
# raise as soon as a NaN is produced:
# jax.config.update("jax_debug_nans", True)
# use 64-bit types instead of JAX's default 32-bit ones:
# jax.config.update("jax_enable_x64", True)

In [2]:


# Make the project root (two directories above the notebook's working
# directory) importable so that the local `lib` package resolves.
# The path is made absolute so a later os.chdir() cannot break imports, and
# the append is skipped on re-runs so sys.path does not accumulate duplicates.
PROJECT_ROOT = os.path.abspath(os.path.join("..", ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from lib import utils as U
from lib.ehr.dataset import load_dataset, load_dataset_scheme, Dataset
from lib.ehr.interface import Patients
from lib.ehr.concepts import DemographicVectorConfig


In [3]:
import logging

# Show INFO-level log records (e.g. the trainer's per-evaluation metric tables
# and JAX backend notices that appear in the outputs below).
# Use setLevel() rather than assigning `.level` directly: setLevel() also
# invalidates the loggers' cached isEnabledFor() results, so loggers that were
# already queried pick up the new threshold.
logging.root.setLevel(logging.INFO)

In [4]:
# Data and cache locations.
# DATA_DIR is the root folder of the raw EHR data; the loading cell exposes it
# to the dataset loaders through the DATA_DIR environment variable.
# os.path.expanduser('~') returns $HOME when it is set and otherwise falls back
# to the platform's home lookup, instead of yielding "None/GP/ehr-data".
HOME = os.path.expanduser('~')
DATA_DIR = f'{HOME}/GP/ehr-data'
# NOTE(review): SOURCE_DIR is not used in the visible cells.
SOURCE_DIR = os.path.abspath("..")
# Where a freshly built `Patients` interface is saved.
# "inteface" (sic) is kept as-is so caches already on disk remain valid; the
# original comment hinted at a '_200' variant of this path.
cache_to_disk = 'cached_inteface/patients'
# False -> rebuild the interface from the raw data; otherwise the path of a
# saved interface to load (e.g. set it to `cache_to_disk` to reuse the last build).
use_cached = False


In [5]:
# Build (or reload) the `Patients` interface for the M3 dataset.
# - use_cached falsy: load the raw dataset with DATA_DIR exported to the loaders
#   (dask process scheduler, 12 workers), build the interface with age/gender/
#   ethnicity demographics, then save it to `cache_to_disk`, overwriting any
#   previous cache.
# - use_cached truthy: it is the path of a saved interface; load it directly.
# Both branches set `splits` from cut points [0.8, 0.9] with seed 42, balanced
# by admissions.
# NOTE(review): `splits` is reassigned in the training section before it is
# used, so these splits are not what the trainer sees.
if not use_cached:
    with U.modified_environ(DATA_DIR=DATA_DIR):
        with dask.config.set(scheduler='processes', num_workers=12):
            m3_dataset = load_dataset('M3', sample=None)
            splits = m3_dataset.random_splits([0.8, 0.9], random_seed=42, balanced='admissions')

            demographic_vector_conf = DemographicVectorConfig(age=True, gender=True, ethnicity=True)
            m3patients = Patients(m3_dataset, demographic_vector_conf).load_subjects(num_workers=12)

            m3patients.save(cache_to_disk, overwrite=True)
else:
    m3patients = Patients.load(use_cached)
    splits = m3patients.dataset.random_splits([0.8, 0.9], random_seed=42, balanced='admissions')

Please ensure that each individual file can fit in memory and
use the keyword ``blocksize=None to remove this message``
Setting ``blocksize=None``
  warn(
Please ensure that each individual file can fit in memory and
use the keyword ``blocksize=None to remove this message``
Setting ``blocksize=None``
  warn(
Please ensure that each individual file can fit in memory and
use the keyword ``blocksize=None to remove this message``
Setting ``blocksize=None``
  warn(

Mapping converts codes that are not supported by the source scheme.|M-domain - S|=0; |M-domain - S|/|M-domain|=0.00); first5(M-domain - S)=[]

Source codes that not covered by the mapping. |S - M-domain|=2497; |S - M-domain|/|S|=0.14; first5(S - M-domain)=['001', '002', '003', '003.2', '004']

|M-domain|=14878; first5(M-domain) ['001.0', '001.1', '001.9', '002.0', '002.1'].

|S|=17375; first5(S)=['001', '001.0', '001.1', '001.9', '002']
        

Mapping converts codes that are not supported by the source scheme.|M-domain - S|=0

In [6]:
# Number of subjects held by the interface (displayed as the cell output).
n_subjects = len(m3patients.subjects)
n_subjects

7514

In [7]:
# m3patients.size_in_bytes() / 1024 ** 3

In [8]:
# val_batch = m3patients.device_batch(splits[1])

In [9]:
# tst_batch = m3patients.device_batch(splits[2])

In [10]:
# val_batch.size_in_bytes() / 1024 ** 3, tst_batch.size_in_bytes() / 1024 ** 3

In [11]:
# batch = m3patients.device_batch(splits[0][:32])

In [12]:
# batch.size_in_bytes() / 1024 ** 3

In [13]:
# len(batch.subjects)

In [14]:
# batch.n_admissions()

In [15]:
# batch.n_segments()

In [16]:
# batch.n_obs_times()

In [17]:
# s = batch.subjects[splits[0][6]].admissions[0]

### Training the ICE-NODE (neural ordinary differential equation) model — التدريب على نموذج المعادلات التفاضلية الاعتيادية العصبية


In [18]:
from lib.ml import (ICENODE, ICENODEDimensions, PatientEmbeddingDimensions, 
                    Trainer, TrainerReporting, OptimizerConfig, WarmupConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric)

import jax.random as jrandom

In [19]:
# Fixed PRNG key so model initialisation is reproducible.
key = jrandom.PRNGKey(0)

# Embedding sizes: 10 for `dx` (presumably diagnosis codes) and 5 for the
# demographic vector; the ICE-NODE memory state has 15 dimensions.
emb_dims = PatientEmbeddingDimensions(dx=10, demo=5)
dims = ICENODEDimensions(mem=15, emb=emb_dims)

# The model takes its code scheme and demographic layout from the loaded
# interface so its input/output sizes match the data.
m = ICENODE(
    dims=dims,
    scheme=m3patients.dataset.scheme,
    demographic_vector_config=m3patients.demographic_vector_config,
    key=key,
)

INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'


In [20]:
# res = m.batch_predict(m4inpatients.device_batch(), leave_pbar=True)

In [21]:
# Subject splits used for training/evaluation: cut points [0.9, 0.95]
# (presumably 90/5/5 train/val/test), balanced by number of admissions.
# NOTE(review): unlike the dataset-level split in the loading cell, no
# random_seed is passed here; confirm Patients.random_splits' default keeps
# runs reproducible.
split_cut_points = [0.9, 0.95]
splits = m3patients.random_splits(split_cut_points, balanced='admissions')

# Main training run: Adam at lr 1e-3, 80 epochs, batch_size 256, no
# regularisation hyperparameters, and the all-pairs sigmoid ranking loss as
# the dx training loss.
adam_config = OptimizerConfig(opt='adam', lr=1e-3)
trainer = Trainer(
    optimizer_config=adam_config,
    reg_hyperparams=None,
    epochs=80,
    batch_size=256,
    dx_loss='allpairs_sigmoid_rank',
)

# Warm-up phase passed to the trainer call: 0.1 epochs with batch_size 8,
# Adam at lr 1e-3 and decay_rate 0.5 (presumably a learning-rate decay
# factor — confirm in WarmupConfig).
warmup = WarmupConfig(
    epochs=0.1,
    batch_size=8,
    opt='adam',
    lr=1e-3,
    decay_rate=0.5,
)

# Losses reported (not optimised) during evaluation; each shows up as a
# `LossMetric.<name>` column in the logged tables.
monitored_dx_losses = (
    'softmax_bce',
    'balanced_focal_softmax_bce',
    'balanced_focal_bce',
    'allpairs_exp_rank',
    'allpairs_hard_rank',
    'allpairs_sigmoid_rank',
)
loss_metric = LossMetric(m3patients, dx_loss=monitored_dx_losses)

code_auc = CodeAUC(m3patients)
admission_auc = AdmissionAUC(m3patients)
# Top-k alarm accuracy over 5 code partitions for k in {3, 5, 10, 15, 20};
# it is given the training split (presumably to form the partitions — confirm
# in lib.metric).
top_alarm_accuracy = CodeGroupTopAlarmAccuracy(
    m3patients,
    n_partitions=5,
    top_k_list=[3, 5, 10, 15, 20],
    train_split=splits[0],
)
metrics = [code_auc, admission_auc, top_alarm_accuracy, loss_metric]

# Reporting under 'dx_icenode': metrics above, parameter snapshots and a
# config JSON, with console output disabled.
reporting = TrainerReporting(
    output_dir='dx_icenode',
    metrics=metrics,
    console=False,
    parameter_snapshots=True,
    config_json=True,
)

In [None]:

# Train `m` on `m3patients` using the splits, reporting and warm-up configured
# above, with n_evals=100 evaluations over the run.
# NOTE(review): the saved output below reports epochs=20, batch_size=512 and
# "Continuing training from step 17", which contradicts the Trainer configured
# above (epochs=80, batch_size=256) and continue_training=False here. The output
# is either stale or the trainer resumed from files already in output_dir —
# re-run on a fresh kernel and an empty output_dir to confirm.
training_options = dict(
    splits=splits,
    reporting=reporting,
    n_evals=100,
    warmup_config=warmup,
    continue_training=False,
)
res = trainer(m, m3patients, **training_options)

INFO:root:Continuing training from step 17
INFO:root:HPs: {'opt_config': {'opt': 'adam', 'lr': 0.001, 'decay_rate': None, 'reverse_schedule': False}, 'reg_hyperparams': None, 'epochs': 20, 'batch_size': 512}


Loading to device:   0%|          | 0/360 [00:00<?, ?subject/s]

  0%|          | 0/20 [00:00<?, ?Epoch/s]

  0%|          | 0/21 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/341 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/316 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/335 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/299 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/333 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/274 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/331 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/327 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/334 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/321 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/305 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/340 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/365 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/338 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/300 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/342 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/312 [00:00<?, ?subject/s]

Loading to device:   0%|          | 0/317 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/317 [00:00<?, ?subject/s]

  0%|          | 0.00/213482.04 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/341 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/341 [00:00<?, ?subject/s]

  0%|          | 0.00/220950.33 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/301 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/301 [00:00<?, ?subject/s]

  0%|          | 0.00/218839.46 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/297 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/297 [00:00<?, ?subject/s]

  0%|          | 0.00/209441.21 [00:00<?, ?odeint-days/s]

Embedding:   0%|          | 0/297 [00:00<?, ?subject/s]

  0%|          | 0.00/209441.21 [00:00<?, ?odeint-days/s]

Embedding:   0%|          | 0/360 [00:00<?, ?subject/s]

  0%|          | 0.00/250814.31 [00:00<?, ?odeint-days/s]

INFO:root:train    CodeAUC.I0C1.auc  CodeAUC.I0C1.n  CodeAUC.I1C10.auc  CodeAUC.I1C10.n  \
21          0.285448               2            0.45341               15   

    CodeAUC.I2C100.auc  CodeAUC.I2C100.n  CodeAUC.I3C101.auc  \
21            0.486886                36            0.565723   

    CodeAUC.I3C101.n  CodeAUC.I4C102.auc  CodeAUC.I4C102.n  ...  G1k20.acc  \
21               177            0.622558                 9  ...   0.158602   

    G2k20.acc  G3k20.acc  G4k20.acc  LossMetric.allpairs_exp_rank  \
21   0.233365   0.395122    0.48329                     3.7896497   

    LossMetric.allpairs_hard_rank  LossMetric.allpairs_sigmoid_rank  \
21                    0.073687896                        0.28351343   

    LossMetric.balanced_focal_bce  LossMetric.balanced_focal_softmax_bce  \
21                     0.20307705                              0.7666606   

    LossMetric.softmax_bce  
21                8.719011  

[1 rows x 525 columns]
INFO:root:val    CodeAUC.I0C1

  0%|          | 0/21 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/334 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/334 [00:00<?, ?subject/s]

  0%|          | 0.00/217016.32 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/319 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/319 [00:00<?, ?subject/s]

  0%|          | 0.00/204529.69 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/321 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/321 [00:00<?, ?subject/s]

  0%|          | 0.00/218744.45 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/323 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/323 [00:00<?, ?subject/s]

  0%|          | 0.00/223033.58 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/346 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/346 [00:00<?, ?subject/s]

  0%|          | 0.00/217654.21 [00:00<?, ?odeint-days/s]

Embedding:   0%|          | 0/346 [00:00<?, ?subject/s]

  0%|          | 0.00/217654.21 [00:00<?, ?odeint-days/s]

Embedding:   0%|          | 0/360 [00:00<?, ?subject/s]

  0%|          | 0.00/250814.31 [00:00<?, ?odeint-days/s]

INFO:root:train    CodeAUC.I0C1.auc  CodeAUC.I0C1.n  CodeAUC.I1C10.auc  CodeAUC.I1C10.n  \
21          0.285448               2           0.453410               15   
26          0.685728               4           0.531379               14   

    CodeAUC.I2C100.auc  CodeAUC.I2C100.n  CodeAUC.I3C101.auc  \
21            0.486886                36            0.565723   
26            0.538402                51            0.556575   

    CodeAUC.I3C101.n  CodeAUC.I4C102.auc  CodeAUC.I4C102.n  ...  G1k20.acc  \
21               177            0.622558                 9  ...   0.158602   
26               184            0.583019                 3  ...   0.173010   

    G2k20.acc  G3k20.acc  G4k20.acc  LossMetric.allpairs_exp_rank  \
21   0.233365   0.395122    0.48329                     3.7896497   
26   0.225184   0.417081    0.61868                      3.432319   

    LossMetric.allpairs_hard_rank  LossMetric.allpairs_sigmoid_rank  \
21                    0.073687896                

Loading to device:   0%|          | 0/354 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/354 [00:00<?, ?subject/s]

  0%|          | 0.00/219958.34 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/313 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/313 [00:00<?, ?subject/s]

  0%|          | 0.00/215509.48 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/324 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/324 [00:00<?, ?subject/s]

  0%|          | 0.00/220829.46 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/349 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/349 [00:00<?, ?subject/s]

  0%|          | 0.00/239650.63 [00:00<?, ?odeint-days/s]

Embedding:   0%|          | 0/349 [00:00<?, ?subject/s]

  0%|          | 0.00/239650.63 [00:00<?, ?odeint-days/s]

Embedding:   0%|          | 0/360 [00:00<?, ?subject/s]

  0%|          | 0.00/250814.31 [00:00<?, ?odeint-days/s]

INFO:root:train    CodeAUC.I0C1.auc  CodeAUC.I0C1.n  CodeAUC.I1C10.auc  CodeAUC.I1C10.n  \
21          0.285448               2           0.453410               15   
26          0.685728               4           0.531379               14   
30          0.573585               2           0.684637                8   

    CodeAUC.I2C100.auc  CodeAUC.I2C100.n  CodeAUC.I3C101.auc  \
21            0.486886                36            0.565723   
26            0.538402                51            0.556575   
30            0.527062                32            0.539939   

    CodeAUC.I3C101.n  CodeAUC.I4C102.auc  CodeAUC.I4C102.n  ...  G1k20.acc  \
21               177            0.622558                 9  ...   0.158602   
26               184            0.583019                 3  ...   0.173010   
30               167            0.263653                 1  ...   0.154930   

    G2k20.acc  G3k20.acc  G4k20.acc  LossMetric.allpairs_exp_rank  \
21   0.233365   0.395122   0.483290      

Loading to device:   0%|          | 0/305 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/305 [00:00<?, ?subject/s]

  0%|          | 0.00/202374.73 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/298 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/298 [00:00<?, ?subject/s]

  0%|          | 0.00/235650.37 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/310 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/310 [00:00<?, ?subject/s]

  0%|          | 0.00/222101.04 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/304 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/304 [00:00<?, ?subject/s]

  0%|          | 0.00/216490.19 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/312 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/312 [00:00<?, ?subject/s]

  0%|          | 0.00/220808.95 [00:00<?, ?odeint-days/s]

Embedding:   0%|          | 0/312 [00:00<?, ?subject/s]

  0%|          | 0.00/220808.95 [00:00<?, ?odeint-days/s]

Embedding:   0%|          | 0/360 [00:00<?, ?subject/s]

  0%|          | 0.00/250814.31 [00:00<?, ?odeint-days/s]

INFO:root:train    CodeAUC.I0C1.auc  CodeAUC.I0C1.n  CodeAUC.I1C10.auc  CodeAUC.I1C10.n  \
21          0.285448               2           0.453410               15   
26          0.685728               4           0.531379               14   
30          0.573585               2           0.684637                8   
35          0.092979               1           0.529750                7   

    CodeAUC.I2C100.auc  CodeAUC.I2C100.n  CodeAUC.I3C101.auc  \
21            0.486886                36            0.565723   
26            0.538402                51            0.556575   
30            0.527062                32            0.539939   
35            0.528842                43            0.521657   

    CodeAUC.I3C101.n  CodeAUC.I4C102.auc  CodeAUC.I4C102.n  ...  G1k20.acc  \
21               177            0.622558                 9  ...   0.158602   
26               184            0.583019                 3  ...   0.173010   
30               167            0.263653         

Loading to device:   0%|          | 0/321 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/321 [00:00<?, ?subject/s]

  0%|          | 0.00/226895.85 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/314 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/314 [00:00<?, ?subject/s]

  0%|          | 0.00/230132.74 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/346 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/346 [00:00<?, ?subject/s]

  0%|          | 0.00/252597.28 [00:00<?, ?odeint-days/s]

Loading to device:   0%|          | 0/316 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/316 [00:00<?, ?subject/s]

In [None]:
from lib.ml.trainer import TrainerSignals

In [None]:
# Exploratory: create a TrainerSignals instance to inspect its signals.
# NOTE(review): this is a new instance, so it presumably does not carry any
# handlers that `trainer`/`reporting` registered elsewhere — confirm against
# lib.ml.trainer before drawing conclusions from it.
signals = TrainerSignals()

In [None]:
# Display the receivers attached to the `start_training` signal.
# NOTE(review): leftover debugging cell; remove once the inspection is done.
signals.start_training.receivers