In [1]:
%load_ext autoreload
%autoreload 2
# Standard library / third-party imports for this notebook.
# NOTE(review): glob, random, defaultdict, Path, display, pd and tqdm are not used in the
# cells shown here — drop them if no later cell needs them.
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

from tqdm import tqdm
import jax
# Run JAX on the CPU backend only.
jax.config.update('jax_platform_name', 'cpu')

In [2]:


# Make the repository's `lib` package importable. This assumes the notebook's working
# directory is two levels below the directory that contains `lib` (TODO confirm).
sys.path.append("../..")

from lib import utils as U
from lib.ehr.dataset import load_dataset
from lib.ehr.inpatient_interface import Inpatients

In [3]:
import logging

# Show DEBUG messages from the project's loaders (e.g. "Loading dataframe files").
# `setLevel` is used instead of assigning `logging.root.level` directly: it also clears
# the per-logger `isEnabledFor` cache, so loggers that were already queried see the
# new level.
logging.getLogger().setLevel(logging.DEBUG)


In [4]:
# Point the dataset loader at the EHR data folder `DATA_DIR` (~/GP/ehr-data).
# `os.path.expanduser` falls back to the platform's home-directory lookup when $HOME is
# unset, instead of silently building the path 'None/GP/ehr-data'.
HOME = os.path.expanduser('~')
DATA_DIR = os.path.join(HOME, 'GP', 'ehr-data')
# NOTE(review): SOURCE_DIR is not used in the cells shown here.
SOURCE_DIR = os.path.abspath("..")

# NOTE(review): `U.modified_environ` presumably sets the DATA_DIR environment variable
# only for the duration of the `with` block, where `load_dataset` picks it up — confirm.
with U.modified_environ(DATA_DIR=DATA_DIR):
    m4inpatient_dataset = load_dataset('M4ICU', max_workers=1)
   

DEBUG:root:Loading dataframe files
DEBUG:root:[DONE] Loading dataframe files
DEBUG:root:Matching admission_id
DEBUG:root:[DONE] Matching admission_id
DEBUG:root:Time casting..
DEBUG:root:[DONE] Time casting..


INFO: Pandarallel will run on 1 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


DEBUG:root:Dataframes validation and time conversion
INFO:root:Unrecognised ICD v10 codes: 3323 (28.74%)
DEBUG:root:
                    Unrecognised <class 'lib.ehr.coding_scheme.DxICD10'> codes (3323)
                    to be removed: ['E08.3513', 'E10.3213', 'E10.3219', 'E10.3291', 'E10.3293', 'E10.3299', 'E10.3313', 'E10.3319', 'E10.3393', 'E10.3399', 'E10.3411', 'E10.3413', 'E10.3491', 'E10.3511', 'E10.3512', 'E10.3513', 'E10.3519', 'E10.3522', 'E10.3523', 'E10.3531', 'E10.3532', 'E10.3559', 'E10.3591', 'E10.3592', 'E10.3593', 'E10.3599', 'E11.3213', 'E11.3219', 'E11.3291', 'E11.3292', 'E11.3293', 'E11.3299', 'E11.3311', 'E11.3313', 'E11.3319', 'E11.3391', 'E11.3393', 'E11.3399', 'E11.3413', 'E11.3419', 'E11.3491', 'E11.3492', 'E11.3493', 'E11.3499', 'E11.3513', 'E11.3519', 'E11.3521', 'E11.3532', 'E11.3542', 'E11.3553', 'E11.3591', 'E11.3592', 'E11.3593', 'E11.3599', 'H34.8112', 'H34.8120', 'H34.8122', 'H34.8192', 'H34.8310', 'H34.8320', 'H35.3110', 'H35.3120', 'H35.3130', 'H35.

INFO:root:Unrecognised ICD v9 codes: 118 (1.63%)
DEBUG:root:
                    Unrecognised <class 'lib.ehr.coding_scheme.DxICD9'> codes (118)
                    to be removed: ['041.49', '173.21', '173.22', '173.30', '173.31', '173.32', '173.40', '173.41', '173.42', '173.50', '173.51', '173.52', '173.59', '173.60', '173.61', '173.62', '173.70', '173.71', '173.72', '173.79', '173.80', '173.81', '173.82', '173.91', '173.92', '173.99', '282.40', '282.43', '282.44', '282.46', '284.11', '284.12', '284.19', '286.52', '286.53', '286.59', '294.20', '294.21', '310.81', '310.89', '331.6', '348.82', '358.30', '365.70', '365.72', '365.73', '414.4', '415.13', '425.11', '425.18', '444.09', '488.81', '488.82', '488.89', '512.2', '512.82', '512.83', '512.84', '512.89', '516.31', '516.32', '516.33', '516.34', '516.35', '516.36', '516.37', '516.4', '516.5', '518.51', '518.52', '518.53', '539.01', '539.09', '539.81', '539.89', '573.5', '596.81', '596.82', '596.83', '596.89', '629.31', '649.81', '704.

In [25]:
# Randomly split the dataset by subject. Later cells use splits[0] to fit preprocessing
# and splits[2] for the inspected subset, so the result holds subject ids.
# NOTE(review): [0.8, 0.9] look like cumulative split boundaries (i.e. 80/10/10) — confirm
# against `random_splits`. The fixed seed keeps re-runs reproducible.
SPLIT_BOUNDARIES = [0.8, 0.9]
SPLIT_SEED = 42
splits = m4inpatient_dataset.random_splits(SPLIT_BOUNDARIES, SPLIT_SEED, 'subjects')



In [17]:
# Fit preprocessing on splits[0] only (presumably the training split), so held-out
# subjects do not influence the fitted statistics.
train_subject_ids = splits[0]
preprocessing = m4inpatient_dataset.fit_preprocessing(train_subject_ids)

In [18]:
# Apply the fitted preprocessing to the dataset. Nothing is assigned from the call, and
# its log reports outliers removed from `obs`, so it appears to modify the dataset in place.
# NOTE(review): re-running this cell without reloading the dataset may apply the
# preprocessing twice; restart and run top-to-bottom to be safe.
m4inpatient_dataset.apply_preprocessing(preprocessing)

DEBUG:root:Removed 2320851 (0.023) outliers from obs


In [62]:
# Build inpatient records for a small debug subset: the first 10 subject ids of splits[2].
# (The name keeps its original spelling, `m4inaptients`, because later cells refer to it.)
debug_subject_ids = splits[2][:10]
m4inaptients = Inpatients(m4inpatient_dataset, debug_subject_ids)

DEBUG:root:Loading subjects..


INFO: Pandarallel will run on 1 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


DEBUG:root:Extracting dx codes...
DEBUG:root:[DONE] Extracting dx codes
DEBUG:root:Extracting dx codes history...
DEBUG:root:[DONE] Extracting dx codes history
DEBUG:root:Extracting outcome...
DEBUG:root:[DONE] Extracting outcome
DEBUG:root:Extracting procedures...
DEBUG:root:[DONE] Extracting procedures
DEBUG:root:Extracting inputs...
DEBUG:root:[DONE] Extracting inputs
DEBUG:root:Extracting observables...
DEBUG:root:[DONE] Extracting observables
DEBUG:root:Compiling admissions...
DEBUG:root:[DONE] Loading subjects


In [63]:
# Memory footprint of the debug subset in GiB (size_in_bytes / 2**30).
m4inaptients.size_in_bytes / 2 ** 30

0.0014329766854643822

In [64]:
# Convert the same 10-subject subset to JAX arrays.
jax_subject_ids = splits[2][:10]
m4inaptients_jax = m4inaptients.to_jax_arrays(jax_subject_ids)



In [65]:
# Memory footprint of the JAX-converted subset in GiB (size_in_bytes / 2**30).
m4inaptients_jax.size_in_bytes / 2 ** 30

0.0013888729736208916

In [66]:
# Number of subjects in the JAX-converted subset (10 ids were requested).
n_jax_subjects = len(m4inaptients_jax.subjects)
n_jax_subjects

10

In [67]:
# Number of admissions in the JAX-converted subset, as reported by `n_admissions`.
m4inaptients_jax.n_admissions

28

In [68]:
# Segment count for the debug subset.
# NOTE(review): presumably the number of time segments the admissions' observables are
# split into — confirm in `Inpatients.n_segments`.
m4inaptients.n_segments

2347

In [69]:
# Observation-time count for the debug subset.
# NOTE(review): presumably the number of timestamps at which observables were recorded —
# confirm in `Inpatients.n_obs_times`.
m4inaptients.n_obs_times

1627

In [70]:
# NOTE(review): presumably the fraction of observable entries that are actually observed
# (i.e. observation density) — confirm in `Inpatients.p_obs`.
m4inaptients.p_obs

0.2370723212456464

In [71]:
# `subjects` maps subject id -> Inpatient record. Rendering every record floods the
# notebook and its saved output, so preview only the first entry. The preview is still a
# dict, so it renders in the same form as the full mapping (and is empty if no subjects).
N_SUBJECTS_PREVIEW = 1
dict(list(m4inaptients.subjects.items())[:N_SUBJECTS_PREVIEW])

{'19485534': Inpatient(
   subject_id='19485534',
   static_info=InpatientStaticInfo(
     gender='M',
     date_of_birth=Timestamp('2136-01-01 00:00:00'),
     ethnicity=bool[5](numpy),
     ethnicity_scheme=<lib.ehr.coding_scheme.MIMIC4Eth5 object at 0x7f0f3c19f9d0>,
     constant_vec=bool[6](numpy)
   ),
   admissions=[
     InpatientAdmission(
       admission_id='20148586',
       admission_dates=(
         Timestamp('2200-07-04 22:26:00'),
         Timestamp('2200-07-07 18:25:00')
       ),
       dx_codes=CodesVector(
         vec=bool[17375](numpy),
         scheme=<lib.ehr.coding_scheme.DxICD9 object at 0x7f0d1520f9d0>
       ),
       dx_codes_history=CodesVector(
         vec=bool[17375](numpy),
         scheme=<lib.ehr.coding_scheme.DxICD9 object at 0x7f0d1520f9d0>
       ),
       outcome=CodesVector(
         vec=bool[2081](numpy),
         scheme=<lib.ehr.outcome.OutcomeExtractor object at 0x7f0f3c19f520>
       ),
       observables=[
         InpatientObservables(
    

In [77]:
# Co-occurrence counts between observables for the debug subset.
# NOTE(review): presumably how often pairs of observables are recorded together —
# confirm in `Inpatients.obs_coocurrence_matrix`. `a` is displayed by the next cell.
a = m4inaptients.obs_coocurrence_matrix

In [78]:
# Display the observable co-occurrence matrix (an int32 JAX array).
a

Array([[   4,    3,    3, ...,    2,    2,    0],
       [   3,   85,   84, ...,   80,   16,    0],
       [   3,   84,  135, ...,  109,   38,    2],
       ...,
       [   2,   80,  109, ..., 1373,  750,   26],
       [   2,   16,   38, ...,  750,  771,    8],
       [   0,    0,    2, ...,   26,    8,   47]], dtype=int32)

In [79]:
# Pick the first subject of splits[2] from the JAX-converted subset for inspection.
first_subject_id = splits[2][0]
s = m4inaptients_jax.subjects[first_subject_id]

In [80]:
# Split outcome codes into 5 groups, one index array per group.
# NOTE(review): presumably grouped by how often each outcome occurs among the given
# subjects — confirm in `outcome_frequency_partitions`.
n_outcome_partitions = 5
m4inaptients_jax.outcome_frequency_partitions(n_outcome_partitions, splits[2][:10])

[Array([   0,    1,    2, ..., 1760, 1811, 1817], dtype=int32),
 Array([1830, 1835, 1855, 1858, 1861, 1885, 1897, 1961, 1962,  195,  391,
        1097, 1388], dtype=int32),
 Array([1490, 1562, 1566, 1616, 1791, 1828, 1894, 1954], dtype=int32),
 Array([1615, 1595, 1740, 2061], dtype=int32),
 Array([ 367, 1610,  388], dtype=int32)]

In [81]:
# Timestamps of the first `observables` entry of the subject's first admission.
# NOTE(review): values look like hours since admission start — confirm units.
first_admission = s.admissions[0]
first_admission.observables[0].time

Array([ 7.5666666, 31.566668 ], dtype=float32)

In [82]:
# Static information for the selected subject: gender, date of birth and ethnicity.
s.static_info

InpatientStaticInfo(
  gender='M',
  date_of_birth=Timestamp('2136-01-01 00:00:00'),
  ethnicity=bool[5],
  ethnicity_scheme=<lib.ehr.coding_scheme.MIMIC4Eth5 object at 0x7f0f3c19f9d0>,
  constant_vec=bool[6]
)

In [83]:
# Subject's age at the start of their first admission; `admission_dates` is a
# (start, end) pair. The displayed value is in years.
admission_start = s.admissions[0].admission_dates[0]
s.static_info.age(admission_start)

64.50376454483231

## TODO

1. Squeeze the code vectors (store them in a more compact form) to reduce memory use.
2. Downcast float32 arrays to float16 to reduce memory use.