# Table of Contents

<a name="outline"></a>

## Setup

- [A](#seca) External Imports
- [B](#secb) Internal Imports
- [C](#secd) Configurations and Paths
- [D](#sece) JAX Interface


## Training

- [1](#sec1) Training ICE-NODE and The Baselines


<a name="seca"></a>

### A External Imports [^](#outline)

In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

from tqdm import tqdm
import jax
jax.config.update('jax_platform_name', 'gpu')

<a name="secb"></a>

### B Internal Imports [^](#outline)

In [2]:


sys.path.append("..")

from lib import utils as U
from lib.ehr.dataset import load_dataset

In [3]:
# Assign the folder of the dataset to `DATA_DIR`.

HOME = os.environ.get('HOME')
DATA_DIR = f'{HOME}/GP/ehr-data'
SOURCE_DIR = os.path.abspath("..")

<a name="secd"></a>

### C Configurations and Paths [^](#outline)

In [4]:

output_dir = 'mimic_artefacts_surv'
Path(output_dir).mkdir(parents=True, exist_ok=True)

In [5]:
with U.modified_environ(DATA_DIR=DATA_DIR):
    m3_dataset = load_dataset('M3')
#     m4_dataset = load_dataset('M4')
   

                Unrecognised <class 'lib.ehr.coding_scheme.DxICD9'> codes (38)
                to be removed: ['041.49', '282.40', '282.46', '284.11', '284.12', '284.19', '294.20', '294.21', '348.82', '365.70', '425.11', '425.18', '444.09', '512.83', '512.84', '512.89', '516.31', '516.34', '516.36', '518.51', '518.52', '518.53', '573.5', '596.89', '719.70', '747.32', '793.11', '793.19', '795.51', '997.49', '998.01', '998.09', '999.32', '999.33', 'V12.55', 'V13.89', 'V54.82', 'V88.21']
                Unrecognised <class 'lib.ehr.coding_scheme.PrICD9'> codes (7)
                to be removed: ['02.21', '17.55', '17.56', '35.05', '36.01', '36.02', '36.05']


In [6]:
from lib.ml import ICENODE, ICENODE_UNIFORM, GRU, RETAIN, WindowLogReg, NJODE
%load_ext autoreload
%autoreload 2
# Model classes and their predefined hyperparameter configurations.

model_cls = {
    'ICE-NODE': ICENODE,
    'ICE-NODE_UNIFORM': ICENODE_UNIFORM,
    'GRU': GRU,
    'RETAIN': RETAIN,
    'LogReg': WindowLogReg,
    'NJODE': NJODE
}

model_config = {
    'ICE-NODE': f'{SOURCE_DIR}/expt_configs/icenode.json' ,
    'ICE-NODE_UNIFORM': f'{SOURCE_DIR}/expt_configs/icenode.json' ,
    'GRU': f'{SOURCE_DIR}/expt_configs/gru.json' ,
    'RETAIN': f'{SOURCE_DIR}/expt_configs/retain.json',
    'LogReg': f'{SOURCE_DIR}/expt_configs/window_logreg.json',
    'NJODE': f'{SOURCE_DIR}/expt_configs/njode.json'
}

model_config = {clf: U.load_config(file) for clf, file in model_config.items()}

clfs = [
    'ICE-NODE', 
    'ICE-NODE_UNIFORM', 
    'GRU', 'RETAIN', 'LogReg', 'NJODE']




In [7]:
# icenode_variants = []

# timescale = {'Ts7': 7.0, 'Ts30': 30.0}
# embeddings_size = {'E200': 200, 'E250': 250, 'E300': 300}
# tayreg = {'Ty0': 0, 'Ty3': 3}
# for ts_key, ts_val in timescale.items():
#     for e_key, e_val in embeddings_size.items():
#         for tay_key, tay_val in tayreg.items():
#             model_name = f'ICE-NODE-{"".join((ts_key, e_key, tay_key))}'
#             icenode_variants.append(model_name)
            
#             config = U.load_config(f'{SOURCE_DIR}/expt_configs/icenode.json')
#             config['emb']['dx']['embeddings_size'] = e_val
#             config['model']['timescale'] = ts_val
#             config['training']['tay_reg'] = tay_val
#             model_config[model_name] = config
#             model_cls[model_name] = ICENODE
# clfs.extend(icenode_variants)

In [8]:
m3_train_output_dir = {clf: f'{output_dir}/m3_train/{clf}' for clf in clfs}
m4_train_output_dir = {clf: f'{output_dir}/m4_train/{clf}' for clf in clfs}

for d in [*m3_train_output_dir.values(), *m4_train_output_dir.values()]:
    Path(d).mkdir(parents=True, exist_ok=True)

In [9]:
from lib.ml import ConfigDiskWriter, MinibatchLogger, EvaluationDiskWriter, ParamsDiskWriter

The reporter objects are called inside the training iterations. Each has its own role:
1. **ConfigDiskWriter**: writes the experiment configuration as a JSON file in the training directory.
2. **MinibatchLogger**: logs the training progress to the console.
3. **EvaluationDiskWriter**: writes the evaluation results as CSV tables in the training directory at each of the 100 reporting steps.
4. **ParamsDiskWriter**: writes a snapshot of the model parameters at each of the 100 reporting steps.
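
The reporter (callback) pattern described above can be sketched as a toy training loop that invokes every reporter at 100 evenly spaced checkpoints. The names `InMemoryEvaluationWriter`, `checkpoint_iterations` and `train` are hypothetical; this is not the actual interface of the `lib.ml` reporters, and the loss is a placeholder.

```python
import math


class InMemoryEvaluationWriter:
    """Hypothetical reporter: stores each report in memory instead of on disk."""

    def __init__(self):
        self.rows = []

    def report(self, step, n_steps, payload):
        self.rows.append({"step": step, "n_steps": n_steps, **payload})


def checkpoint_iterations(n_iters, n_steps=100):
    """0-based iterations closing each of `n_steps` (roughly) equal chunks of training."""
    return sorted({math.ceil((k + 1) * n_iters / n_steps) - 1 for k in range(n_steps)})


def train(n_iters, reporters, n_steps=100):
    """Toy training loop: reporters are invoked only at the checkpoint iterations."""
    checkpoints = set(checkpoint_iterations(n_iters, n_steps))
    step = 0
    for it in range(n_iters):
        loss = 1.0 / (it + 1)  # placeholder for a real optimisation step
        if it in checkpoints:
            step += 1
            for reporter in reporters:
                reporter.report(step, n_steps, {"iteration": it, "loss": loss})


writer = InMemoryEvaluationWriter()
train(n_iters=1000, reporters=[writer])
print(len(writer.rows), writer.rows[0]["iteration"], writer.rows[-1]["iteration"])
```

Keeping reporters separate from the loop means new outputs (console, disk, parameters) can be added without touching the training code.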

In [10]:
def make_reporters(output_dir, config):
    return [ConfigDiskWriter(output_dir=output_dir, config=config),
            MinibatchLogger(config),
            EvaluationDiskWriter(output_dir=output_dir),
            ParamsDiskWriter(output_dir=output_dir)]

m3_reporters = {model: make_reporters(m3_train_output_dir[model], model_config[model]) for model in clfs}
# m4_reporters = {model: make_reporters(m4_train_output_dir[model], model_config[model]) for model in clfs}

<a name="sece"></a>

### D JAX Interface [^](#outline)

In [11]:
from lib.ehr.coding_scheme import DxCCS, DxFlatCCS, DxICD9, DxICD10
from lib.ehr import Subject_JAX
from lib.ehr import StaticInfoFlags

%load_ext autoreload
%autoreload 2



The dictionary `code_scheme` in the next cell specifies the code spaces of:
- 'dx': diagnostic input codes (input features). Possible values: `DxCCS()`, `DxFlatCCS()`, `DxICD9()`, `DxICD10()`.
- 'outcome': diagnostic outcome codes (the prediction targets). Possible values:
    - `OutcomeExtractor('<outcome_label>')` to predict all `<outcome_label>` codes, or
    - `FirstOccurrenceOutcomeExtractor('<outcome_label>')` to predict `<outcome_label>` codes only at their first occurrence per patient.
    - `'<outcome_label>'` selects a subset of dx codes defined by a JSON file in `lib/ehr/resources/outcome_filters`. Each label maps to its JSON file as follows:
        - `'dx_flatccs_mlhc_groups'`: `'dx_flatccs_mlhc_groups.json'`
        - `'dx_flatccs_filter_v1'`: `'dx_flatccs_v1.json'`
        - `'dx_icd9_filter_v1'`: `'dx_icd9_v1.json'`
        - `'dx_icd9_filter_v2_groups'`: `'dx_icd9_v2_groups.json'`
        - `'dx_icd9_filter_v3_groups'`: `'dx_icd9_v3_groups.json'`
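
A minimal sketch of resolving an outcome label to its filter file, assuming the label-to-file mapping listed above and a repository root passed as `source_dir` (e.g. `SOURCE_DIR`). `OUTCOME_FILTER_FILES` and `outcome_filter_path` are hypothetical helpers, not part of `lib.ehr`.

```python
from pathlib import Path

# Label-to-file mapping as listed above (hypothetical helper, not lib.ehr's API).
OUTCOME_FILTER_FILES = {
    'dx_flatccs_mlhc_groups': 'dx_flatccs_mlhc_groups.json',
    'dx_flatccs_filter_v1': 'dx_flatccs_v1.json',
    'dx_icd9_filter_v1': 'dx_icd9_v1.json',
    'dx_icd9_filter_v2_groups': 'dx_icd9_v2_groups.json',
    'dx_icd9_filter_v3_groups': 'dx_icd9_v3_groups.json',
}


def outcome_filter_path(label, source_dir):
    """Return the JSON filter file for `label`, failing loudly on unknown labels."""
    try:
        filename = OUTCOME_FILTER_FILES[label]
    except KeyError:
        raise ValueError(f"Unknown outcome label {label!r}; "
                         f"expected one of {sorted(OUTCOME_FILTER_FILES)}") from None
    return Path(source_dir) / 'lib' / 'ehr' / 'resources' / 'outcome_filters' / filename


print(outcome_filter_path('dx_flatccs_filter_v1', '..'))
```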
    

**Note**: `OutcomeExtractor` can be configured through a JSON file to focus on only a subset of the diagnostic codes. For example,
you can restrict the prediction objective to a small subset of interest (e.g. only pulmonary heart
disease codes).
`OutcomeExtractor` can also be replaced by `FirstOccurrenceOutcomeExtractor` so that the objective is
to predict only the first occurrence of each code per patient; subsequent (redundant)
occurrences are ignored and not incorporated in the loss function.
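
To illustrate the first-occurrence idea, here is a minimal NumPy sketch of one way to build such a mask for a single subject and apply it to a binary cross-entropy loss. The helpers `first_occurrence_mask` and `masked_bce` are hypothetical; this is not how `FirstOccurrenceOutcomeExtractor` is implemented in `lib.ehr`, only an illustration of the behaviour described above.

```python
import numpy as np


def first_occurrence_mask(outcomes):
    """Mask of (visit, code) entries kept when only first occurrences count.

    `outcomes` is a binary (n_visits, n_codes) matrix of one subject's visits in
    chronological order. An entry is kept (True) if the code has not occurred in
    any earlier visit, so the first occurrence itself is kept and every later
    visit is masked out for that code.
    """
    outcomes = np.asarray(outcomes, dtype=int)
    occurred_before = np.cumsum(outcomes, axis=0) - outcomes  # earlier occurrences only
    return occurred_before == 0


def masked_bce(outcomes, probs, mask, eps=1e-7):
    """Mean binary cross-entropy over the unmasked entries only."""
    probs = np.clip(probs, eps, 1 - eps)
    bce = -(outcomes * np.log(probs) + (1 - outcomes) * np.log(1 - probs))
    return float((bce * mask).sum() / mask.sum())


outcomes = np.array([[0, 1],
                     [1, 0],
                     [1, 1],
                     [0, 0]])
mask = first_occurrence_mask(outcomes)
print(mask.astype(int))
print(masked_bce(outcomes, np.full(outcomes.shape, 0.5), mask))
```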

In [12]:
from lib.ehr import OutcomeExtractor, SurvivalOutcomeExtractor

In [13]:
code_scheme = {
    'dx': DxCCS(), # other options: DxFlatCCS(), DxICD9(), DxICD10()
    'outcome': SurvivalOutcomeExtractor('dx_flatccs_filter_v1')
}

### Adding Demographic Information in Training

Which static (demographic) information should be included as control features? **Set the corresponding flag of `StaticInfoFlags` to `True` to include it;** in the next cell, gender and age are included.

In [14]:

static_info_flags = StaticInfoFlags(gender=True, age=True)

m3_interface = Subject_JAX.from_dataset(m3_dataset, 
                                        code_scheme=code_scheme, 
                                        static_info_flags=static_info_flags,
                                        data_max_size_gb=1)
# m4_interface = Subject_JAX.from_dataset(m4_dataset, 
#                                         code_scheme=code_scheme, 
#                                         static_info_flags=static_info_flags,
#                                        data_max_size_gb=1)

m3_splits = m3_interface.random_splits(split1=0.7, split2=0.85, random_seed=42)
# m4_splits = m4_interface.random_splits(split1=0.7, split2=0.85, random_seed=42)


                            S - M_domain (2497, p=0.14371223021582732):
                            ['001', '002', '003', '003.2', '004']...

                            M_domain - S (0, p=0.0):
                            []...

                            M_domain (14878):
                            ['001.0', '001.1', '001.9', '002.0', '002.1']...

                            S (17375): ['001', '001.0', '001.1', '001.9', '002']...

                            S - M_domain (2497, p=0.14371223021582732):
                            ['001', '002', '003', '003.2', '004']...

                            M_domain - S (0, p=0.0):
                            []...

                            M_domain (14878):
                            ['001.0', '001.1', '001.9', '002.0', '002.1']...

                            S (17375): ['001', '001.0', '001.1', '001.9', '002']...
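
`random_splits(split1=0.7, split2=0.85, ...)` appears to take cumulative fractions, which would give roughly 70% training, 15% validation and 15% test subjects. Assuming that interpretation (not verified against `lib.ehr`), a reproducible split over subject IDs can be sketched with the standard library as follows; `random_subject_splits` is a hypothetical name.

```python
import random


def random_subject_splits(subject_ids, split1=0.7, split2=0.85, random_seed=42):
    """Shuffle subject IDs reproducibly and cut them at cumulative fractions."""
    ids = sorted(subject_ids)                 # fixed order before shuffling
    random.Random(random_seed).shuffle(ids)   # seeded, independent of global state
    n = len(ids)
    i1, i2 = round(split1 * n), round(split2 * n)
    return ids[:i1], ids[i1:i2], ids[i2:]


train_ids, val_ids, test_ids = random_subject_splits(range(100))
print(len(train_ids), len(val_ids), len(test_ids))
```

Splitting by subject (rather than by visit) keeps all visits of a patient in one partition, which avoids leaking a patient's history between training and evaluation.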


In [15]:
import jax.random as jrandom
import lib.ml as ml
%load_ext autoreload
%autoreload 2
key = jrandom.PRNGKey(0)



In the next cell, each model and its trainer are instantiated from the per-model experiment configuration loaded earlier into `model_config`.
The class name of the trainer is also specified in each experiment configuration.
For example, this is the configuration file of the ICE-NODE experiment:

```json
{
    "emb": {
        "dx": {
            "decoder_n_layers": 2,
            "classname": "MatrixEmbeddings",
            "embeddings_size": 300
        }
    },
    "model": {
        "ode_dyn_label": "mlp3",
        "ode_init_var": 1e-7,
        "state_size": 30,
        "timescale": 30
    },
    "training": {
        "batch_size": 256,
        "decay_rate": [0.25, 0.33],
        "lr": [7e-5, 1e-3],
        "epochs": 60,
        "reg_hyperparams": {
            "L_dyn": 1000.0,
            "L_l1": 0,
            "L_l2": 0
        },
        "opt": "adam",
        "classname": "ODETrainer2LR"
    }
}
```

The `"classname"` entry under `"training"` names the trainer class, so that class must be available in the `ml` package (`lib.ml`). Since the class name is given as a string, one way to get `ml.ODETrainer2LR` is `getattr(ml, 'ODETrainer2LR')`.
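
The lookup-by-name pattern can be sketched in isolation. Here `ml` is a `types.SimpleNamespace` standing in for `lib.ml`, and `ODETrainer2LR` is a hypothetical stand-in class that only stores its hyperparameters; `resolve_class` is likewise a hypothetical helper that mirrors the `getattr` call used to build `trainers`, with a clearer error message.

```python
import types


class ODETrainer2LR:
    """Hypothetical stand-in for lib.ml.ODETrainer2LR: just stores its hyperparameters."""

    def __init__(self, **hyperparams):
        self.hyperparams = hyperparams


# Stand-in for the `lib.ml` module, exposing trainer classes as attributes.
ml = types.SimpleNamespace(ODETrainer2LR=ODETrainer2LR)

training_config = {"batch_size": 256, "lr": [7e-5, 1e-3], "epochs": 60,
                   "opt": "adam", "classname": "ODETrainer2LR"}


def resolve_class(module, classname):
    """Look up a class by name, with a clearer error than a bare AttributeError."""
    cls = getattr(module, classname, None)
    if cls is None:
        raise ValueError(f"{classname!r} is not available in {module!r}")
    return cls


trainer_cls = resolve_class(ml, training_config["classname"])
trainer = trainer_cls(**training_config)  # the whole "training" dict is passed as kwargs
print(type(trainer).__name__, trainer.hyperparams["epochs"])
```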

In [16]:
m3_models = {clf: model_cls[clf].from_config(model_config[clf],
                                              m3_interface,
                                              m3_splits[0],
                                              key) for clf in clfs}
# m4_models = {clf: model_cls[clf].from_config(model_config[clf],
#                                               m4_interface,
#                                               m4_splits[0],
#                                               key) for clf in clfs}



trainers_cls = {clf: getattr(ml, model_config[clf]["training"]["classname"]) for clf in clfs}
trainers = {clf: trainers_cls[clf](**model_config[clf]["training"]) for clf in clfs}

## Metrics of Interest Specification

In [17]:
from lib.metric import (CodeAUC, UntilFirstCodeAUC, AdmissionAUC, CodeGroupTopAlarmAccuracy, LossMetric, MetricsCollection)

## Evaluation Metrics per Model

1. *CodeAUC*: evaluates the prediction AUC per code (aggregating over all subject visits, for all subjects)
2. *UntilFirstCodeAUC*: same as *CodeAUC*, but evaluates the prediction AUC only up to the first occurrence of each code for each subject: once the code has occurred, all subsequent visits are ignored for that code. If the code never occurs for a particular subject, all of that subject's visits are ignored.
3. *AdmissionAUC*: evaluates the prediction AUC per visit (i.e. the probability of assigning higher risk values to the present codes than to the absent ones).
4. *CodeGroupTopAlarmAccuracy*: partitions the codes into groups according to code frequency (from the most frequent to the least frequent); for each visit, it picks the `k` codes with the highest predicted risk and evaluates the accuracy of these top-`k` riskiest codes in being actually present.
5. *LossMetric*: records the values of different loss variants, which do not necessarily include the loss function actually used in training.
6. *MetricsCollection*: Groups multiple metrics to be considered at once.
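
As a rough illustration of the per-visit AUC described for *AdmissionAUC*, the sketch below computes, for one visit, the probability that a present code receives a higher risk than an absent one, counting ties as one half. The function `visit_auc` is hypothetical; how `lib.metric` aggregates visits (e.g. by averaging) is not shown and may differ.

```python
import numpy as np


def visit_auc(true_codes, risks):
    """Pairwise AUC for one visit: P(risk of a present code > risk of an absent code)."""
    true_codes = np.asarray(true_codes, dtype=bool)
    risks = np.asarray(risks, dtype=float)
    pos, neg = risks[true_codes], risks[~true_codes]
    if pos.size == 0 or neg.size == 0:
        return float("nan")  # undefined without both present and absent codes
    diff = pos[:, None] - neg[None, :]  # all (present, absent) pairs
    return float((diff > 0).mean() + 0.5 * (diff == 0).mean())


print(visit_auc([1, 0, 1, 0], [0.9, 0.8, 0.3, 0.1]))
```

Of the four (present, absent) pairs in the example, three are ranked correctly, so the visit scores 0.75.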

**Note:** the method `outcome_by_percentiles` gives different results depending on the 'outcome' entry of the `code_scheme` dictionary:
- `OutcomeExtractor`: the counting considers every occurrence of each code per subject, including redundant (repeated) ones, then aggregates over all subjects.
- `FirstOccurrenceOutcomeExtractor`: the counting considers only the first occurrence of each code per subject, then aggregates over all subjects.
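
The two counting modes can be contrasted with a small NumPy sketch. `code_counts` is a hypothetical helper that only mimics the behaviour described in the note; it is not the implementation behind `outcome_by_percentiles`.

```python
import numpy as np


def code_counts(subjects, first_occurrence_only=False):
    """Aggregate per-code counts over subjects.

    Each element of `subjects` is a binary (n_visits, n_codes) matrix.
    - all occurrences: every visit in which a code is present is counted;
    - first occurrence only: a code is counted at most once per subject.
    """
    total = None
    for visits in subjects:
        visits = np.asarray(visits, dtype=int)
        if first_occurrence_only:
            counts = (visits.sum(axis=0) > 0).astype(int)
        else:
            counts = visits.sum(axis=0)
        total = counts if total is None else total + counts
    return total


subjects = [np.array([[1, 0], [1, 1]]),  # subject 1: code 0 twice, code 1 once
            np.array([[0, 1]])]          # subject 2: code 1 once
print(code_counts(subjects), code_counts(subjects, first_occurrence_only=True))
```

Code 0 is twice as frequent as code 1 per occurrence count within subject 1, yet under first-occurrence counting code 1 becomes the more frequent one, which is why the resulting frequency groups can differ.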

In [18]:
# percentile_range=20 partitions the codes into five groups, where each group contains
# codes that together constitute 20% of the codes in all visits of the specified `subjects` list.
m3_code_freq_partitions = m3_interface.outcome_by_percentiles(percentile_range=20, subjects=m3_splits[0])
# m4_code_freq_partitions = m4_interface.outcome_by_percentiles(percentile_range=20, subjects=m4_splits[0])

# Evaluate for different k values
top_k_list = [3, 5, 10, 15, 20]

m3_metrics = [CodeAUC(m3_interface),
              UntilFirstCodeAUC(m3_interface),
              AdmissionAUC(m3_interface),
              LossMetric(m3_interface),
              CodeGroupTopAlarmAccuracy(m3_interface, top_k_list=top_k_list, code_groups=m3_code_freq_partitions)]
# m4_metrics = [CodeAUC(m4_interface),
#               UntilFirstCodeAUC(m4_interface),
#               AdmissionAUC(m4_interface),
#               LossMetric(m4_interface),
#               CodeGroupTopAlarmAccuracy(m4_interface, top_k_list=top_k_list, code_groups=m4_code_freq_partitions)]

m3_metrics = MetricsCollection(m3_metrics)
# m4_metrics = MetricsCollection(m4_metrics)
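
The frequency partitioning described in the comment above can be approximated as follows: sort codes from the most to the least frequent and start a new group every `percentile_range`% of cumulative occurrences. `partition_by_percentiles` is a hypothetical helper; in particular, assigning a code that straddles a group boundary by the cumulative share *preceding* it is an assumption, and very frequent codes can leave some groups empty.

```python
import numpy as np


def partition_by_percentiles(counts, percentile_range=20):
    """Group code indices by cumulative frequency share, most frequent first."""
    counts = np.asarray(counts, dtype=float)
    total = counts.sum()
    order = np.argsort(-counts, kind="stable")        # code indices, most frequent first
    sorted_counts = counts[order]
    prior = np.cumsum(sorted_counts) - sorted_counts  # occurrences of codes ranked before
    prior_share = 100.0 * prior / total               # cumulative % preceding each code
    n_groups = int(np.ceil(100 / percentile_range))
    groups = [[] for _ in range(n_groups)]
    for code, share in zip(order, prior_share):
        g = min(int(share // percentile_range), n_groups - 1)
        groups[g].append(int(code))
    return groups


print(partition_by_percentiles([50, 30, 10, 5, 5], percentile_range=20))
```

In this example code 0 alone covers 50% of all occurrences, so it fills the first group and the second group stays empty.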

<a name="sec1"></a>

### 1 Training ICE-NODE and The Baselines [^](#outline)

In [19]:
from lib.ml import MetricsHistory

def m3_train(clf):
    output_dir = m3_train_output_dir[clf]
    config = model_config[clf]
    model = m3_models[clf]
    trainer = trainers[clf]
    reporters = [EvaluationDiskWriter(output_dir), # <- responsible for writing evaluation tables on disk at the given path
                 ParamsDiskWriter(output_dir), # <- responsible for writing model parameters snapshot after each iteration.
                 ConfigDiskWriter(output_dir, config), # writes the config file as JSON
                ]
    
    history = MetricsHistory(m3_metrics) # <- empty history
    
    return trainer(model, m3_interface, m3_splits, history=history, reporters=reporters, prng_seed=42)

# def m4_train(clf):
#     output_dir = m4_train_output_dir[clf]
#     config = model_config[clf]
#     model = m4_models[clf]
#     trainer = trainers[clf]
#     reporters = [EvaluationDiskWriter(output_dir), # <- responsible for writing evaluation tables on disk at the given path
#                  ParamsDiskWriter(output_dir), # <- responsible for writing model parameters snapshot after each iteration.
#                  ConfigDiskWriter(output_dir, config), # writes the config file as JSON
#                 ]
    
#     history = MetricsHistory(m4_metrics) # <- empty history
    
#     return trainer(model, m4_interface, m4_splits, history=history, reporters=reporters, prng_seed=42)

#### NJODE

In [20]:
m3_njode_results = m3_train('NJODE')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7193/7193 [10:11:46<00:00,  5.10s/it]


#### ICE-NODE

In [None]:
# for icenode_variant in icenode_variants:
#     print(icenode_variant)
#     m3_train(icenode_variant)

In [None]:
m3_icenode_results = m3_train('ICE-NODE')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉| 4794/4795 [4:21:41<00:02,  2.72s/it]

In [None]:
# m4_icenode_results = m4_train('ICE-NODE')  # uncomment together with the M4 cells above

#### ICE-NODE_UNIFORM

In [None]:
m3_icenode_uniform_results = m3_train('ICE-NODE_UNIFORM')

In [None]:
# m4_icenode_uniform_results = m4_train('ICE-NODE_UNIFORM')  # uncomment together with the M4 cells above

#### GRU

In [None]:
m3_gru_results = m3_train('GRU')

 70%|█████████████████████████████████████████████████████████████████████████████▍                                 | 16052/23018 [4:15:52<2:03:42,  1.07s/it]

In [None]:
# m4_gru_results = m4_train('GRU')  # uncomment together with the M4 cells above

#### RETAIN

In [None]:
m3_retain_results = m3_train('RETAIN')

In [None]:
# m4_retain_results = m4_train('RETAIN')  # uncomment together with the M4 cells above

#### LogReg

In [None]:
m3_logreg_results = m3_train('LogReg')

In [None]:
# m4_logreg_results = m4_train('LogReg')  # uncomment together with the M4 cells above