In [1]:
# %env XLA_PYTHON_CLIENT_PREALLOCATE=false


In [2]:

%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

import jax
# Run JAX on the CPU backend with 64-bit dtypes enabled. Set these before any
# JAX computation runs in the kernel.
for _jax_flag, _jax_value in (('jax_platform_name', 'cpu'),
                              ('jax_enable_x64', True)):
    jax.config.update(_jax_flag, _jax_value)
del _jax_flag, _jax_value
# Optional debugging switches (disabled):
#   jax.config.update('jax_log_compiles', True)
#   jax.config.update("jax_debug_nans", True)

In [3]:


# Make the repository root (two levels above the current working directory)
# importable so `lib` resolves. The path is made absolute so a later chdir
# cannot change its meaning. It is added only once so that re-running this
# cell does not keep growing sys.path.
PROJECT_ROOT = os.path.abspath(os.path.join("..", ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from lib import utils as U
from lib.ehr import load_dataset, load_dataset_scheme, load_dataset_config, Dataset
from lib.ehr.interface import Patients, InterfaceConfig
from lib.ehr.concepts import DemographicVectorConfig, LeadingObservableConfig


In [4]:
from lib.ml import EvaluationConfig, Evaluation

In [5]:
# Evaluation template (JSON) from the ICENODE checkout.
# NOTE(review): presumably U.load_config expands '~' itself — confirm.
EVAL_TEMPLATE_PATH = '~/GP/ICENODE/experiment_templates/icu/eval.json'
conf = U.load_config(EVAL_TEMPLATE_PATH)

In [6]:
# Parse the raw loaded dict into a typed EvaluationConfig.
# The isinstance guard keeps this cell idempotent: on re-execution `conf` is
# already an EvaluationConfig, and it is left as-is instead of being passed
# back into `from_dict`.
if not isinstance(conf, EvaluationConfig):
    conf = EvaluationConfig.from_dict(conf)
conf

In [7]:
# Point the evaluation at the local experiments directory and its SQLite db.
# The directory is resolved from the current user's home instead of a
# hard-coded /home/<user> path, so the cell works for anyone with the same
# ~/GP/ehr-data layout. Each path_update result is rebound to `conf`.
conf = conf.path_update('experiments_dir',
                        os.path.join(os.path.expanduser('~'), 'GP', 'ehr-data', 'm4icu_out'))
conf = conf.path_update('db', 'db.sqlite')
conf

In [8]:
import logging
# Emit INFO-level log records. Use setLevel rather than assigning
# `logging.root.level` directly: CPython (3.7+) caches isEnabledFor() results
# per logger. setLevel clears that cache, but a bare attribute write does not,
# so loggers that were already queried could keep the old level.
logging.getLogger().setLevel(logging.INFO)

In [9]:
# Build the evaluation driver from the prepared config. It is used below to
# look up experiments (get_experiment) and load metrics (load_metrics).
ev = Evaluation(conf)

In [11]:
# Experiment to inspect (presumably a run name under experiments_dir — confirm).
EVAL_EXPERIMENT_NAME = 'onestate_mlp_mse_inskelkoopman48'
exp = ev.get_experiment(EVAL_EXPERIMENT_NAME)

In [12]:
# Load the experiment's patient interface; its `.dataset` is used next to
# recover the data splits.
IF = exp.load_interface()

In [13]:
# Recover the experiment's data splits over IF.dataset.
# NOTE(review): `splits` is rebound in the "Load Model" section; anything that
# relies on this value must run before that cell.
splits = exp.load_splits(IF.dataset)


In [14]:
# Load the evaluation's metric objects for this interface and these splits.
# NOTE(review): `metrics` is rebound in the "Load Model" section.
metrics = ev.load_metrics(IF, splits)

In [19]:
# Inspect the fields reported by the second metric (index 1).
# NOTE(review): magic index — prefer selecting the metric by name/type.
metrics.metrics[1].fields()

In [10]:
# NOTE(review): this cell carries execution count 10, i.e. it was run *before*
# cells 11-19 above; under Restart & Run All it now runs after them. Confirm
# the intended order. Presumably this launches the full evaluation run, which
# may be long-running.
ev.start()

## Load Dataset

In [5]:
# Dataset selection: the MIMIC-IV ICU cohort (per the directory name),
# optionally subsampled.
tag = 'M4ICU'
# expanduser instead of os.environ.get("HOME"): the latter returns None when
# HOME is unset, which silently produced the path 'None/GP/...'.
PATH = os.path.join(os.path.expanduser('~'), 'GP', 'ehr-data', 'mimic4icu-cohort')
sample = 100  # presumably None disables subsampling — confirm with load_dataset_config
# The misspelling "inteface" is kept on purpose: correcting it would move the
# cache directory and orphan existing cached interfaces.
cache = f'cached_inteface/patients_{tag}_{sample or ""}'
dataset_config = load_dataset_config(tag,
                                     sample=sample,
                                     path=PATH)

In [6]:
import json
dataset_scheme = load_dataset_scheme(tag)
# Pretty-print the target-scheme options this dataset scheme supports
# (presumably the valid choices for make_target_scheme_config below).
interface_schem_options = dataset_scheme.supported_target_scheme_options
options_json = json.dumps(interface_schem_options, indent=4, sort_keys=True)
print(options_json)

In [7]:

# Demographic attributes included in the patient demographic vector.
demographic_vector_conf = DemographicVectorConfig(age=True, gender=True, ethnicity=True)

# Leading-observable targets for AKI at several look-ahead horizons (hours),
# with 'max' as the window aggregate.
AKI_OBS_INDEX = 42  # NOTE(review): presumably the AKI entry of dataset_scheme.obs — confirm
AKI_LEADING_HOURS = [6, 12, 24, 48, 72]
leading_AKI = LeadingObservableConfig(leading_hours=AKI_LEADING_HOURS,
                                      window_aggregate='max',
                                      scheme=dataset_scheme.obs,
                                      index=AKI_OBS_INDEX)

In [8]:
# Target coding schemes for the interface (per the scheme names: ICD-9
# diagnoses, grouped ICD-9 outcomes, 5-class ethnicity).
interface_scheme = dataset_scheme.make_target_scheme_config(
    dx='DxICD9',
    outcome='dx_icd9_filter_v3_groups',
    ethnicity='MIMIC4Eth5',
)
# Full interface configuration combining the schemes, demographics,
# leading-observable targets and the on-disk cache location.
interface_config = InterfaceConfig(
    scheme=interface_scheme,
    dataset_scheme=dataset_scheme,
    demographic_vector=demographic_vector_conf,
    leading_observable=leading_AKI,
    cache=cache,
)

In [9]:
def dataset_gen(dataset_config, split_quantiles=(0.8, 0.9), random_seed=42,
                balanced='admissions'):
    """Load a dataset, then remove outliers and scale using training-split statistics.

    Both the outlier remover and the scalers are fitted on the first split
    (the training split) only. The fitted transforms are then applied to the
    whole dataset.

    Args:
        dataset_config: configuration passed to ``load_dataset``.
        split_quantiles: cut points forwarded to ``dataset.random_splits``.
            The default ``(0.8, 0.9)`` reproduces the previous hard-coded value
            (presumably an 80/10/10 split — confirm).
        random_seed: seed for the random split (default 42, as before).
        balanced: balancing unit forwarded to ``random_splits``
            (default ``'admissions'``, as before).

    Returns:
        The outlier-filtered, scaled dataset.
    """
    dataset = load_dataset(config=dataset_config)
    splits = dataset.random_splits(list(split_quantiles), random_seed=random_seed,
                                   balanced=balanced)
    train_split = splits[0]

    # Outlier removal, fitted on the training split only.
    outlier_remover = dataset.fit_outlier_remover(train_split)
    dataset = dataset.remove_outliers(outlier_remover)

    # Scalers are fitted after outlier removal, on the same training split.
    scalers = dataset.fit_scalers(train_split)
    return dataset.apply_scalers(scalers)
        

In [10]:
# Build the Patients interface, or load it from `cache` if present
# (presumably dataset_gen is only invoked on a cache miss — confirm).
NUM_WORKERS = 8
m4patients = Patients.try_load_cached(
    interface_config,
    dataset_config=dataset_config,
    dataset_generator=dataset_gen,
    num_workers=NUM_WORKERS,
)

## Load Model

In [11]:
from lib.ml import InpatientExperiment
from lib.ehr import TrajectoryConfig
from lib import Config, Module

In [12]:
# Trained experiment to inspect, resolved under the user's home directory.
# Uses expanduser rather than os.environ.get("HOME"), which yields 'None/...'
# when HOME is unset.
experiments_dir = os.path.join(os.path.expanduser('~'), 'GP', 'ehr-data', 'm4icu_out',
                               'onestate_mlp_dtw_inicenode')

# Parameter snapshot inside params.zip, loaded into the model further down
# (name suggests training step 1470).
params_file = 'step1470.eqx'

experiment_config = Config.from_dict(U.load_config(os.path.join(experiments_dir, 'config.json')))
experiment = InpatientExperiment(config=experiment_config)

In [13]:
# Rebuild the experiment's data splits and instantiate its model for this
# interface; trained parameters are loaded later from params.zip.
# NOTE(review): this rebinds `splits` from the Evaluation section above.
splits = experiment.load_splits(m4patients.dataset)
model = experiment.load_model(m4patients)

In [14]:
# Size of each split.
list(map(len, splits))

In [15]:
# Load the experiment's metric objects for these patients and splits.
# NOTE(review): rebinds `metrics` from the Evaluation section above.
metrics = experiment.load_metrics(m4patients, splits)

In [16]:
# Load the trained weights named `params_file` from the experiment's params.zip;
# the result is rebound to `model`.
model = model.load_params_from_archive(f'{experiments_dir}/params.zip', params_file)


In [17]:
# Batch the third split (index 2, presumably the test split) for the device.
test_split = m4patients.device_batch(splits[2])

In [18]:
from lib.visualisables import ModelVisualiser
# Unpack the interface's two schemes; naming suggests (source/dataset,
# target/interface) order — confirm.
ds_src_scheme, ds_t_scheme = m4patients.schemes


In [19]:
# Visualiser configuration: the dataset's scaler history, a 0.5 trajectory
# sampling rate, the observable/intervention schemes, and the
# leading-observable config of the interface.
visualiser_kwargs = dict(
    scalers_history=m4patients.dataset.scalers_history,
    trajectory_config=TrajectoryConfig(sampling_rate=0.5),
    obs_scheme=ds_t_scheme.obs,
    int_input_scheme=ds_src_scheme.int_input,
    int_proc_scheme=ds_t_scheme.int_proc,
    leading_observable_config=m4patients.config.leading_observable,
)
visualiser = ModelVisualiser(**visualiser_kwargs)

In [43]:
# Run the model over the test batch and collect visualisation data. It is
# indexed below as a two-level mapping (presumably subject id -> admission id).
vis = visualiser.batch_predict(model, test_split)

In [44]:
from bokeh.plotting import figure, show
from bokeh.io import output_notebook
# Render Bokeh figures inline in this notebook.
output_notebook()

In [45]:

# Keys (presumably admission ids) of subject 13672788 whose `.lead` is non-empty.
subject_vis = vis['13672788']
[admission for admission in subject_vis if len(subject_vis[admission].lead) > 0]

In [87]:
# Build the Bokeh figures for one (subject, admission) entry. The result is
# indexed by 'interventions', 'obs' and 'lead' in the cells below.
# NOTE(review): hard-coded ids — keep in sync with the selection cell above.
figures = visualiser.make_bokeh(vis['13672788']['21019221'])

In [88]:
# Interventions figure for the selected admission.
# NOTE(review): execution count 88 precedes 76/77 below — cells ran out of order.
show(figures['interventions'])

In [76]:
# Observables figure for the selected admission.
show(figures['obs'])

In [77]:
# Leading-observable figure for the selected admission.
show(figures['lead'])