## Libs


In [1]:
%load_ext autoreload
%autoreload 2

import sys
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import jax.random as jrandom 
import equinox as eqx
import optax
from tqdm.notebook import tqdm
from collections import defaultdict
from functools import partial
from sklearn.datasets import fetch_california_housing, load_diabetes

jax.config.update('jax_enable_x64', True)
# jax.config.update('jax_check_tracer_leaks', True) 
sys.path.append("../../..")
from lib.ml.base_models import ICNNObsDecoder
import lib.ehr.example_datasets.mimiciv_aki as m4aki
from lib.ehr.tvx_ehr import TVxEHR
from lib.utils import modified_environ, write_config
 

## Data Loading

In [2]:

rng = np.random.RandomState(42)

X_diabetes, y_diabetes = load_diabetes(return_X_y=True)
X_california, y_california = fetch_california_housing(return_X_y=True)
X_california = X_california[:3000]
y_california = y_california[:3000]
X_diabetes = X_diabetes[:3000]
y_diabetes = y_diabetes[:3000]


def add_missing_values(X_full, y_full):
    n_samples, n_features = X_full.shape

    # Add missing values in 75% of the lines
    missing_rate = 0.75
    n_missing_samples = int(n_samples * missing_rate)

    missing_samples = np.zeros(n_samples, dtype=bool)
    missing_samples[:n_missing_samples] = True

    rng.shuffle(missing_samples)
    missing_features = rng.randint(0, n_features, n_missing_samples)
    X_missing = X_full.copy()
    X_missing[missing_samples, missing_features] = np.nan
    y_missing = y_full.copy()

    return X_missing, y_missing


X_miss_california, y_miss_california = add_missing_values(X_california, y_california)
X_miss_diabetes, y_miss_diabetes = add_missing_values(X_diabetes, y_diabetes)

In [3]:
X_miss_california.shape, X_miss_diabetes.shape

((3000, 8), (442, 10))

### Split

In [4]:
rng = np.random.RandomState(0)

from sklearn.ensemble import RandomForestRegressor

# To use the experimental IterativeImputer, we need to explicitly ask for it:
from sklearn.experimental import enable_iterative_imputer  # noqa
from sklearn.impute import IterativeImputer, KNNImputer, SimpleImputer
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline


In [5]:
def get_scores_for_imputer(imputer, X_missing, X_full):
    M_missing = np.where(np.isnan(X_missing), False, True)
    
    X_full_hat = imputer.fit_transform(X_missing)

    SE = (X_full - X_full_hat)**2

    return np.nanmean(SE, where=~M_missing)




x_labels = []
california_scores = []
diabetes_scores = []


### Replace missing values by 0

In [6]:
zero_imputer =  SimpleImputer(missing_values=np.nan, add_indicator=False, strategy="constant", fill_value=0)
california_scores.append(get_scores_for_imputer(zero_imputer,  X_miss_california, X_california))
diabetes_scores.append(get_scores_for_imputer(zero_imputer,  X_miss_diabetes, X_diabetes))

x_labels.append("Zero imputation")

### kNN-imputation of the missing values

In [7]:

knn_imputer =  KNNImputer(missing_values=np.nan)
california_scores.append(get_scores_for_imputer(knn_imputer,  X_miss_california, X_california))
diabetes_scores.append(get_scores_for_imputer(knn_imputer,  X_miss_diabetes, X_diabetes))



x_labels.append("KNN Imputation")

### Impute missing values with mean

In [8]:
mean_imputer =  SimpleImputer(missing_values=np.nan, add_indicator=False, strategy="mean", fill_value=0)
california_scores.append(get_scores_for_imputer(mean_imputer,  X_miss_california, X_california))
diabetes_scores.append(get_scores_for_imputer(mean_imputer,  X_miss_diabetes, X_diabetes))


x_labels.append("Mean Imputation")

### Iterative Imputer

In [40]:
iter_imputer = IterativeImputer(
        missing_values=np.nan,
        add_indicator=False,
        random_state=0,
        n_nearest_features=5,
        max_iter=25,
        sample_posterior=True,
    )
california_scores.append(get_scores_for_imputer(iter_imputer,  X_miss_california, X_california))
diabetes_scores.append(get_scores_for_imputer(iter_imputer,  X_miss_diabetes, X_diabetes))

x_labels.append("Iterative Imputation")

In [41]:

results = pd.DataFrame({'Diabetes': diabetes_scores, 'California': california_scores, 'Method': x_labels})

In [42]:
results

Unnamed: 0,Diabetes,California,Method
0,0.002004,261858.315676,Zero imputation
1,0.001222,87331.981458,KNN Imputation
2,0.002029,84315.71887,Mean Imputation
3,0.003351,325368.340942,Iterative Imputation
4,0.002578,188718.492461,Iterative Imputation
5,0.00265,169989.546168,Iterative Imputation
6,0.002966,168304.38818,Iterative Imputation
7,0.002165,184039.414123,Iterative Imputation
8,0.001842,497378.765971,Iterative Imputation
9,0.003794,195866.006261,Iterative Imputation


## Model Configuration

In [12]:
model = ICNNObsDecoder(observables_size=obs_mask.shape[1], state_size=0, 
                       hidden_size_multiplier=3, depth=4, key=jrandom.PRNGKey(0))


NameError: name 'obs_mask' is not defined

In [12]:
model

ICNNObsDecoder(
  f_energy=ICNN(
    Wzs=(
      PositiveSquaredLinear(
        weight=f64[300,100],
        bias=f64[300],
        in_features=100,
        out_features=300,
        use_bias=True
      ),
      PositiveSquaredLinear(
        weight=f64[300,300],
        bias=None,
        in_features=300,
        out_features=300,
        use_bias=False
      ),
      PositiveSquaredLinear(
        weight=f64[300,300],
        bias=None,
        in_features=300,
        out_features=300,
        use_bias=False
      ),
      PositiveSquaredLinear(
        weight=f64[300,300],
        bias=None,
        in_features=300,
        out_features=300,
        use_bias=False
      ),
      PositiveSquaredLinear(
        weight=f64[300,300],
        bias=None,
        in_features=300,
        out_features=300,
        use_bias=False
      ),
      PositiveSquaredLinear(
        weight=f64[300,300],
        bias=None,
        in_features=300,
        out_features=300,
        use_bias=False
   

In [13]:
model.f_energy

ICNN(
  Wzs=(
    PositiveSquaredLinear(
      weight=f64[300,100],
      bias=f64[300],
      in_features=100,
      out_features=300,
      use_bias=True
    ),
    PositiveSquaredLinear(
      weight=f64[300,300],
      bias=None,
      in_features=300,
      out_features=300,
      use_bias=False
    ),
    PositiveSquaredLinear(
      weight=f64[300,300],
      bias=None,
      in_features=300,
      out_features=300,
      use_bias=False
    ),
    PositiveSquaredLinear(
      weight=f64[300,300],
      bias=None,
      in_features=300,
      out_features=300,
      use_bias=False
    ),
    PositiveSquaredLinear(
      weight=f64[300,300],
      bias=None,
      in_features=300,
      out_features=300,
      use_bias=False
    ),
    PositiveSquaredLinear(
      weight=f64[300,300],
      bias=None,
      in_features=300,
      out_features=300,
      use_bias=False
    ),
    PositiveSquaredLinear(
      weight=f64[300,300],
      bias=None,
      in_features=300,
      out_fea

## Training

In [14]:
def mse(X, X_hat, M, axis=None):
    return jnp.mean((X - X_hat)**2, where=M, axis=axis)

@eqx.filter_jit
def loss(model, batch_X, batch_M, batch_M_art):
    # Zero for artificially missig values
    batch_X_art = jnp.where(batch_M_art, batch_X, 0.)
    # Tune for artificially masked-out values, fix mask-in (batch_M_art) values.
    X_imp, aux = eqx.filter_vmap(model.partial_input_optimise)(batch_X_art, batch_M_art)
    # Penalise discrepancy with artifially masked-out values.
    return mse(batch_X, X_imp, (~batch_M_art) & batch_M), aux


@eqx.filter_jit
def zero_imputer_loss(batch_X, batch_M, batch_M_art):
    return mse(batch_X, jnp.zeros_like(batch_X), (~batch_M_art) & batch_M)

@eqx.filter_jit
def weighted_loss(model, batch_X, batch_M, batch_M_art, weights):
    # Zero for artificially missig values
    batch_X_art = jnp.where(batch_M_art, batch_X, 0.)
    # Tune for artificially masked-out values, fix mask-in (batch_M_art) values.
    X_imp, aux = eqx.filter_vmap(model.partial_input_optimise)(batch_X_art, batch_M_art)
    # Penalise discrepancy with artifially masked-out values.
    loss_vec = mse(batch_X, X_imp, (~batch_M_art) & batch_M, axis=0)
    unweighted_loss = mse(batch_X, X_imp, (~batch_M_art) & batch_M)
    weighted_loss_ = jnp.nansum(loss_vec * weights) / jnp.sum(weights) 
    return weighted_loss_, unweighted_loss, aux
    
@eqx.filter_value_and_grad(has_aux=True)
def loss_grad(model, batch_X, batch_M, batch_M_art):
    return loss(model, batch_X, batch_M, batch_M_art)


@eqx.filter_jit
def make_step(model, opt_state, batch_X, batch_M, batch_M_art):
    (loss, aux), grads = loss_grad(model, batch_X, batch_M, batch_M_art)
    updates, opt_state = optim.update(grads, opt_state, params=model)
    model = eqx.apply_updates(model, updates)
    return (loss, aux), model, opt_state

def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    indices = jnp.arange(dataset_size)
    while True:
        perm = jrandom.permutation(key, indices)
        (key,) = jrandom.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size


lr=1e-3
steps=1000000
train_batch_size=256
test_batch_size=512
# train_batch_size=1
# test_batch_size=1
eval_frequency = 10

optim = optax.adadelta(lr)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
data_train = (obs_val_train, obs_mask_train, art_mask_train)
data_test = (obs_val_test, obs_mask_test, art_mask_test)



In [None]:
train_batches = dataloader(data_train, train_batch_size, key=jrandom.PRNGKey(0))
test_batches = iter(dataloader(data_test, train_batch_size, key=jrandom.PRNGKey(0)))

progress = tqdm(range(steps))
train_history = defaultdict(list)
test_history = defaultdict(list)

for step, batch_train in zip(progress, train_batches):
    start = time.time()
    (train_loss, train_aux), model, opt_state = make_step(model, opt_state, *batch_train)
    train_zloss =  zero_imputer_loss(*batch_train)
    train_nsteps = int(sum(train_aux.n_steps) / len(train_aux.n_steps))
    train_history['zloss'].append(train_zloss)
    train_history['loss'].append(train_loss)
    train_history['n_opt_steps'].append(train_nsteps)
    
    end = time.time()
    if (step % eval_frequency) == 0 or step == steps - 1:
        batch_test = next(test_batches)
        test_wloss, test_loss, aux = weighted_loss(model, *batch_test, weights=art_mask_train.mean(axis=0))
        test_zloss = zero_imputer_loss(*batch_test)
        nsteps = int(sum(aux.n_steps) / len(aux.n_steps))
        test_history['zloss'].append(test_zloss)
        test_history['loss'].append(test_loss)
        test_history['wloss'].append(test_wloss)
        test_history['n_opt_steps'].append(nsteps)
        
    progress.set_description(f"Trn-L: {train_loss:.3f}, Trn-Z-L: {train_zloss: .3f}, Tst N-steps: {train_nsteps}, " 
                             f"Tst-L: {test_loss:.3f}, Tst-W-L: {test_wloss:.3f}, Tst-Z-L: "
                             f"{test_zloss:.3f}, Tst N-steps: {nsteps}, Computation time: {end - start:.2f}, ")

  0%|          | 0/1000000 [00:00<?, ?it/s]

In [None]:
train_stats = pd.DataFrame(train_history)

In [None]:
(train_stats.zloss > train_stats.loss).mean()