In [1]:
# %reload_ext autoreload
# %autoreload 2
from typing import Optional, Tuple, Literal
import os
import sys
import time
import matplotlib.pyplot as plt
import seaborn as sns
import sqlalchemy
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import jax.random as jrandom 
import jax.nn as jnn
import equinox as eqx
import optax
from tqdm.notebook import tqdm
from collections import defaultdict
from functools import partial
from scipy.stats import spearmanr, pearsonr

# JAX runtime flags, applied in this order: enable x64 mode, then restrict
# the platform list to CPU.
for _jax_flag, _jax_flag_value in (('jax_enable_x64', True), ('jax_platforms', 'cpu')):
    jax.config.update(_jax_flag, _jax_flag_value)

# Tracer-leak checking is left disabled; uncomment when debugging.
# jax.config.update('jax_check_tracer_leaks', True)

# Make the directory three levels above the working directory importable
# (presumably the repository root that provides the `lib` package -- confirm).
sys.path.append("../../..")
# %reload_ext autoreload
# %autoreload 2
    
from lib.base import Config, Module
from lib.ml.in_models import InICENODELiteICNNImpute
from lib.ml.rectilinear_modules import ZeroImputer, MeanImputer, RectilinearImputer
from lib.ml.experiment import Experiment
from lib.metric.metrics import MetricsCollection
import lib.ehr.example_datasets.mimiciv_aki as m4aki
from lib.ehr.tvx_ehr import TVxEHR
from lib.utils import modified_environ, write_config, append_params_to_zip, zip_members, load_config, translate_path

import pub_ready_plots
from pub_ready_plots import get_mpl_rcParams
# Figure sizing request passed to pub_ready_plots. Both fractions must lie
# between 0 and 1. Other layouts mentioned for this call: "iclr", "neurips",
# "poster-portrait", "poster-landscape".
_FIGURE_LAYOUT = dict(width_frac=1, height_frac=0.2, layout="jmlr")
rc_params, fig_width_in, fig_height_in = get_mpl_rcParams(**_FIGURE_LAYOUT)

# Always use matplotlib's constrained layout for figures in this notebook.
rc_params['figure.constrained_layout.use'] = True

# Optional font-size overrides (currently disabled):
# rc_params['font.size'] = 10
# rc_params['axes.titlesize'] = 12
# rc_params['axes.labelsize'] = 10
# rc_params['legend.fontsize'] = 10

# Install the settings globally for every subsequent plot.
plt.rcParams.update(rc_params)
# %reload_ext autoreload
# %autoreload 2

In [2]:
# Build the evaluation metrics collection from the evaluation config file.
# NOTE: the name `metrics` is rebound to a DataFrame by the results-database
# cell further down, so this collection is only available until that cell runs.
EVAL_CONFIG_FILE = '/home/asem/GP/ICENODE/experiment_templates/icu/eval_config.json'
_eval_config = Config.from_dict(load_config(translate_path(EVAL_CONFIG_FILE)))
metrics_config = _eval_config.metrics
_metric_modules = tuple(Module.import_module(config=metric_cfg) for metric_cfg in metrics_config)
metrics = MetricsCollection(metrics=_metric_modules)

In [3]:
# Experiment directory (relative to the working directory) and the snapshot
# file name of interest.
EXP_DIR = 'inicenodeliteicnn_mlpdyn_pretrain_g0'
SNAPSHOT = 'step0340.eqx'

# Re-create the Experiment object from the config stored in that directory.
_experiment_config_path = os.path.join(EXP_DIR, 'config.json')
_experiment_config = Config.from_dict(load_config(_experiment_config_path))
experiment = Experiment(_experiment_config)


In [None]:
# Locations of the two stored TVxEHR datasets (group 0 and group 1).
_TVX_AKI_0_FILE = "/home/asem/GP/ehr-data/mimic4aki-cohort/tvx_aki_0.h5"
_TVX_AKI_1_FILE = "/home/asem/GP/ehr-data/mimic4aki-cohort/tvx_aki_1.h5"

# Load them one after the other; later cells refer to `tvx0` and `tvx1`.
tvx0 = TVxEHR.load(_TVX_AKI_0_FILE)
tvx1 = TVxEHR.load(_TVX_AKI_1_FILE)

In [6]:
def gen_data(tvx, seed=0, keep_prob=0.8):
    """Flatten the observables of every admission in `tvx` into row-aligned tables.

    One row is produced per observation timestamp. Admissions are concatenated
    in the iteration order of `tvx.subjects.values()` and, within a subject, of
    `subject.admissions`.

    Args:
        tvx: Dataset object exposing
            * `subjects`: mapping whose values have `subject_id` and
              `admissions`; each admission has `admission_id` and
              `observables` with `time`, `value` and `mask` arrays;
            * `scheme.obs`: object with `codes` (column order) and `desc`
              (mapping from code to a description used as column name).
        seed: Integer seed of the JAX PRNG key used to draw the artificial
            mask. The default (0) matches the previously hard-coded key.
        keep_prob: Bernoulli probability with which an entry is kept in the
            artificial mask. The default (0.8) matches the previously
            hard-coded value.

    Returns:
        Tuple `(obs_val, obs_mask, artificial_mask, meta)` of DataFrames that
        share the same row order:
            * `obs_val`: stacked observation values, one column per feature;
            * `obs_mask`: stacked observation mask cast to int;
            * `artificial_mask`: `obs_mask` AND-ed with the random Bernoulli
              draw, so it can only remove entries from `obs_mask`;
            * `meta`: `subject_id`, `admission_id`, `time` and `time_index`
              (position of the timestamp within its admission) for each row.
    """
    # Walk the subject/admission hierarchy once so every derived list below is
    # guaranteed to share the same ordering.
    admissions = [(subject, adm)
                  for subject in tvx.subjects.values()
                  for adm in subject.admissions]
    obs = [adm.observables for _, adm in admissions]

    # Repeat the identifiers once per timestamp. `list.extend` keeps this
    # linear, unlike `sum(list_of_lists, [])` which is quadratic.
    adm_id, subj_id = [], []
    for (subject, adm), obs_i in zip(admissions, obs):
        n_times = len(obs_i.time)
        adm_id.extend([adm.admission_id] * n_times)
        subj_id.extend([subject.subject_id] * n_times)

    obs_val = np.vstack([obs_i.value for obs_i in obs])
    obs_mask = np.vstack([obs_i.mask for obs_i in obs])
    obs_time = np.hstack([obs_i.time for obs_i in obs])
    obs_time_index = np.hstack([np.arange(len(obs_i.time)) for obs_i in obs])

    # Column names come from the scheme of the dataset passed in (not from any
    # notebook-level global).
    obs_scheme = tvx.scheme.obs
    features = [obs_scheme.desc.get(code) for code in obs_scheme.codes]

    obs_val = pd.DataFrame(obs_val, columns=features)
    obs_mask = pd.DataFrame(obs_mask.astype(int), columns=features)
    meta = pd.DataFrame({'subject_id': subj_id, 'admission_id': adm_id,
                         'time': obs_time, 'time_index': obs_time_index})

    # Randomly drop a fraction of the available entries.
    keep = np.array(jrandom.bernoulli(jrandom.PRNGKey(seed), p=keep_prob, shape=obs_mask.shape))
    artificial_mask = obs_mask & keep
    return obs_val, obs_mask, artificial_mask, meta

# Flatten both datasets into tabular form.
# NOTE: gen_data seeds its PRNG with a fixed key by default, so the two
# artificial masks are drawn from the same random stream.
obs_val0, obs_mask0, art_mask0, meta0 = gen_data(tvx0)
obs_val1, obs_mask1, art_mask1, meta1 = gen_data(tvx1)

# `DataFrame.to_csv` does not create missing parent directories, so make sure
# the output directory exists before writing.
MISSINGNESS_DIR = 'g0g1_missingness_data'
os.makedirs(MISSINGNESS_DIR, exist_ok=True)

# Files are named `missingness_<kind><group>.csv`.
_missingness_tables = {
    0: {'vals': obs_val0, 'mask': obs_mask0, 'artificial_mask': art_mask0, 'meta': meta0},
    1: {'vals': obs_val1, 'mask': obs_mask1, 'artificial_mask': art_mask1, 'meta': meta1},
}
for _group, _tables in _missingness_tables.items():
    for _kind, _table in _tables.items():
        _table.to_csv(os.path.join(MISSINGNESS_DIR, f'missingness_{_kind}{_group}.csv'))


In [8]:
# Reload the exported tables. The first CSV column holds the row index that
# was written by `to_csv`, so it is restored as the index.
_read_indexed_csv = partial(pd.read_csv, index_col=[0])

# Group 0
obs_val0 = _read_indexed_csv('g0g1_missingness_data/missingness_vals0.csv')
obs_mask0 = _read_indexed_csv('g0g1_missingness_data/missingness_mask0.csv')
art_mask0 = _read_indexed_csv('g0g1_missingness_data/missingness_artificial_mask0.csv')
meta0 = _read_indexed_csv('g0g1_missingness_data/missingness_meta0.csv')

# Group 1
obs_val1 = _read_indexed_csv('g0g1_missingness_data/missingness_vals1.csv')
obs_mask1 = _read_indexed_csv('g0g1_missingness_data/missingness_mask1.csv')
art_mask1 = _read_indexed_csv('g0g1_missingness_data/missingness_artificial_mask1.csv')
meta1 = _read_indexed_csv('g0g1_missingness_data/missingness_meta1.csv')


In [10]:
# Display the group-1 per-observation metadata reloaded from CSV
# (columns: subject_id, admission_id, time, time_index).
meta1

Unnamed: 0,subject_id,admission_id,time,time_index
0,10010867,22429197,0.666667,0
1,10010867,22429197,0.783333,1
2,10010867,22429197,0.800000,2
3,10010867,22429197,0.816667,3
4,10010867,22429197,0.833333,4
...,...,...,...,...
689982,19995179,22929215,11.583333,15
689983,19995179,22929215,12.583333,16
689984,19995179,22929215,13.583333,17
689985,19995179,22929215,14.583333,18


In [11]:
# Construct the rectilinear imputer from the `tvx0` dataset.
# NOTE(review): presumably this estimates the imputer's parameters from tvx0 --
# confirm against RectilinearImputer.from_tvx_ehr.
rect_imputer = RectilinearImputer.from_tvx_ehr(tvx0)

In [12]:
# Apply the imputer built from `tvx0` to the other dataset, `tvx1`.
# NOTE(review): the structure of the returned predictions is not visible here --
# check RectilinearImputer.batch_predict before relying on it.
rect_preds = rect_imputer.batch_predict(tvx1)

  0%|          | 0/2957 [00:00<?, ?it/s]

In [7]:

# Load the evaluation-results SQLite database and join its tables into one
# flat `res` frame (one row per stored result).
db_name = "seg_pretrain_evals.sqlite"

engine = sqlalchemy.create_engine(f"sqlite:///{db_name}",
                                  execution_options={"sqlite_raw_colnames": True},
                                  connect_args={'timeout': 5})

# Raw tables, keyed by table name.
df = {name: pd.read_sql_table(name, engine) for name in
      ('evaluation_runs', 'evaluation_status', 'experiments', 'metrics', 'results')}

# Give every lookup table unambiguous id/name columns so the joins below do
# not collide on the generic `id` / `name` columns.
# NOTE: this rebinds `metrics` (a MetricsCollection built in an earlier cell)
# to a DataFrame; the name is kept so later cells that use it keep working.
metrics = df['metrics'].rename(columns={'name': 'metric', 'id': 'metric_id'})
eval_runs = df['evaluation_runs'].rename(columns={'id': 'evaluation_id'})
experiments = df['experiments'].rename(columns={'name': 'experiment', 'id': 'experiment_id'})
eval_status = df['evaluation_status'].rename(columns={'id': 'status_id', 'name': 'status'})

# Left joins keep every result row even when a lookup entry is missing.
res = (
    df['results']
    .merge(metrics, on='metric_id', how='left')
    .merge(eval_runs, on='evaluation_id', how='left')
    .merge(experiments, on='experiment_id', how='left')
    .merge(eval_status, on='status_id', how='left')
)

# Training step = first run of digits in the snapshot name. The raw string
# avoids the invalid-escape warning for `\d`, and `expand=False` yields a
# Series rather than a one-column DataFrame.
res['step'] = res['snapshot'].str.extract(r'(\d+)', expand=False).astype(int)

In [11]:
# Per-column R^2 results, ordered from lowest to highest value.
_is_r2_metric = res['metric'].str.startswith('PerColumnObsPredictionLoss.r2')
res.loc[_is_r2_metric].sort_values(by='value')

In [17]:
# Look up the observable code stored at position 84 of the `tvx0` scheme.
# NOTE(review): 84 is a magic index -- presumably taken from a column number in
# the per-column results; confirm and consider naming it.
tvx0.scheme.obs.codes[84]