In [1]:
# %env XLA_PYTHON_CLIENT_PREALLOCATE=false


In [2]:

# Setup: auto-reload edited modules (e.g. the project's `lib` package) before each cell runs.
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

# NOTE(review): glob, random, defaultdict, Path, display and pd appear unused in the visible cells — confirm before removing.
import jax
# Pin JAX to the CPU backend; this must run before any JAX computation in the session.
# NOTE(review): newer JAX releases document `jax_platforms` as the successor of `jax_platform_name` — confirm against the installed JAX version.
jax.config.update('jax_platform_name', 'cpu')
# Optional JAX debugging switches, disabled by default:
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [3]:


# Make the repository root importable so that the `lib` package below resolves.
# NOTE(review): "../.." is relative to the kernel's working directory, so this assumes the
# notebook is launched two levels below the repo root — confirm.
sys.path.append("../..")

from lib import utils as U
from lib.ehr import load_dataset, load_dataset_scheme, load_dataset_config, Dataset
from lib.ehr.interface import Patients, InterfaceConfig
from lib.ehr.concepts import DemographicVectorConfig, LeadingObservableConfig


In [4]:
import logging
# Raise the root logger's verbosity to INFO (the cached-loading messages below are logged at INFO).
# Use setLevel() rather than assigning `logging.root.level`: setLevel() also clears the
# per-logger isEnabledFor() caches, so loggers that already evaluated a level check
# pick up the new threshold.
logging.getLogger().setLevel(logging.INFO)

## Load Dataset

In [5]:
tag = 'M4ICU'
# Dataset root under the user's home directory. os.path.expanduser('~') falls back to the
# password database (or USERPROFILE on Windows) when HOME is unset, whereas
# os.environ.get("HOME") would silently yield a bogus 'None/GP/...' path.
PATH = f'{os.path.expanduser("~")}/GP/ehr-data/mimic4icu-cohort'
# Sample size forwarded to load_dataset_config; a falsy value (e.g. None) drops the
# suffix from the cache name below.
sample = 100
# On-disk cache location for the interface.
# NOTE(review): 'inteface' is misspelled, but it is part of the existing cache path;
# renaming it would presumably orphan caches already written there.
cache = f'cached_inteface/patients_{tag}_{sample or ""}'
# Build the dataset configuration for the chosen cohort tag, sample setting and data root.
dataset_config = load_dataset_config(tag, path=PATH, sample=sample)

In [6]:
import json

# Show, as sorted and indented JSON, which target coding schemes each space
# (dx, ethnicity, gender, int_input, ...) supports for this dataset.
dataset_scheme = load_dataset_scheme(tag)
interface_schem_options = dataset_scheme.supported_target_scheme_options
print(json.dumps(interface_schem_options,
                 indent=4,
                 sort_keys=True))

{
    "dx": [
        "DxICD10",
        "DxCCS",
        "DxICD9",
        "DxFlatCCS"
    ],
    "ethnicity": [
        "MIMIC4Eth32",
        "MIMIC4Eth5"
    ],
    "gender": [
        "Gender"
    ],
    "int_input": [
        "MIMICInput",
        "MIMICInputGroups"
    ],
    "int_proc": [
        "MIMICProcedures",
        "MIMICProcedureGroups"
    ],
    "obs": [
        "MIMICObservables"
    ],
    "outcome": [
        "dx_icd9_filter_v1",
        "dx_icd9_filter_v2_groups",
        "dx_flatccs_mlhc_groups",
        "dx_flatccs_filter_v1",
        "dx_icd9_filter_v3_groups"
    ]
}


In [7]:

# Demographic attributes to include in the demographic vector.
demographic_vector_conf = DemographicVectorConfig(age=True, gender=True, ethnicity=True)

# Leading-observable targets: observable `index` from dataset_scheme.obs, presumably
# aggregated with 'max' over look-ahead windows of 6, 12, 24, 48 and 72 hours.
# NOTE(review): index 42 is assumed to be the AKI observable (hence the variable name) — confirm.
leading_AKI = LeadingObservableConfig(
    leading_hours=[6, 12, 24, 48, 72],
    window_aggregate='max',
    scheme=dataset_scheme.obs,
    index=42,
)

In [8]:
# Choose target schemes from the supported options listed above for the dx, outcome and
# ethnicity spaces, then combine them with the feature configs into the interface config.
interface_scheme = dataset_scheme.make_target_scheme_config(
    dx='DxICD9',
    outcome='dx_icd9_filter_v3_groups',
    ethnicity='MIMIC4Eth5',
)
interface_config = InterfaceConfig(
    scheme=interface_scheme,
    dataset_scheme=dataset_scheme,
    demographic_vector=demographic_vector_conf,
    leading_observable=leading_AKI,
    cache=cache,
)

In [9]:
def dataset_gen(dataset_config, split_quantiles=(0.8, 0.9), random_seed=42,
                balanced='admissions'):
    """Load the dataset, remove outliers, and scale it using training-split statistics.

    Both the outlier remover and the scalers are fitted on the first split returned by
    ``random_splits`` (the training split) and then applied to the whole dataset.

    Args:
        dataset_config: configuration forwarded to ``load_dataset``.
        split_quantiles: cut points forwarded (as a list) to ``random_splits``; presumably
            cumulative proportions, so the default ``(0.8, 0.9)`` yields three splits —
            TODO confirm against ``Dataset.random_splits``.
        random_seed: seed forwarded to ``random_splits``.
        balanced: balancing criterion forwarded to ``random_splits``.

    Returns:
        The outlier-filtered dataset with the fitted scalers applied.

    NOTE(review): these splits are drawn here independently of the ones later restored
    with ``experiment.load_splits``. Confirm that both describe the same partition;
    otherwise evaluation subjects may influence the fitted outlier thresholds and scalers.
    """
    dataset = load_dataset(config=dataset_config)
    splits = dataset.random_splits(list(split_quantiles),
                                   random_seed=random_seed,
                                   balanced=balanced)
    train_split = splits[0]

    # Outlier removal: fit on the training split, apply to the whole dataset.
    outlier_remover = dataset.fit_outlier_remover(train_split)
    dataset = dataset.remove_outliers(outlier_remover)

    # Scaling: fit on the (outlier-filtered) training split, apply to the whole dataset.
    scalers = dataset.fit_scalers(train_split)
    return dataset.apply_scalers(scalers)
        

In [10]:
# Load the patient interface from the on-disk cache when available (see the
# "Loading cached subjects" log below); otherwise it is presumably built via
# dataset_gen(dataset_config) using 8 workers — confirm in Patients.try_load_cached.
m4patients = Patients.try_load_cached(
    interface_config,
    dataset_generator=dataset_gen,
    dataset_config=dataset_config,
    num_workers=8,
)

INFO:root:Loading cached subjects.


## Load Model

In [11]:
from lib.ml import InpatientExperiment
from lib.ehr import TrajectoryConfig
from lib import Config, Module

In [12]:
# Experiment output directory under the user's home. os.path.expanduser('~') avoids the
# bogus 'None/...' path that os.environ.get("HOME") produces when HOME is unset.
experiments_dir = f'{os.path.expanduser("~")}/GP/ehr-data/m4icu_out/onestate_mlp_dtw_inicenode'

# Parameter snapshot to restore from the experiment's params archive.
params_file = 'step1470.eqx'

# Rebuild the experiment from the JSON configuration saved alongside its outputs.
experiment_config = Config.from_dict(U.load_config(f'{experiments_dir}/config.json'))
experiment = InpatientExperiment(config=experiment_config)

In [13]:
# Restore the experiment's saved data splits and build its model for this interface.
# The trained parameters are loaded from the archive in a later cell.
splits = experiment.load_splits(m4patients.dataset)
model = experiment.load_model(m4patients)

In [14]:
# Number of entries in each restored split.
list(map(len, splits))

[77, 18, 5]

In [15]:
# Build the experiment's metrics for the restored splits.
# NOTE(review): `metrics` is not used by any later visible cell — drop it or use it.
metrics = experiment.load_metrics(m4patients, splits)

In [16]:
# Load the `params_file` snapshot from the experiment's params.zip into the model.
model = model.load_params_from_archive(f'{experiments_dir}/params.zip', params_file)


In [17]:
# Load the subjects of the third (test) split onto the device; JAX is pinned to CPU in the setup cell.
test_split = m4patients.device_batch(splits[2])

Loading to device:   0%|          | 0/5 [00:00<?, ?subject/s]

In [18]:
# Trajectory sampling configuration used when generating visualisable predictions.
# NOTE(review): the unit of sampling_rate is not visible here — confirm in TrajectoryConfig.
traj_conf = TrajectoryConfig(sampling_rate=0.5)

In [25]:
# Predict visualisable trajectories for the test batch, in scaled and unscaled form.
# NOTE(review): the TrajectoryConfig is passed as `store_embeddings`; the keyword name is
# surprising for a config object — confirm against predict_visualisables' signature.
scaled_vis, unscaled_vis = model.predict_visualisables(test_split, store_embeddings=traj_conf)

Embedding:   0%|          | 0/5 [00:00<?, ?subject/s]

  0%|          | 0.00/166.23 [00:00<?, ?longitudinal-days/s]



In [41]:
# Spot-check the scaled values of observable 'o23' for one hard-coded subject/admission pair.
# NOTE(review): the ids are hard-coded and must belong to the test split above; the
# out-of-order execution count suggests this was run after other, since-removed cells.
scaled_vis['13672788']['24862501'].obs['o23'].value

array([0.1301, 0.1328], dtype=float16)

In [None]:
# NOTE(review): `cpu_predictions` is not defined in any visible cell (hidden kernel state), so
# this cell raises NameError on a fresh kernel. It is also accessed here without the
# subject/admission indexing used by later `cpu_predictions[...][...]` cells, so it presumably
# held a different object at the time. Define it explicitly or delete this scratch cell.
cpu_predictions.admission.interventions.segmented_input.shape[1]

In [None]:
# Partition `predictions` with eqx.partition using eqx.is_array as the filter, and show the array part.
# NOTE(review): `predictions` is not defined in any visible cell (hidden kernel state), so this
# raises NameError on a fresh kernel. The equinox import belongs in the top import cell.
import equinox as eqx
arrs, others = eqx.partition(predictions, eqx.is_array)
arrs

In [None]:
# NOTE(review): exact duplicate of the previous cell (minus the import), with the same undefined
# `predictions` dependency — candidate for deletion.
arrs, others = eqx.partition(predictions, eqx.is_array)
arrs

In [None]:
import numpy as np
import jax.numpy as jnp

# Toy check: how eqx.partition splits a nested dict mixing NumPy and JAX arrays
# (displays the part rejected by eqx.is_array). Relies on `eqx` imported in an earlier cell.
arrs, others = eqx.partition(
    dict(x=np.arange(10),
         z=dict(x=np.arange(5),
                z=dict(y=jnp.arange(3)))),
    eqx.is_array,
)
others

In [None]:
# Check which device holds one admission's segmented intervention input.
# NOTE(review): `cpu_predictions` is not defined in any visible cell (hidden kernel state).
# Recent JAX versions deprecate the `Array.device()` method in favour of `.devices()` — confirm
# against the installed version.
cpu_predictions['13672788']['24862501'].admission.interventions.segmented_input.device()

In [None]:
# NOTE(review): exact duplicate of the previous cell, with the same undefined `cpu_predictions`
# dependency — candidate for deletion.
cpu_predictions['13672788']['24862501'].admission.interventions.segmented_input.device()

In [None]:
# Inspect the index of the fitted 'int_input' scaler's max_val.
# NOTE(review): max_val is presumably a pandas Series keyed by input code — confirm.
m4patients.dataset.scalers_history['int_input'].max_val.index