# Table of Contents

<a name="outline"></a>

## Setup

- [A](#seca) External Imports
- [B](#secb) Internal Imports
- [C](#secd) Configurations and Paths
- [D](#sece) JAX Interface
- [E](#sece) General Utility Functions


## Training

- [1](#sec1) Training ICE-NODE and The Baselines


<a name="seca"></a>

### A External Imports [^](#outline)

In [6]:
# pip install --upgrade jax==0.4.1 jaxlib==0.4.1+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [1]:
# Set the fraction of GPU memory that JAX pre-allocates (here 20% of the 48 GB GPU)
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.2

env: XLA_PYTHON_CLIENT_MEM_FRACTION=0.2
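Outside Jupyter, where the `%env` magic is not available, the same setting can be applied from plain Python. This is a minimal sketch: the variable has to be set before JAX initializes its GPU backend (in practice, before `import jax`), and the 48 GB capacity is taken from the comment above purely for the arithmetic.

```python
import os

# Must run before `import jax`; XLA reads this when the GPU client is created.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.2"

# Memory that would be pre-allocated on a 48 GB GPU at this fraction.
gpu_memory_gb = 48
preallocated_gb = float(os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]) * gpu_memory_gb
print(f"Pre-allocating about {preallocated_gb:.1f} GB")
```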


In [2]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

from tqdm import tqdm
import jax
jax.config.update('jax_platform_name', 'gpu')

<a name="secb"></a>

### B Internal Imports [^](#outline)

In [3]:


sys.path.append("..")

from lib import utils as U
from lib.ehr.dataset import load_dataset

In [5]:
# Assign the path of the dataset CSV file to `DATA_FILE`.

# HOME and DATA_FILE are arbitrary; change them as appropriate.
HOME = os.environ.get('HOME')
DATA_FILE = f'{HOME}/Documents/DS211/users/tb1009/DATA/ICE_TEST_50000.csv'
SOURCE_DIR = os.path.abspath("..")

<a name="secd"></a>

### C Configurations and Paths [^](#outline)

In [6]:
output_dir = 'cprd_artefacts'
Path(output_dir).mkdir(parents=True, exist_ok=True)

In [7]:
with U.modified_environ(DATA_FILE=DATA_FILE):
    cprd_dataset = load_dataset('CPRD')
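`U.modified_environ` is used above as a context manager that makes `DATA_FILE` visible to `load_dataset` only inside the `with` block. Its implementation is not shown in this notebook; the following is a minimal sketch of how such a helper is commonly written with `contextlib`, not the `lib.utils` code itself. `DEMO_DATA_FILE` is a made-up variable name so the demo does not touch the real `DATA_FILE`.

```python
import contextlib
import os


@contextlib.contextmanager
def modified_environ(**overrides):
    """Temporarily set environment variables, restoring the previous state on exit."""
    # Remember the original values (None means the variable was unset).
    saved = {name: os.environ.get(name) for name in overrides}
    os.environ.update({name: str(value) for name, value in overrides.items()})
    try:
        yield
    finally:
        for name, old_value in saved.items():
            if old_value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old_value


# Usage: the variable exists only inside the `with` block.
os.environ.pop("DEMO_DATA_FILE", None)
with modified_environ(DEMO_DATA_FILE="/tmp/example.csv"):
    inside = os.environ["DEMO_DATA_FILE"]
outside = os.environ.get("DEMO_DATA_FILE")
print(inside, outside)
```

Restoring in a `finally` clause means the environment is cleaned up even if loading the dataset raises an exception.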
   

In [8]:
from lib.ml import ICENODE, ICENODE_UNIFORM, GRU, RETAIN, WindowLogReg

# Model classes, keyed by the names used throughout this notebook.
model_cls = {
    'ICE-NODE': ICENODE,
    'ICE-NODE_UNIFORM': ICENODE_UNIFORM,
    'GRU': GRU,
    'RETAIN': RETAIN,
    'LogReg': WindowLogReg
}

# Predefined hyperparameters for each model, stored as JSON experiment configs.
model_config = {
    'ICE-NODE': f'{SOURCE_DIR}/expt_configs/icenode.json',
    'ICE-NODE_UNIFORM': f'{SOURCE_DIR}/expt_configs/icenode.json',
    'GRU': f'{SOURCE_DIR}/expt_configs/gru.json',
    'RETAIN': f'{SOURCE_DIR}/expt_configs/retain.json',
    'LogReg': f'{SOURCE_DIR}/expt_configs/window_logreg.json'
}

# Replace each config path with the loaded config dictionary.
model_config = {clf: U.load_config(file) for clf, file in model_config.items()}

clfs = ['ICE-NODE', 'ICE-NODE_UNIFORM', 'GRU', 'RETAIN', 'LogReg']

In [9]:
cprd_train_output_dir = {clf: f'{output_dir}/train/{clf}' for clf in clfs}

# Create one training output directory per model.
for d in cprd_train_output_dir.values():
    Path(d).mkdir(parents=True, exist_ok=True)

In [10]:
from lib.ml import ConfigDiskWriter, MinibatchLogger, EvaluationDiskWriter, ParamsDiskWriter

The reporter objects are called inside the training iterations. Each has its own role:
1. **ConfigDiskWriter**: writes the experiment config as a JSON file in the model's training directory.
2. **MinibatchLogger**: prints the training progress details to the console.
3. **EvaluationDiskWriter**: writes the evaluation results as CSV tables in the same training directory at each of the 100 steps.
4. **ParamsDiskWriter**: writes a snapshot of the model parameters at each of the 100 steps.
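The reporters act as callbacks hooked into the training loop. The sketch below illustrates that pattern with hypothetical classes and functions (`Reporter`, `InMemoryEvaluationWriter`, `run_training`) that are not part of `lib.ml`; it only shows how a loop might invoke its reporters at a fixed number of evaluation steps (100, as described above).

```python
from dataclasses import dataclass, field


class Reporter:
    """Hypothetical callback interface; every hook defaults to a no-op."""

    def report_config(self):
        pass

    def report_evaluation(self, step, metrics):
        pass


@dataclass
class InMemoryEvaluationWriter(Reporter):
    """Stand-in for a disk writer: keeps one evaluation record per step."""
    records: list = field(default_factory=list)

    def report_evaluation(self, step, metrics):
        self.records.append((step, dict(metrics)))


def run_training(n_iterations, n_eval_steps, reporters):
    """Toy loop that calls the reporters at evenly spaced evaluation steps."""
    for reporter in reporters:
        reporter.report_config()
    eval_every = max(1, n_iterations // n_eval_steps)
    for it in range(1, n_iterations + 1):
        loss = 1.0 / it  # placeholder for a real optimisation step
        if it % eval_every == 0:
            for reporter in reporters:
                reporter.report_evaluation(it // eval_every, {"loss": loss})


writer = InMemoryEvaluationWriter()
run_training(n_iterations=1000, n_eval_steps=100, reporters=[writer])
print(len(writer.records), writer.records[-1])
```

Keeping each side effect (config, logging, evaluation tables, parameter snapshots) in its own reporter lets the same trainer be reused with a different set of outputs.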

In [11]:
def make_reporters(output_dir, config):
    return [ConfigDiskWriter(output_dir=output_dir, config=config),
            MinibatchLogger(config),
            EvaluationDiskWriter(output_dir=output_dir),
            ParamsDiskWriter(output_dir=output_dir)]

reporters = {model: make_reporters(cprd_train_output_dir[model], model_config[model]) for model in clfs}

<a name="sece"></a>

### D JAX Interface [^](#outline)

In [12]:
from lib.ehr.coding_scheme import DxLTC212FlatCodes, DxLTC9809FlatMedcodes, EthCPRD5, EthCPRD16
from lib.ehr import OutcomeExtractor, SurvivalOutcomeExtractor
from lib.ehr import Subject_JAX
from lib.ehr import StaticInfoFlags


The dictionary `code_scheme` in the next cell specifies the code spaces of:
- 'dx': diagnostic input codes (input features). Possible values:
    - `DxLTC9809FlatMedcodes()` for medcodes, or
    - `DxLTC212FlatCodes()` for disease numbers.
- 'outcome': diagnostic outcome codes (the targets to be predicted). Possible values:
    - `OutcomeExtractor('dx_cprd_ltc9809')` for medcode predictions,
    - `OutcomeExtractor('dx_cprd_ltc212')` for disease number predictions,
    - `SurvivalOutcomeExtractor('dx_cprd_ltc9809')` for medcode predictions (first occurrence per patient only), or
    - `SurvivalOutcomeExtractor('dx_cprd_ltc212')` for disease number predictions (first occurrence per patient only).
- 'eth': ethnicity codes. Possible values:
    - `EthCPRD16()` for the 16 ethnicity classifications, or
    - `EthCPRD5()` for the 5 ethnicity classifications.

**Note**: `OutcomeExtractor` can be given only a subset of the diagnostic codes. For example,
you can focus the prediction objective on a small subset of interest (e.g. only pulmonary-heart
disease codes).
`OutcomeExtractor` can also be replaced by `SurvivalOutcomeExtractor`, which restricts the prediction
objective to the first occurrence of each code per patient: subsequent, redundant occurrences are
ignored and do not contribute to the loss function.
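To make the difference concrete, the sketch below computes, for a single subject, per-visit targets and a loss mask. It assumes the first-occurrence semantics described above: a code contributes to the loss up to and including the visit where it first appears and is masked out afterwards. This is a plain-Python illustration, not the `SurvivalOutcomeExtractor` implementation; the function name and the list-of-sets visit format are assumptions.

```python
def outcome_targets(visits, codes, first_occurrence_only=False):
    """Return per-visit (targets, masks) for one subject.

    visits: list of sets of codes, one set per visit, in time order.
    codes: the outcome code space.
    """
    targets, masks = [], []
    seen = set()  # codes already observed in an earlier visit
    for visit in visits:
        targets.append({c: int(c in visit) for c in codes})
        if first_occurrence_only:
            # Codes seen in earlier visits no longer contribute to the loss.
            masks.append({c: int(c not in seen) for c in codes})
        else:
            masks.append({c: 1 for c in codes})
        seen |= visit
    return targets, masks


visits = [{"A"}, {"A", "B"}, {"B"}]
codes = ["A", "B"]
targets, all_mask = outcome_targets(visits, codes)
_, first_mask = outcome_targets(visits, codes, first_occurrence_only=True)
print(first_mask)  # [{'A': 1, 'B': 1}, {'A': 0, 'B': 1}, {'A': 0, 'B': 0}]
```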

In [13]:
code_scheme = {
    # 'dx': DxLTC9809FlatMedcodes(),  # alternative: medcodes as input
    'dx': DxLTC212FlatCodes(),
    # 'outcome': OutcomeExtractor('dx_cprd_ltc9809'),  # alternative: predict medcodes
    'outcome': OutcomeExtractor('dx_cprd_ltc212'),
    # To consider only the first occurrence of codes per subject, comment out the line above and uncomment the line below.
    # 'outcome': SurvivalOutcomeExtractor('dx_cprd_ltc9809'),
    'eth': EthCPRD5()
}

### Adding Demographic Information in Training

Which static (demographic) information should be included as control features? **Each line in the next cell enables one type of static information; comment a line out to exclude it.**

In [None]:
static_info_flags = StaticInfoFlags(
    gender=True,
    age=True,
    idx_deprivation=True,
    ethnicity=EthCPRD5(),  # <- pass the ethnicity scheme of interest, not just `True`.
)
cprd_interface = Subject_JAX.from_dataset(cprd_dataset, code_scheme=code_scheme, static_info_flags=static_info_flags)
cprd_splits = cprd_interface.random_splits(split1=0.7, split2=0.85, random_seed=42)
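`random_splits(split1=0.7, split2=0.85, ...)` appears to take cumulative cut points, which would give roughly 70% training, 15% validation and 15% test subjects; `cprd_splits[0]` (presumably the training split) is the one passed to the models below. The sketch assumes that reading of the arguments and reproduces it on plain subject IDs; it is not the `Subject_JAX` implementation.

```python
import random


def random_splits(subject_ids, split1, split2, random_seed=42):
    """Shuffle subject IDs and cut them at cumulative fractions split1 < split2."""
    ids = sorted(subject_ids)  # deterministic starting order
    random.Random(random_seed).shuffle(ids)
    n = len(ids)
    cut1, cut2 = round(split1 * n), round(split2 * n)
    return ids[:cut1], ids[cut1:cut2], ids[cut2:]


train_ids, valid_ids, test_ids = random_splits(range(1000), split1=0.7, split2=0.85)
print(len(train_ids), len(valid_ids), len(test_ids))  # 700 150 150
```

Splitting by subject rather than by visit keeps all visits of one patient in the same split.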

In [None]:
import jax.random as jrandom
import lib.ml as ml

key = jrandom.PRNGKey(0)

Each model has a dictionary (loaded above into `model_config`) specifying its experiment configuration.
The class name of the trainer to use is also specified in each experiment config.
For example, this is the configuration file of the ICE-NODE experiment:

```json
{
    "emb": {
        "dx": {
           "decoder_n_layers": 2,
           "classname":  "MatrixEmbeddings",      
           "embeddings_size": 300
        }
    },
    "model": {
        "ode_dyn_label": "mlp3",
        "ode_init_var": 1e-7,
        "state_size": 30,
        "timescale": 30
    },
    "training": {
        "batch_size": 256,
        "decay_rate": [0.25, 0.33],
        "lr": [7e-5,  1e-3],
        "epochs": 60,
        "reg_hyperparams": {
            "L_dyn": 1000.0,
            "L_l1": 0,
            "L_l2": 0
        },
        "opt": "adam",
        "classname": "ODETrainer2LR"
    }
}
```

The `"classname"` entry of the `"training"` section names the trainer class, so this class must be available through the `ml` package. Since the class name is given as a string, one way to get `ml.ODETrainer2LR` is `getattr(ml, 'ODETrainer2LR')`.
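The lookup can be tried without the project code by using a throwaway module in place of `ml`. `ODETrainer2LR` below is a dummy class with an assumed keyword-argument constructor, not the real trainer. As in the next code cell, the whole `"training"` dictionary (including `"classname"` itself) is passed as keyword arguments, so the trainer's constructor has to accept that key.

```python
import json
import types

# A stand-in for the `ml` package containing one dummy trainer class.
ml = types.ModuleType("ml")


class ODETrainer2LR:
    def __init__(self, **kwargs):
        # Keep every training option, including "classname".
        self.options = kwargs


ml.ODETrainer2LR = ODETrainer2LR

config = json.loads("""
{"training": {"batch_size": 256, "epochs": 60, "opt": "adam",
              "classname": "ODETrainer2LR"}}
""")

# Resolve the class from its name, then build it from the same dictionary.
trainer_cls = getattr(ml, config["training"]["classname"])
trainer = trainer_cls(**config["training"])
print(trainer_cls.__name__, trainer.options["batch_size"])
```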

In [None]:
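# Build each model from its config, the data interface, the first split (`cprd_splits[0]`) and the PRNG key.
cprd_models = {clf: model_cls[clf].from_config(model_config[clf],
                                              cprd_interface,
                                              cprd_splits[0],
                                              key) for clf in clfs}

# Resolve each trainer class by name and instantiate it with the training options.
cprd_trainers_cls = {clf: getattr(ml, model_config[clf]["training"]["classname"]) for clf in clfs}
cprd_trainers = {clf: cprd_trainers_cls[clf](**model_config[clf]["training"]) for clf in clfs}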

## Metrics of Interest Specification

In [None]:
from lib.metric import (CodeAUC, UntilFirstCodeAUC, AdmissionAUC, CodeGroupTopAlarmAccuracy, MetricsCollection, LossMetric)

## Evaluation Metrics per Model

1. *CodeAUC*: evaluates the prediction AUC per code, aggregating over all visits of all subjects.
2. *UntilFirstCodeAUC*: same as *CodeAUC*, but for each subject the AUC is evaluated only up to the first occurrence of the code; once the code has occurred, all subsequent visits of that subject are ignored for that code. If the code never appears for a particular subject, all of that subject's visits are ignored.
3. *AdmissionAUC*: evaluates the prediction AUC per visit, i.e. the probability that the model assigns higher risk to the codes present in the visit than to the absent ones.
4. *CodeGroupTopAlarmAccuracy*: partitions codes into groups according to code frequency (from the most frequent to the least frequent); for each visit, it picks the `k` codes with the highest predicted risk and evaluates the accuracy of these top-`k` alarms, i.e. whether the flagged codes are indeed present.
5. *MetricsCollection*: groups multiple metrics so they can be evaluated at once.
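Both AUC variants can be expressed with the same pairwise (Mann-Whitney) formulation of the AUC; they differ only in which axis of the visit-by-code risk matrix is scored. The sketch below is a plain-Python illustration on a made-up 3-by-3 example, not the `lib.metric` implementation, and the `auc` helper is hypothetical.

```python
def auc(scores, labels):
    """Probability that a random positive is ranked above a random negative (ties count 0.5)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        return None  # undefined without both classes
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Rows are visits (pooled over subjects), columns are codes.
risks = [[0.9, 0.2, 0.4],
         [0.3, 0.8, 0.6],
         [0.6, 0.7, 0.5]]
truth = [[1, 0, 0],
         [0, 1, 0],
         [1, 0, 1]]

# CodeAUC-style: one AUC per code, computed down each column.
code_auc = [auc([row[j] for row in risks], [row[j] for row in truth]) for j in range(3)]
# AdmissionAUC-style: one AUC per visit, computed across each row.
admission_auc = [auc(r, t) for r, t in zip(risks, truth)]
print(code_auc, admission_auc)  # [1.0, 1.0, 0.5] [1.0, 1.0, 0.0]
```

The last visit shows why both views are useful: every code is ranked well across visits, yet within that visit the absent code receives the highest risk.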

**Note:** the method `outcome_by_percentiles` gives different results depending on the 'outcome' entry of the `code_scheme` dictionary:
- `OutcomeExtractor`: the counting includes each code and its redundant occurrences for each subject, then aggregates over all subjects.
- `SurvivalOutcomeExtractor`: the counting includes only the first occurrence of each code for each subject, then aggregates over all subjects.
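The sketch below shows how the two counting modes can change code frequencies on the same data. It is a hypothetical helper over a list-of-visits representation, not the counting code behind `outcome_by_percentiles`.

```python
from collections import Counter


def code_counts(subjects, first_occurrence_only=False):
    """Count code occurrences over all subjects.

    subjects: list of subjects, each a time-ordered list of visits (sets of codes).
    """
    counts = Counter()
    for visits in subjects:
        if first_occurrence_only:
            # Each code counts at most once per subject.
            counts.update(set().union(*visits))
        else:
            # Every visit contributes, including redundant occurrences.
            for visit in visits:
                counts.update(visit)
    return counts


subjects = [[{"A"}, {"A", "B"}, {"A"}],
            [{"B"}, {"B"}]]
all_counts = code_counts(subjects)
first_counts = code_counts(subjects, first_occurrence_only=True)
print(all_counts, first_counts)
```

Here codes `A` and `B` are equally frequent when redundant occurrences are counted, but `B` becomes twice as frequent as `A` when only first occurrences count, which can move codes between frequency groups.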

In [None]:
# percentile_range=20 partitions the codes into five groups, where each group contains
# codes that together account for about 20% of all code occurrences in the visits of the given 'subjects' list.
code_freq_partitions = cprd_interface.outcome_by_percentiles(percentile_range=20, subjects=cprd_splits[0])

# Evaluate for different k values
top_k_list = [3, 5, 10, 15, 20]

metrics = [CodeAUC(cprd_interface),
           UntilFirstCodeAUC(cprd_interface),
           AdmissionAUC(cprd_interface),
           LossMetric(cprd_interface),
           CodeGroupTopAlarmAccuracy(cprd_interface, top_k_list=top_k_list, code_groups=code_freq_partitions)]
all_metrics = MetricsCollection(metrics)
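A minimal sketch of one way to build such frequency groups and to score top-`k` alarms per group. Both helpers are hypothetical: the exact boundary rule of `outcome_by_percentiles` and the exact accuracy definition of `CodeGroupTopAlarmAccuracy` may differ. Here a code goes to the group in which its cumulative share of occurrences starts, and a group's accuracy is the fraction of the top-`k` alarms falling in that group whose code is actually present. The tiny example uses `percentile_range=50` (two groups) and a single `k` to keep it readable.

```python
from collections import Counter


def partition_by_percentiles(counts, percentile_range=20):
    """Split codes, most frequent first, into groups of ~percentile_range% of occurrences each."""
    n_groups = 100 // percentile_range
    total = sum(counts.values())
    groups = [[] for _ in range(n_groups)]
    cumulative = 0
    for code, count in counts.most_common():
        # The group is chosen by the share of occurrences accumulated before this code.
        share_before = 100 * cumulative / total
        groups[min(int(share_before // percentile_range), n_groups - 1)].append(code)
        cumulative += count
    return groups


def top_k_group_accuracy(risks, present, groups, k):
    """For each group, the fraction of top-k alarms in that group whose code is present."""
    group_of = {code: g for g, codes in enumerate(groups) for code in codes}
    hits = [0] * len(groups)
    alarms = [0] * len(groups)
    for visit_risks, visit_present in zip(risks, present):
        # The k codes with the highest predicted risk in this visit.
        top_k = sorted(visit_risks, key=visit_risks.get, reverse=True)[:k]
        for code in top_k:
            g = group_of[code]
            alarms[g] += 1
            hits[g] += int(code in visit_present)
    # None marks groups that received no alarms.
    return [h / a if a else None for h, a in zip(hits, alarms)]


counts = Counter({"A": 50, "B": 30, "C": 15, "D": 5})
groups = partition_by_percentiles(counts, percentile_range=50)
risks = [{"A": 0.9, "B": 0.8, "C": 0.1, "D": 0.2},
         {"A": 0.3, "B": 0.4, "C": 0.7, "D": 0.6}]
present = [{"A"}, {"C"}]
accuracy = top_k_group_accuracy(risks, present, groups, k=2)
print(groups, accuracy)
```

Reporting accuracy per frequency group shows whether a model's good overall numbers come only from the most common codes.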

<a name="sec1"></a>

### 1 Training ICE-NODE and The Baselines [^](#outline)

In [None]:
from lib.ml import MetricsHistory

def train(clf, **kwargs):
    output_dir = cprd_train_output_dir[clf]
    config = model_config[clf]
    model = cprd_models[clf]
    trainer = cprd_trainers[clf]
    reporters = [EvaluationDiskWriter(output_dir), # <- responsible for writing evaluation tables on disk at the given path
                 ParamsDiskWriter(output_dir), # <- responsible for writing model parameters snapshot after each iteration.
                 ConfigDiskWriter(output_dir, config), # writes the config file as JSON
                ]
    
    history = MetricsHistory(all_metrics) # <- empty history
    
    return trainer(model, cprd_interface, cprd_splits, 
                   history=history, reporters=reporters, prng_seed=42, **kwargs)

#### ICE-NODE

In [None]:
## TODO: Training may take a long time; a pretrained model already exists in (yy).
icenode_results = train('ICE-NODE', continue_training=False)

#### ICE-NODE_UNIFORM

In [None]:
icenode_uniform_results = train('ICE-NODE_UNIFORM')

#### GRU

In [21]:
gru_results = train('GRU')


#### RETAIN

In [22]:
retain_results = train('RETAIN')

  1%|▏         | 18/1320 [9:37:55<696:43:14, 1926.42s/it] 


KeyboardInterrupt: 

#### LogReg

In [None]:
logreg_results = train('LogReg')

  0%|          | 0/145 [00:00<?, ?it/s]