## Libs


In [1]:
%load_ext autoreload
%autoreload 2
from typing import Tuple
import sys
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import optax
from tqdm.notebook import tqdm
from collections import defaultdict
from functools import partial
import sklearn.neural_network as sknn
from sklearn.datasets import fetch_california_housing, load_diabetes
from sklearn.ensemble import RandomForestRegressor
from sklearn.experimental import enable_iterative_imputer  # noqa
from sklearn.impute import IterativeImputer, KNNImputer, SimpleImputer
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


# Enable 64-bit (double precision) types in JAX instead of its 32-bit default.
jax.config.update('jax_enable_x64', True)
# Restrict JAX to the CPU backend.
jax.config.update('jax_platforms', 'cpu')
# Debug switch, disabled by default; uncomment to have JAX check for leaked tracers.
# jax.config.update('jax_check_tracer_leaks', True) 
# Add the directory three levels up to the import path so the `lib.*` imports
# below resolve (presumably the repository root -- confirm against the layout).
# The path is relative, so it depends on the kernel's working directory.
sys.path.append("../../..")
from lib.ml.icnn_modules import ProbICNNImputerTrainer
import lib.ehr.example_datasets.mimiciv_aki as m4aki
from lib.ehr.tvx_ehr import TVxEHR
from lib.utils import modified_environ, write_config
 

## Data Loading

In [2]:
# Fully observed benchmark datasets, each as a (features, target) pair of arrays.
# NOTE(review): fetch_california_housing presumably downloads and caches the data
# on first use, so this cell may need network access -- confirm.
(X_diabetes_full, y_diabetes) = load_diabetes(return_X_y=True)
(X_california_full, y_california) = fetch_california_housing(return_X_y=True)

def add_missingness(X_full: jnp.ndarray, key: jax.Array, p: float = 0.8) -> Tuple[jnp.ndarray, np.ndarray]:
    """Hide entries of ``X_full`` completely at random by replacing them with NaN.

    Args:
        X_full: Fully observed data array.
        key: JAX PRNG key used to draw the observation mask.
        p: Probability that an entry is *kept* (observed). The expected
            fraction of missing entries is therefore ``1 - p``.

    Returns:
        A pair ``(X, mask)``: ``X`` is ``X_full`` with the hidden entries set
        to NaN, and ``mask`` is a boolean NumPy array of the same shape that
        is True where the entry is observed and False where it was hidden.

    Raises:
        ValueError: If ``p`` is not within ``[0, 1]``.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"p must be a probability in [0, 1], got {p!r}")
    # True marks an observed entry; materialised as a NumPy array so it can be
    # used for plain NumPy boolean indexing downstream.
    mask = np.array(jr.bernoulli(key, p=p, shape=X_full.shape))
    X = jnp.where(mask, X_full, jnp.nan)
    return X, mask


# Hide entries of each dataset at random: the third argument is the probability
# of *keeping* an entry, so roughly 15% of the values become NaN.
# Each masked array is derived from its own fully observed dataset, so that
# M_diabetes lines up with X_diabetes_full when the imputers are scored.
X_california, M_california = add_missingness(X_california_full, jr.PRNGKey(0), 0.85)
X_diabetes, M_diabetes = add_missingness(X_diabetes_full, jr.PRNGKey(0), 0.85)

### Optimisers and Imputer Configurations

In [3]:
# Registry mapping a short optimiser name to the optax constructor it refers to.
optax_optimisers = dict(
    adam=optax.adam,
    polyak_sgd=optax.polyak_sgd,
    novograd=optax.novograd,
    lamb=optax.lamb,
    yogi=optax.yogi,
)


def _icnn_imputer_factory(model_name: str, loss: str, *, steps: int = 5000,
                          artificial_missingness: float = 0.8, **overrides):
    """Return a zero-argument factory for one ProbICNNImputerTrainer configuration.

    Args:
        model_name: Forwarded as ``icnn_model_name``.
        loss: Forwarded as ``loss``.
        steps: Forwarded as ``steps``.
        artificial_missingness: Forwarded as ``artificial_missingness``.
        **overrides: Extra keyword arguments forwarded unchanged
            (e.g. ``icnn_positivity='squared'``).

    Returns:
        A callable that builds a fresh trainer every time it is called.
    """
    def build():
        return ProbICNNImputerTrainer(
            steps=steps,
            icnn_model_name=model_name,
            loss=loss,
            artificial_missingness=artificial_missingness,
            **overrides,
        )
    return build


# Registry of imputer factories: each value is a zero-argument callable that
# returns a new, unfitted imputer, so every dataset gets its own instance.
imputers = {
    'zero_imputer': lambda: SimpleImputer(missing_values=np.nan, add_indicator=False, strategy="constant", fill_value=0),
    # fill_value only applies to strategy="constant", so it is not passed here.
    'mean_imputer': lambda: SimpleImputer(missing_values=np.nan, add_indicator=False, strategy="mean"),
    'knn_imputer': lambda: KNNImputer(missing_values=np.nan),
    'iter_imputer': lambda: IterativeImputer(
        missing_values=np.nan,
        add_indicator=False,
        random_state=0,
        n_nearest_features=5,
        max_iter=5,
        sample_posterior=True,
    ),
    # ICNN variants: {stacked, staged} x {log-normal, KL} losses, first with the
    # trainer's default positivity setting, then with icnn_positivity='squared'.
    # Generates the keys 'icnn_<model>_<loss>[_sq]' in the original order.
    **{
        f'icnn_{model_name}_{loss_tag}{suffix}': _icnn_imputer_factory(model_name, loss, **overrides)
        for suffix, overrides in (('', {}), ('_sq', {'icnn_positivity': 'squared'}))
        for model_name in ('stacked', 'staged')
        for loss_tag, loss in (('lognormal', 'log_normal'), ('kl', 'kl_divergence'))
    },
}



## Imputation-only Performance

In [4]:
# Build a fresh instance of every registered imputer and fit it on each masked
# dataset. The stored value is whatever `fit` returns.
# NOTE(review): assumes every imputer's `fit` returns the fitted object itself,
# as scikit-learn estimators do -- confirm for ProbICNNImputerTrainer.
diabetes_trained_imputer = {
    name: make_imputer().fit(X_diabetes)
    for name, make_imputer in imputers.items()
}
california_trained_imputer = {
    name: make_imputer().fit(X_california)
    for name, make_imputer in imputers.items()
}

  0%|          | 0/5000 [00:00<?, ?it/s]

ValueError: `where` does not uniquely identify a single element of `pytree`. This usually occurs when trying to replace a `None` value:

  >>> eqx.tree_at(lambda t: t[0], (None, None, 1), True)


for which the fix is to specify that `None`s should be treated as leaves:

  >>> eqx.tree_at(lambda t: t[0], (None, None, 1), True,
  ...             is_leaf=lambda x: x is None)

In [None]:
def _r_squared(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Coefficient of determination, ``1 - SS_res / SS_tot``.

    Returns NaN when the score is undefined: no samples, or a target with
    zero variance.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_true.size == 0:
        return float('nan')
    ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))
    if ss_tot == 0.0:
        return float('nan')
    ss_res = float(np.sum((y_true - y_pred) ** 2))
    return 1.0 - ss_res / ss_tot


def per_feature_imputer_performance(imputer, X_full: jnp.ndarray, mask: jnp.ndarray):
    """Score a fitted imputer on the hidden entries of every feature.

    Args:
        imputer: Fitted object exposing ``transform(X)``, which takes a 2-D
            array containing NaNs and returns an array of the same shape.
        X_full: Fully observed 2-D data (rows are samples, columns are features).
        mask: Array of the same shape as ``X_full``; truthy where the entry is
            shown to the imputer, falsy where it is hidden and must be imputed.

    Returns:
        1-D NumPy array with one R^2 score per feature, computed only over that
        feature's hidden entries. A feature with no hidden entries, or whose
        hidden ground-truth values are constant, gets NaN.
    """
    X_full = np.asarray(X_full)
    # Coerce to bool so that `~` below is a logical (not bitwise) negation.
    observed = np.asarray(mask, dtype=bool)
    X_missing = np.where(observed, X_full, np.nan)
    # The imputer may return a non-NumPy array (e.g. a JAX array).
    X_imputed = np.asarray(imputer.transform(X_missing))

    n_features = X_full.shape[1]
    r2 = np.full(n_features, np.nan)
    for i in range(n_features):
        hidden = ~observed[:, i]
        r2[i] = _r_squared(X_full[hidden, i], X_imputed[hidden, i])
    return r2

In [None]:
# Per-feature R^2 of every fitted imputer on the hidden diabetes entries.
diabetes_r2_scores = {
    name: per_feature_imputer_performance(fitted, X_diabetes_full, M_diabetes)
    for name, fitted in diabetes_trained_imputer.items()
}

In [None]:
# Per-feature R^2 of every fitted imputer on the hidden California housing entries.
california_r2_scores = {
    name: per_feature_imputer_performance(fitted, X_california_full, M_california)
    for name, fitted in california_trained_imputer.items()
}