## Libs


In [1]:
%load_ext autoreload
%autoreload 2

import sys
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import jax.random as jrandom 
import equinox as eqx
import optax
from tqdm.notebook import tqdm
from collections import defaultdict
from functools import partial

jax.config.update('jax_enable_x64', True)
# jax.config.update('jax_check_tracer_leaks', True) 
sys.path.append("../../..")
from lib.ml.base_models import ICNNObsDecoder
import lib.ehr.example_datasets.mimiciv_aki as m4aki
from lib.ehr.tvx_ehr import TVxEHR
from lib.utils import modified_environ, write_config
 

## Data Loading

### First Time Loading

In [2]:
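# NOTE(review): this cell binds `tvx`, but the commented-out cells below read `tvx0`;
# align the names (or restore the missing step that produced `tvx0`) before re-enabling this section.
# tvx = m4aki.TVxAKIMIMICIVDataset.load('/home/asem/GP/ehr-data/mimic4aki-cohort/tvx_aki.h5')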

In [3]:
# obs = [adm.observables  for subject in tvx0.subjects.values() for adm in subject.admissions]
# adm_id = sum(([adm.admission_id] * len(adm.observables.time)  for subject in tvx0.subjects.values() for adm in subject.admissions), [])
# subj_id = sum(([subject.subject_id] * len(adm.observables.time)  for subject in tvx0.subjects.values() for adm in subject.admissions), [])

In [4]:
# obs_val = np.vstack([obs_i.value for obs_i in obs])
# obs_mask = np.vstack([obs_i.mask for obs_i in obs])
# obs_time = np.hstack([obs_i.time for obs_i in obs])

In [5]:
# tvx0.scheme.obs
# features = list(map(tvx0.scheme.obs.desc.get, tvx0.scheme.obs.codes))

In [6]:
# obs_val = pd.DataFrame(obs_val, columns=features)
# obs_mask = pd.DataFrame(obs_mask.astype(int), columns=features)
# meta = pd.DataFrame({'subject_id': subj_id, 'admission_id': adm_id, 'time': obs_time})


In [7]:
# artificial_mask = obs_mask.copy()
# artificial_mask = obs_mask & np.array(jrandom.bernoulli(jrandom.PRNGKey(0), p=0.5, shape=obs_mask.shape))


In [8]:
# obs_val.to_csv('missingness_data/missingness_vals.csv')
# obs_mask.to_csv('missingness_data/missingness_mask.csv')
# meta.to_csv('missingness_data/meta.csv')
# artificial_mask.to_csv('missingness_data/missingness_artificial_mask.csv')


### Later Loading

In [9]:
# Reload the cached tables from `missingness_data/`; in every file the first CSV
# column is used as the row index (`index_col=[0]`).
# NOTE(review): presumably these are the files written by the commented-out
# "First Time Loading" export cell — confirm they are in sync with each other.
obs_val = pd.read_csv('missingness_data/missingness_vals.csv', index_col=[0])
obs_mask = pd.read_csv('missingness_data/missingness_mask.csv', index_col=[0])
artificial_mask = pd.read_csv('missingness_data/missingness_artificial_mask.csv', index_col=[0])
# `meta` is loaded here but not used by the split or training cells below.
meta = pd.read_csv('missingness_data/meta.csv', index_col=[0])


### Split

In [10]:
split_ratio = 0.7
seed = 0

# Shuffle the row positions once, then cut the permutation at the split point.
# NOTE(review): rows are shuffled individually, so rows that belong to the same
# admission/subject may end up on both sides of the split — confirm this is intended.
indices = jrandom.permutation(jrandom.PRNGKey(seed), len(obs_val))
n_train = int(split_ratio * len(indices))
train_idx, test_idx = indices[:n_train], indices[n_train:]


def _rows_as_jax(frame, row_idx):
    """Pick the rows of `frame` at positions `row_idx` and return them as a JAX array."""
    return jnp.array(frame.iloc[row_idx].to_numpy())


# Training partition.
obs_val_train = _rows_as_jax(obs_val, train_idx)
obs_mask_train = _rows_as_jax(obs_mask, train_idx)
art_mask_train = _rows_as_jax(artificial_mask, train_idx)

# Test partition.
obs_val_test = _rows_as_jax(obs_val, test_idx)
obs_mask_test = _rows_as_jax(obs_mask, test_idx)
art_mask_test = _rows_as_jax(artificial_mask, test_idx)

## Model Configuration

In [11]:
# Build the decoder with one observable per column of `obs_mask` and `state_size=0`.
# The PRNG key is fixed, so re-running this cell passes the same key every time.
# NOTE(review): presumably `key` seeds the parameter initialisation — confirm in ICNNObsDecoder.
model = ICNNObsDecoder(observables_size=obs_mask.shape[1], state_size=0, 
                       hidden_size_multiplier=3, depth=8, key=jrandom.PRNGKey(0))


In [12]:
# Display the model object (its repr) as the cell output.
model

In [13]:
# Display the model's `f_energy` attribute as the cell output.
model.f_energy

## Training

In [16]:
def mse(X, X_hat, M, axis=None):
    """Masked mean squared error between `X` and `X_hat`.

    Args:
        X: reference values.
        X_hat: estimates, broadcast-compatible with `X`.
        M: mask selecting the entries that contribute to the mean; any
            non-zero entry counts as selected. `None` averages every entry.
        axis: axis (or axes) to reduce over; `None` reduces over everything.

    Returns:
        The mean of the squared differences over the selected entries. A
        reduction slice without any selected entry evaluates to NaN (0 / 0).
    """
    if M is None:
        return jnp.mean((X - X_hat) ** 2, axis=axis)
    # Make integer 0/1 masks behave as logical masks.
    selected = jnp.asarray(M).astype(bool)
    # Zero the residual *before* squaring. Masking only inside the reduction
    # keeps the forward value finite, but a NaN sitting at a masked-out entry
    # would still reach the gradient as 0 * NaN = NaN through the square.
    residual = jnp.where(selected, X - X_hat, 0.)
    return jnp.mean(residual ** 2, where=selected, axis=axis)

@eqx.filter_jit
def loss(model, batch_X, batch_M, batch_M_art):
    """Imputation loss on the artificially hidden entries of a batch.

    Args:
        model: object exposing `partial_input_optimise(x, mask)`, which is
            mapped over the leading (batch) axis of its two inputs.
        batch_X: batch of values.
        batch_M: mask of the entries that are actually observed.
        batch_M_art: artificial mask; entries where it is set are kept as
            inputs, the remaining ones are replaced by zero.

    Returns:
        `(mse, aux)`: the masked MSE between `batch_X` and the imputed batch,
        restricted to entries that are observed but artificially hidden, and
        the auxiliary output of `partial_input_optimise`.
    """
    # Scored entries: observed in the data, yet hidden by the artificial mask.
    held_out = batch_M & (~batch_M_art)
    # Artificially missing values are presented to the model as zeros.
    visible_X = jnp.where(batch_M_art, batch_X, 0.)
    # The model receives the artificial mask together with the zero-filled input.
    impute = eqx.filter_vmap(model.partial_input_optimise)
    X_imp, aux = impute(visible_X, batch_M_art)
    return mse(batch_X, X_imp, held_out), aux


@eqx.filter_jit
def zero_imputer_loss(batch_X, batch_M, batch_M_art):
    """Baseline loss: MSE obtained when every hidden entry is imputed with zero.

    The error is measured on the same entries as `loss`: those observed in
    `batch_M` but switched off in `batch_M_art`.
    """
    held_out = batch_M & (~batch_M_art)
    zero_guess = jnp.zeros_like(batch_X)
    return mse(batch_X, zero_guess, held_out)

@eqx.filter_jit
def weighted_loss(model, batch_X, batch_M, batch_M_art, weights):
    """Per-feature weighted imputation loss, together with the unweighted one.

    Args:
        model: object exposing `partial_input_optimise(x, mask)`, which is
            mapped over the leading (batch) axis of its two inputs.
        batch_X: batch of values.
        batch_M: mask of the entries that are actually observed.
        batch_M_art: artificial mask; entries where it is set are kept as
            inputs, the remaining ones are replaced by zero.
        weights: per-feature weights, multiplied with the per-feature loss.

    Returns:
        `(weighted, unweighted, aux)`: the weighted average of the per-feature
        MSEs, the plain masked MSE over the whole batch, and the auxiliary
        output of `partial_input_optimise`.
    """
    # Zero for artificially missing values.
    batch_X_art = jnp.where(batch_M_art, batch_X, 0.)
    # Tune the artificially masked-out values, keep the masked-in (batch_M_art) values fixed.
    X_imp, aux = eqx.filter_vmap(model.partial_input_optimise)(batch_X_art, batch_M_art)
    # Penalise the discrepancy on entries that are observed but artificially masked-out.
    held_out = (~batch_M_art) & batch_M
    loss_vec = mse(batch_X, X_imp, held_out, axis=0)
    unweighted_loss = mse(batch_X, X_imp, held_out)

    # A feature without any held-out entry in this batch has a NaN per-feature
    # loss. Leave it out of the numerator AND the denominator: dropping it from
    # the numerator only (nansum over all weights) deflates the average.
    weights = jnp.asarray(weights)
    weighted_terms = loss_vec * weights
    scored = ~jnp.isnan(weighted_terms)
    weighted_sum = jnp.sum(jnp.where(scored, weighted_terms, 0.))
    weight_total = jnp.sum(jnp.where(scored, weights, 0.))
    # If no feature is scored at all this is 0 / 0 = NaN, matching `unweighted_loss`.
    weighted_loss_ = weighted_sum / weight_total
    return weighted_loss_, unweighted_loss, aux
    
@eqx.filter_value_and_grad(has_aux=True)
def loss_grad(model, batch_X, batch_M, batch_M_art):
    """Evaluate `loss` so the decorator can differentiate it w.r.t. `model`.

    `loss` returns `(value, aux)`; with `has_aux=True` only `value` is
    differentiated and `aux` is passed through next to it.
    """
    value_and_aux = loss(model, batch_X, batch_M, batch_M_art)
    return value_and_aux


@eqx.filter_jit
def make_step(model, opt_state, batch_X, batch_M, batch_M_art):
    """Run one optimisation step on a batch.

    Uses the module-level optimiser `optim`.

    Args:
        model: current model.
        opt_state: optimiser state matching `optim`.
        batch_X, batch_M, batch_M_art: batch passed on to `loss_grad`.

    Returns:
        `((loss, aux), model, opt_state)`: the loss value and auxiliary output
        computed *before* the update, the updated model and the updated
        optimiser state.
    """
    (loss_value, aux), grads = loss_grad(model, batch_X, batch_M, batch_M_art)
    # Hand the optimiser the same filtered view of the model that its state was
    # initialised with, so `params` never carries non-array leaves of the model.
    params = eqx.filter(model, eqx.is_inexact_array)
    updates, opt_state = optim.update(grads, opt_state, params=params)
    model = eqx.apply_updates(model, updates)
    return (loss_value, aux), model, opt_state

def dataloader(arrays, batch_size, *, key):
    """Endlessly yield shuffled mini-batches drawn jointly from `arrays`.

    Every pass over the data uses a fresh permutation of the row positions and
    yields all full batches of that pass; a trailing partial batch is dropped.

    Args:
        arrays: sequence of arrays sharing the same leading dimension.
        batch_size: number of rows per batch, between 1 and the dataset size.
        key: JAX PRNG key driving the shuffles.

    Yields:
        A tuple with one batch per input array, all indexed by the same rows.

    Raises:
        ValueError: if the arrays disagree on their leading dimension, or if
            `batch_size` is not in `[1, dataset_size]` (a larger batch could
            never be filled and the generator would spin without yielding).
    """
    dataset_size = arrays[0].shape[0]
    if any(array.shape[0] != dataset_size for array in arrays):
        raise ValueError("All arrays must share the same leading dimension.")
    if not 0 < batch_size <= dataset_size:
        raise ValueError(
            f"batch_size must be between 1 and the dataset size ({dataset_size}), "
            f"got {batch_size}.")

    indices = jnp.arange(dataset_size)
    while True:
        # Split first, so the key consumed by the permutation is never reused.
        key, perm_key = jrandom.split(key)
        perm = jrandom.permutation(perm_key, indices)
        # The last start still leaves room for a full batch, including the case
        # where the dataset size is an exact multiple of `batch_size`.
        for start in range(0, dataset_size - batch_size + 1, batch_size):
            batch_perm = perm[start:start + batch_size]
            yield tuple(array[batch_perm] for array in arrays)


# --- Optimisation settings ---
lr = 1e-3
steps = 1_000_000
eval_frequency = 10  # a test batch is evaluated every `eval_frequency` steps

# --- Batch sizes ---
train_batch_size = 256
test_batch_size = 512

# --- Data tuples, ordered as (values, observed mask, artificial mask) ---
data_train = (obs_val_train, obs_mask_train, art_mask_train)
data_test = (obs_val_test, obs_mask_test, art_mask_test)

# --- Optimiser ---
# The state is initialised from the model leaves selected by `eqx.is_inexact_array`.
optim = optax.adadelta(lr)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))



In [None]:
train_batches = dataloader(data_train, train_batch_size, key=jrandom.PRNGKey(0))
# Evaluation batches use `test_batch_size` (the train size was used here before,
# which left `test_batch_size` without any effect).
test_batches = dataloader(data_test, test_batch_size, key=jrandom.PRNGKey(0))

progress = tqdm(range(steps))
train_history = defaultdict(list)
test_history = defaultdict(list)

# Per-feature weights for the weighted test loss: the mean of the artificial
# training mask over rows. Loop-invariant, so computed once up front.
eval_weights = art_mask_train.mean(axis=0)

for step, batch_train in zip(progress, train_batches):
    start = time.time()
    (train_loss, train_aux), model, opt_state = make_step(model, opt_state, *batch_train)
    train_zloss = zero_imputer_loss(*batch_train)
    # Average `n_steps` over the batch, truncated to an int. A single device
    # reduction replaces the element-by-element Python `sum`.
    train_nsteps = int(jnp.mean(jnp.asarray(train_aux.n_steps)))
    # Store plain Python floats so the history converts to numeric DataFrame columns.
    train_history['zloss'].append(float(train_zloss))
    train_history['loss'].append(float(train_loss))
    train_history['n_opt_steps'].append(train_nsteps)

    end = time.time()
    # Step 0 always evaluates, so the test metrics below exist for the first progress update.
    if (step % eval_frequency) == 0 or step == steps - 1:
        batch_test = next(test_batches)
        test_wloss, test_loss, aux = weighted_loss(model, *batch_test, weights=eval_weights)
        test_zloss = zero_imputer_loss(*batch_test)
        nsteps = int(jnp.mean(jnp.asarray(aux.n_steps)))
        test_history['zloss'].append(float(test_zloss))
        test_history['loss'].append(float(test_loss))
        test_history['wloss'].append(float(test_wloss))
        test_history['n_opt_steps'].append(nsteps)

    # The training step count is labelled "Trn N-steps" (it was mislabelled "Tst N-steps").
    progress.set_description(f"Trn-L: {train_loss:.3f}, Trn-Z-L: {train_zloss: .3f}, Trn N-steps: {train_nsteps}, "
                             f"Tst-L: {test_loss:.3f}, Tst-W-L: {test_wloss:.3f}, Tst-Z-L: "
                             f"{test_zloss:.3f}, Tst N-steps: {nsteps}, Computation time: {end - start:.2f}, ")

In [None]:
# One row per training step, one column per metric recorded in `train_history`.
train_stats = pd.DataFrame(data=train_history)

In [None]:
# Fraction of training steps where the zero-imputation baseline loss exceeds the model loss.
train_stats['zloss'].gt(train_stats['loss']).mean()