In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd
import dask

from tqdm import tqdm
import jax
# Set JAX's `jax_platform_name` option to 'cpu' for this session.
jax.config.update('jax_platform_name', 'cpu')
# Optional JAX switches (compile logging, NaN debugging, 64-bit mode),
# currently disabled:
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [2]:


# Make the directory two levels above the working directory importable; the
# `lib` package imported below is expected to live there.  The membership
# check keeps the cell idempotent: re-running it no longer piles duplicate
# "../.." entries onto sys.path.
if "../.." not in sys.path:
    sys.path.append("../..")

from lib import utils as U
from lib.ehr import load_dataset, load_dataset_scheme, load_dataset_config, Dataset
from lib.ehr.interface import Patients, InterfaceConfig
from lib.ehr.concepts import DemographicVectorConfig, LeadingObservableConfig


In [3]:
import logging

# Raise the root logger to INFO through the public API.  Assigning to
# `logging.root.level` directly skips the bookkeeping that `setLevel`
# performs (such as resetting cached level checks), so loggers that were
# already used could keep dropping INFO records.
logging.getLogger().setLevel(logging.INFO)

In [4]:
# Dataset tag used to look up the dataset configuration and scheme.
tag = 'M4ICU'
# Resolve the home directory with expanduser instead of reading $HOME
# directly: when HOME is unset, os.environ.get("HOME") returns None and the
# f-string would silently produce the bogus path "None/GP/...".
PATH = f'{os.path.expanduser("~")}/GP/ehr-data/mimic4icu-cohort'
# Forwarded as `sample=` to load_dataset_config and embedded in the cache name
# (presumably a subject sample size -- confirm against load_dataset_config).
sample = 500
# Cache location handed to InterfaceConfig later on.  A falsy `sample` leaves
# the suffix empty.  The "inteface" spelling is kept on purpose so the path
# string itself does not change.
cache =  f'cached_inteface/patients_{tag}_{sample or ""}'
dataset_config = load_dataset_config(tag,
                                     sample=sample,
                                     path=PATH)

In [5]:
# Display the cache path string built in the configuration cell.
cache

'cached_inteface/patients_M4ICU_500'

##### Possible Interface Scheme Configurations

In [6]:
import json

# Load the dataset scheme registered under `tag`, then pretty-print the
# target-scheme options it reports as supported (keys sorted, 4-space indent).
dataset_scheme = load_dataset_scheme(tag)
interface_schem_options = dataset_scheme.supported_target_scheme_options
print(
    json.dumps(
        interface_schem_options,
        indent=4,
        sort_keys=True,
    )
)

{
    "dx": [
        "DxICD9",
        "DxICD10",
        "DxFlatCCS",
        "DxCCS"
    ],
    "ethnicity": [
        "MIMIC4Eth32",
        "MIMIC4Eth5"
    ],
    "gender": [
        "Gender"
    ],
    "int_input": [
        "MIMICInput",
        "MIMICInputGroups"
    ],
    "int_proc": [
        "MIMICProcedures",
        "MIMICProcedureGroups"
    ],
    "obs": [
        "MIMICObservables"
    ],
    "outcome": [
        "dx_icd9_filter_v3_groups",
        "dx_icd9_filter_v2_groups",
        "dx_icd9_filter_v1",
        "dx_flatccs_filter_v1",
        "dx_flatccs_mlhc_groups"
    ]
}


#### Leading Observable for Early Prediction Task

In [7]:
# Tabulate the observables scheme and show the rows whose description
# contains 'aki'.
scheme_df = dataset_scheme.obs.as_dataframe()
# na=False makes rows with a missing description count as "no match", so the
# boolean mask never contains NaN (pandas refuses to index with such a mask).
display(scheme_df[scheme_df.desc.str.contains('aki', na=False)])

Unnamed: 0,code,desc
42,o42,aki_stage_smoothed


In [8]:

# Demographic vector attributes: age, gender and ethnicity are all switched off.
demographic_vector_conf = DemographicVectorConfig(age=False,
                                                  gender=False,
                                                  ethnicity=False)

# Leading observable for the early-prediction task.  Index 42 is the row the
# scheme lookup displays as code 'o42' / 'aki_stage_smoothed'.  The leading
# hours are the twelve multiples of 6, from 6 up to and including 72.
leading_AKI = LeadingObservableConfig(
    leading_hours=tuple(range(6, 73, 6)),
    window_aggregate='max',
    scheme=dataset_scheme.obs,
    index=42,
)

In [9]:
# Target scheme selection: ICD-9 diagnoses, the v3 grouped ICD-9 outcome
# filter, and the 5-level ethnicity scheme.
interface_scheme = dataset_scheme.make_target_scheme_config(
    dx='DxICD9',
    ethnicity='MIMIC4Eth5',
    outcome='dx_icd9_filter_v3_groups',
)

# Bundle the schemes, the demographic-vector settings, the leading observable
# and the cache path into one interface configuration.
interface_config = InterfaceConfig(
    scheme=interface_scheme,
    dataset_scheme=dataset_scheme,
    demographic_vector=demographic_vector_conf,
    leading_observable=leading_AKI,
    cache=cache,
)

### Training the Neural ODE Model (التدريب على نموذج المعادلات التفاضلية الاعتيادية العصبية)


In [10]:
from lib.ml import (InICENODE, InICENODEConfig, InpatientEmbeddingConfig,  SplitConfig,
                    InTrainer, TrainerConfig, TrainerReporting, OptimizerConfig, WarmupConfig, ReportingConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC, CodeLevelMetricConfig, MetricLevelsConfig,
                         LossMetricConfig,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric, CodeGroupTopAlarmAccuracyConfig)
from lib.ml import Experiment, InpatientExperiment, ExperimentConfig, SplitConfig

import jax.random as jrandom

In [11]:
# Embedding configuration values passed to InpatientEmbeddingConfig.
emb_dims = InpatientEmbeddingConfig(
    dx=30,
    inp=15,
    proc=15,
    demo=0,
    inp_proc_demo=10,
)

# Model hyperparameters, plus the model class name recorded as a string.
model_config = InICENODEConfig(emb=emb_dims, mem=15, obs=25, lead=5)
model_classname = InICENODE.__name__

In [12]:
# Main training setup: Adam at lr=1e-3, 80 epochs, batches of 128, with the
# named losses for the dx / obs / lead outputs.
trainer_config = TrainerConfig(
    optimizer=OptimizerConfig(opt='adam', lr=1e-3),
    epochs=80,
    batch_size=128,
    dx_loss='balanced_focal_bce',
    obs_loss='mse',
    lead_loss='mse',
)

# Warm-up settings.
# NOTE(review): the ExperimentConfig built further down is given warmup=None,
# so this object looks unused in the visible cells -- confirm that is intended.
warmup = WarmupConfig(
    opt='adam',
    lr=1e-3,
    decay_rate=0.5,
    epochs=0.1,
    batch_size=8,
)




In [13]:
# Loss names handed to LossMetricConfig below (these lists are separate from
# the single loss names chosen in the trainer configuration).
dx_loss = [
    "softmax_bce",
    "balanced_focal_softmax_bce",
    "balanced_focal_bce",
    "allpairs_exp_rank",
    "allpairs_hard_rank",
    "allpairs_sigmoid_rank",
]
lead_loss = ["mse", "mae", "rms"]
obs_loss = ["mse", "mae", "rms"]

# Pair each metric class with its configuration and export every pair through
# the metric class's `export_module_class`.
metrics_conf = [
    metric_cls.export_module_class(metric_cfg)
    for metric_cls, metric_cfg in (
        (CodeAUC,
         CodeLevelMetricConfig(aggregate_level=True, code_level=True)),
        (AdmissionAUC,
         MetricLevelsConfig(admission=False, aggregate=True, subject_aggregate=False)),
        (CodeGroupTopAlarmAccuracy,
         CodeGroupTopAlarmAccuracyConfig(n_partitions=5, top_k_list=[3, 5, 10, 15, 20])),
        (LossMetric,
         LossMetricConfig(dx_loss=dx_loss, lead_loss=lead_loss, obs_loss=obs_loss)),
    )
]

In [14]:
# Reporting options: output goes under 'inicenode'; console output, parameter
# snapshots and the config JSON are enabled, model statistics are disabled.
reporting_conf = ReportingConfig(
    output_dir='inicenode',
    console=True,
    config_json=True,
    parameter_snapshots=True,
    model_stats=False,
)

In [15]:
# Assemble the full experiment configuration from the pieces built above.
# The split is 80/10/10 (train/val/test) with balanced='admissions'; warm-up
# and regularisation hyperparameters are explicitly set to None.
expt_config = ExperimentConfig(
    dataset=dataset_config,
    interface=interface_config,
    split=SplitConfig(train=0.8, val=0.1, test=0.1, balanced='admissions'),
    model=model_config,
    model_classname=model_classname,
    trainer=trainer_config,
    metrics=metrics_conf,
    reporting=reporting_conf,
    n_evals=100,
    continue_training=True,
    warmup=None,
    reg_hyperparams=None,
)

In [16]:
# Build the experiment object from the assembled configuration.
experiment = InpatientExperiment(expt_config)

In [None]:
# Run the experiment and keep whatever `run()` returns.
result = experiment.run()

INFO:root:Loading cached subjects.
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
INFO:root:Continuing training from step 0
INFO:root:HPs: TrainerConfig(
  optimizer=OptimizerConfig(
    opt='adam',
    lr=0.001,
    decay_rate=None,
    reverse_schedule=False
  ),
  epochs=80,
  batch_size=128,
  dx_loss='balanced_focal_bce',
  obs_loss='mse',
  lead_loss='mse'
)


Loading to device:   0%|          | 0/34 [00:00<?, ?subject/s]

  0%|          | 0/80 [00:00<?, ?Epoch/s]

  0%|          | 0/9 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/52 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/52 [00:00<?, ?subject/s]

  0%|          | 0.00/710.04 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/56 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/56 [00:00<?, ?subject/s]

  0%|          | 0.00/882.62 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/46 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/46 [00:00<?, ?subject/s]

  0%|          | 0.00/752.12 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/52 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/52 [00:00<?, ?subject/s]

  0%|          | 0.00/901.79 [00:00<?, ?longitudinal-days/s]



Loading to device:   0%|          | 0/52 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/52 [00:00<?, ?subject/s]

  0%|          | 0.00/989.21 [00:00<?, ?longitudinal-days/s]