In [1]:
# %env XLA_PYTHON_CLIENT_PREALLOCATE=false


In [2]:

%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

import jax
jax.config.update('jax_platform_name', 'cpu')
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [3]:


sys.path.append("../..")

from lib import utils as U
from lib.ehr import load_dataset, load_dataset_scheme, load_dataset_config, Dataset
from lib.ehr.interface import Patients, InterfaceConfig
from lib.ehr.concepts import DemographicVectorConfig, LeadingObservableConfig


## Load Dataset

In [4]:
# Dataset tag; also used below to look up the dataset scheme and to name the cache.
tag = 'M4ICU'
# Directory handed to load_dataset_config as `path`.
# Path.home() is used rather than os.environ.get("HOME"): the latter returns None
# when HOME is unset, which would silently produce the bogus path "None/GP/...".
PATH = f'{Path.home()}/GP/ehr-data/mimic4icu-cohort'
# Sample size forwarded to load_dataset_config; a falsy value leaves the cache-name suffix empty.
sample = 15000
# NOTE(review): the "inteface" spelling is kept byte-for-byte so that any cache
# already written under this name is still found.
cache =  f'cached_inteface/patients_{tag}_{sample or ""}'
dataset_config = load_dataset_config(tag,
                                     sample=sample,
                                     path=PATH)

In [5]:
import json
# Load the scheme registered for this dataset tag, then pretty-print the
# target-scheme options it reports as supported (sorted keys, 4-space indent).
dataset_scheme = load_dataset_scheme(tag)
interface_schem_options = dataset_scheme.supported_target_scheme_options
print(json.dumps(interface_schem_options, sort_keys=True, indent=4))

In [6]:

# Demographic vector: all three attribute flags are set to False.
demographic_vector_conf = DemographicVectorConfig(age=False,
                                                  gender=False,
                                                  ethnicity=False)

# Leading observable: horizons of 6, 12, ..., 72 hours, aggregated with 'max',
# for the entry at index 42 of the dataset's observation scheme.
leading_AKI = LeadingObservableConfig(
    leading_hours=tuple(range(6, 73, 6)),
    window_aggregate='max',
    scheme=dataset_scheme.obs,
    index=42,
)

In [7]:
# Build the target-scheme configuration from the dataset scheme, selecting the
# diagnosis, outcome and ethnicity schemes by name.
interface_scheme = dataset_scheme.make_target_scheme_config(dx='DxICD9',
                                                            outcome='dx_icd9_filter_v3_groups',
                                                            ethnicity='MIMIC4Eth5')
# Bundle the schemes, the demographic and leading-observable settings, and the
# cache location into one interface configuration.
interface_config = InterfaceConfig(scheme=interface_scheme,
                                   dataset_scheme=dataset_scheme,
                                   demographic_vector=demographic_vector_conf,
                                   leading_observable=leading_AKI,
                                   cache=cache)

In [8]:
def dataset_gen(dataset_config):
    """Load the dataset described by ``dataset_config``, clean it and scale it.

    Both the outlier remover and the scalers are fitted on the first split
    only (split points 0.8 / 0.9, seed 42, balanced on 'admissions'), and are
    then applied to the dataset as a whole.

    Returns:
        The dataset after outlier removal and scaling.
    """
    raw = load_dataset(config=dataset_config)
    partitions = raw.random_splits([0.8, 0.9], random_seed=42, balanced='admissions')
    fitting_partition = partitions[0]

    # Step 1: drop outliers, using statistics from the fitting partition only.
    remover = raw.fit_outlier_remover(fitting_partition)
    cleaned = raw.remove_outliers(remover)

    # Step 2: scale, again fitting on the same partition (now outlier-free).
    fitted_scalers = cleaned.fit_scalers(fitting_partition)
    return cleaned.apply_scalers(fitted_scalers)
        

In [None]:
# Obtain the Patients interface, passing dataset_gen as the dataset generator.
# NOTE(review): presumably a cached copy is returned when one exists and
# dataset_gen is only invoked otherwise — confirm in Patients.try_load_cached.
m4patients = Patients.try_load_cached(interface_config,
                                      dataset_config=dataset_config,
                                      dataset_generator=dataset_gen,
                                      num_workers=8)

In [None]:
# import equinox as eqx
# # Delete heavy loads
# m4patients = eqx.tree_at(lambda x: x.subjects, m4patients, {})
# m4patients = eqx.tree_at(lambda x: x.dataset, m4patients, None)

## Load Model

In [None]:
from lib.ml import InpatientExperiment
from lib import Config, Module

In [None]:
# Directory of the experiment run to analyse; the commented lines are alternative runs.
# Path.home() is used rather than os.environ.get("HOME"): the latter returns None
# when HOME is unset, which would silently produce the bogus path "None/GP/...".
experiments_dir = f'{Path.home()}/GP/ehr-data/m4icu_out/backup_override/sigmo_dtw_B32_icenode'
# experiments_dir = f'{Path.home()}/GP/ehr-data/m4icu_out/backup_override/sigmo_mse_B32_icenode'

# experiments_dir = f'{Path.home()}/GP/ehr-data/m4icu_out/mono_mse_icenode'
# experiments_dir = f'{Path.home()}/GP/ehr-data/m4icu_out/sigmo_dtw_icenode'
# experiments_dir = f'{Path.home()}/GP/ehr-data/m4icu_out/sigmo_mse_icenode'

# Checkpoint entry to restore from the run's params archive; alternatives kept below.
params_file = 'step9293.eqx'
# params_file = 'step8260.eqx'

# params_file = 'step3355.eqx'
# params_file = 'step7227.eqx'
# params_file = 'step2839.eqx'

# Read the run's config.json, convert it to a Config, and build the experiment from it.
experiment_config = U.load_config(f'{experiments_dir}/config.json')
experiment_config = Config.from_dict(experiment_config)
experiment = InpatientExperiment(config=experiment_config)

In [None]:
# Ask the experiment for its data splits over the patients' dataset, and for its model.
splits = experiment.load_splits(m4patients.dataset)
model = experiment.load_model(m4patients)

In [None]:
# Size of each split, in order.
list(map(len, splits))

In [None]:
# Load the experiment's metrics for these patients and splits.
metrics = experiment.load_metrics(m4patients, splits)

In [None]:
# Rebind `model` to the result of loading the `params_file` entry from the run's params.zip.
model = model.load_params_from_archive(f'{experiments_dir}/params.zip', params_file)


In [None]:
# Build a device batch from the third split (index 2).
# NOTE(review): presumably the held-out test split, per the variable name — confirm against load_splits.
test_split = m4patients.device_batch(splits[2])

In [None]:
# Run the model's batch prediction over the test batch.
predictions = model.batch_predict(test_split)

In [None]:
# Save the predictions inside the experiment directory, with the checkpoint name in the file name.
predictions.save(f'{experiments_dir}/predictions_{params_file}')

In [None]:
from lib.ehr import Predictions
# Reload the predictions previously saved for this checkpoint.
predictions = Predictions.load(f'{experiments_dir}/predictions_{params_file}')

In [None]:
# NOTE(review): this calls a private method and discards its return value, so it
# presumably mutates `predictions` in place — confirm, otherwise the file saved
# below would not contain the defragmented observables.
predictions._defragment_observables()
# Save under a separate "defrag_" file name so the file written earlier is not overwritten.
predictions.save(f'{experiments_dir}/defrag_predictions_{params_file}')

In [16]:
from lib.ehr import Predictions
# Load the "defrag_" predictions file for this checkpoint; the cells below use this object.
predictions = Predictions.load(f'{experiments_dir}/defrag_predictions_{params_file}')

In [None]:
from lib.metric import AKISegmentedAdmissionMetric, AKISegmentedAdmissionConfig

In [None]:
# Instantiate the segmented-admission AKI metric for these patients.
# NOTE(review): stable_window=72 is presumably expressed in hours — confirm in AKISegmentedAdmissionConfig.
aki_metric = AKISegmentedAdmissionMetric(patients=m4patients, 
                                         config=AKISegmentedAdmissionConfig(stable_window=72))

In [None]:
# Evaluate the metric on the predictions; the call returns three values: `res`
# (a mapping, queried with .items() further down) and two segmentation
# containers, one of them keyed by class (see the .keys() call below).
res,segmented_AKI,segmented_AKI_byclass = aki_metric(predictions)

In [None]:
# Display the first value returned by the metric.
res

In [29]:
# List the class names available in the by-class segmentation.
segmented_AKI_byclass.keys()

In [30]:
# Inspect the segmentation stored under the hard-coded key '28669544'.
segmented_AKI['28669544']

In [60]:
# Inspect the entry at position 3 of the 'AKI_pre_emergence' class.
segmented_AKI_byclass['AKI_pre_emergence'][3]

In [63]:
# [sid for sid in predictions if '27896316' in predictions[sid]]

In [64]:
import numpy as np
# pred = predictions['14139649']['28669544']
# Select one prediction by two hard-coded keys (presumably subject id, then admission id — confirm).
pred = predictions['14316710']['27896316' ]
# Column 42 of the admission's observable values; 42 is also the index configured for leading_AKI.
aki_now = pred.admission.observables.value[:, 42]
# Leading observable attached to the admission (presumably the ground truth, per the _gt suffix)
# and the one attached to the prediction itself.
lobs_gt = pred.admission.leading_observable
lobs = pred.leading_observable
# Time axis of the predicted leading observable; used as the x-axis in the plot below.
aki_t = lobs.time
# Collapse axis 1 with a max, giving one value per row of the leading observable.
aki_preds = np.max(lobs.value, axis=1)
# True for rows where the mask is positive somewhere along axis 1. Not used in the cells shown here.
aki_mask = lobs.mask.max(axis=1) > 0
aki_gt = np.max(lobs_gt.value, axis=1)

In [65]:
import matplotlib.pyplot as plt

# Scatter each series against the leading-observable time axis, one labelled
# series per call, all drawn with the "x" marker.
aki_series = (
    ("lead_aki_preds", aki_preds),
    ("lead_aki_gt", aki_gt),
    ("aki_now", aki_now),
)
for series_label, series_values in aki_series:
    plt.scatter(aki_t, series_values, label=series_label, marker="x")

plt.legend()
plt.show()

In [57]:
# Display the raw values of the column-42 series extracted above.
aki_now

In [56]:
# Inspect the segmentation stored under the hard-coded key '28669544'.
segmented_AKI['28669544']

In [21]:
import pandas as pd

In [38]:
# Evaluation tables written by the experiment run; the first CSV column becomes the index.
val_evals_path = f'{experiments_dir}/val_evals.csv.gz'
trn_evals_path = f'{experiments_dir}/train_evals.csv.gz'

val_df = pd.read_csv(val_evals_path, index_col=[0])
trn_df = pd.read_csv(trn_evals_path, index_col=[0])

In [39]:
# Display the validation evaluation table.
val_df

In [60]:
# Group the validation columns by the substring their label contains.
obs_cols, lead_cols, dx_cols = (
    [column for column in val_df.columns if keyword in column]
    for keyword in ('obs', 'lead', 'dx')
)
obs_cols, lead_cols, dx_cols

In [61]:
# Keep only the obs / lead / dx columns, in that order, for both tables.
selected_cols = [*obs_cols, *lead_cols, *dx_cols]
trn_obs_df = trn_df[selected_cols]
val_obs_df = val_df[selected_cols]


In [62]:
# Display the training table restricted to the selected columns.
trn_obs_df

In [43]:
# Display the validation table restricted to the selected columns.
val_obs_df

In [81]:
import matplotlib.pyplot as plt
import numpy as np
plt.rcParams.update({'font.size': 11})

# One x-position per row of the validation table, numbered from 1.
epochs = np.arange(len(val_obs_df)) + 1

# Validation vs. training curve for the observation MSE term.
plt.plot(epochs, val_obs_df['LossMetric.obs_mse'].values, label="Validation Loss", marker='o')
plt.plot(epochs, trn_obs_df['LossMetric.obs_mse'].values, label="Training Loss", marker='o')

# Raw string: '\m' is not a valid Python escape sequence and triggers a warning otherwise.
plt.title(r'$\mathcal{L}_z$: Observation MSE Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
# The legend must be added before savefig, otherwise the exported PDF has no legend.
plt.legend()

current_figure = plt.gcf()
current_figure.savefig(f'{experiments_dir}/obs_loss.pdf', bbox_inches='tight')

plt.show()

In [82]:
# Validation vs. training curve for the leading-observable MSE term.
plt.plot(epochs, val_obs_df['LossMetric.lead_mse'].values, label="Validation Loss", marker='o')
plt.plot(epochs, trn_obs_df['LossMetric.lead_mse'].values, label="Training Loss", marker='o')

# Raw string: '\m' is not a valid Python escape sequence and triggers a warning otherwise.
plt.title(r'$\mathcal{L}_q$: AKI Early Prediction MSE Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
# The legend must be added before savefig, otherwise the exported PDF has no legend.
plt.legend()

current_figure = plt.gcf()
current_figure.savefig(f'{experiments_dir}/lead_loss.pdf', bbox_inches='tight')

plt.show()

In [83]:
# Validation vs. training curve for the diagnosis (dx) balanced focal BCE term.
plt.plot(epochs, val_obs_df['LossMetric.dx_balanced_focal_bce'].values, label="Validation Loss", marker='o')
plt.plot(epochs, trn_obs_df['LossMetric.dx_balanced_focal_bce'].values, label="Training Loss", marker='o')

# Raw string: '\m' is not a valid Python escape sequence and triggers a warning otherwise.
plt.title(r'$\mathcal{L}_x$: Discharge Codes Prediction BCE Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
# The legend must be added before savefig, otherwise the exported PDF has no legend.
plt.legend()

current_figure = plt.gcf()
current_figure.savefig(f'{experiments_dir}/dx_loss.pdf', bbox_inches='tight')

plt.show()

In [84]:
# Weighted sum of the three loss terms: 50 x obs MSE + 50 x lead MSE + dx BCE.
# NOTE(review): the weights of 50 presumably mirror the training configuration — confirm against config.json.
val_loss = (50 * val_obs_df['LossMetric.obs_mse']
            + 50 * val_obs_df['LossMetric.lead_mse']
            + val_obs_df['LossMetric.dx_balanced_focal_bce'])
trn_loss = (50 * trn_obs_df['LossMetric.obs_mse']
            + 50 * trn_obs_df['LossMetric.lead_mse']
            + trn_obs_df['LossMetric.dx_balanced_focal_bce'])

plt.plot(epochs, val_loss.values, label="Validation Loss", marker='o')
plt.plot(epochs, trn_loss.values, label="Training Loss", marker='o')

# Raw string: '\m' is not a valid Python escape sequence and triggers a warning otherwise.
plt.title(r'$\mathcal{L}$: Multi-Objective Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
# The legend must be added before savefig, otherwise the exported PDF has no legend.
plt.legend()

current_figure = plt.gcf()
current_figure.savefig(f'{experiments_dir}/loss.pdf', bbox_inches='tight')

plt.show()

In [88]:
# Keep only the metric entries whose name begins with 'first'.
first_emergence = dict(
    (metric_name, metric_value)
    for metric_name, metric_value in res.items()
    if metric_name.startswith('first')
)

In [92]:
# Shorten each key to the part after its last underscore.
first_emergence = {
    metric_name.rsplit('_', 1)[-1]: metric_value
    for metric_name, metric_value in first_emergence.items()
}

In [94]:
# Single-row table (index 0) with one column per shortened metric key.
fem_auc_df = pd.DataFrame(first_emergence, index=[0])

In [95]:
# Print the table as LaTeX source; print() keeps the line breaks instead of showing an escaped repr.
print(fem_auc_df.to_latex())