# Table of Contents

<a name="outline"></a>

## Setup

- [A](#seca) External Imports
- [B](#secb) Internal Imports
- [D](#secd) Configurations and Paths
- [E](#sece) Patient Interface and Train/Val/Test Partitioning


## Training

- [1](#sec1) Training ICE-NODE and The Baselines


<a name="seca"></a>

### A External Imports [^](#outline)

In [1]:
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

from tqdm import tqdm
import jax
jax.config.update('jax_platform_name', 'cpu')

<a name="secb"></a>

### B Internal Imports [^](#outline)

In [2]:
%load_ext autoreload
%autoreload 2

sys.path.append("..")

from lib import utils as U
from lib.ehr.dataset import load_dataset

In [3]:
# HOME and DATA_STORE are arbitrary, change as appropriate.
HOME = os.environ.get('HOME')
DATA_STORE = f'{HOME}/GP/ehr-data'
DATA_FILE = os.path.join(DATA_STORE, 'cprd-data/DUMMY_DATA.csv')
SOURCE_DIR = os.path.abspath("..")

<a name="secd"></a>

### D Configurations and Paths [^](#outline)

Create the output directory `cprd_artefacts` for the training artefacts, then load the CPRD dataset from `DATA_FILE`.

In [4]:
output_dir = 'cprd_artefacts'
Path(output_dir).mkdir(parents=True, exist_ok=True)

In [5]:
with U.modified_environ(DATA_FILE=DATA_FILE):
    cprd_dataset = load_dataset('CPRD')
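Judging by its name and usage, `U.modified_environ` sets `DATA_FILE` only for the duration of the `with` block, so that `load_dataset('CPRD')` can locate the data file without permanently changing the process environment. A minimal sketch of such a context manager, assuming that behaviour (this `modified_environ` is our own stand-in, not the code in `lib.utils`):

```python
import contextlib
import os


@contextlib.contextmanager
def modified_environ(**overrides):
    """Temporarily set environment variables; restore the previous values on exit.

    Hypothetical stand-in for `U.modified_environ`, written for illustration only.
    """
    saved = {name: os.environ.get(name) for name in overrides}
    os.environ.update({name: str(value) for name, value in overrides.items()})
    try:
        yield
    finally:
        for name, old in saved.items():
            if old is None:
                # The variable did not exist before the block: remove it again.
                os.environ.pop(name, None)
            else:
                os.environ[name] = old


# Usage: DATA_FILE is visible only inside the `with` block.
with modified_environ(DATA_FILE='/tmp/DUMMY_DATA.csv'):
    print(os.environ['DATA_FILE'])  # /tmp/DUMMY_DATA.csv
```

The `finally` clause restores the environment even if the loader raises an exception.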
   

In [6]:
from lib.ml import ICENODE, ICENODE_UNIFORM, GRU, RETAIN, WindowLogReg

"""
predefined hyperparams re: each model.
"""

model_cls = {
    'ICE-NODE': ICENODE,
    'ICE-NODE-G': ICENODE,
    'ICE-NODE_UNIFORM': ICENODE_UNIFORM,
    'GRU': GRU,
    'GRU-G': GRU,
    'RETAIN': RETAIN,
    'LogReg': WindowLogReg
}

model_config = {
    'ICE-NODE': f'{SOURCE_DIR}/expt_configs/cprd/icenode.json',
    'ICE-NODE_UNIFORM': f'{SOURCE_DIR}/expt_configs/cprd/icenode.json',
    'GRU': f'{SOURCE_DIR}/expt_configs/cprd/gru.json',
    'RETAIN': f'{SOURCE_DIR}/expt_configs/cprd/retain.json',
    'LogReg': f'{SOURCE_DIR}/expt_configs/cprd/window_logreg.json'
}

model_config = {clf: U.load_config(file) for clf, file in model_config.items()}

clfs = ['ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN', 'LogReg']

In [7]:
cprd_train_output_dir = {clf: f'{output_dir}/train/{clf}' for clf in clfs}

for d in cprd_train_output_dir.values():
    Path(d).mkdir(parents=True, exist_ok=True)

In [8]:
from lib.ml import ConfigDiskWriter, MinibatchLogger, EvaluationDiskWriter, ParamsDiskWriter
# The reporter objects are called inside the training iterations.
# Each has its own role:
# 1. ConfigDiskWriter: writes the experiment config as a JSON file in the training directory.
# 2. MinibatchLogger: prints the training progress details to the console.
# 3. EvaluationDiskWriter: writes the evaluation results as CSV tables in the training directory at each of the 100 evaluation steps.
# 4. ParamsDiskWriter: writes a snapshot of the model parameters at each of the 100 evaluation steps.
def make_reporters(output_dir, config):
    return [ConfigDiskWriter(output_dir=output_dir, config=config),
            MinibatchLogger(config),
            EvaluationDiskWriter(output_dir=output_dir),
            ParamsDiskWriter(output_dir=output_dir)]

reporters = {model: make_reporters(cprd_train_output_dir[model], model_config[model]) for model in clfs}
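The reporters follow a callback pattern: the trainer owns the loop and hands each evaluation to every reporter, and each reporter decides what to persist. The sketch below illustrates only that pattern; the `report(step, evaluation)` method, the `*Sketch` classes, `train_loop_sketch` and the file names are hypothetical and are not the `lib.ml` API. It assumes, as the comments above suggest, that progress is reported at 100 evenly spaced evaluation steps.

```python
import json
import os
import tempfile


class ConfigWriterSketch:
    """Hypothetical analogue of ConfigDiskWriter: writes the config once, at step 0."""

    def __init__(self, output_dir, config):
        self.output_dir, self.config = output_dir, config

    def report(self, step, evaluation):
        if step == 0:
            with open(os.path.join(self.output_dir, 'config.json'), 'w') as f:
                json.dump(self.config, f)


class EvaluationWriterSketch:
    """Hypothetical analogue of EvaluationDiskWriter: appends one CSV row per step."""

    def __init__(self, output_dir):
        self.path = os.path.join(output_dir, 'evaluation.csv')

    def report(self, step, evaluation):
        with open(self.path, 'a') as f:
            if step == 0:
                f.write('step,' + ','.join(evaluation) + '\n')  # header row
            f.write(f'{step},' + ','.join(str(v) for v in evaluation.values()) + '\n')


def train_loop_sketch(n_iters, evaluate, reporters, n_eval_steps=100):
    """Run n_iters iterations; hand an evaluation to every reporter n_eval_steps times."""
    eval_every = max(1, n_iters // n_eval_steps)
    step = 0
    for i in range(1, n_iters + 1):
        # ... one optimisation update on a minibatch would go here ...
        if i % eval_every == 0 and step < n_eval_steps:
            evaluation = evaluate(i)
            for reporter in reporters:
                reporter.report(step, evaluation)
            step += 1
    return step


with tempfile.TemporaryDirectory() as demo_dir:
    demo_reporters = [ConfigWriterSketch(demo_dir, {'epochs': 60}),
                      EvaluationWriterSketch(demo_dir)]
    demo_steps = train_loop_sketch(1000, lambda i: {'loss': 1.0 / i}, demo_reporters)
    with open(os.path.join(demo_dir, 'evaluation.csv')) as f:
        demo_rows = f.read().splitlines()
    demo_has_config = os.path.exists(os.path.join(demo_dir, 'config.json'))

print(demo_steps, len(demo_rows), demo_has_config)  # 100 101 True
```

Keeping persistence in separate reporter objects lets the same training loop write configs, tables and parameter snapshots (or nothing at all) without changing the loop itself.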


<a name="sece"></a>

### E Patient Interface and Train/Val/Test Partitioning [^](#outline)

In [9]:
from lib.ehr.coding_scheme import DxLTC212FlatCodes, DxLTC9809FlatMedcodes, EthCPRD5, EthCPRD16
from lib.ehr import OutcomeExtractor
from lib.ehr import Subject_JAX

code_scheme = {
    'dx': DxLTC9809FlatMedcodes(),
    'dx_outcome': OutcomeExtractor('dx_cprd_ltc9809'),
    'eth': EthCPRD5()
}
cprd_interface = Subject_JAX.from_dataset(cprd_dataset, code_scheme=code_scheme)
cprd_splits = cprd_interface.random_splits(split1=0.7, split2=0.85, random_seed=42)
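`random_splits` belongs to the project's `Subject_JAX` interface and its implementation is not shown. A minimal sketch of one plausible reading, assuming `split1` and `split2` are cumulative cut points over the shuffled subject IDs, so that 0.7 and 0.85 yield a 70/15/15 train/validation/test partition returned in that order (consistent with `cprd_splits[0]` being used as the training subjects). `random_splits_sketch` is a hypothetical name:

```python
import random


def random_splits_sketch(subject_ids, split1=0.7, split2=0.85, random_seed=42):
    """Shuffle subjects with a fixed seed and cut at two cumulative fractions.

    train = [0, split1), validation = [split1, split2), test = [split2, 1].
    Hypothetical illustration, not the Subject_JAX implementation.
    """
    ids = sorted(subject_ids)  # fixed starting order, so the seed alone determines the shuffle
    random.Random(random_seed).shuffle(ids)
    n = len(ids)
    cut1, cut2 = int(split1 * n), int(split2 * n)
    return ids[:cut1], ids[cut1:cut2], ids[cut2:]


train_ids, val_ids, test_ids = random_splits_sketch(range(100))
print(len(train_ids), len(val_ids), len(test_ids))  # 70 15 15
```

Splitting by subject (rather than by visit) keeps all visits of a patient in the same partition.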


In [10]:
cprd_percentiles = cprd_interface.dx_outcome_by_percentiles(20)
cprd_train_percentiles = cprd_interface.dx_outcome_by_percentiles(20, cprd_splits[0])
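`dx_outcome_by_percentiles(20)` groups the outcome codes by frequency in bands of 20 percentage points, i.e. five groups. Its implementation lives in `lib.ehr` and is not shown, so here is a minimal sketch of one plausible rule, assuming codes are ranked from the most to the least frequent and each code is assigned to the band containing the cumulative share of occurrences seen before it. `outcome_by_percentiles_sketch` is a hypothetical name. Computing the counts on the training subjects only, as `cprd_train_percentiles` does, presumably keeps test-set frequencies out of the grouping.

```python
from math import ceil


def outcome_by_percentiles_sketch(code_counts, percentile_range=20):
    """Partition codes into ceil(100 / percentile_range) frequency bands.

    code_counts: mapping code -> number of occurrences over all visits.
    Hypothetical illustration, not the lib.ehr implementation.
    """
    total = sum(code_counts.values())
    n_groups = ceil(100 / percentile_range)
    groups = [[] for _ in range(n_groups)]
    seen = 0  # occurrences of all codes ranked before the current one
    # Most frequent first; ties broken by code name for a deterministic result.
    for code, count in sorted(code_counts.items(), key=lambda kv: (-kv[1], kv[0])):
        idx = min((100 * seen) // (percentile_range * total), n_groups - 1)
        groups[idx].append(code)
        seen += count
    return groups


demo_counts = {'a': 50, 'b': 30, 'c': 10, 'd': 5, 'e': 5}
print(outcome_by_percentiles_sketch(demo_counts))  # [['a'], [], ['b'], [], ['c', 'd', 'e']]
```

With a dominant code some bands can be empty, as in the example: code `a` alone accounts for half of all occurrences.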

In [11]:
import jax.random as jrandom
import lib.ml as ml

key = jrandom.PRNGKey(0)

The next cell builds each model from its configuration and creates a dictionary holding the trainer class of each model.
The class name of the trainer is already specified in the experiment configuration `model_config`.
For example, this is the configuration file of the ICE-NODE experiment:

```json
{
    "emb": {
        "dx": {
           "decoder_n_layers": 2,
           "classname":  "MatrixEmbeddings",      
           "embeddings_size": 300
        }
    },
    "model": {
        "ode_dyn_label": "mlp3",
        "ode_init_var": 1e-7,
        "state_size": 30,
        "timescale": 30
    },
    "training": {
        "batch_size": 256,
        "decay_rate": [0.25, 0.33],
        "lr": [7e-5,  1e-3],
        "epochs": 60,
        "reg_hyperparams": {
            "L_dyn": 1000.0,
            "L_l1": 0,
            "L_l2": 0
        },
        "opt": "adam",
        "classname": "ODETrainer2LR" <---- "classname, so this class should be available through ml package."
    }
}
```

Since we have a string of the classname, one way to get `ml.ODETrainer2LR` is `getattr(ml, 'ODETrainer2LR')`
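Resolving a class from its configured name is plain Python. The sketch below uses a stand-in namespace instead of `lib.ml`; `ODETrainer2LRSketch`, `ml_sketch` and the `demo_*` names are hypothetical, and the training entries are a subset of the ICE-NODE config above.

```python
import types


class ODETrainer2LRSketch:
    """Hypothetical trainer that only records the hyperparameters it receives."""

    def __init__(self, lr, epochs, batch_size, **other_hparams):
        self.lr, self.epochs, self.batch_size = lr, epochs, batch_size
        self.other_hparams = other_hparams


# Stand-in for the `lib.ml` module: any object with attributes works with getattr.
ml_sketch = types.SimpleNamespace(ODETrainer2LR=ODETrainer2LRSketch)

demo_training_config = {"batch_size": 256, "lr": [7e-5, 1e-3], "epochs": 60,
                        "opt": "adam", "classname": "ODETrainer2LR"}

# Copy the config, pop the class name, resolve it, and pass the rest as keyword arguments.
demo_hparams = dict(demo_training_config)
demo_trainer_cls = getattr(ml_sketch, demo_hparams.pop("classname"))
demo_trainer = demo_trainer_cls(**demo_hparams)
print(type(demo_trainer).__name__, demo_trainer.epochs)  # ODETrainer2LRSketch 60
```

One design choice to be aware of: the notebook passes the whole `training` dictionary, `classname` included, to the trainer constructor, which works only if the trainer accepts (or ignores) that keyword. Popping `classname` from a copy first, as in the sketch, is a defensive alternative; whether the `lib.ml` trainers need it is not shown here.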

In [12]:
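# Build each model from its experiment config; `from_config` also receives the
# patient interface, the training subjects (cprd_splits[0]) and a PRNG key.
cprd_models = {clf: model_cls[clf].from_config(model_config[clf],
                                              cprd_interface,
                                              cprd_splits[0],
                                              key) for clf in clfs}

# Resolve each trainer class by its name in the config, then instantiate it
# with the training hyperparameters.
cprd_trainers_cls = {clf: getattr(ml, model_config[clf]["training"]["classname"]) for clf in clfs}
cprd_trainers = {clf: cprd_trainers_cls[clf](**model_config[clf]["training"]) for clf in clfs}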

## Metrics of Interest Specification

In [13]:
from lib.metric import (CodeAUC, UntilFirstCodeAUC, AdmissionAUC, CodeGroupTopAlarmAccuracy, MetricsCollection)

## Evaluation Metrics per Model

1. *CodeAUC*: evaluates the prediction AUC per code (aggregating over all visits of all subjects).
2. *UntilFirstCodeAUC*: same as *CodeAUC*, but for each subject only the visits up to and including the first occurrence of the code are evaluated; once the code has occurred, all subsequent visits are ignored for that code. If the code never appears for a subject, all of that subject's visits are ignored.
3. *AdmissionAUC*: evaluates the prediction AUC per visit (i.e., the probability that the model assigns a higher risk to a present code than to an absent one).
4. *CodeGroupTopAlarmAccuracy*: partitions the codes into groups by frequency (from the most frequent to the least frequent); for each visit it picks the `k` codes with the highest predicted risk, and measures, per group, the accuracy of these top-`k` alarms, i.e. how often the flagged codes are indeed present.
5. *MetricsCollection*: groups multiple metrics so they are evaluated together.
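The three AUC variants above differ only in which (score, label) pairs are pooled. The `lib.metric` implementations are not shown here, so the following is a minimal pure-Python sketch under our own reading of the descriptions; `pairwise_auc`, `admission_auc`, `code_auc` and `until_first_occurrence` are hypothetical helpers, not the library's API. AUC is computed as the probability that a randomly chosen positive receives a higher score than a randomly chosen negative, with ties counting one half.

```python
import math


def pairwise_auc(scores, labels):
    """P(random positive scores higher than random negative); ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]
    if not pos or not neg:
        return math.nan  # AUC is undefined unless both classes are present
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def admission_auc(visit_risks, visit_labels):
    """AdmissionAUC-style score for ONE visit: present codes vs absent codes."""
    return pairwise_auc(visit_risks, visit_labels)


def code_auc(risks_per_visit, labels_per_visit, code):
    """CodeAUC-style score for ONE code, pooled over all visits of all subjects."""
    scores = [visit_risks[code] for visit_risks in risks_per_visit]
    labels = [visit_labels[code] for visit_labels in labels_per_visit]
    return pairwise_auc(scores, labels)


def until_first_occurrence(labels_per_subject):
    """UntilFirstCodeAUC-style visit mask for ONE code.

    labels_per_subject: for each subject, the code's 0/1 label at each visit.
    Keeps visits up to and including the first positive one; subjects in which
    the code never appears are dropped. Returns (subject, visit) index pairs.
    """
    kept = []
    for s, labels in enumerate(labels_per_subject):
        if 1 not in labels:
            continue  # code never occurs for this subject: ignore all its visits
        first = labels.index(1)
        kept.extend((s, v) for v in range(first + 1))
    return kept


print(admission_auc([0.9, 0.2, 0.6, 0.1], [1, 0, 1, 0]))  # 1.0
print(until_first_occurrence([[0, 1, 1], [0, 0, 0], [1, 0]]))  # [(0, 0), (0, 1), (2, 0)]
```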

In [14]:
# percentile_range=20 partitions the codes into five groups, where each group contains
# codes that together constitute 20% of the code occurrences in all visits of the specified `subjects` list.
code_freq_partitions = cprd_interface.dx_outcome_by_percentiles(percentile_range=20, subjects=cprd_splits[0])

# Evaluate for different k values
top_k_list = [3, 5, 10, 15, 20]

metrics = [CodeAUC(cprd_interface),
          UntilFirstCodeAUC(cprd_interface),
          AdmissionAUC(cprd_interface),
          CodeGroupTopAlarmAccuracy(cprd_interface, top_k_list=top_k_list, code_groups=code_freq_partitions)]
all_metrics = MetricsCollection(metrics)
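A hedged sketch of how a top-`k` alarm accuracy per frequency group could be computed, under our reading of the *CodeGroupTopAlarmAccuracy* description: for each visit take the `k` codes with the highest predicted risk; then, for each code group, divide the number of those alarms whose code is truly present by the number of alarms that fell in the group. `top_alarm_accuracy`, its argument layout and the `ACC-P{i}-k{k}` result keys are hypothetical; the real metric in `lib.metric` may aggregate or normalise differently.

```python
import math


def top_alarm_accuracy(visit_risks, visit_labels, code_groups, top_k_list):
    """Per-group accuracy of the top-k alarms, pooled over all visits.

    visit_risks:  per visit, a list of predicted risks indexed by code.
    visit_labels: per visit, a list of 0/1 ground-truth labels indexed by code.
    code_groups:  list of code-index groups (e.g. frequency percentile groups).
    """
    group_of = {c: i for i, group in enumerate(code_groups) for c in group}
    result = {}
    for k in top_k_list:
        alarms = [0] * len(code_groups)  # top-k picks landing in each group
        hits = [0] * len(code_groups)    # ... of which the code is actually present
        for risks, labels in zip(visit_risks, visit_labels):
            # Indices of the k highest risks (ties broken by the lower code index).
            top = sorted(range(len(risks)), key=lambda c: (-risks[c], c))[:k]
            for c in top:
                if c in group_of:
                    alarms[group_of[c]] += 1
                    hits[group_of[c]] += labels[c]
        for i in range(len(code_groups)):
            result[f'ACC-P{i}-k{k}'] = hits[i] / alarms[i] if alarms[i] else math.nan
    return result


demo_acc = top_alarm_accuracy(visit_risks=[[0.9, 0.8, 0.1, 0.3], [0.2, 0.7, 0.6, 0.1]],
                              visit_labels=[[1, 0, 0, 1], [0, 1, 1, 0]],
                              code_groups=[[0, 1], [2, 3]],
                              top_k_list=[1, 2])
print(demo_acc)
```

In this sketch a group that receives no alarms for a given `k` gets `nan` rather than 0, so a rarely flagged group is not mistaken for a badly predicted one.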

<a name="sec1"></a>

### 1 Training ICE-NODE and The Baselines [^](#outline)

In [15]:
from lib.ml import MetricsHistory

def train(clf):
    output_dir = cprd_train_output_dir[clf]
    config = model_config[clf]
    model = cprd_models[clf]
    trainer = cprd_trainers[clf]
    reporters = [EvaluationDiskWriter(output_dir), # <- responsible for writing evaluation tables on disk at the given path
                 ParamsDiskWriter(output_dir), # <- responsible for writing model parameters snapshot after each iteration.
                 ConfigDiskWriter(output_dir, config), # writes the config file as JSON
                ]
    
    history = MetricsHistory(all_metrics) # <- empty history
    
    return trainer(model, cprd_interface, cprd_splits, history=history, reporters=reporters, prng_seed=42)

#### ICE-NODE

In [16]:
## TODO: This may take a long time; a pretrained model already exists in (yy).
icenode_results = train('ICE-NODE')



#### ICE-NODE_UNIFORM

In [13]:
## TODO: This can take up to (xx); a trained model already exists in (yy).
cprd_trained_icenode_uni = train('ICE-NODE_UNIFORM')



#### GRU

In [14]:
## TODO: This can take up to (xx); a trained model already exists in (yy).
cprd_trained_gru = train('GRU')


#### RETAIN

In [15]:
## TODO: This can take up to (xx); a trained model already exists in (yy).
cprd_trained_retain = train('RETAIN')


#### LogReg

In [16]:
cprd_trained_logreg = train('LogReg')

