## Libs


In [1]:
%load_ext autoreload
%autoreload 2
from typing import Optional, Tuple, Literal

import sys
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import jax.random as jrandom 
import jax.nn as jnn
import equinox as eqx
import optax
from tqdm.notebook import tqdm
from collections import defaultdict
from functools import partial

jax.config.update('jax_enable_x64', True)
# jax.config.update('jax_check_tracer_leaks', True) 
sys.path.append("../../..")
from lib.ml.base_models import ICNNObsDecoder, ImputerMetrics, ICNN
import lib.ehr.example_datasets.mimiciv_aki as m4aki
from lib.ehr.tvx_ehr import TVxEHR
from lib.utils import modified_environ, write_config
 

In [2]:
class ProbICNNObsDecoder(eqx.Module):
    """Probabilistic imputer built from two ICNN observation decoders.

    ``icnn_mean`` imputes the missing values. ``icnn_var`` then receives the
    concatenation ``[imputed values, pre-softplus std]``, keeps the whole value
    half fixed, and optimises the std half only at non-fixed positions. The
    returned std is ``softplus`` of that half.
    """
    icnn_mean: ICNNObsDecoder  # imputes the values (mean).
    icnn_var: ICNNObsDecoder  # imputes the pre-softplus std, conditioned on the imputed mean.

    def __init__(self, observables_size: int, state_size: int, hidden_size_multiplier: float,
                 depth: int,
                 optax_optimiser_name: Literal['adam', 'polyak_sgd', 'lamb', 'yogi'] = 'adam', *,
                 key: jrandom.PRNGKey):
        key_mu, key_sigma = jrandom.split(key, 2)
        self.icnn_mean = ICNNObsDecoder(observables_size, state_size, hidden_size_multiplier, depth,
                                        optax_optimiser_name,
                                        key=key_mu)
        # The std network has twice as many observables ([mean, std]), so the multiplier is halved.
        # Presumably the hidden width scales with multiplier * observables - confirm in ICNNObsDecoder.
        # True division is used because `//` floors (a multiplier of 1 would become 0).
        self.icnn_var = ICNNObsDecoder(observables_size * 2, state_size, hidden_size_multiplier / 2, depth,
                                       optax_optimiser_name,
                                       key=key_sigma)

    @eqx.filter_jit
    def prob_partial_input_optimise(self, input: jnp.ndarray, fixed_mask: jnp.ndarray) -> Tuple[
        Tuple[jnp.ndarray, jnp.ndarray], ImputerMetrics]:
        """Impute non-fixed entries and estimate a per-entry standard deviation.

        Args:
            input: observation vector (presumably 1-D; callers in this notebook vmap over rows).
            fixed_mask: entries held fixed during the input optimisation.

        Returns:
            ``((mu, std), metrics)``; ``metrics`` come from the mean network only.
        """
        mu, metrics = self.icnn_mean.partial_input_optimise(input, fixed_mask)

        # Pre-softplus std initialisation: -4 (softplus ~ 0.018) at fixed entries and
        # 10 (softplus ~ 10) at entries being imputed.
        std_init = jnp.where(fixed_mask, -4., 10.)
        # The whole mean half is fixed; in the std half only non-fixed entries are optimised.
        # `ones_like(fixed_mask)` keeps the combined mask in the dtype of `fixed_mask`
        # instead of turning it into a float array.
        var_mask = jnp.hstack((jnp.ones_like(fixed_mask), fixed_mask))
        mu_std, _ = self.icnn_var.partial_input_optimise(jnp.hstack((mu, std_init)), var_mask)
        mu, std = jnp.hsplit(mu_std, 2)
        std = jnn.softplus(std)
        return (mu, std), metrics

    @eqx.filter_jit
    def partial_input_optimise(self, input: jnp.ndarray, fixed_mask: jnp.ndarray) -> Tuple[jnp.ndarray, ImputerMetrics]:
        """Deterministic interface: return only the imputed mean and the metrics."""
        (mu, _), metrics = self.prob_partial_input_optimise(input, fixed_mask)
        return mu, metrics




class ProbStackedICNNImputer(ICNNObsDecoder):
    """Single ICNN imputer over the stacked vector ``[values, pre-softplus std]``.

    One energy network sees both halves. Fixed entries keep their value and a
    small initial std. Non-fixed entries have both their value and their std
    optimised jointly.
    """
    f_energy: ICNN

    def __init__(self, observables_size: int, hidden_size_multiplier: float, depth: int,
                 optax_optimiser_name: Literal['adam', 'polyak_sgd', 'lamb', 'yogi'] = 'adam', *,
                 key: jrandom.PRNGKey):
        # The parent decoder is sized for twice the observables: values followed by std parameters.
        stacked_size = observables_size * 2
        super().__init__(observables_size=stacked_size, state_size=0,
                         hidden_size_multiplier=hidden_size_multiplier, depth=depth,
                         optax_optimiser_name=optax_optimiser_name, key=key)

    @eqx.filter_jit
    def prob_partial_input_optimise(self, input: jnp.ndarray, fixed_mask: jnp.ndarray) -> Tuple[
        Tuple[jnp.ndarray, jnp.ndarray], ImputerMetrics]:
        """Jointly impute values and a per-entry standard deviation.

        Returns:
            ``((mu, std), metrics)`` where ``std`` is ``softplus`` of the optimised std half.
        """
        # Pre-softplus std start: -4 (softplus ~ 0.018) where fixed, 10 (softplus ~ 10) where imputed.
        raw_std_start = jnp.where(fixed_mask, -4., 10.)
        stacked_input = jnp.hstack((input, raw_std_start))
        # The same mask governs both halves: fixed entries keep value and std.
        stacked_mask = jnp.hstack((fixed_mask, fixed_mask))
        stacked_out, metrics = super().partial_input_optimise(stacked_input, stacked_mask)
        mean_part, raw_std_part = jnp.hsplit(stacked_out, 2)
        return (mean_part, jnn.softplus(raw_std_part)), metrics

    @eqx.filter_jit
    def partial_input_optimise(self, input: jnp.ndarray, fixed_mask: jnp.ndarray) -> Tuple[jnp.ndarray, ImputerMetrics]:
        """Deterministic interface: drop the std and return ``(mu, metrics)``."""
        (mean_part, _), metrics = self.prob_partial_input_optimise(input, fixed_mask)
        return mean_part, metrics





## Data Loading

### First Time Loading

In [3]:
# tvx0 = m4aki.TVxAKIMIMICIVDataset.load('/home/asem/GP/ehr-data/mimic4aki-cohort/tvx_aki.h5')

In [4]:
# obs = [adm.observables  for subject in tvx0.subjects.values() for adm in subject.admissions]
# adm_id = sum(([adm.admission_id] * len(adm.observables.time)  for subject in tvx0.subjects.values() for adm in subject.admissions), [])
# subj_id = sum(([subject.subject_id] * len(adm.observables.time)  for subject in tvx0.subjects.values() for adm in subject.admissions), [])

In [5]:
# obs_val = np.vstack([obs_i.value for obs_i in obs])
# obs_mask = np.vstack([obs_i.mask for obs_i in obs])
# obs_time = np.hstack([obs_i.time for obs_i in obs])

In [6]:
# tvx0.scheme.obs
# features = list(map(tvx0.scheme.obs.desc.get, tvx0.scheme.obs.codes))

In [7]:
# obs_val = pd.DataFrame(obs_val, columns=features)
# obs_mask = pd.DataFrame(obs_mask.astype(int), columns=features)
# meta = pd.DataFrame({'subject_id': subj_id, 'admission_id': adm_id, 'time': obs_time})


In [8]:
# artificial_mask = obs_mask.copy()
# artificial_mask = obs_mask & np.array(jrandom.bernoulli(jrandom.PRNGKey(0), p=0.5, shape=obs_mask.shape))


In [9]:
# obs_val.to_csv('missingness_data/missingness_vals.csv')
# obs_mask.to_csv('missingness_data/missingness_mask.csv')
# meta.to_csv('missingness_data/meta.csv')
# artificial_mask.to_csv('missingness_data/missingness_artificial_mask.csv')


### Later Loading

In [None]:
# Reload the cached tables written by the "First Time Loading" cells (first CSV column is the index).
obs_val, obs_mask, artificial_mask, meta = (
    pd.read_csv(csv_path, index_col=[0])
    for csv_path in ('missingness_data/missingness_vals.csv',
                     'missingness_data/missingness_mask.csv',
                     'missingness_data/missingness_artificial_mask.csv',
                     'missingness_data/meta.csv')
)


### Split

In [None]:
# Random row-level 70/30 split.
# NOTE(review): rows are shuffled independently, so rows sharing a subject/admission
# (see `meta`) can land in both splits - confirm this leakage is acceptable.
split_ratio = 0.7
seed = 0
n_rows = len(obs_val)
indices = jrandom.permutation(jrandom.PRNGKey(seed), n_rows)
n_train = int(split_ratio * n_rows)
train_idx, test_idx = indices[:n_train], indices[n_train:]

# Materialise each split as JAX arrays: values, observed mask, artificial (kept-in) mask.
obs_val_train, obs_mask_train, art_mask_train = (
    jnp.array(frame.iloc[train_idx].to_numpy()) for frame in (obs_val, obs_mask, artificial_mask))
obs_val_test, obs_mask_test, art_mask_test = (
    jnp.array(frame.iloc[test_idx].to_numpy()) for frame in (obs_val, obs_mask, artificial_mask))

## Model Configuration

In [None]:
# Hyper-parameters shared by the deterministic and the probabilistic imputer.
_icnn_kwargs = dict(optax_optimiser_name='polyak_sgd', hidden_size_multiplier=2, depth=4)

model = ICNNObsDecoder(observables_size=obs_mask.shape[1], state_size=0,
                       key=jrandom.PRNGKey(0), **_icnn_kwargs)
p_model = ProbStackedICNNImputer(observables_size=obs_mask.shape[1],
                                 key=jrandom.PRNGKey(0), **_icnn_kwargs)


## Training

In [None]:
@eqx.filter_jit
def gaussian_kl(y: Tuple[jnp.ndarray, jnp.ndarray], y_hat: Tuple[jnp.ndarray, jnp.ndarray],
                mask: Optional[jnp.ndarray] = None, axis: Optional[int] = None) -> jnp.ndarray:
    """Masked mean of the element-wise KL(N(mean_hat, std_hat) || N(mean, std)).

    ``y = (mean, std)`` and ``y_hat = (mean_hat, std_hat)``. The prediction
    ``y_hat`` is the first argument of the KL. NaN terms are skipped (``nanmean``).
    """
    mean, std = y
    mean_hat, std_hat = y_hat
    log_std_ratio = jnp.log(std) - jnp.log(std_hat)
    spread_term = (std_hat ** 2 + (mean - mean_hat) ** 2) / (2 * (std ** 2))
    kl = log_std_ratio + spread_term - 0.5
    return jnp.nanmean(kl, axis=axis, where=mask)


@eqx.filter_jit
def log_normal(y: Tuple[jnp.ndarray, jnp.ndarray], y_hat: Tuple[jnp.ndarray, jnp.ndarray],
               mask: Optional[jnp.ndarray] = None, axis: Optional[int] = None) -> jnp.ndarray:
    """Masked mean Gaussian negative log-likelihood (without the 0.5*log(2*pi) constant).

    Despite the name this is a normal-distribution NLL. It scores the target
    mean ``y[0]`` under N(mean_hat, std_hat); the target std ``y[1]`` is ignored.
    A 1e-6 offset keeps the std away from zero.
    """
    target_mean, _ignored_std = y
    pred_mean, pred_std = y_hat
    z = (target_mean - pred_mean) / (pred_std + 1e-6)
    nll = 0.5 * (z ** 2 + 2 * jnp.log(pred_std + 1e-6))
    return jnp.mean(nll, axis=axis, where=mask)


@eqx.filter_jit
def r_squared(y: jnp.ndarray, y_hat: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
    """Coefficient of determination 1 - SS_res / SS_tot over the entries selected by ``mask``.

    Returns NaN when ``mask`` selects at most one entry.
    """
    y, y_hat, mask = (arr.squeeze() for arr in (y, y_hat, mask))

    centre = jnp.nanmean(y, where=mask)
    total_ss = jnp.nansum((y - centre) ** 2, where=mask)
    residual_ss = jnp.nansum((y - y_hat) ** 2, where=mask)

    has_enough = mask.sum() > 1
    return jnp.where(has_enough, 1 - (residual_ss / total_ss), jnp.nan)


@eqx.filter_jit
def r_squared_ranked_prob(y: jnp.ndarray, y_hat: jnp.ndarray, mask: jnp.ndarray, sigma: jnp.ndarray, k: int):
    """R² restricted to the ``k`` masked entries with the smallest predicted std.

    Args:
        y: true values (called per feature column in this notebook, i.e. 1-D).
        y_hat: imputed values, same shape as ``y``.
        mask: entries eligible for scoring.
        sigma: predicted standard deviation per entry.
        k: number of most-confident entries to score; must be less than the
            length of ``y`` along axis 0 (``argpartition`` requirement).

    Returns:
        R² over the selected entries. The result is NaN if fewer than ``k``
        entries are eligible, or if ``r_squared`` returns NaN.
    """
    # Ineligible entries get infinite std so they rank last.
    sigma = jnp.where(mask, sigma, jnp.inf)
    # Indices of the k smallest std values (unordered among themselves).
    top_k = jnp.argpartition(sigma, k, axis=0)[:k]
    # Use the JAX gather: np.take_along_axis is a NumPy routine and is not meant
    # for JAX tracers under jit/vmap.
    y = jnp.take_along_axis(y, top_k, axis=0)
    y_hat = jnp.take_along_axis(y_hat, top_k, axis=0)
    mask = jnp.take_along_axis(mask, top_k, axis=0)
    # If any selected entry is ineligible, fewer than k eligible entries existed.
    return jnp.where(jnp.all(mask), r_squared(y, y_hat, mask), jnp.nan)


def mse(X: jnp.ndarray, X_hat: jnp.ndarray, M: jnp.ndarray, axis: Optional[int] = None) -> jnp.ndarray:
    """Mean squared error between ``X`` and ``X_hat`` over the entries selected by ``M``.

    Args:
        X: reference values.
        X_hat: predictions (same number of elements as ``X``).
        M: entries to include; bool or 0/1 integers (cast to bool).
        axis: ``None`` (default) averages over all selected entries of the
            flattened inputs; an integer averages along that axis of the
            un-flattened inputs (previously any axis was applied to the
            flattened vector, which made the argument useless).

    Returns:
        A scalar for ``axis=None``, otherwise one mean per slice. Slices with no
        selected entries give NaN.
    """
    # Cast so non-boolean masks act as selectors rather than weights.
    M = jnp.asarray(M).astype(bool)
    if axis is None:
        return jnp.mean((X.flatten() - X_hat.flatten()) ** 2, where=M.flatten())
    return jnp.mean((X - X_hat) ** 2, where=M, axis=axis)


@eqx.filter_jit
def loss(model: ICNNObsDecoder, batch_X: jnp.ndarray, batch_M: jnp.ndarray, batch_M_art: jnp.ndarray) -> Tuple[
    jnp.ndarray, ImputerMetrics]:
    """Masked-reconstruction MSE of a deterministic imputer on one batch.

    Entries outside ``batch_M_art`` are zeroed and re-imputed row by row, while
    entries inside it stay fixed. The error is scored only where a value was
    observed (``batch_M``) but hidden by the artificial mask.

    Returns:
        ``(mse, metrics)`` with the metrics returned by the vmapped imputer.
    """
    hidden_input = jnp.where(batch_M_art, batch_X, 0.)
    impute_rows = eqx.filter_vmap(model.partial_input_optimise)
    X_imp, aux = impute_rows(hidden_input, batch_M_art)
    scored = (~batch_M_art) & batch_M
    return mse(batch_X, X_imp, scored), aux


@eqx.filter_jit
def prob_loss(model: ProbICNNObsDecoder, batch_X: jnp.ndarray, batch_M: jnp.ndarray, batch_M_art: jnp.ndarray) -> Tuple[
    jnp.ndarray, ImputerMetrics]:
    """Gaussian NLL of a probabilistic imputer on artificially hidden entries.

    Entries outside ``batch_M_art`` are zeroed and re-imputed (mean and std) row
    by row. ``log_normal`` scores them only where the value was observed
    (``batch_M``) but hidden. The 0.01 target std is passed along but unused by
    ``log_normal``.
    """
    hidden_input = jnp.where(batch_M_art, batch_X, 0.)
    impute_rows = eqx.filter_vmap(model.prob_partial_input_optimise)
    (X_imp, std), aux = impute_rows(hidden_input, batch_M_art)
    scored = (~batch_M_art) & batch_M
    target = (batch_X, jnp.zeros_like(batch_X) + 0.01)
    return log_normal(target, (X_imp, std), scored), aux


@eqx.filter_jit
def model_r_squared(model: ICNNObsDecoder, batch_X: jnp.ndarray, batch_M: jnp.ndarray,
                    batch_M_art: jnp.ndarray) -> jnp.ndarray:
    """Per-feature R² of the imputer on artificially hidden entries.

    Returns:
        ``(r2_vec, metrics)`` (a pair, despite the annotation). ``r2_vec`` holds
        one R² per column of ``batch_X``, NaN where a column has at most one
        scored entry. ``metrics`` come from the vmapped imputer.
    """
    hidden_input = jnp.where(batch_M_art, batch_X, 0.)
    X_imp, aux = eqx.filter_vmap(model.partial_input_optimise)(hidden_input, batch_M_art)
    scored = (~batch_M_art) & batch_M
    per_column_r2 = eqx.filter_vmap(r_squared)
    return per_column_r2(batch_X.T, X_imp.T, scored.T), aux

    
@eqx.filter_jit
def model_r_squared_ranked_prob(model: ICNNObsDecoder, batch_X: jnp.ndarray, batch_M: jnp.ndarray,
                                batch_M_art: jnp.ndarray, k: int) -> Tuple[jnp.ndarray, ImputerMetrics]:
    """Per-feature R² over the ``k`` most confident (smallest-std) hidden entries.

    Jitted like the other metric helpers. ``k`` is a Python int, so it is static
    and each distinct ``k`` compiles once.

    Returns:
        ``(r2_vec, metrics)`` with one value per column of ``batch_X`` (NaN when a
        column has fewer than ``k`` scored entries) and the imputer's metrics.
    """
    # Score only entries that were observed but artificially hidden.
    mask = (~batch_M_art) & batch_M
    # Zero the artificially hidden values before imputation.
    batch_X_art = jnp.where(batch_M_art, batch_X, 0.)
    # Optimise the hidden entries; entries in batch_M_art stay fixed.
    (X_imp, std), aux = eqx.filter_vmap(model.prob_partial_input_optimise)(batch_X_art, batch_M_art)
    # One ranked R² per feature column; k is passed through unbatched.
    r2_vec = eqx.filter_vmap(r_squared_ranked_prob)(batch_X.T, X_imp.T, mask.T, std.T, k)
    return r2_vec, aux



@eqx.filter_value_and_grad(has_aux=True)
def loss_grad(model: ICNNObsDecoder, batch_X: jnp.ndarray, batch_M: jnp.ndarray,
              batch_M_art: jnp.ndarray) -> Tuple[
    jnp.ndarray, ImputerMetrics]:
    """Differentiable wrapper around ``loss``.

    Thanks to the decorator, callers receive ``((loss_value, metrics), grads)``
    with gradients for ``model``.
    """
    loss_value, metrics = loss(model, batch_X, batch_M, batch_M_art)
    return loss_value, metrics


@eqx.filter_value_and_grad(has_aux=True)
def prob_loss_grad(model: ProbICNNObsDecoder, batch_X: jnp.ndarray, batch_M: jnp.ndarray,
                   batch_M_art: jnp.ndarray) -> Tuple[
    jnp.ndarray, ImputerMetrics]:
    """Differentiable wrapper around ``prob_loss``.

    Thanks to the decorator, callers receive ``((loss_value, metrics), grads)``
    with gradients for ``model``.
    """
    loss_value, metrics = prob_loss(model, batch_X, batch_M, batch_M_art)
    return loss_value, metrics


@eqx.filter_jit
def make_step(model: ICNNObsDecoder, opt_state, batch_X: jnp.ndarray, batch_M: jnp.ndarray,
              batch_M_art: jnp.ndarray):
    """One optimiser step on the deterministic ``loss`` using the global ``optim``.

    Returns:
        ``((loss_value, metrics), updated_model, updated_opt_state)``.
    """
    # Named `loss_value` so the local does not shadow the `loss` function used in value_fn.
    (loss_value, aux), grads = loss_grad(model, batch_X, batch_M, batch_M_art)

    def value_fn(params):
        # Scalar objective for optax transforms that accept a `value_fn` extra argument:
        # rebuild the full model from the trainable params and drop the metrics.
        return loss(eqx.combine(params, model), batch_X, batch_M, batch_M_art)[0]

    updates, opt_state = optim.update(grads, opt_state,
                                      params=eqx.filter(model, eqx.is_inexact_array),
                                      value=loss_value, grad=grads,
                                      value_fn=value_fn)

    model = eqx.apply_updates(model, updates)
    return (loss_value, aux), model, opt_state


@eqx.filter_jit
def make_prob_step(model: ProbICNNObsDecoder, opt_state, batch_X: jnp.ndarray, batch_M: jnp.ndarray,
                   batch_M_art: jnp.ndarray):
    """One optimiser step on ``prob_loss`` using the global ``optim``.

    Returns:
        ``((loss_value, metrics), updated_model, updated_opt_state)``.
    """
    # Named `loss_value` so the local does not shadow the module-level `loss` function.
    (loss_value, aux), grads = prob_loss_grad(model, batch_X, batch_M, batch_M_art)

    def value_fn(params):
        # Scalar objective for optax transforms that accept a `value_fn` extra argument.
        # prob_loss returns (value, metrics); only the value is an objective.
        return prob_loss(eqx.combine(params, model), batch_X, batch_M, batch_M_art)[0]

    updates, opt_state = optim.update(grads, opt_state,
                                      params=eqx.filter(model, eqx.is_inexact_array),
                                      value=loss_value, grad=grads,
                                      value_fn=value_fn)

    model = eqx.apply_updates(model, updates)
    return (loss_value, aux), model, opt_state


def dataloader(arrays, batch_size, *, key):
    """Yield shuffled mini-batches forever.

    Each epoch draws a fresh permutation of the rows and yields consecutive
    ``batch_size`` slices of it. A trailing partial batch is dropped so every
    batch has the same shape (avoids re-compiling jitted steps).

    Args:
        arrays: sequence of arrays sharing the same leading (row) dimension.
        batch_size: rows per batch; must be in ``[1, dataset_size]``.
        key: JAX PRNG key controlling the shuffling.

    Yields:
        A tuple with one ``batch_size``-row slice of each array in ``arrays``.

    Raises:
        ValueError: if ``batch_size`` is outside ``[1, dataset_size]``; the loop
            would otherwise never yield (or yield empty batches forever).
    """
    dataset_size = arrays[0].shape[0]
    if not 1 <= batch_size <= dataset_size:
        raise ValueError(f"batch_size must be in [1, {dataset_size}], got {batch_size}")
    indices = jnp.arange(dataset_size)
    while True:
        # Split before use so no key is consumed twice.
        key, perm_key = jrandom.split(key)
        perm = jrandom.permutation(perm_key, indices)
        # Upper bound includes the last *full* batch (the former `end < dataset_size` skipped it).
        for start in range(0, dataset_size - batch_size + 1, batch_size):
            batch_perm = perm[start:start + batch_size]
            yield tuple(array[batch_perm] for array in arrays)

In [None]:
# Optimisation schedule.
lr = 1e-3
steps = 1_000_000
eval_frequency = 10  # evaluate on a test batch every this many steps

# Batch sizes; for quick debugging runs use:
# train_batch_size = 1
# test_batch_size = 1
train_batch_size = 256
test_batch_size = 1024

# Outer optimiser over the probabilistic imputer's trainable (inexact) arrays.
optim = optax.novograd(lr)
opt_state = optim.init(eqx.filter(p_model, eqx.is_inexact_array))

# (values, observed mask, artificial kept-in mask) for each split, consumed by `dataloader`.
data_train = (obs_val_train, obs_mask_train, art_mask_train)
data_test = (obs_val_test, obs_mask_test, art_mask_test)

In [None]:
# Infinite shuffled batch streams. The test stream is sampled every `eval_frequency` steps
# and uses its own batch size, `test_batch_size`.
train_batches = dataloader(data_train, train_batch_size, key=jrandom.PRNGKey(0))
# `dataloader` is a generator, so it is already an iterator usable with `next`.
test_batches = dataloader(data_test, test_batch_size, key=jrandom.PRNGKey(0))
train_history = defaultdict(list)
test_history = defaultdict(list)

In [None]:
progress = tqdm(range(steps))

for step, batch_train in zip(progress, train_batches):
    start = time.time()
    (train_loss, train_aux), p_model, opt_state = make_prob_step(p_model, opt_state, *batch_train)
    # Per-feature R² on the artificially hidden entries: overall, and restricted to the
    # 5 entries per feature with the smallest predicted std.
    r2_vec, _ = model_r_squared(p_model, *batch_train)
    r2_vec_rank, _ = model_r_squared_ranked_prob(p_model, *batch_train, k=5)
    r2_vec = np.array(r2_vec)
    r2_vec_rank = np.array(r2_vec_rank)
    # Average inner imputation-optimisation step count (presumably one entry per batch row).
    # Reduced on device instead of iterating the array element by element with Python's sum.
    train_nsteps = int(jnp.mean(train_aux.n_steps))
    train_history['R2'].append(r2_vec)
    train_history['R2_rank5'].append(r2_vec_rank)
    train_history['loss'].append(train_loss)
    train_history['n_opt_steps'].append(train_nsteps)

    # The reported time covers the training step and its metrics, not the evaluation below.
    end = time.time()
    # Step 0 always evaluates, so the test_* names used in the description below are defined.
    if (step % eval_frequency) == 0 or step == steps - 1:
        batch_test = next(test_batches)
        test_loss, _ = prob_loss(p_model, *batch_test)
        r2_vec_test, _ = model_r_squared(p_model, *batch_test)
        r2_vec_rank_test, _ = model_r_squared_ranked_prob(p_model, *batch_test, k=10)
        r2_vec_test = np.array(r2_vec_test)
        r2_vec_rank_test = np.array(r2_vec_rank_test)
        test_history['loss'].append(test_loss)
        test_history['R2'].append(r2_vec_test)
        test_history['R2_rank10'].append(r2_vec_rank_test)

    progress.set_description(f"Trn-L: {train_loss:.3f}, Trn-R2: ({np.nanmax(r2_vec_rank):.2f}, {np.nanmin(r2_vec_rank):.2f}, {np.nanmean(r2_vec_rank):.2f}, {np.nanmedian(r2_vec_rank):.2f}),  Trn-N-steps: {train_nsteps}, " 
                             f"Tst-L:  {test_loss:.3f}, Tst-R2:  ({np.nanmax(r2_vec_rank_test):.2f}, {np.nanmin(r2_vec_rank_test):.2f}, {np.nanmean(r2_vec_rank_test):.2f}, {np.nanmedian(r2_vec_rank_test):.2f}), "
                             f"Computation time: {end - start:.2f}, ")
                            

In [None]:
# One row per training step; columns are the keys recorded in `train_history`.
train_stats = pd.DataFrame.from_dict(train_history)

In [None]:
# Fraction of steps where `zloss` exceeds `loss`.
# NOTE(review): the training loop only records 'R2', 'R2_rank5', 'loss' and 'n_opt_steps'
# in `train_history`, so `train_stats` has no 'zloss' column and this line raises
# AttributeError as written. It is presumably left over from a run that logged a second
# loss - confirm what it should compare, or remove it.
(train_stats.zloss > train_stats.loss).mean()