In [1]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Set before `import jax` below so the value is in the environment when JAX
# starts. NOTE(review): this variable controls the fraction of GPU memory that
# XLA preallocates; with the CPU platform forced later in this cell it likely
# has no effect — confirm whether it is still needed.
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.1
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd
import dask

from tqdm import tqdm
import jax
# Pin JAX to the CPU platform for this notebook.
jax.config.update('jax_platform_name', 'cpu')
# Optional debugging switches; uncomment the ones needed.
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

env: XLA_PYTHON_CLIENT_MEM_FRACTION=0.1


In [2]:


# Make the project's `lib` package importable. The entry is a relative path, so
# it resolves against the kernel's working directory — presumably the
# notebook's own folder, two levels below the repository root (verify when the
# kernel is started from elsewhere).
sys.path.append("../..")

from lib import utils as U
from lib.ehr.dataset import load_dataset, load_dataset_scheme, load_dataset_config, Dataset
from lib.ehr.interface import Patients, InterfaceConfig
from lib.ehr.concepts import CPRDDemographicVectorConfig, DemographicVectorConfig


In [3]:
import logging

# logging.root.level = logging.DEBUG


In [4]:
# Dataset identifier passed to the `load_dataset_*` helpers.
tag = 'CPRD'

# Resolve the home directory through pathlib instead of os.environ.get("HOME"):
# inside an f-string an unset HOME silently produced the path "None/GP/...".
PATH = str(Path.home() / 'GP' / 'ehr-data' / 'cprd-data' / 'DUMMY_DATA.csv')
# PATH = f'{HOME}/Documents/DS211/users/tb1009/DATA/PAT_COHORT/ICENODE_SUBSET_1000.csv'

sample = None
# Cache location handed to InterfaceConfig. The directory spelling is kept
# exactly as it was so that previously written caches are still found.
cache = 'cached_inteface/cprd_1000'
dataset_config = load_dataset_config(tag,
                                     sample=sample,
                                     path=PATH)


def dataset_gen(c):
    """Load the dataset described by the configuration `c`."""
    return load_dataset(config=c)

##### Possible Interface Scheme Configurations

In [5]:
import json

# Load the dataset scheme once and reuse it; the previous version called
# load_dataset_scheme(tag) a second time just to read one attribute.
dataset_scheme = load_dataset_scheme(tag)
# Target-scheme choices supported by this dataset scheme, keyed by scheme field.
interface_schem_options = dataset_scheme.supported_target_scheme_options
print(json.dumps(interface_schem_options, sort_keys=True, indent=4))


{
    "dx": [
        "DxLTC9809FlatMedcodes",
        "DxLTC212FlatCodes"
    ],
    "ethnicity": [
        "CPRDEthnicity16",
        "CPRDEthnicity5"
    ],
    "gender": [
        "CPRDGender"
    ],
    "imd": [
        "CPRDIMDCategorical"
    ],
    "outcome": [
        "dx_cprd_ltc212",
        "dx_cprd_ltc9809"
    ]
}


In [6]:

# Demographic attributes that are switched on in the demographic vector.
demographic_vector_conf = CPRDDemographicVectorConfig(
    age=True,
    gender=True,
    ethnicity=True,
    imd=True,
)

# Target coding schemes, picked from the supported options.
# Alternative outcome scheme: 'dx_cprd_ltc212'.
interface_scheme = dataset_scheme.make_target_scheme_config(
    dx='DxLTC9809FlatMedcodes',
    outcome='dx_cprd_ltc9809',
    ethnicity='CPRDEthnicity5',
)

# Gather the scheme selection, demographic settings and cache location.
interface_config = InterfaceConfig(
    scheme=interface_scheme,
    dataset_scheme=dataset_scheme,
    demographic_vector=demographic_vector_conf,
    cache=cache,
)


In [7]:
# Echo the assembled configuration to confirm the selected schemes.
interface_config

InterfaceConfig(
  demographic_vector=CPRDDemographicVectorConfig(
    gender=True,
    age=True,
    ethnicity=True,
    imd=True
  ),
  leading_observable=None,
  scheme={
    'dx':
    'DxLTC9809FlatMedcodes',
    'ethnicity':
    'CPRDEthnicity5',
    'gender':
    'CPRDGender',
    'imd':
    'CPRDIMDCategorical',
    'outcome':
    'dx_cprd_ltc9809'
  },
  cache='cached_inteface/cprd_1000'
)

In [8]:
# Obtain the Patients interface through try_load_cached. Judging by the name,
# it reuses the cache referenced in interface_config when one exists and
# otherwise builds from dataset_generator — confirm against lib.ehr.interface.
cprd_patients = Patients.try_load_cached(
    interface_config,
    dataset_config=dataset_config,
    dataset_generator=dataset_gen,
    num_workers=8,
)

In [9]:
# Sanity check: number of subjects held by the loaded interface.
len(cprd_patients.subjects)

180

In [13]:
from lib.ml import (ICENODE, ICENODEDimensions, 
                    GRU, GRUDimensions,
                    RETAIN, RETAINDimensions,
                    OutpatientEmbeddingDimensions, 
                    Trainer, TrainerReporting, OptimizerConfig, WarmupConfig,
                    TrainerConfig, ReportingConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric)

import jax.random as jrandom

In [14]:
# Embedding sizes shared by all models in this notebook.
emb_dims = OutpatientEmbeddingDimensions(dx=30, demo=5)
# NOTE(review): the same PRNG key is handed to every model below. This is kept
# as-is so earlier runs stay reproducible; consider jrandom.split if
# independent initialisations are wanted.
key = jrandom.PRNGKey(0)


def _make_model(model_cls, dims):
    """Instantiate `model_cls` with the arguments common to every model here.

    Args:
        model_cls: model class to construct (ICENODE, GRU or RETAIN).
        dims: the dimensions object matching `model_cls`.

    Returns:
        The constructed model, built with the schemes and demographic vector
        configuration of `cprd_patients` and the shared `key`.
    """
    return model_cls(dims=dims,
                     schemes=cprd_patients.schemes,
                     demographic_vector_config=cprd_patients.config.demographic_vector,
                     key=key)


def icenode_model():
    """Return an ICENODE model with mem=15 and the shared embedding sizes."""
    return _make_model(ICENODE, ICENODEDimensions(mem=15, emb=emb_dims))


def gru_model():
    """Return a GRU model using the shared embedding sizes."""
    return _make_model(GRU, GRUDimensions(emb=emb_dims))


def retain_model():
    """Return a RETAIN model with mem_a=25, mem_b=25 and the shared embedding sizes."""
    return _make_model(RETAIN, RETAINDimensions(mem_a=25, mem_b=25, emb=emb_dims))


# Models to train, keyed by the name that is later used as the output directory.
models = {
    #'dx_icenode': icenode_model(),
    'dx_gru': gru_model(),
    'dx_retain': retain_model()
}


In [15]:
# Random subject splits with boundaries at 0.9 and 0.95, balanced by admissions.
splits = cprd_patients.random_splits([0.9, 0.95], balanced='admissions')

# Optimiser and training-loop settings shared by every model.
trainer_conf = TrainerConfig(
    optimizer=OptimizerConfig(opt='adam', lr=1e-3),
    epochs=80,
    batch_size=256,
    dx_loss='balanced_focal_bce',
)
trainer = Trainer(trainer_conf)

# Report several diagnosis losses side by side.
loss_metric = LossMetric(
    cprd_patients,
    dx_loss=('softmax_bce', 'balanced_focal_softmax_bce', 'balanced_focal_bce'),
)

# Evaluation metrics; splits[0] is supplied as the training split where needed.
metrics = [
    CodeAUC(cprd_patients),
    AdmissionAUC(cprd_patients),
    CodeGroupTopAlarmAccuracy(
        cprd_patients,
        n_partitions=5,
        top_k_list=[3, 5, 10, 15, 20],
        train_split=splits[0],
    ),
    loss_metric,
]

In [16]:
# Train every model registered in `models` and keep the trainer's return value
# per model name. Iterating the dict directly (instead of a separate hard-coded
# list of keys) keeps the loop in step with whatever `models` contains.
res = {}
for name, model in models.items():
    print(name)
    # Each model reports into an output directory named after its key.
    reporting_conf = ReportingConfig(output_dir=name,
                                     console=True,
                                     model_stats=False,
                                     parameter_snapshots=True,
                                     config_json=True)
    reporting = TrainerReporting(reporting_conf, metrics=metrics)

    res[name] = trainer(model, cprd_patients,
                        splits=splits,
                        reporting=reporting,
                        n_evals=100,
                        warmup_config=None,
                        continue_training=False)

dx_gru


Loading to device:   0%|          | 0/8 [00:00<?, ?subject/s]

  0%|          | 0/80 [00:00<?, ?Epoch/s]

  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/163 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/163 [00:00<?, ?subject/s]

  0%|          | 0.00/259239.00 [00:00<?, ?longitudinal-days/s]

Embedding:   0%|          | 0/163 [00:00<?, ?subject/s]

  0%|          | 0.00/259239.00 [00:00<?, ?longitudinal-days/s]

Embedding:   0%|          | 0/8 [00:00<?, ?subject/s]

  0%|          | 0.00/14123.00 [00:00<?, ?longitudinal-days/s]

  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/163 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/163 [00:00<?, ?subject/s]

  0%|          | 0.00/259239.00 [00:00<?, ?longitudinal-days/s]

Embedding:   0%|          | 0/163 [00:00<?, ?subject/s]

  0%|          | 0.00/259239.00 [00:00<?, ?longitudinal-days/s]

KeyboardInterrupt: 