In [1]:
# %reload_ext autoreload
# %autoreload 2
from typing import Optional, Tuple, Literal
import os
import sys
import time
import matplotlib.pyplot as plt
import seaborn as sns
import sqlalchemy
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import jax.random as jrandom 
import jax.nn as jnn
import equinox as eqx
import optax
from tqdm.notebook import tqdm
from collections import defaultdict
from functools import partial
from scipy.stats import spearmanr, pearsonr

jax.config.update('jax_enable_x64', True)
jax.config.update('jax_platforms', 'cpu')

# jax.config.update('jax_check_tracer_leaks', True) 
sys.path.append("../../..")
# %reload_ext autoreload
# %autoreload 2
    
from lib.base import Config, Module
from lib.ml.in_models import InICENODELiteICNNImpute
from lib.ml.rectilinear_modules import ZeroImputer, MeanImputer, RectilinearImputer
from lib.ml.experiment import Experiment
from lib.metric.metrics import MetricsCollection
import lib.ehr.example_datasets.mimiciv_aki as m4aki
from lib.ehr.tvx_ehr import TVxEHR
from lib.utils import modified_environ, write_config, append_params_to_zip, zip_members, load_config, translate_path

import pub_ready_plots
from pub_ready_plots import get_mpl_rcParams
# Publication-ready matplotlib settings from pub_ready_plots for the JMLR layout.
# width_frac / height_frac are fractions (0..1) of the layout's figure size.
rc_params, fig_width_in, fig_height_in = get_mpl_rcParams(
    layout="jmlr",    # other options: "iclr", "neurips", "poster-portrait", "poster-landscape"
    width_frac=1,
    height_frac=0.2,
)
rc_params['figure.constrained_layout.use'] = True

# Optional font-size overrides (currently disabled):
# rc_params['font.size'] = 10
# rc_params['axes.titlesize'] = 12
# rc_params['axes.labelsize'] = 10
# rc_params['legend.fontsize'] = 10

plt.rcParams.update(rc_params)
# %reload_ext autoreload
# %autoreload 2

'/home/asem'

In [2]:
# EVAL_CONFIG_FILE = '/home/asem/GP/ICENODE/experiment_templates/icu/eval_config.json'
# metrics_config = Config.from_dict(load_config(translate_path(EVAL_CONFIG_FILE))).metrics
# metrics = MetricsCollection(metrics=tuple(Module.import_module(config=config) for config in metrics_config))

# Trained ODE-ICNN experiment to evaluate and the snapshot name later passed to
# `load_params_from_archive` together with `<EXP_DIR>/params.zip`.
EXP_DIR = f'{os.environ["HOME"]}/GP/ehr-data/aki_out/seg_pretrain/AutoODEICNN_mlpdyn_g0'
SNAPSHOT = 'step0240.eqx'
experiment = Experiment(
    Config.from_dict(
        load_config(os.path.join(EXP_DIR, 'config.json'))
    )
)


In [3]:

# Output directory for every *_pred_*.csv written by the cells below. Create it up front:
# DataFrame.to_csv does not create missing parent directories.
RESULTS_DIR = 'ode-icnn_results'
os.makedirs(RESULTS_DIR, exist_ok=True)


In [None]:
# Real cohort splits: tvx0 is what the imputers are built from, tvx1 what they predict.
# Paths are rooted at $HOME (like EXP_DIR) instead of a hard-coded /home/<user> prefix,
# so the notebook runs on other machines.
tvx0 = TVxEHR.load(f'{os.environ["HOME"]}/GP/ehr-data/mimic4aki-cohort/tvx_aki_0.h5')
tvx1 = TVxEHR.load(f'{os.environ["HOME"]}/GP/ehr-data/mimic4aki-cohort/tvx_aki_1.h5')

In [4]:
# Phantom cohort, rooted at $HOME like EXP_DIR rather than a hard-coded /home/<user> path.
# NOTE: the second line overrides any real splits loaded earlier. tvx0 and tvx1 become the
# SAME object, so every model below is built from and evaluated on identical data.
tvx_phantom = TVxEHR.load(f'{os.environ["HOME"]}/GP/ehr-data/mimic4aki-cohort/tvx_aki_phantom.h5')
tvx0 = tvx1 = tvx_phantom

In [6]:
# Instantiate the experiment's model for tvx0, then restore the SNAPSHOT parameters from
# the experiment's params.zip archive.
# NOTE(review): the meaning of load_model's second argument (0) is not visible here — confirm.
odeicnn = (
    experiment
    .load_model(tvx0, 0)
    .load_params_from_archive(os.path.join(EXP_DIR, 'params.zip'), SNAPSHOT)
)

In [7]:
def gen_data(tvx, keep_prob: float = 0.8, seed: int = 0):
    """Flatten every admission's observables in ``tvx`` into row-aligned DataFrames.

    There is one row per observation time of each admission. Rows are ordered by the
    iteration order of ``tvx.subjects``, then each subject's ``admissions``, then time.

    Args:
        tvx: dataset object (a ``TVxEHR`` in this notebook). Reads:
            - ``tvx.subjects`` (a mapping whose values expose ``subject_id`` and
              ``admissions``);
            - each admission's ``admission_id`` and ``observables.value`` /
              ``.mask`` / ``.time``;
            - ``tvx.scheme.obs.codes`` / ``.desc`` for the column labels.
        keep_prob: probability that an observed entry is kept in the artificial mask.
            Defaults to 0.8, the previously hard-coded value.
        seed: seed of the ``jax.random.PRNGKey`` used to sample the artificial mask.

    Returns:
        ``(obs_val, obs_mask, artificial_mask, meta)``:
            - ``obs_val``: the raw values.
            - ``obs_mask``: the observation mask as 0/1 ints.
            - ``artificial_mask``: ``obs_mask`` AND-ed with a Bernoulli(``keep_prob``)
              draw, so it is always a subset of the observed entries.
            - ``meta``: columns ``subject_id, admission_id, time, time_index``,
              aligned with the rows.
    """
    features = list(map(tvx.scheme.obs.desc.get, tvx.scheme.obs.codes))
    admissions = [(subject.subject_id, adm)
                  for subject in tvx.subjects.values()
                  for adm in subject.admissions]
    obs = [adm.observables for _, adm in admissions]

    # Repeat each admission's ids once per time point. The nested comprehensions are
    # linear, unlike the former `sum(list_of_lists, [])`, which is quadratic.
    subj_id = [sid for sid, adm in admissions for _ in range(len(adm.observables.time))]
    adm_id = [adm.admission_id for _, adm in admissions for _ in range(len(adm.observables.time))]

    if obs:
        obs_val = np.vstack([obs_i.value for obs_i in obs])
        obs_mask = np.vstack([obs_i.mask for obs_i in obs])
        obs_time = np.hstack([obs_i.time for obs_i in obs])
        obs_time_index = np.hstack([np.arange(len(obs_i.time)) for obs_i in obs])
    else:
        # np.vstack / np.hstack raise on an empty list; return correctly-shaped empty frames.
        obs_val = np.empty((0, len(features)))
        obs_mask = np.zeros((0, len(features)), dtype=bool)
        obs_time = np.empty(0)
        obs_time_index = np.empty(0, dtype=int)

    obs_val = pd.DataFrame(obs_val, columns=features)
    obs_mask = pd.DataFrame(obs_mask.astype(int), columns=features)
    meta = pd.DataFrame({'subject_id': subj_id, 'admission_id': adm_id,
                         'time': obs_time, 'time_index': obs_time_index})

    # Randomly drop ~(1 - keep_prob) of the observed entries.
    # NOTE(review): presumably this holds out observed values for scoring imputation — confirm.
    keep = np.array(jrandom.bernoulli(jrandom.PRNGKey(seed), p=keep_prob, shape=obs_mask.shape))
    artificial_mask = obs_mask & keep
    return obs_val, obs_mask, artificial_mask, meta

In [None]:
# obs_val0, obs_mask0, art_mask0, meta0 = gen_data(tvx0)
# obs_val1, obs_mask1, art_mask1, meta1 = gen_data(tvx1)

# obs_val0.to_csv('g0g1_missingness_data/missingness_vals0.csv')
# obs_mask0.to_csv('g0g1_missingness_data/missingness_mask0.csv')
# art_mask0.to_csv('g0g1_missingness_data/missingness_artificial_mask0.csv')
# meta0.to_csv('g0g1_missingness_data/missingness_meta0.csv')

# obs_val1.to_csv('g0g1_missingness_data/missingness_vals1.csv')
# obs_mask1.to_csv('g0g1_missingness_data/missingness_mask1.csv')
# art_mask1.to_csv('g0g1_missingness_data/missingness_artificial_mask1.csv')
# meta1.to_csv('g0g1_missingness_data/missingness_meta1.csv')

In [9]:
# obs_valp, obs_maskp, art_maskp, metap = gen_data(tvx_phantom)
# obs_valp.to_csv('g0g1_missingness_data/missingness_valsp.csv')
# obs_maskp.to_csv('g0g1_missingness_data/missingness_maskp.csv')
# art_maskp.to_csv('g0g1_missingness_data/missingness_artificial_maskp.csv')
# metap.to_csv('g0g1_missingness_data/missingness_metap.csv')

In [8]:
# obs_val0 = pd.read_csv('g0g1_missingness_data/missingness_vals0.csv', index_col=[0])
# obs_mask0 = pd.read_csv('g0g1_missingness_data/missingness_mask0.csv', index_col=[0])
# art_mask0 = pd.read_csv('g0g1_missingness_data/missingness_artificial_mask0.csv', index_col=[0])
# meta0 = pd.read_csv('g0g1_missingness_data/missingness_meta0.csv', index_col=[0])

# obs_val1 = pd.read_csv('g0g1_missingness_data/missingness_vals1.csv', index_col=[0])
# obs_mask1 = pd.read_csv('g0g1_missingness_data/missingness_mask1.csv', index_col=[0])
# art_mask1 = pd.read_csv('g0g1_missingness_data/missingness_artificial_mask1.csv', index_col=[0])
# meta1 = pd.read_csv('g0g1_missingness_data/missingness_meta1.csv', index_col=[0])


In [8]:
# Cached phantom-cohort frames. Presumably these were written by the commented-out
# gen_data(tvx_phantom) export cell; the first CSV column is restored as the index.
obs_valp, obs_maskp, art_maskp, metap = (
    pd.read_csv(csv_path, index_col=[0])
    for csv_path in (
        'g0g1_missingness_data/missingness_valsp.csv',
        'g0g1_missingness_data/missingness_maskp.csv',
        'g0g1_missingness_data/missingness_artificial_maskp.csv',
        'g0g1_missingness_data/missingness_metap.csv',
    )
)


In [9]:
# The phantom frames stand in for both split "0" (used for fitting below) and split "1"
# (used for imputation). Each pair aliases the SAME object — no copies are made.
obs_val0, obs_val1 = obs_valp, obs_valp
obs_mask0, obs_mask1 = obs_maskp, obs_maskp
art_mask0, art_mask1 = art_maskp, art_maskp
meta0, meta1 = metap, metap


In [13]:
# Rectilinear baseline: constructed from tvx0, then batch-predicts tvx1. Its predictions are
# exported below alongside the ODE-ICNN ones.
rect_imputer = RectilinearImputer.from_tvx_ehr(tvx0)
rect_preds = rect_imputer.batch_predict(tvx1)

In [16]:
# Batch predictions of the restored ODE-ICNN model on tvx1 (same target data as rect_preds).
odeicnn_preds = odeicnn.batch_predict(tvx1)

  0%|          | 0/300 [00:00<?, ?it/s]

In [10]:
def predictions_to_dataframe(obs_columns, predictions):
    """Stack per-admission predicted observables into one frame plus a row-metadata frame.

    Args:
        obs_columns: column labels for the predicted observable values.
        predictions: object whose ``subject_predictions`` maps subject id to an iterable
            of admission predictions. Each one exposes ``admission.admission_id`` and
            ``observables.time`` / ``observables.value``.

    Returns:
        ``(values_df, meta_df)``:
            - ``values_df``: all ``observables.value`` blocks stacked vertically.
            - ``meta_df``: columns ``subject_id, admission_id, time_index, time``,
              aligned row by row with ``values_df``.
    """
    value_blocks = []
    meta = {key: [] for key in ('subject_id', 'admission_id', 'time_index', 'time')}
    for subject_id, admission_preds in predictions.subject_predictions.items():
        for adm_pred in admission_preds:
            times = adm_pred.observables.time
            n_rows = len(times)
            meta['subject_id'] += [subject_id] * n_rows
            meta['admission_id'] += [adm_pred.admission.admission_id] * n_rows
            meta['time_index'] += list(range(n_rows))
            meta['time'] += times.tolist()
            value_blocks.append(adm_pred.observables.value)

    values_df = pd.DataFrame(np.vstack(value_blocks), columns=obs_columns)
    return values_df, pd.DataFrame(meta)


# Human-readable labels for the predicted observables. They are taken from the scheme of
# tvx1 — the dataset the predictions above were made on — rather than tvx_phantom, so the
# labels still follow the data if the phantom override is removed.
features = list(map(tvx1.scheme.obs.desc.get, tvx1.scheme.obs.codes))




In [21]:
# Convert both models' predictions into (values, meta) frames labelled with `features`.
rect_predictions_df, rect_meta_df = predictions_to_dataframe(features, rect_preds)
odeicnn_predictions_df, odeicnn_meta_df = predictions_to_dataframe(features, odeicnn_preds)

# Persist each model's frames under RESULTS_DIR, with file names prefixed by the model tag.
for model_tag, pred_values_df, pred_meta_df in (
    ('RECTLIN', rect_predictions_df, rect_meta_df),
    ('ODEICNN', odeicnn_predictions_df, odeicnn_meta_df),
):
    pred_values_df.to_csv(f'{RESULTS_DIR}/{model_tag}_pred_X_test_imp.csv')
    pred_meta_df.to_csv(f'{RESULTS_DIR}/{model_tag}_pred_meta.csv')


(71886, 100)

In [7]:

# db_name = "seg_pretrain_evals.sqlite"

# engine = sqlalchemy.create_engine("sqlite:///%s" % db_name, execution_options={"sqlite_raw_colnames": True},
#                                  connect_args={'timeout': 5})


# df = {name: pd.read_sql_table(name, engine) for name in 
#       ('evaluation_runs', 'evaluation_status', 'experiments', 'metrics', 'results')}
# metrics = df['metrics'].rename(columns={'name': 'metric', 'id': 'metric_id'})
# eval_runs = df['evaluation_runs'].rename(columns={'id': 'evaluation_id'})
# experiments = df['experiments'].rename(columns={'name': 'experiment', 'id': 'experiment_id'})
# eval_status = df['evaluation_status'].rename(columns={'id': 'status_id', 'name': 'status'})

# res = pd.merge(df['results'], metrics, left_on='metric_id', right_on='metric_id', how='left')
# res = pd.merge(res, eval_runs, left_on='evaluation_id', right_on='evaluation_id', how='left')
# res = pd.merge(res, experiments, left_on='experiment_id', right_on='experiment_id', how='left')
# res = pd.merge(res, eval_status, left_on='status_id', right_on='status_id', how='left')
# res['step'] = res.snapshot.str.extract('(\d+)').astype(int)

# res[res.metric.str.startswith('PerColumnObsPredictionLoss.r2')].sort_values('value')

# tvx0.scheme.obs.codes[84]

## Scikit-learn imputation baselines

In [12]:
from sklearn.experimental import enable_iterative_imputer  # noqa
from sklearn.impute import IterativeImputer, KNNImputer, SimpleImputer

# Factories for the scikit-learn baselines. Calling a value returns a fresh, unfitted estimator.
sklearn_imputers = {
    # Replace every missing entry with the constant 0.
    'zero_imputer': partial(SimpleImputer, missing_values=np.nan, add_indicator=False,
                            strategy="constant", fill_value=0),
    # Per-column mean. fill_value is only consulted by strategy="constant", so it is inert here.
    'mean_imputer': partial(SimpleImputer, missing_values=np.nan, add_indicator=False,
                            strategy="mean", fill_value=0),
    'knn_imputer': partial(KNNImputer, missing_values=np.nan),
    # Seeded (random_state=0), so posterior sampling is reproducible.
    'iter_imputer': partial(
        IterativeImputer,
        missing_values=np.nan,
        add_indicator=False,
        random_state=0,
        n_nearest_features=5,
        max_iter=5,
        sample_posterior=True,
    ),
}

In [13]:
# Fit each scikit-learn baseline on split 0, then impute split 1. Unobserved entries
# (mask == 0) become NaN, the missing-value marker every imputer above is configured with.
# Both inputs are built once instead of inside the comprehensions.
X_train_missing = np.where(obs_mask0, obs_val0, np.nan)
# Bug fix: the values to impute are obs_val1. The previous code passed obs_mask1 as the
# values, so the imputers were fed the 0/1 mask instead of the measurements.
X_test_missing = np.where(obs_mask1, obs_val1, np.nan)
# NOTE(review): art_mask1 (observed entries with a random subset hidden) is defined earlier
# but unused here. If the goal is to score imputation of held-out observed values, the test
# input probably wants art_mask1 instead of obs_mask1 — confirm.

sklearn_trained_imputers = {name: make_imputer().fit(X_train_missing)
                            for name, make_imputer in sklearn_imputers.items()}
sklearn_imputed_X = {name: imputer.transform(X_test_missing)
                     for name, imputer in sklearn_trained_imputers.items()}

for sklearn_name, imputed_X_ in sklearn_imputed_X.items():
    # Bug fix: `obs_val` was never defined (NameError); label columns from the test frame.
    # NOTE: by default sklearn imputers drop features that are entirely missing at fit time,
    # which would make the column count differ from obs_val1.columns.
    X_test_imp_df = pd.DataFrame(imputed_X_, columns=obs_val1.columns)
    X_test_imp_df.to_csv(f'{RESULTS_DIR}/{sklearn_name}_pred_X_test_imp.csv')


KeyboardInterrupt: 