In [1]:
# Re-import edited modules automatically before executing code (autoreload mode 2),
# so changes to the project library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Cap the fraction of GPU memory XLA preallocates for JAX at 10%.
# Set here, ahead of `import jax` below, so it is in place when XLA initialises
# its GPU client.
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.1
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd
import dask

from tqdm import tqdm
import jax
# Select the GPU as JAX's platform. This must run before any JAX computation
# initialises a backend.
jax.config.update('jax_platform_name', 'gpu')
# Debugging switches (uncomment as needed):
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

env: XLA_PYTHON_CLIENT_MEM_FRACTION=0.1


In [2]:


sys.path.append("../..")

from lib import utils as U
from lib.ehr.dataset import load_dataset, load_dataset_scheme, Dataset
from lib.ehr.interface import Patients
from lib.ehr.concepts import CPRDDemographicVectorConfig, DemographicVectorConfig


In [3]:
import logging
import pprint

# Uncomment to enable DEBUG-level logging on the root logger:
# logging.root.level = logging.DEBUG

# Pretty-printer for inspecting nested structures (4-space indent, 80-column width).
pp = pprint.PrettyPrinter(indent=4, width=80)


In [4]:
# Path of the CPRD cohort CSV file, assigned to `DATA_FILE`.
# Fall back to `~` expansion when $HOME is unset or empty, instead of silently
# building a path that starts with "None/".
HOME = os.environ.get('HOME') or os.path.expanduser('~')
DATA_FILE = f'{HOME}/Documents/DS211/users/tb1009/DATA/PAT_COHORT/ICENODE_SUBSET_1000.csv'
SOURCE_DIR = os.path.abspath("..")
# Directory the freshly built patient interface is saved to.
# NOTE: the 'inteface' spelling is kept as-is so existing caches on disk are still found.
cache_to_disk = 'cached_inteface/cprd_1000'
# Path of a previously cached interface to load, or False to rebuild from DATA_FILE.
# Other caches used before: 'cached_inteface/cprd_50000', 'cached_inteface/m4inpatients_8000'.
use_cached = False

##### Possible Interface Scheme Configurations

In [5]:
import json


# Target schemes the CPRD dataset scheme supports, per field, printed as sorted JSON.
interface_schem_options = load_dataset_scheme('CPRD').supported_target_scheme_options
options_json = json.dumps(interface_schem_options, indent=4, sort_keys=True)
print(options_json)


{
    "dx": [
        "DxLTC9809FlatMedcodes",
        "DxLTC212FlatCodes"
    ],
    "ethnicity": [
        "CPRDEthnicity16",
        "CPRDEthnicity5"
    ],
    "gender": [
        "CPRDGender"
    ],
    "imd": [
        "CPRDIMDCategorical"
    ],
    "outcome": [
        "dx_cprd_ltc212",
        "dx_cprd_ltc9809"
    ]
}


In [6]:
# Target schemes for the patient interface (choices are listed by the cell above).
# Alternative outcome: 'dx_cprd_ltc212'.
cprd_interface_scheme_kw = {
    'dx': 'DxLTC9809FlatMedcodes',
    'outcome': 'dx_cprd_ltc9809',
    'ethnicity': 'CPRDEthnicity5',
}
# Attributes included in the demographic vector: all four enabled.
demographic_vector_conf = CPRDDemographicVectorConfig(
    age=True,
    gender=True,
    ethnicity=True,
    imd=True,
)

In [7]:
# Either load a previously cached patient interface, or build it from DATA_FILE
# and cache it to `cache_to_disk` for later runs.
# NOTE(review): the `splits` computed here are not consumed by anything in this
# cell. The training-setup cell re-splits `cprd_patients`, and that split is the
# one the trainer and metrics use.
N_WORKERS = 12  # parallelism for both the dask scheduler and subject loading

if use_cached:
    cprd_patients = Patients.load(use_cached)
    splits = cprd_patients.dataset.random_splits([0.8, 0.9], random_seed=42, balanced='admissions')

else:
    # DATA_FILE is exposed as an environment variable for the duration of the
    # load; load_dataset presumably reads it from there — confirm in lib.ehr.dataset.
    with U.modified_environ(DATA_FILE=DATA_FILE), dask.config.set(scheduler='processes', num_workers=N_WORKERS):

        # Load the raw CPRD dataset (no subsampling).
        cprd_dataset = load_dataset('CPRD', sample=None)
        # Seeded random split, balanced by admissions.
        splits = cprd_dataset.random_splits([0.8, 0.9], random_seed=42, balanced='admissions')

        # Build the patient interface with the chosen target schemes and demographics.
        cprd_patients = Patients(cprd_dataset, demographic_vector_conf,
                                 **cprd_interface_scheme_kw).load_subjects(num_workers=N_WORKERS)

        # Cache to disk so later runs can set `use_cached = cache_to_disk`.
        cprd_patients.save(cache_to_disk, overwrite=True)

In [8]:
# Sanity check: number of subjects loaded into the patient interface.
n_subjects = len(cprd_patients.subjects)
n_subjects

1000

In [9]:
from lib.ml import (ICENODE, ICENODEDimensions, OutpatientEmbeddingDimensions, 
                    Trainer, TrainerReporting, OptimizerConfig, WarmupConfig, GRU, GRUDimensions,
                   RETAIN, RETAINDimensions)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric)

import jax.random as jrandom

In [11]:
# PRNG key passed to every model constructor below.
key = jrandom.PRNGKey(0)
# Embedding sizes shared by all models (dx=30, demo=5).
emb_dims = OutpatientEmbeddingDimensions(dx=30, demo=5)

def icenode_model():
    """Build an ICENODE model (mem=15) over the CPRD interface schemes, using the shared embeddings and key."""
    return ICENODE(
        dims=ICENODEDimensions(mem=15, emb=emb_dims),
        schemes=cprd_patients.schemes,
        demographic_vector_config=cprd_patients.demographic_vector_config,
        key=key,
    )

def gru_model():
    """Build a GRU model over the CPRD interface schemes, using the shared embeddings and key."""
    gru_dims = GRUDimensions(emb=emb_dims)
    interface = cprd_patients
    return GRU(dims=gru_dims,
               schemes=interface.schemes,
               demographic_vector_config=interface.demographic_vector_config,
               key=key)

def retain_model():
    """Build a RETAIN model (mem_a=25, mem_b=25) over the CPRD interface schemes, using the shared embeddings and key."""
    model_kwargs = dict(
        dims=RETAINDimensions(mem_a=25, mem_b=25, emb=emb_dims),
        schemes=cprd_patients.schemes,
        demographic_vector_config=cprd_patients.demographic_vector_config,
        key=key,
    )
    return RETAIN(**model_kwargs)

# Models to train, keyed by run name (the training loop also uses the name as the output directory).
# ICENODE is currently disabled; add ('dx_icenode', icenode_model) to enable it.
_model_factories = (
    ('dx_gru', gru_model),
    ('dx_retain', retain_model),
)
models = {name: build() for name, build in _model_factories}

2023-08-24 16:58:20.970995: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 11.6 which is older than the ptxas CUDA version (11.8.89). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [12]:
# Final split used by the trainer and by the metrics below.
# The seed is passed explicitly, matching the other split calls in this notebook,
# so the split (and every reported metric) is reproducible across kernel restarts.
splits = cprd_patients.random_splits([0.9, 0.95],
                                    random_seed=42,
                                    balanced='admissions')

# Adam at lr=1e-3, 80 epochs, batches of 256, 'balanced_focal_bce' diagnosis loss, no regularisation.
optimizer_config = OptimizerConfig(opt='adam', lr=1e-3)
trainer = Trainer(optimizer_config=optimizer_config,
                  reg_hyperparams=None,
                  epochs=80,
                  batch_size=256,
                  dx_loss='balanced_focal_bce')

# Optional warm-up configuration.
# NOTE: currently unused — the training loop passes `warmup_config=None`.
warmup = WarmupConfig(epochs=0.1, batch_size=8, opt='adam', lr=1e-3, decay_rate=0.5)

# Diagnosis losses tracked as a reporting metric.
reported_dx_losses = (
    'softmax_bce',
    'balanced_focal_softmax_bce',
    'balanced_focal_bce',
    'allpairs_exp_rank',
    'allpairs_hard_rank',
    'allpairs_sigmoid_rank',
)
loss_metric = LossMetric(cprd_patients, dx_loss=reported_dx_losses)

# Metrics passed to the trainer's reporting. CodeAUC(cprd_patients) is available but disabled.
admission_auc = AdmissionAUC(cprd_patients)
top_alarm_accuracy = CodeGroupTopAlarmAccuracy(cprd_patients,
                                               n_partitions=5,
                                               top_k_list=[3, 5, 10, 15, 20],
                                               train_split=splits[0])
metrics = [admission_auc, top_alarm_accuracy, loss_metric]


# Default reporter. NOTE: the training loop below builds its own reporter per
# model (output_dir=<model name>), so this instance is not what it uses.
reporting = TrainerReporting(
    output_dir='dx_icenode',
    metrics=metrics,
    console=True,
    parameter_snapshots=True,
    config_json=True,
)

  frequency_vec = frequency_vec / frequency_vec.sum()


In [None]:
# Train every model in `models`, each with its own reporter writing to a
# directory named after the model, and collect the trainer outputs by name.
# Iterating `models` directly (instead of a hard-coded list of names) keeps this
# loop in sync with the models cell: an enabled model is never silently skipped,
# and a disabled one cannot raise KeyError here.
res = {}
for name, model in models.items():
    print(name)
    reporting = TrainerReporting(output_dir=name,
                                 metrics=metrics,
                                 console=True,
                                 model_stats=False,
                                 parameter_snapshots=True,
                                 config_json=True)
    res[name] = trainer(model, cprd_patients,
                        splits=splits,
                        reporting=reporting,
                        n_evals=800,
                        warmup_config=None,
                        continue_training=False)

dx_retain


Loading to device:   0%|          | 0/60 [00:00<?, ?subject/s]

  0%|          | 0/80 [00:00<?, ?Epoch/s]

  0%|          | 0/56 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/23 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/23 [00:00<?, ?subject/s]

  0%|          | 0.00/137448.00 [00:00<?, ?longitudinal-days/s]