## Imports and configuration


In [1]:
%load_ext autoreload
%autoreload 2

import sys
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import jax.random as jrandom 
import equinox as eqx
import optax
from tqdm.notebook import tqdm
from collections import defaultdict
from functools import partial
from sklearn.datasets import fetch_california_housing, load_diabetes

# Enable 64-bit floats/ints in JAX (JAX defaults to 32-bit); set early, before arrays are created.
jax.config.update('jax_enable_x64', True)
# jax.config.update('jax_check_tracer_leaks', True) 
# Make the project-local `lib` package importable.
# NOTE(review): assumes the notebook runs three directories below the repo root — confirm.
sys.path.append("../../..")
from lib.ml.base_models import ICNNObsDecoder
import lib.ehr.example_datasets.mimiciv_aki as m4aki
from lib.ehr.tvx_ehr import TVxEHR
from lib.utils import modified_environ, write_config
 

## Data Loading

In [2]:

# Notebook-level RNG; `add_missing_values` draws from it unless given its own RandomState.
rng = np.random.RandomState(42)

# Cap on the number of rows kept per dataset. The diabetes dataset only has 442 rows
# (see the shape check below), so the cap only affects California housing.
MAX_SAMPLES = 3000

X_diabetes, y_diabetes = load_diabetes(return_X_y=True)
X_california, y_california = fetch_california_housing(return_X_y=True)
X_california, y_california = X_california[:MAX_SAMPLES], y_california[:MAX_SAMPLES]
X_diabetes, y_diabetes = X_diabetes[:MAX_SAMPLES], y_diabetes[:MAX_SAMPLES]


def add_missing_values(X_full, y_full, missing_rate=0.75, random_state=None):
    """Return copies of ``X_full``/``y_full`` with one NaN injected into a random subset of rows.

    A fraction ``missing_rate`` of the rows is picked at random. In each picked row exactly
    one randomly chosen feature is set to NaN. ``y_full`` is copied unchanged.

    Args:
        X_full: 2-D array of shape (n_samples, n_features).
        y_full: targets; returned as an unmodified copy.
        missing_rate: fraction of rows that receive a missing value, in [0, 1] (default 0.75).
        random_state: ``np.random.RandomState`` to draw from. Defaults to the notebook-level
            ``rng`` so existing calls keep consuming the same random stream.

    Returns:
        ``(X_missing, y_missing)``. Non-floating ``X_full`` is cast to float64 so it can hold NaN.

    Raises:
        ValueError: if ``missing_rate`` is outside [0, 1].
    """
    if not 0.0 <= missing_rate <= 1.0:
        raise ValueError(f"missing_rate must be in [0, 1], got {missing_rate}")
    rs = rng if random_state is None else random_state

    n_samples, n_features = X_full.shape
    n_missing_samples = int(n_samples * missing_rate)

    missing_samples = np.zeros(n_samples, dtype=bool)
    missing_samples[:n_missing_samples] = True
    rs.shuffle(missing_samples)
    # One feature index per selected row: the boolean row mask (n_missing_samples True
    # entries) pairs element-wise with this index array in the assignment below.
    missing_features = rs.randint(0, n_features, n_missing_samples)

    X_missing = X_full.copy()
    if not np.issubdtype(X_missing.dtype, np.floating):
        # Integer arrays cannot store NaN.
        X_missing = X_missing.astype(np.float64)
    X_missing[missing_samples, missing_features] = np.nan
    y_missing = y_full.copy()

    return X_missing, y_missing


# Both calls draw from the shared notebook RNG, so their order determines the masks.
X_miss_california, y_miss_california = add_missing_values(X_full=X_california, y_full=y_california)
X_miss_diabetes, y_miss_diabetes = add_missing_values(X_full=X_diabetes, y_full=y_diabetes)

In [3]:
(X_miss_california.shape, X_miss_diabetes.shape)  # (n_samples, n_features) per dataset

((3000, 8), (442, 10))

### Baseline imputers (scikit-learn)

In [4]:
# NOTE(review): this rebinding does not change the masks already drawn above, and no later
# cell shown here reads `rng`; it only matters if the masking cell is re-run on its own.
rng = np.random.RandomState(0)

# NOTE(review): RandomForestRegressor, cross_val_score and make_pipeline are not used in
# any cell shown here.
from sklearn.ensemble import RandomForestRegressor

# To use the experimental IterativeImputer, we need to explicitly ask for it:
from sklearn.experimental import enable_iterative_imputer  # noqa
from sklearn.impute import IterativeImputer, KNNImputer, SimpleImputer
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline


In [5]:
def get_scores_for_imputer(imputer, X_missing, X_full):
    """Fit ``imputer`` on ``X_missing`` and score how well it recovers the missing entries.

    Args:
        imputer: object with a ``fit_transform`` method (e.g. a scikit-learn imputer).
        X_missing: data with NaN at the entries to be imputed.
        X_full: ground-truth data of the same shape.

    Returns:
        ``(overall, per_feature)``: the mean squared error over all originally-missing
        entries, and the same MSE computed per column. A column without missing entries
        yields NaN in ``per_feature``.
    """
    was_missing = np.isnan(X_missing)
    X_imputed = imputer.fit_transform(X_missing)
    squared_err = np.square(X_full - X_imputed)
    overall = np.nanmean(squared_err, where=was_missing)
    per_feature = np.nanmean(squared_err, where=was_missing, axis=0)
    return overall, per_feature




# Parallel accumulators: one method label and one (overall, per-feature) score pair per
# dataset are appended by each imputation cell below.
x_labels, california_scores, diabetes_scores = [], [], []


### Replace missing values by 0

In [6]:
# Replace every missing entry with the constant 0.
zero_imputer = SimpleImputer(
    missing_values=np.nan,
    strategy="constant",
    fill_value=0,
    add_indicator=False,
)
california_scores.append(
    get_scores_for_imputer(zero_imputer, X_miss_california, X_california)
)
diabetes_scores.append(
    get_scores_for_imputer(zero_imputer, X_miss_diabetes, X_diabetes)
)
x_labels.append("Zero imputation")

### kNN-imputation of the missing values

In [7]:
# Fill each missing entry from the nearest neighbouring rows (default KNNImputer settings).
knn_imputer = KNNImputer(missing_values=np.nan)
california_scores.append(
    get_scores_for_imputer(knn_imputer, X_miss_california, X_california)
)
diabetes_scores.append(
    get_scores_for_imputer(knn_imputer, X_miss_diabetes, X_diabetes)
)
x_labels.append("KNN Imputation")

### Impute missing values with mean

In [8]:
# Replace each missing entry with its column mean. `fill_value` is only consulted by
# strategy="constant", so it is omitted here rather than passed and silently ignored.
mean_imputer = SimpleImputer(missing_values=np.nan, add_indicator=False, strategy="mean")
california_scores.append(get_scores_for_imputer(mean_imputer, X_miss_california, X_california))
diabetes_scores.append(get_scores_for_imputer(mean_imputer, X_miss_diabetes, X_diabetes))
x_labels.append("Mean Imputation")

### Iterative Imputer

In [9]:
# Round-robin regression imputer. With sample_posterior=True each imputed value is a draw
# from the fitted estimator's predictive distribution rather than its mean, so the MSE
# reported here includes that sampling noise.
iter_imputer = IterativeImputer(
    missing_values=np.nan,
    add_indicator=False,
    random_state=0,
    n_nearest_features=5,
    max_iter=5,
    sample_posterior=True,
)
california_scores.append(
    get_scores_for_imputer(iter_imputer, X_miss_california, X_california)
)
diabetes_scores.append(
    get_scores_for_imputer(iter_imputer, X_miss_diabetes, X_diabetes)
)
x_labels.append("Iterative Imputation")

## Model Configuration

In [10]:
# One ICNN decoder per dataset; no latent state (state_size=0), same seed for both.
# NOTE: "claifornia" is a typo, but later cells reference this exact name.
claifornia_imputer = ICNNObsDecoder(
    observables_size=X_miss_california.shape[1],
    state_size=0,
    hidden_size_multiplier=3,
    depth=10,
    key=jrandom.PRNGKey(0),
)
diabetes_imputer = ICNNObsDecoder(
    observables_size=X_miss_diabetes.shape[1],
    state_size=0,
    hidden_size_multiplier=3,
    depth=10,
    key=jrandom.PRNGKey(0),
)

2024-06-25 20:44:39.721573: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


## Training

In [11]:
def mse(X, X_hat, M, axis=None):
    """Mean squared error between ``X`` and ``X_hat`` over the entries where ``M`` is True.

    ``axis`` is forwarded to ``jnp.mean``. The result is NaN wherever ``M`` selects
    no entries along the reduction.
    """
    squared_error = jnp.square(X - X_hat)
    return jnp.mean(squared_error, axis=axis, where=M)

@eqx.filter_jit
def loss(model, batch_X, batch_M, batch_M_art):
    """Reconstruction loss on entries that are observed but artificially hidden from the model.

    Args:
        model: imputer exposing ``partial_input_optimise(x, mask)``; vmapped over the batch.
        batch_X: batch of inputs (presumably NaN-free; the caller zero-fills missing values).
        batch_M: boolean mask of truly observed entries.
        batch_M_art: boolean mask of entries left visible to the model (the caller builds
            it as a random subset of ``batch_M``).

    Returns:
        ``(masked MSE over batch_M & ~batch_M_art, aux from partial_input_optimise)``.
    """
    # Zero out every entry the model is not allowed to see.
    visible_X = jnp.where(batch_M_art, batch_X, 0.)
    # Optimise the hidden entries while the visible (batch_M_art) entries stay fixed.
    X_imputed, aux = eqx.filter_vmap(model.partial_input_optimise)(visible_X, batch_M_art)
    # Penalise only entries that were truly observed but hidden from the model.
    held_out = batch_M & ~batch_M_art
    return mse(batch_X, X_imputed, held_out), aux

@eqx.filter_jit
def mean_imputer_loss(batch_X, batch_M, batch_M_art):
    """Baseline loss: fill each hidden entry with its column mean over the visible entries.

    Uses the same held-out mask and masked MSE as ``loss``, with the model replaced by a
    per-feature mean computed only from entries in ``batch_M_art``. This keeps the held-out
    targets and the zero-filled truly-missing entries out of the baseline. A feature with no
    visible entries in the batch falls back to 0.

    Args:
        batch_X: batch of inputs (presumably zero-filled where missing).
        batch_M: boolean mask of truly observed entries.
        batch_M_art: boolean mask of entries visible to the imputer.

    Returns:
        Scalar masked MSE over ``batch_M & ~batch_M_art``.
    """
    n_visible = jnp.sum(batch_M_art, axis=0)
    col_sum = jnp.sum(jnp.where(batch_M_art, batch_X, 0.), axis=0)
    # jnp.maximum avoids 0/0; the where replaces those columns with 0.
    col_mean = jnp.where(n_visible > 0, col_sum / jnp.maximum(n_visible, 1), 0.)
    return mse(batch_X, col_mean, (~batch_M_art) & batch_M)

    
@eqx.filter_value_and_grad(has_aux=True)
def loss_grad(model, batch_X, batch_M, batch_M_art):
    """``loss`` wrapped by equinox to also return gradients with respect to ``model``.

    Calling it returns ``((loss_value, aux), grads)``.
    """
    loss_value, aux = loss(model, batch_X, batch_M, batch_M_art)
    return loss_value, aux


@eqx.filter_jit
def make_step(model, optim, opt_state, batch_X, batch_M, batch_M_art):
    """Run one optimisation step: evaluate loss and gradients, then apply the optax update.

    Returns:
        ``((loss_value, aux), new_model, new_opt_state)``.
    """
    # Named `loss_value` so the module-level `loss` function is not shadowed here.
    (loss_value, aux), grads = loss_grad(model, batch_X, batch_M, batch_M_art)
    updates, new_opt_state = optim.update(grads, opt_state, params=model)
    new_model = eqx.apply_updates(model, updates)
    return (loss_value, aux), new_model, new_opt_state

def dataloader(arrays, batch_size, *, key):
    """Yield shuffled mini-batches from ``arrays`` forever.

    All arrays are indexed with the same random permutation of the leading axis, so rows
    stay aligned across arrays. Each pass reshuffles with a fresh key and yields
    ``dataset_size // batch_size`` full batches. Leftover rows of a pass are skipped.

    Args:
        arrays: sequence of arrays sharing the same leading dimension.
        batch_size: rows per batch; must satisfy ``1 <= batch_size <= dataset_size``.
        key: JAX PRNG key controlling the shuffling.

    Yields:
        Tuples holding ``array[batch_indices]`` for each input array.

    Raises:
        ValueError: on the first ``next()`` if the arrays disagree on their leading
            dimension or ``batch_size`` is out of range. Without this check an oversized
            batch makes the generator spin forever without yielding.
    """
    dataset_size = arrays[0].shape[0]
    if any(array.shape[0] != dataset_size for array in arrays):
        raise ValueError("All arrays must have the same leading dimension.")
    if not 0 < batch_size <= dataset_size:
        raise ValueError(f"batch_size must be in [1, {dataset_size}], got {batch_size}.")
    indices = jnp.arange(dataset_size)
    while True:
        perm = jrandom.permutation(key, indices)
        (key,) = jrandom.split(key, 1)
        # Every start with start + batch_size <= dataset_size, so the last full batch is kept.
        for start in range(0, dataset_size - batch_size + 1, batch_size):
            batch_perm = perm[start:start + batch_size]
            yield tuple(array[batch_perm] for array in arrays)


def _mean_n_steps(aux):
    """Average the per-row solver step counts stored in ``aux.n_steps``, truncated to int."""
    # NOTE(review): assumes `aux.n_steps` is array-like with one entry per batch row — confirm.
    return int(jnp.mean(jnp.asarray(aux.n_steps)))


def train_imputer(model, X_missing, lr=1e-3, steps=1000, train_batch_size=8, test_batch_size=8, eval_frequency = 10,
                  split=0.7, key=None, keep_prob=0.8):
    """Train an imputer to reconstruct observed entries that are randomly hidden from it.

    Rows are split into train/test sets. At every step a random ``1 - keep_prob`` share of
    the observed entries in a training batch is hidden, and the model is optimised to
    reconstruct them (see ``loss``). Every ``eval_frequency`` steps, and on the last step,
    the same loss is evaluated on a held-out test batch. The mean-imputer baseline
    (``mean_imputer_loss``) is recorded alongside for comparison.

    Args:
        model: equinox model exposing ``partial_input_optimise``.
        X_missing: 2-D numpy array with NaN at missing entries.
        lr: Adam learning rate.
        steps: number of training steps.
        train_batch_size, test_batch_size: rows per train/test batch.
        eval_frequency: evaluate on a test batch every this many steps (>= 1).
        split: fraction of rows used for training, in (0, 1).
        key: JAX PRNG key; defaults to ``jrandom.PRNGKey(0)``.
        keep_prob: probability that an observed entry stays visible to the model.

    Returns:
        ``(model, train_history, test_history)``. Each history maps ``'mloss'``, ``'loss'``
        and ``'n_opt_steps'`` to lists of Python numbers.

    Raises:
        ValueError: if ``eval_frequency`` < 1 or ``split`` is outside (0, 1).
    """
    if eval_frequency < 1:
        raise ValueError(f"eval_frequency must be >= 1, got {eval_frequency}.")
    if not 0.0 < split < 1.0:
        raise ValueError(f"split must be in (0, 1), got {split}.")
    if key is None:
        key = jrandom.PRNGKey(0)

    optim = optax.adam(lr)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

    # True where observed; NaNs are then zero-filled so JAX never sees them.
    mask = jnp.asarray(~np.isnan(X_missing))
    X_filled = jnp.nan_to_num(jnp.asarray(X_missing), nan=0.0)

    # Independent keys for the split, each loader and the per-step masking (no key reuse).
    split_key, train_key, test_key, key = jrandom.split(key, 4)
    n_rows = len(X_filled)
    train_idx, test_idx = jnp.split(jrandom.permutation(split_key, n_rows), [int(split * n_rows)])
    data_train = (X_filled[train_idx], mask[train_idx])
    data_test = (X_filled[test_idx], mask[test_idx])

    train_batches = dataloader(data_train, train_batch_size, key=train_key)
    test_batches = dataloader(data_test, test_batch_size, key=test_key)

    progress = tqdm(range(steps))
    train_history = defaultdict(list)
    test_history = defaultdict(list)
    # Placeholders shown in the progress bar until the first evaluation.
    test_loss = test_mloss = float("nan")
    test_nsteps = -1

    for step, (batch_X, batch_M) in zip(progress, train_batches):
        start = time.time()
        key, train_mask_key, test_mask_key = jrandom.split(key, 3)
        # Keep each observed entry visible with probability keep_prob; the rest are held out.
        batch_M_art = batch_M & jrandom.bernoulli(train_mask_key, p=keep_prob, shape=batch_M.shape)

        (train_loss, train_aux), model, opt_state = make_step(model, optim, opt_state, batch_X, batch_M, batch_M_art)
        train_mloss = mean_imputer_loss(batch_X, batch_M, batch_M_art)
        train_nsteps = _mean_n_steps(train_aux)
        train_history['mloss'].append(float(train_mloss))
        train_history['loss'].append(float(train_loss))
        train_history['n_opt_steps'].append(train_nsteps)
        end = time.time()

        if (step % eval_frequency) == 0 or step == steps - 1:
            # Evaluate on a held-out test batch, not the batch just trained on.
            test_batch_X, test_batch_M = next(test_batches)
            test_batch_M_art = test_batch_M & jrandom.bernoulli(test_mask_key, p=keep_prob, shape=test_batch_M.shape)
            test_loss, test_aux = loss(model, test_batch_X, test_batch_M, test_batch_M_art)
            test_mloss = mean_imputer_loss(test_batch_X, test_batch_M, test_batch_M_art)
            test_nsteps = _mean_n_steps(test_aux)
            test_history['mloss'].append(float(test_mloss))
            test_history['loss'].append(float(test_loss))
            test_history['n_opt_steps'].append(test_nsteps)

        progress.set_description(
            f"Trn-L: {float(train_loss):.3f}, Trn-M-L: {float(train_mloss): .3f}, Trn N-steps: {train_nsteps}, "
            f"Tst-L: {float(test_loss):.3f}, Tst-M-L: {float(test_mloss):.3f}, Tst N-steps: {test_nsteps}, "
            f"Computation time: {end - start:.2f}, "
        )
    return model, train_history, test_history

In [12]:
# Train one ICNN imputer per dataset.
# NOTE: the recorded run of this cell failed with a lineax "non-finite output" error (see output).
claifornia_imputer, cal_train_history, cal_test_history = train_imputer(
    claifornia_imputer,
    X_missing=X_miss_california,
    lr=5e-4,
    steps=1500,
    train_batch_size=64,
    test_batch_size=64,
    eval_frequency=10,
)
diabetes_imputer, diab_train_history, diab_test_history = train_imputer(
    diabetes_imputer,
    X_missing=X_miss_diabetes,
    lr=5e-4,
    steps=1500,
    train_batch_size=8,
    test_batch_size=8,
    eval_frequency=10,
)


  0%|          | 0/1500 [00:00<?, ?it/s]

jax.pure_callback failed
Traceback (most recent call last):
  File "/home/asem/GP/env/icenode-dev/lib/python3.11/site-packages/jax/_src/callback.py", line 77, in pure_callback_impl
    return callback(*args)
           ^^^^^^^^^^^^^^^
  File "/home/asem/GP/env/icenode-dev/lib/python3.11/site-packages/jax/_src/callback.py", line 65, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/asem/GP/env/icenode-dev/lib/python3.11/site-packages/equinox/_errors.py", line 70, in raises
    raise EqxRuntimeError(msgs[_index.item()])
equinox._errors.EqxRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that t

XlaRuntimeError: The linear solver returned non-finite (NaN or inf) output. This usually means that the
operator was not well-posed, and that the solver does not support this.

If you are trying solve a linear least-squares problem then you should pass
`solver=AutoLinearSolver(well_posed=False)`. By default `lineax.linear_solve`
assumes that the operator is square and nonsingular.

If you *were* expecting this solver to work with this operator, then it may be because:

(a) the operator is singular, and your code has a bug; or

(b) the operator was nearly singular (i.e. it had a high condition number:
    `jnp.linalg.cond(operator.as_matrix())` is large), and the solver suffered from
    numerical instability issues; or

(c) the operator is declared to exhibit a certain property (e.g. positive definiteness)
    that is does not actually satisfy.
-------
This error occurred during the runtime of your JAX program. Setting the environment
variable `EQX_ON_ERROR=breakpoint` is usually the most useful way to debug such errors.
(This can be navigated using most of the usual commands for the Python debugger:
`u` and `d` to move through stack frames, the name of a variable to print its value,
etc.) See also `https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more
information.


In [None]:
# Per-step training histories as DataFrames (columns: mloss, loss, n_opt_steps).
cal_train_res, diab_train_res = map(pd.DataFrame, (cal_train_history, diab_train_history))


In [None]:
cal_train_res["mloss"].gt(cal_train_res["loss"]).mean()  # share of steps where the mean baseline loses

In [None]:
diab_train_res["mloss"].gt(diab_train_res["loss"]).mean()  # share of steps where the mean baseline loses

In [None]:
# Full diabetes training history (pandas truncates long frames in the rendered output).
diab_train_res

In [None]:
def get_scores_for_icnn_imputer(imputer, X_missing, X_full):
    """Score a trained ICNN imputer on the entries that are NaN in ``X_missing``.

    NaNs are zero-filled. Each row is passed to ``imputer.partial_input_optimise`` together
    with its observed-entry mask. The return value is the same ``(overall MSE,
    per-feature MSE)`` pair over originally-missing entries as ``get_scores_for_imputer``.
    """
    observed = jnp.asarray(~np.isnan(X_missing))
    X_zero_filled = jnp.nan_to_num(X_missing)
    X_imputed, _aux = eqx.filter_vmap(imputer.partial_input_optimise)(X_zero_filled, observed)
    squared_err = (X_full - X_imputed)**2
    was_missing = ~observed
    return (
        np.nanmean(squared_err, where=was_missing),
        np.nanmean(squared_err, where=was_missing, axis=0),
    )

In [None]:
california_scores.append(
    get_scores_for_icnn_imputer(claifornia_imputer, X_miss_california, X_california)
)
diabetes_scores.append(
    get_scores_for_icnn_imputer(diabetes_imputer, X_miss_diabetes, X_diabetes)
)
x_labels.append("ICNN Imputation")

In [None]:
# Split each list of (overall, per_feature) pairs into two tuples.
# NOTE: this rebinds the score lists to tuples of floats, so re-running the cell fails.
california_scores, california_scores_per_feature = tuple(zip(*california_scores))
diabetes_scores, diabetes_scores_per_feature = tuple(zip(*diabetes_scores))


In [None]:
# Overall MSE on originally-missing entries, one row per imputation method.
# California scores are computed above but currently left out of this table.
results = pd.DataFrame({"Diabetes": diabetes_scores, "Method": x_labels})
results

In [None]:
# Per-feature MSE tables: rows are feature indices, columns are imputation methods.
cal_features = [f'feature_{i}' for i in range(len(california_scores_per_feature[0]))]
diab_features = [f'feature_{i}' for i in range(len(diabetes_scores_per_feature[0]))]
cal_features_res = pd.DataFrame(dict(zip(x_labels, california_scores_per_feature)), index=cal_features)
# Diabetes counterpart, so `diab_features` is actually used.
diab_features_res = pd.DataFrame(dict(zip(x_labels, diabetes_scores_per_feature)), index=diab_features)

In [None]:
# Per-feature MSE on California housing (rows: features, columns: imputation methods).
cal_features_res