# Table of Contents

<a name="outline"></a>

## Setup

- [A](#seca) External Imports
- [B](#secb) Internal Imports
- [D](#secd) Configurations and Paths
- [E](#sece) Patient Interface and Train/Val/Test Partitioning
- [F](#secf) General Utility Functions


## Training

- [1](#sec1) Training ICE-NODE and The Baselines on MIMIC-III
- [2](#sec2) Training ICE-NODE and The Baselines on MIMIC-IV

<a name="seca"></a>

### A External Imports [^](#outline)

In [None]:
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

from tqdm import tqdm

<a name="secb"></a>

### B Internal Imports [^](#outline)

In [None]:
# sys.path.append('..')
sys.path.append('repo')

from icenode.train_icenode_2lr import ICENODE
from icenode.train_icenode_uniform2lr import ICENODE as ICENODE_UNIFORM
from icenode.train_gram import GRAM
from icenode.train_retain import RETAIN
from icenode.metrics import evaluation_table
from icenode.utils import write_params, load_config, load_params

from icenode.mimic3.dag import CCSDAG
from icenode.mimic3.concept import DiagSubject
from icenode.jax_interface import SubjectDiagSequenceJAXInterface,  DiagnosisJAXInterface 

%load_ext autoreload
%autoreload 2

<a name="secd"></a>

### D Configurations and Paths [^](#outline)

In [None]:
# Locations of the admission and diagnosis tables, one pair per dataset.
mimic3_files = dict(adm_df='data/mimic3_adm_df.csv.gz',
                    diag_df='data/mimic3_diag_df.csv.gz')

mimic4_files = dict(adm_df='data/mimic4_adm_df.csv.gz',
                    diag_df='data/mimic4_diag_df.csv.gz')

# Model label -> class used to build that model.
model_cls = dict([
    ('ICE-NODE', ICENODE),
    ('ICE-NODE_UNIFORM', ICENODE_UNIFORM),
    ('GRU', GRAM),
    ('RETAIN', RETAIN),
])

# Model label -> loaded configuration (optimal hyperparams re: each model).
# Both ICE-NODE variants are loaded from the same JSON file.
model_config = {
    clf: load_config(config_file)
    for clf, config_file in [
        ('ICE-NODE', 'models_config/icenode_2lr.json'),
        ('ICE-NODE_UNIFORM', 'models_config/icenode_2lr.json'),
        ('GRU', 'models_config/gru.json'),
        ('RETAIN', 'models_config/retain.json'),
    ]
}

# Model labels, in the order they were registered in `model_cls`.
clfs = [*model_cls]

<a name="sece"></a>

### E Patient Interface and Train/Val/Test Partitioning [^](#outline)

In [None]:
def get_patient_interface(mimic_files, clfs):
    """Load one dataset and return the patient interface each model needs.

    Args:
        mimic_files: dict with the keys 'adm_df' and 'diag_df', each holding
            the path of a CSV file readable by `pd.read_csv`.
        clfs: iterable of model labels; each must be one of 'ICE-NODE',
            'ICE-NODE_UNIFORM', 'GRU' or 'RETAIN'.

    Returns:
        A dict mapping each label in `clfs` to its interface object. Labels
        of the same kind share a single interface instance.

    Raises:
        KeyError: if `clfs` contains a label with no registered interface
            kind. This is checked before any file is read.
    """
    interface_kind = {
        'ICE-NODE':  'timestamped',
        'ICE-NODE_UNIFORM': 'timestamped',
        'GRU': 'sequential',
        'RETAIN': 'sequential'
    }

    # Fail fast: loading the tables and building the interfaces is wasted
    # work if one of the requested labels cannot be resolved afterwards.
    clfs = list(clfs)
    unknown = [clf for clf in clfs if clf not in interface_kind]
    if unknown:
        raise KeyError(f'No interface kind registered for: {unknown}')

    adm_df = pd.read_csv(mimic_files['adm_df'])
    diag_df = pd.read_csv(mimic_files['diag_df'], dtype={'ICD9_CODE': str})

    # Cast columns of dates to datetime64 and drop the time-of-day part.
    # `infer_datetime_format` is deprecated since pandas 2.0 (format
    # inference is the default), so it is no longer passed.
    for column in ('ADMITTIME', 'DISCHTIME'):
        adm_df[column] = pd.to_datetime(adm_df[column]).dt.normalize()

    subjects = DiagSubject.to_list(adm_df, diag_df)
    ccs_dag = CCSDAG()

    interface_by_kind = {
        'timestamped': DiagnosisJAXInterface(subjects, ccs_dag),
        'sequential': SubjectDiagSequenceJAXInterface(subjects, ccs_dag)
    }

    return {clf: interface_by_kind[interface_kind[clf]] for clf in clfs}
    

# Build the per-model patient interfaces for MIMIC-IV, then MIMIC-III.
m4_interface = get_patient_interface(mimic_files=mimic4_files, clfs=clfs)
m3_interface = get_patient_interface(mimic_files=mimic3_files, clfs=clfs)

# The splits are drawn once from the first model's interface with a fixed seed;
# the resulting train/valid ids are the ones passed to every training call below.
# NOTE(review): split1/split2 presumably mark a 70% / 15% / 15% partition - confirm.
m4_train_ids, m4_valid_ids, m4_test_ids = m4_interface[clfs[0]].random_splits(
    split1=0.7,
    split2=0.85,
    random_seed=42)
m3_train_ids, m3_valid_ids, m3_test_ids = m3_interface[clfs[0]].random_splits(
    split1=0.7,
    split2=0.85,
    random_seed=42)


In [None]:
# Code partitions computed from the first model's interface of each dataset.
# NOTE(review): the argument 20 is presumably the percentile width/count used to
# group the flat CCS diagnosis codes - confirm against diag_flatccs_by_percentiles.
m4_percentiles = m4_interface[clfs[0]].diag_flatccs_by_percentiles(20)
m3_percentiles = m3_interface[clfs[0]].diag_flatccs_by_percentiles(20)

# Same call with the training ids passed as a second argument; these are the
# partitions handed to train_model as `percentile_codes`.
m4_train_percentiles = m4_interface[clfs[0]].diag_flatccs_by_percentiles(20, m4_train_ids)
m3_train_percentiles = m3_interface[clfs[0]].diag_flatccs_by_percentiles(20, m3_train_ids)


<a name="secf"></a>

### F Utility Functions [^](#outline)

In [None]:
def get_model(clf, config, interface):
    """Build the model registered under `clf` and its initial state.

    Looks up the class in `model_cls`, creates the model from `config` and
    `interface`, and initialises its state from the same `config`.

    Returns:
        A `(model, state)` tuple.
    """
    model_factory = model_cls[clf]
    model = model_factory.create_model(config, interface, [], None)
    return model, model.init(config)
        
def get_models(clfs, config, interface):
    """Build a `(model, state)` pair for every label in `clfs`.

    `config` and `interface` are both indexed by model label.

    Returns:
        A dict mapping each label to the tuple returned by `get_model`.
    """
    models = {}
    for clf in clfs:
        models[clf] = get_model(clf, config[clf], interface[clf])
    return models


def train_model(model, m_state, config, train_ids,
                valid_ids, training_output, percentile_codes, seed=42):
    """Train `model` on random mini-batches, checkpointing at each 1% of progress.

    Every iteration reshuffles a private copy of `train_ids` and uses its first
    `batch_size` entries as the batch. Whenever the rounded progress percentage
    changes (and always on the last iteration) the model is evaluated on the
    current batch ('TRN') and on `valid_ids` ('VAL'), and the parameters are
    written to `<training_output>/stepNNNN_params.pickle`.

    Args:
        model: object providing `step_optimizer`, `hasnan`, `eval` and
            `write_params`.
        m_state: initial model state, as passed to and returned by
            `model.step_optimizer`.
        config: dict providing `config['training']['batch_size']` and
            `config['training']['epochs']`.
        train_ids: non-empty sequence of training ids supporting `.copy()`;
            the caller's sequence is never modified.
        valid_ids: ids evaluated at every checkpoint.
        training_output: directory for the checkpoints; created if missing.
        percentile_codes: forwarded unchanged to `evaluation_table`.
        seed: seed of the batch-sampling RNG (default 42).

    Returns:
        `(m_state, step_evaluation)`, where `step_evaluation` maps a progress
        step (in percent) to its evaluation table. If a NaN is detected the
        loop stops early and the state returned is the one holding the NaN.

    Raises:
        ValueError: if `train_ids` is empty or the batch size is not positive.
    """
    batch_size = config['training']['batch_size']
    epochs = config['training']['epochs']

    # Validate before touching the filesystem: an empty training set would
    # otherwise surface as a ZeroDivisionError when computing `iters`.
    if len(train_ids) == 0:
        raise ValueError('train_ids is empty; nothing to train on.')
    if batch_size < 1:
        raise ValueError(f'batch_size must be positive, got {batch_size}.')

    # Make a new directory (if doesn't exist)
    Path(training_output).mkdir(parents=True, exist_ok=True)

    step_evaluation = {}

    # Copy because the sequence is mutable and random.Random shuffles in-place.
    train_ids = train_ids.copy()
    rng = random.Random(seed)
    batch_size = min(batch_size, len(train_ids))
    iters = round(epochs * len(train_ids) / batch_size)

    for i in tqdm(range(iters)):
        rng.shuffle(train_ids)
        train_batch = train_ids[:batch_size]

        # Step = 1% progress
        current_step = round((i + 1) * 100 / iters)
        previous_step = round(i * 100 / iters)

        m_state = model.step_optimizer(current_step, m_state, train_batch)
        if model.hasnan(m_state):
            print('NaN detected')
            break

        # Evaluate only when the progress percentage advances, plus once on
        # the final iteration.
        if current_step == previous_step and i < iters - 1:
            continue

        raw_res = {
            'TRN': model.eval(m_state, train_batch),
            'VAL': model.eval(m_state, valid_ids)
        }

        eval_df, _ = evaluation_table(raw_res, percentile_codes)

        step_evaluation[current_step] = eval_df
        fname = os.path.join(training_output, f'step{current_step:04d}_params.pickle')
        model.write_params(m_state, fname)

    return m_state, step_evaluation

 

In [None]:
# One (model, state) pair per model label, built separately for each dataset.
m4_models = get_models(clfs=clfs, config=model_config, interface=m4_interface)

m3_models = get_models(clfs=clfs, config=model_config, interface=m3_interface)

<a name="sec1"></a>

### 1 Training ICE-NODE and The Baselines on MIMIC-III [^](#outline)

#### ICE-NODE

In [None]:
m3_icenode_model, m3_icenode_state = m3_models['ICE-NODE']

# TODO: This may take a long time; a pretrained model already exists in (yy).
m3_icenode_state, m3_icenode_evals = train_model(
    model=m3_icenode_model,
    m_state=m3_icenode_state,
    config=model_config['ICE-NODE'],
    train_ids=m3_train_ids,
    valid_ids=m3_valid_ids,
    training_output='trained_models/m3_icenode',
    percentile_codes=m3_train_percentiles)

#### ICE-NODE_UNIFORM

In [None]:
m3_icenode_U_model, m3_icenode_U_state = m3_models['ICE-NODE_UNIFORM']

# TODO: This can take up to (xx); a trained model already exists in (yy).
m3_icenode_U_state, m3_icenode_U_evals = train_model(
    model=m3_icenode_U_model,
    m_state=m3_icenode_U_state,
    config=model_config['ICE-NODE_UNIFORM'],
    train_ids=m3_train_ids,
    valid_ids=m3_valid_ids,
    training_output='trained_models/m3_icenode_uniform',
    percentile_codes=m3_train_percentiles)


#### GRU

In [None]:
m3_gru_model, m3_gru_state = m3_models['GRU']

# TODO: This can take up to (xx); a trained model already exists in (yy).
m3_gru_state, m3_gru_evals = train_model(
    model=m3_gru_model,
    m_state=m3_gru_state,
    config=model_config['GRU'],
    train_ids=m3_train_ids,
    valid_ids=m3_valid_ids,
    training_output='trained_models/m3_gru',
    percentile_codes=m3_train_percentiles)

#### RETAIN

In [None]:
m3_retain_model, m3_retain_state = m3_models['RETAIN']

# TODO: This can take up to (xx); a trained model already exists in (yy).
m3_retain_state, m3_retain_evals = train_model(
    model=m3_retain_model,
    m_state=m3_retain_state,
    config=model_config['RETAIN'],
    train_ids=m3_train_ids,
    valid_ids=m3_valid_ids,
    training_output='trained_models/m3_retain',
    percentile_codes=m3_train_percentiles)

<a name="sec2"></a>

### 2 Training ICE-NODE and The Baselines on MIMIC-IV [^](#outline)

#### ICE-NODE

In [None]:
m4_icenode_model, m4_icenode_state = m4_models['ICE-NODE']

# TODO: This can take up to (xx); a trained model already exists in (yy).
# The checkpoints go to a MIMIC-IV specific directory: writing them to
# 'trained_models/m3_icenode' would overwrite the MIMIC-III ICE-NODE
# checkpoints, which use the same stepNNNN_params.pickle file names.
m4_icenode_state, m4_icenode_evals = train_model(m4_icenode_model, m4_icenode_state,
                                                 model_config['ICE-NODE'], m4_train_ids, m4_valid_ids,
                                                 'trained_models/m4_icenode',
                                                 m4_train_percentiles)

#### ICE-NODE_UNIFORM

In [None]:
m4_icenode_U_model, m4_icenode_U_state = m4_models['ICE-NODE_UNIFORM']

# TODO: This can take up to (xx); a trained model already exists in (yy).
m4_icenode_U_state, m4_icenode_U_evals = train_model(
    model=m4_icenode_U_model,
    m_state=m4_icenode_U_state,
    config=model_config['ICE-NODE_UNIFORM'],
    train_ids=m4_train_ids,
    valid_ids=m4_valid_ids,
    training_output='trained_models/m4_icenode_uniform',
    percentile_codes=m4_train_percentiles)


#### GRU

In [None]:
m4_gru_model, m4_gru_state = m4_models['GRU']

# TODO: This can take up to (xx); a trained model already exists in (yy).
m4_gru_state, m4_gru_evals = train_model(
    model=m4_gru_model,
    m_state=m4_gru_state,
    config=model_config['GRU'],
    train_ids=m4_train_ids,
    valid_ids=m4_valid_ids,
    training_output='trained_models/m4_gru',
    percentile_codes=m4_train_percentiles)

#### RETAIN

In [None]:
m4_retain_model, m4_retain_state = m4_models['RETAIN']

# RESOURCES WARNING: with this large dataset and its occasionally long patient
# histories, this model unfortunately requires more memory than is available in
# usual high-end GPUs (e.g. 12 GB in my main workstation).
# For this particular experiment, we relied on CPUs and the CPU RAM (over 64 GB).
# Regarding training on MIMIC-IV: ICE-NODE and ICE-NODE_UNIFORM finished training
# in less than 48 hours and the GRU model finished in less than 24 hours, whereas
# RETAIN training on MIMIC-IV would need more than three weeks to finish on a CPU.
# There is already a pretrained model that we add to this anonymous repository.
m4_retain_state, m4_retain_evals = train_model(
    model=m4_retain_model,
    m_state=m4_retain_state,
    config=model_config['RETAIN'],
    train_ids=m4_train_ids,
    valid_ids=m4_valid_ids,
    training_output='trained_models/m4_retain',
    percentile_codes=m4_train_percentiles)