In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd
import dask

from tqdm import tqdm
import jax
jax.config.update('jax_platform_name', 'cpu')
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [2]:


sys.path.append("../..")

from lib import utils as U
from lib.ehr.dataset import load_dataset, load_dataset_scheme, load_dataset_config, Dataset
from lib.ehr.interface import Patients, InterfaceConfig
from lib.ehr.concepts import DemographicVectorConfig


In [3]:
import logging
logging.root.level = logging.INFO

In [4]:
tag = 'M3'
PATH = f'{os.environ.get("HOME")}/GP/ehr-data/mimic3-cohort'
sample = None
cache =  f'cached_inteface/patients_{tag}'
dataset_config = load_dataset_config(tag, 
                                     sample=sample,
                                     path=PATH)
dataset_gen = lambda c: load_dataset(config=c)

##### Possible Interface Scheme Configurations

In [5]:
import json
dataset_scheme = load_dataset_scheme(tag)
interface_schem_options = dataset_scheme.supported_target_scheme_options
print(json.dumps(interface_schem_options, sort_keys=True, indent=4))

{
    "dx": [
        "DxICD9",
        "DxICD10",
        "DxCCS",
        "DxFlatCCS"
    ],
    "ethnicity": [
        "MIMIC3Eth37",
        "MIMIC3Eth7"
    ],
    "gender": [
        "Gender"
    ],
    "outcome": [
        "dx_flatccs_mlhc_groups",
        "dx_flatccs_filter_v1",
        "dx_icd9_filter_v1",
        "dx_icd9_filter_v2_groups",
        "dx_icd9_filter_v3_groups"
    ]
}


In [6]:
interface_scheme = dataset_scheme.make_target_scheme_config(dx='DxICD9',
                                                            outcome='dx_icd9_filter_v3_groups',
                                                            ethnicity='MIMIC3Eth7')


# Demographic vector attributes
demographic_vector_conf = DemographicVectorConfig(
    age=False, 
    gender=False, 
    ethnicity=False
)
interface_config = InterfaceConfig(scheme=interface_scheme,
                                   dataset_scheme=dataset_scheme,
                                   demographic_vector=demographic_vector_conf,
                                   cache=cache)

In [7]:
interface_config

InterfaceConfig(
  demographic_vector=DemographicVectorConfig(
    gender=False,
    age=False,
    ethnicity=False
  ),
  leading_observable=None,
  scheme={
    'dx':
    'DxICD9',
    'ethnicity':
    'MIMIC3Eth7',
    'gender':
    'Gender',
    'outcome':
    'dx_icd9_filter_v3_groups'
  },
  cache='cached_inteface/patients_M3'
)

In [8]:
m3patients = Patients.try_load_cached(interface_config,
                                      dataset_config=dataset_config,
                                      dataset_generator=dataset_gen,
                                      num_workers=8)

INFO:root:Loading cached subjects.


In [9]:
splits = m3patients.dataset.random_splits([0.8, 0.9], random_seed=42, balanced='admissions')

In [10]:
len(m3patients.subjects)

7514

In [11]:
m3patients.subjects[splits[0][0]]

Patient(
  subject_id=44464,
  static_info=StaticInfo(
    demographic_vector_config=DemographicVectorConfig(
      gender=False,
      age=False,
      ethnicity=False
    ),
    gender=BinaryCodesVector(
      vec=bool[](numpy),
      scheme=<lib.ehr.coding_scheme.Gender object at 0x7ff14305ac70>
    ),
    ethnicity=CodesVector(
      vec=bool[7](numpy),
      scheme=<lib.ehr.coding_scheme.MIMIC3Eth7 object at 0x7ff13e2ab970>
    ),
    date_of_birth=Timestamp('2123-03-06 00:00:00'),
    constant_vec=f16[0](numpy)
  ),
  admissions=[
    Admission(
      admission_id=118659,
      admission_dates=(
        Timestamp('2184-12-13 12:23:00'),
        Timestamp('2184-12-21 16:15:00')
      ),
      dx_codes=CodesVector(
        vec=bool[17375](numpy),
        scheme=<lib.ehr.coding_scheme.DxICD9 object at 0x7ff14305a7f0>
      ),
      dx_codes_history=CodesVector(
        vec=bool[17375](numpy),
        scheme=<lib.ehr.coding_scheme.DxICD9 object at 0x7ff14305a7f0>
      ),
      outco

In [12]:
# m3patients.size_in_bytes() / 1024 ** 3

In [13]:
# val_batch = m3patients.device_batch(splits[1])

In [14]:
# tst_batch = m3patients.device_batch(splits[2])

In [15]:
# val_batch.size_in_bytes() / 1024 ** 3, tst_batch.size_in_bytes() / 1024 ** 3

In [16]:
# batch = m3patients.device_batch(splits[0][:32])

In [17]:
# batch.size_in_bytes() / 1024 ** 3

In [18]:
# len(batch.subjects)

In [19]:
# batch.n_admissions()

In [20]:
# batch.n_segments()

In [21]:
# batch.n_obs_times()

### التدريب على نموذج المعادلات التفاضلية الاعتيادية العصبية


In [22]:
from lib.ml import (ICENODE, ICENODEDimensions, 
                    GRU, GRUDimensions,
                    RETAIN, RETAINDimensions,
                    PatientEmbeddingDimensions, 
                    Trainer, TrainerReporting, OptimizerConfig, WarmupConfig,
                    TrainerConfig, ReportingConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric)

import jax.random as jrandom

In [23]:
emb_dims = PatientEmbeddingDimensions(dx=50, demo=0)
key = jrandom.PRNGKey(0)

def icenode_model():
    dims = ICENODEDimensions(mem=20, emb=emb_dims)
    return ICENODE(dims=dims, schemes=m3patients.schemes, 
                   demographic_vector_config=m3patients.config.demographic_vector,
                   key=key)

def gru_model():
    dims = GRUDimensions(emb=emb_dims)
    return GRU(dims=dims, schemes=m3patients.schemes, 
                   demographic_vector_config=m3patients.config.demographic_vector,
                   key=key)

def retain_model():
    dims = RETAINDimensions(mem_a=25, mem_b=25, emb=emb_dims)
    return RETAIN(dims=dims, schemes=m3patients.schemes, 
                   demographic_vector_config=m3patients.config.demographic_vector,
                   key=key)

models = {
    'rnk_dx_icenode': icenode_model(),
#     'rnk_dx_gru': gru_model(),
#     'rnk_dx_retain': retain_model()
}

2023-09-04 18:48:56.485938: W external/xla/xla/service/platform_util.cc:198] unable to create StreamExecutor for CUDA:0: failed initializing StreamExecutor for CUDA device ordinal 0: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_OUT_OF_MEMORY: out of memory; total memory reported: 10492641280
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': INTERNAL: no supported devices found for platform CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'


In [24]:
# res = m.batch_predict(m4inpatients.device_batch(), leave_pbar=True)

In [25]:

trainer_conf = TrainerConfig(optimizer=OptimizerConfig(opt='adam', lr=1e-3),
                             epochs=150,
                             batch_size=32,
                             dx_loss='allpairs_sigmoid_rank')
trainer = Trainer(trainer_conf)

warmup = WarmupConfig(epochs=0.1, 
                      batch_size=8,
                      opt='adam', lr=1e-3, 
                      decay_rate=0.5)

loss_metric =  LossMetric(m3patients, 
                          dx_loss=('softmax_bce', 'balanced_focal_softmax_bce', 
                                   'balanced_focal_bce', 'allpairs_exp_rank', 'allpairs_hard_rank', 
                                   'allpairs_sigmoid_rank'))

metrics = [CodeAUC(m3patients), 
           AdmissionAUC(m3patients), 
           CodeGroupTopAlarmAccuracy(m3patients, n_partitions=5, 
                                     top_k_list=[3, 5, 10, 15, 20],
                                     train_split=splits[0]),
           loss_metric]




In [None]:
res = {}
for name in models:
    model = models[name]
    print(name)
    reporting_conf = ReportingConfig(output_dir=name,
                                     console=True,
                                     model_stats=False,
                                     parameter_snapshots=True,
                                     config_json=True)
    reporting = TrainerReporting(reporting_conf, metrics=metrics)
    
    res[name] = trainer(model, m3patients, 
                  splits=splits,
                  reporting=reporting,
                  n_evals=100,
                  warmup_config=warmup,
                  continue_training=True)

rnk_dx_icenode


INFO:root:Continuing training from step 0
INFO:root:HPs: TrainerConfig(
  optimizer=OptimizerConfig(
    opt='adam',
    lr=0.001,
    decay_rate=None,
    reverse_schedule=False
  ),
  epochs=150,
  batch_size=32,
  dx_loss='allpairs_sigmoid_rank',
  obs_loss='mse',
  lead_loss='mse'
)


Loading to device:   0%|          | 0/752 [00:00<?, ?subject/s]

  0%|          | 0/150 [00:00<?, ?Epoch/s]

  0%|          | 0/309 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/21 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/21 [00:00<?, ?subject/s]

  0%|          | 0.00/13545.16 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/23 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/23 [00:00<?, ?subject/s]

  0%|          | 0.00/6827.14 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/17 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/17 [00:00<?, ?subject/s]

  0%|          | 0.00/11354.87 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/22 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/22 [00:00<?, ?subject/s]

  0%|          | 0.00/20565.34 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/27 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/27 [00:00<?, ?subject/s]

  0%|          | 0.00/15321.49 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/15 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/15 [00:00<?, ?subject/s]

  0%|          | 0.00/9832.51 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/23 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/23 [00:00<?, ?subject/s]

  0%|          | 0.00/13310.51 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/22 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/22 [00:00<?, ?subject/s]

  0%|          | 0.00/19270.76 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/16 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/16 [00:00<?, ?subject/s]

  0%|          | 0.00/10212.27 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/18 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/18 [00:00<?, ?subject/s]

  0%|          | 0.00/14519.02 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/19 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/19 [00:00<?, ?subject/s]

  0%|          | 0.00/11033.49 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/19 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/19 [00:00<?, ?subject/s]

  0%|          | 0.00/14989.17 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/16 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/16 [00:00<?, ?subject/s]

  0%|          | 0.00/18543.36 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/18 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/18 [00:00<?, ?subject/s]

  0%|          | 0.00/10902.89 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/16 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/16 [00:00<?, ?subject/s]

  0%|          | 0.00/9872.39 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/19 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/19 [00:00<?, ?subject/s]

  0%|          | 0.00/12978.34 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/24 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/24 [00:00<?, ?subject/s]

  0%|          | 0.00/6506.66 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/17 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/17 [00:00<?, ?subject/s]

  0%|          | 0.00/16058.15 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/24 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/24 [00:00<?, ?subject/s]

  0%|          | 0.00/13622.13 [00:00<?, ?longitudinal-days/s]

###### 