In [1]:
import tensorflow as tf
from time import time
import numpy
import os
import json
import pickle
import pandas
import datetime
from functools import partial, reduce
import importlib

import sys
sys.path.append('../libs')

import data_pipeline
import conv_model
import initialize
import prepare_data
import flacdb
import plot_batch

In [3]:
# Shell escape: list the contents of the checkpoint directory with
# human-readable file sizes.
! ls -lh /scr1/checkpoints

total 79M
-rw-r--r-- 1 kuprel users 57K Nov 17 20:02 1117408_20191117-112416.data-00000-of-00002
-rw-r--r-- 1 kuprel users 40M Nov 17 20:02 1117408_20191117-112416.data-00001-of-00002
-rw-r--r-- 1 kuprel users 27K Nov 17 20:02 1117408_20191117-112416.index
-rw-r--r-- 1 kuprel users 57K Nov 17 18:19 1118450_20191117-135257.data-00000-of-00002
-rw-r--r-- 1 kuprel users 40M Nov 17 18:19 1118450_20191117-135257.data-00001-of-00002
-rw-r--r-- 1 kuprel users 27K Nov 17 18:19 1118450_20191117-135257.index
-rw-r--r-- 1 kuprel users 103 Nov 17 20:02 checkpoint


In [2]:
%%time

# Load hyperparameters, cached signal data/metadata, diagnosis labels and the
# data partition. Every name assigned here is consumed by later cells.
H = initialize.load_hypes()
initial_data_path = '/scr1/mimic/initial_data/'

# Alternative call kept for reference: pass save_path instead of load_path
# (presumably rebuilds the initial data and writes it to disk — TODO confirm).
# sig_data, metadata = initialize.load_initial_data(save_path=initial_data_path)
sig_data, metadata = initialize.load_initial_data(load_path=initial_data_path)
# The diagnosis table is built in three successive steps; each step rebinds
# `diagnosis`, so the order of these lines matters.
diagnosis = initialize.load_diagnosis(H['icd_codes'], metadata)
diagnosis = initialize.augment_diagnosis(diagnosis, metadata)
diagnosis = initialize.fix_diagnosis(diagnosis)
# Per-column prior: (number of entries equal to 1) / (number of entries not
# equal to 0). Presumably 0 encodes "unknown", making this the positive rate
# among known labels — TODO confirm against initialize.fix_diagnosis.
# Computed BEFORE conform_diagnosis below, i.e. on the un-conformed table.
diagnosis_priors = (diagnosis == 1).sum() / (diagnosis != 0).sum()
# The two "measured" hypertension labels get a fixed prior of 0.5 instead of
# the empirical ratio.
diagnosis_priors['measured_systemic_hypertension'] = 0.5
diagnosis_priors['measured_pulmonary_hypertension'] = 0.5
diagnosis = initialize.conform_diagnosis(diagnosis, metadata)
# Maps split names (at least 'train' and 'validation', as used by later cells)
# to selectors that index metadata, sig_data and diagnosis.
partition = initialize.load_partition(H, sig_data)

CPU times: user 3.15 s, sys: 236 ms, total: 3.38 s
Wall time: 3.39 s


In [3]:
%%time
# For each column of `diagnosis`, count how many level-0 index groups in the
# validation partition contain at least one positive (== 1) entry.
# (Level 0 is presumably one group per patient/subject — TODO confirm.)
# `DataFrame.any(level=...)` was deprecated in pandas 1.3 and removed in 2.0;
# grouping on the index level explicitly is the documented replacement and
# produces the same per-group reduction, so the column sums are unchanged.
(diagnosis[partition['validation']] == 1).groupby(level=0).any().sum()

CPU times: user 44 ms, sys: 0 ns, total: 44 ms
Wall time: 41.4 ms


25000                  117
2720                    60
2724                   143
2761                    50
2762                    93
27651                   16
27652                   18
27800                   22
2859                    60
4019                   236
41071                   42
41401                  119
4160                     3
4240                    28
4241                    27
42731                  152
42732                   16
4275                    25
4280                   143
42832                    8
42833                   16
431                     32
5119                    35
51881                  136
5715                    24
5849                   111
5859                    52
78552                   65
79902                   15
99592                   92
V5867                   33
gender_F               228
gender_M               371
race_asian              10
race_black              51
race_hispanic           29
race_white             382
a

In [5]:
%%time
# For each column of `diagnosis`, count how many level-0 index groups in the
# validation partition contain at least one positive (== 1) entry.
# (Level 0 is presumably one group per patient/subject — TODO confirm.)
# `DataFrame.any(level=...)` was deprecated in pandas 1.3 and removed in 2.0;
# grouping on the index level explicitly is the documented replacement and
# produces the same per-group reduction, so the column sums are unchanged.
(diagnosis[partition['validation']] == 1).groupby(level=0).any().sum()

CPU times: user 76 ms, sys: 0 ns, total: 76 ms
Wall time: 71.1 ms


25000                  102
2724                   133
2762                    82
27651                   12
27652                   14
27800                   22
2859                    73
4019                   240
41071                   23
41401                  103
4240                    20
4241                    18
42731                  138
42732                   17
4275                    35
4280                   129
42832                   14
42833                   14
431                     27
5119                    40
51881                  124
5715                    21
5849                   107
5859                    52
78552                   50
79902                   16
99592                   65
gender_F               257
gender_M               312
race_asian              17
race_black              39
race_hispanic           15
race_white             375
age_at_least_75        161
height_at_least_70      69
weight_at_least_100     67
died                    99
d

In [38]:
%%time
# For each column of `diagnosis`, count how many level-0 index groups in the
# validation partition contain at least one positive (== 1) entry.
# (Level 0 is presumably one group per patient/subject — TODO confirm.)
# `DataFrame.any(level=...)` was deprecated in pandas 1.3 and removed in 2.0;
# grouping on the index level explicitly is the documented replacement and
# produces the same per-group reduction, so the column sums are unchanged.
(diagnosis[partition['validation']] == 1).groupby(level=0).any().sum()

CPU times: user 68 ms, sys: 4 ms, total: 72 ms
Wall time: 72.1 ms


25000                   67
2724                    96
2762                    38
27651                    8
27652                   10
27800                   15
2859                    43
4019                   166
41071                   18
41401                   82
4240                    16
4241                    14
42731                   93
42732                   13
4275                    22
4280                    97
42832                    6
42833                   11
431                     17
5119                    24
51881                   77
5715                    10
5849                    73
5859                    36
78552                   27
79902                    7
99592                   42
gender_F               153
gender_M               191
race_asian              14
race_black              27
race_hispanic            8
race_white             240
age_at_least_75        103
height_at_least_70      50
weight_at_least_100     37
died                    54
d

In [5]:
%%time

# Build one input pipeline per data split, stored under the split's name.
dataset = {}
for part in ['train', 'validation']:
    # Selector for this split; the same selector is applied to the metadata,
    # the signal data and the diagnosis labels so they stay aligned.
    I = partition[part]
    row_lengths = initialize.get_row_lengths(metadata[I])
    # Positional arguments for get_tensors, passed after H in this order.
    args = [metadata[I], sig_data[I], diagnosis[I], row_lengths]
    tensors = initialize.get_tensors(H, *args)
    # The split name is also handed to the builder — presumably so training
    # and validation pipelines can be configured differently; TODO confirm
    # in data_pipeline.build.
    dataset[part] = data_pipeline.build(H, tensors, part)

CPU times: user 18.2 s, sys: 200 ms, total: 18.4 s
Wall time: 18.4 s


In [7]:
# Re-execute the already-imported conv_model module so that edits to its
# source file are picked up without restarting the kernel. The reloaded
# module object is the cell's displayed result.
importlib.reload(conv_model)

<module 'conv_model' from '../libs/conv_model.py'>

In [8]:
# Construct the model from the hyperparameters and the per-label priors.
model = conv_model.build(H, diagnosis_priors)

# Train on the 'train' pipeline: 128 steps per epoch, with a single
# validation step drawn from the 'validation' pipeline.
model.fit(
    dataset['train'],
    validation_data=dataset['validation'],
    steps_per_epoch=128,
    validation_steps=1,
)

Train for 128 steps, validate for 1 steps
 19/128 [===>..........................] - ETA: 16:26 - loss: 1.3292 - pressure_loss: 1.6601 - diagnosis_loss: 1.1104 - pressure_ABP_systolic: 21.4015 - pressure_ABP_diastolic: 11.9643 - pressure_ABP_pulse: 18.4847 - pressure_CVP_systolic: 6.4757 - pressure_CVP_diastolic: 5.5955 - pressure_CVP_pulse: 5.4249 - pressure_ICP_systolic: 5.1387 - pressure_ICP_diastolic: 4.5058 - pressure_ICP_pulse: 3.4367 - pressure_PAP_systolic: 10.4759 - pressure_PAP_diastolic: 8.0560 - pressure_PAP_pulse: 8.5764 - diagnosis_25000_diabetes_sensitivity: 0.5435 - diagnosis_25000_diabetes_specificity: 0.4972 - diagnosis_25000_diabetes_accuracy: 0.5203 - diagnosis_25000_diabetes_precise_sensitivity: 0.0981 - diagnosis_25000_diabetes_precise_threshold: 0.9851 - diagnosis_2720_hypercholesterolemia_sensitivity: 0.2426 - diagnosis_2720_hypercholesterolemia_specificity: 0.7832 - diagnosis_2720_hypercholesterolemia_accuracy: 0.5129 - diagnosis_2720_hypercholesterolemia_preci

KeyboardInterrupt: 

In [6]:
# Load the signal table from an HDF file.
# NOTE(review): this rebinds `sig_data`, the same name an earlier cell assigns
# from initialize.load_initial_data and that the dataset-building cell indexes.
# Re-running those cells after this one would use this frame instead —
# confirm that is intended, or load into a differently named variable.
sig_data = pandas.read_hdf('/scr-ssd/mimic/sig_data.hdf')

In [7]:
# Tally how many rows carry each value of the 'sig_name' column; value_counts
# returns the tallies sorted from most to least frequent.
sig_data['sig_name'].value_counts()

II        1871134
PLETH     1579432
RESP      1434547
V         1302166
AVR        914718
ABP        611101
III        310765
CVP        247842
I          173964
MCL        126479
ICP         67941
PAP         53724
MCL1        43520
ART         37237
AVF         33734
AVL         21747
UAP          6313
PLETHR       2866
AOBP         2079
PLETHL       2051
UVP          1220
IC2           712
RAP           526
CO2           443
ECG           428
V1            343
IC1           305
P1            223
LAP           102
BAP            66
V2             45
P4             11
FAP             9
V5              6
AO              5
V3              3
P3              3
P2              2
Name: sig_name, dtype: int64