In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd
import dask

from tqdm import tqdm
import jax
jax.config.update('jax_platform_name', 'cpu')
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [2]:


sys.path.append("../..")

from lib import utils as U
from lib.ehr.dataset import load_dataset, load_dataset_scheme, Dataset
from lib.ehr.interface import Patients
from lib.ehr.concepts import CPRDDemographicVectorConfig, DemographicVectorConfig


In [3]:
import logging
import pprint

logging.root.level = logging.DEBUG
pp = pprint.PrettyPrinter(indent=4)


In [4]:
# Assign the folder of the dataset to `DATA_FILE`.

HOME = os.environ.get('HOME')
DATA_FILE = f'{HOME}/GP/ehr-data/cprd-data/DUMMY_DATA.csv'
SOURCE_DIR = os.path.abspath("..")
cache_to_disk = None #'cached_inteface/m4inpatients_8000'
use_cached =  False # 'cached_inteface/m4inpatients_8000' #cache_to_disk        # False # 

##### Possible Interface Scheme Configurations

In [5]:
import json


interface_schem_options = load_dataset_scheme('CPRD').supported_target_scheme_options
print(json.dumps(interface_schem_options, sort_keys=True, indent=4))


DEBUG:root:Constructing dx_cprd_ltc9809 (<class 'lib.ehr.coding_scheme.DxLTC9809FlatMedcodes'>) scheme
DEBUG:root:Constructing eth_cprd_16 (<class 'lib.ehr.coding_scheme.CPRDEthnicity16'>) scheme
DEBUG:root:Constructing cprd_gender (<class 'lib.ehr.coding_scheme.CPRDGender'>) scheme
DEBUG:root:Constructing cprd_imd_cat (<class 'lib.ehr.coding_scheme.CPRDIMDCategorical'>) scheme


{
    "dx": [
        "DxLTC9809FlatMedcodes",
        "DxLTC212FlatCodes"
    ],
    "ethnicity": [
        "CPRDEthnicity16",
        "CPRDEthnicity5"
    ],
    "gender": [
        "CPRDGender"
    ],
    "imd": [
        "CPRDIMDCategorical"
    ],
    "outcome": [
        "dx_cprd_ltc212",
        "dx_cprd_ltc9809"
    ]
}


In [6]:
cprd_interface_scheme_kw = dict(dx='DxLTC212FlatCodes',
                                outcome='dx_cprd_ltc212',
#                               outcome='dx_cprd_ltc9809',
                                ethnicity='CPRDEthnicity5')

In [7]:
if use_cached:
    cprd_patients = Patients.load(use_cached)
    splits = cprd_patients.dataset.random_splits([0.8, 0.9], random_seed=42, balanced='admissions')

else:
    with U.modified_environ(DATA_FILE=DATA_FILE), dask.config.set(scheduler='processes', num_workers=12):
        
        # Load dataset
        cprd_dataset = load_dataset('CPRD', sample=None)
        # Use training-split for fitting the outlier_remover and the scalers.
        splits = cprd_dataset.random_splits([0.8, 0.9], random_seed=42, balanced='admissions')

        # Demographic vector attributes
        demographic_vector_conf = CPRDDemographicVectorConfig(age=True,
                                                              gender=True,
                                                              ethnicity=True,
                                                              imd=True)
        # Load interface
        cprd_patients = Patients(cprd_dataset, demographic_vector_conf,
                                **cprd_interface_scheme_kw).load_subjects(num_workers=12)

        # Cache to disk
#         cprd_patients.save(cache_to_disk, overwrite=True)

DEBUG:root:Constructing dx_cprd_ltc9809 (<class 'lib.ehr.coding_scheme.DxLTC9809FlatMedcodes'>) scheme
DEBUG:root:Constructing eth_cprd_16 (<class 'lib.ehr.coding_scheme.CPRDEthnicity16'>) scheme
DEBUG:root:Constructing cprd_gender (<class 'lib.ehr.coding_scheme.CPRDGender'>) scheme
DEBUG:root:Constructing cprd_imd_cat (<class 'lib.ehr.coding_scheme.CPRDIMDCategorical'>) scheme
DEBUG:root:Removing subjects by matching demographic(-0)and admissions(-0)tables
DEBUG:root:Constructing dx_cprd_ltc212 (<class 'lib.ehr.coding_scheme.DxLTC212FlatCodes'>) scheme
DEBUG:root:Constructing dx_cprd_ltc212 (<class 'lib.ehr.coding_scheme.OutcomeExtractor'>) scheme
DEBUG:root:Constructing dx_cprd_ltc212 (<class 'lib.ehr.coding_scheme.DxLTC212FlatCodes'>) scheme
DEBUG:root:Constructing eth_cprd_5 (<class 'lib.ehr.coding_scheme.CPRDEthnicity5'>) scheme
DEBUG:root:Constructing cprd_gender (<class 'lib.ehr.coding_scheme.CPRDGender'>) scheme
DEBUG:root:Constructing cprd_imd_cat (<class 'lib.ehr.coding_schem

In [8]:
len(cprd_patients.subjects)

6

In [9]:
from lib.ml import (ICENODE, ICENODEDimensions, OutpatientEmbeddingDimensions, 
                    Trainer, TrainerReporting, OptimizerConfig, WarmupConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric)

import jax.random as jrandom

DEBUG:matplotlib:matplotlib data path: /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/matplotlib/mpl-data
DEBUG:matplotlib:CONFIGDIR=/home/asem/.config/matplotlib
DEBUG:matplotlib:interactive is False
DEBUG:matplotlib:platform is linux
DEBUG:matplotlib:CACHEDIR=/home/asem/.cache/matplotlib
DEBUG:matplotlib.font_manager:Using fontManager instance from /home/asem/.cache/matplotlib/fontlist-v330.json


In [10]:
emb_dims = OutpatientEmbeddingDimensions(dx=30, demo=5)
dims = ICENODEDimensions(mem=15, emb=emb_dims)
key = jrandom.PRNGKey(0)

m = ICENODE(dims=dims, 
            schemes=cprd_patients.schemes,
            demographic_vector_config=cprd_patients.demographic_vector_config,
            key=key)

DEBUG:jax._src.xla_bridge:Initializing backend 'cpu'
DEBUG:jax._src.xla_bridge:Backend 'cpu' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'cuda'
DEBUG:jax._src.xla_bridge:Backend 'cuda' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'rocm'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA
DEBUG:jax._src.xla_bridge:Initializing backend 'tpu'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
DEBUG:jax._src.interpreters.pxla:Compiling _threefry_seed for with global shapes and types [ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _threefry_split for with global shapes and types [ShapedA

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(uint32[2,2]), ShapedArray(int32[]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling squeeze for with global shapes and types [ShapedArray(uint32[1,2])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compil

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _threefr

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform

In [11]:
splits = cprd_patients.random_splits([0.9, 0.95], 
                                    balanced='admissions')

trainer = Trainer(optimizer_config=OptimizerConfig(opt='adam', lr=1e-3),
                    reg_hyperparams=None,
                    epochs=80,
                    batch_size=128,
                    dx_loss='balanced_focal_softmax_bce')

warmup = WarmupConfig(epochs=0.1, 
                      batch_size=8,
                      opt='adam', lr=1e-3, 
                      decay_rate=0.5)

loss_metric =  LossMetric(cprd_patients, 
                          dx_loss=('softmax_bce', 'balanced_focal_softmax_bce', 
                                   'balanced_focal_bce', 'allpairs_exp_rank', 'allpairs_hard_rank', 
                                   'allpairs_sigmoid_rank'))

metrics = [CodeAUC(cprd_patients), 
           AdmissionAUC(cprd_patients), 
           CodeGroupTopAlarmAccuracy(cprd_patients, n_partitions=5, 
                                     top_k_list=[3, 5, 10, 15, 20],
                                     train_split=splits[0]), 
           loss_metric]


reporting = TrainerReporting(output_dir='dx_icenode',
                             metrics=metrics,
                             console=True,
                             parameter_snapshots=True,
                             config_json=True)

  frequency_vec = frequency_vec / frequency_vec.sum()


In [22]:
splits

[['4', '2', '3', '5', '1'], [], ['6']]

In [24]:
res = trainer(m, cprd_patients, 
              splits=splits,
              reporting=reporting,
              n_evals=100,
              warmup_config=warmup,
              continue_training=False)

INFO:root:Warming up...


batch_size 0
n_train_admissions 0
train_ids 1
Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3505, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_347903/814137775.py", line 1, in <module>
    res = trainer(m, cprd_patients,
  File "/home/asem/GP/ICENODE/notebooks/cprd_dx/../../lib/ml/trainer.py", line 581, in __call__
    model = self._warmup(model=model,
  File "/home/asem/GP/ICENODE/notebooks/cprd_dx/../../lib/ml/trainer.py", line 620, in _warmup
    return trainer._train(model=model,
  File "/home/asem/GP/ICENODE/notebooks/cprd_dx/../../lib/ml/trainer.py", line 650, in _train
    iters = round(self.epochs * n_train_admissions / batch_size)
ZeroDivisionError: float division by zero

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2102, 