In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psycopg2
import sklearn
import tensorflow as tf

import helpers as hp

In [None]:
# --- Experiment configuration -------------------------------------------------
RANDOM_STATE = 42   # forwarded to hp.Trainer as random_state
WINDOW_LENGTH = 8   # selects the '*_labeled_<N>h' table / pickle (the 'h' suffix suggests hours)
CLIENT_COUNT = 2    # forwarded to the evaluation call as n_clients
USE_FL = True       # True -> trainer.evaluateFL(...), False -> trainer.evaluate(...)

# Feature columns read from the vitals frame
VITAL_NAMES = [
    'heartrate', 'sysbp', 'diasbp', 'meanbp', 'resprate', 'tempc', 'spo2',
]

# Feature columns read from the labs frame
LAB_NAMES = [
    'albumin', 'bun', 'bilirubin', 'lactate',
    'bicarbonate', 'bands', 'chloride', 'creatinine',
    'glucose', 'hemoglobin', 'hematocrit', 'platelet',
    'potassium', 'ptt', 'sodium', 'wbc',
]

# Label column. Possible values: 'label_death_icu', 'label_death_continuous'
LABEL_NAME = 'label_death_icu'

# Load data

## From SQL:

In [None]:
# Connect to db.
# Connection settings are taken from the environment when present; the fallbacks
# are the previous local-development values, so existing setups keep working.
# NOTE(review): do not commit real credentials -- export the MIMIC_DB_* variables instead.
conn = psycopg2.connect(
    host=os.environ.get('MIMIC_DB_HOST', 'localhost'),
    port=int(os.environ.get('MIMIC_DB_PORT', '5433')),
    dbname=os.environ.get('MIMIC_DB_NAME', 'mimic'),
    user=os.environ.get('MIMIC_DB_USER', 'postgres'),
    password=os.environ.get('MIMIC_DB_PASSWORD', 'postgres'),
)

try:
    # Read vital signs
    vitals = pd.read_sql_query(f'SELECT * FROM mimiciii.vital_labeled_{WINDOW_LENGTH}h;', conn)

    # Read lab values
    labs = pd.read_sql_query(f'SELECT * FROM mimiciii.labs_labeled_{WINDOW_LENGTH}h;', conn)
finally:
    # Always close the connection -- even when a query fails -- so the server
    # can allocate its resources to other requests. (No explicit cursor is
    # needed: pandas manages its own cursor on the connection.)
    conn.close()

## From File:

In [None]:
# Load the pre-extracted frames for the configured window length from local pickles.
# NOTE(review): unpickling can execute arbitrary code -- only load files you created yourself.
vitals = pd.read_pickle(f'vitals_labeled_{WINDOW_LENGTH}h.pickle')
labs = pd.read_pickle(f'labs_labeled_{WINDOW_LENGTH}h.pickle')

# Data Processing
Create the tensor interface specifications (`tf.TensorSpec`) for the vitals input, the labs input, and the label:

In [None]:
# Feature tensors: first dimension left unspecified (None), second dimension has
# one entry per configured feature name.
vitals_spec = tf.TensorSpec(shape=(None, len(VITAL_NAMES)), dtype=tf.float64, name='vitals')
labs_spec = tf.TensorSpec(shape=(None, len(LAB_NAMES)), dtype=tf.float64, name='labs')

# Label tensor
label_spec = tf.TensorSpec(shape=1, dtype=tf.float64, name='label')

## Build Model 

Build RNN-Model:

In [None]:
# Shorthand for the Keras layer namespace used throughout this cell
keras_layers = tf.keras.layers

# --- Vitals channel: masking -> three stacked GRUs -> batch normalisation ---
# Masking uses -2 as the mask value (presumably the padding value produced by
# the data pipeline -- confirm against the helpers module).
inputs_vitals = tf.keras.Input(name='Input_vitals', shape=vitals_spec.shape)
mask_vitals = keras_layers.Masking(name='mask_vitals', mask_value=-2.)(inputs_vitals)
GRU_layer1_vitals = keras_layers.GRU(
    units=16, return_sequences=True, name='GRU_layer1_vitals'
)(mask_vitals)
GRU_layer2_vitals = keras_layers.GRU(
    units=16, return_sequences=True, name='GRU_layer2_vitals'
)(GRU_layer1_vitals)
# The last GRU returns only its final state (return_sequences=False)
GRU_layer3_vitals = keras_layers.GRU(
    units=16, return_sequences=False, name='GRU_layer3_vitals'
)(GRU_layer2_vitals)
normalized_vitals = keras_layers.BatchNormalization(name='BatchNorm_vitals')(GRU_layer3_vitals)

# --- Labs channel: same architecture as the vitals channel ---
inputs_labs = tf.keras.Input(name='Input_labs', shape=labs_spec.shape)
mask_labs = keras_layers.Masking(name='mask_labs', mask_value=-2.)(inputs_labs)
GRU_layer1_labs = keras_layers.GRU(
    units=16, return_sequences=True, name='GRU_layer1_labs'
)(mask_labs)
GRU_layer2_labs = keras_layers.GRU(
    units=16, return_sequences=True, name='GRU_layer2_labs'
)(GRU_layer1_labs)
GRU_layer3_labs = keras_layers.GRU(
    units=16, return_sequences=False, name='GRU_layer3_labs'
)(GRU_layer2_labs)
normalized_labs = keras_layers.BatchNormalization(name='BatchNorm_labs')(GRU_layer3_labs)

# --- Concatenation of both branches ---
merge = keras_layers.Concatenate()([normalized_vitals, normalized_labs])

# Fully connected head. No activation is passed to FCL1/FCL2.
# TODO: decide how many units the dense layers should have.
FCL1 = keras_layers.Dense(units=16, name='FCL1')(merge)
FCL2 = keras_layers.Dense(units=16, name='FCL2')(FCL1)
outputs = keras_layers.Dense(units=1, activation='sigmoid', name='output')(FCL2)

model = tf.keras.Model(
    inputs=[inputs_vitals, inputs_labs],
    outputs=outputs,
    name='RNN_model',
)
model.summary()

In [None]:
# Render the architecture diagram, including tensor shapes, to an image file
tf.keras.utils.plot_model(
    model,
    to_file="./pictures/model.png",
    show_shapes=True,
)

## Build Pipeline:

In [None]:
# Extract vitals.
# DataFrame.sort_values returns a NEW frame; the result has to be kept, otherwise
# the rows are never ordered and the per-stay groups are not in charttime order.
vitals_sorted = vitals.sort_values(['icustay_id', 'charttime'])
vital_data = vitals_sorted[['icustay_id'] + VITAL_NAMES].groupby(['icustay_id'])

# Extract labs (same ordering requirement as for the vitals).
labs_sorted = labs.sort_values(['icustay_id', 'charttime'])
lab_data = labs_sorted[['icustay_id'] + LAB_NAMES].groupby(['icustay_id'])

# Extract ICU-stays and labels: one row per stay, holding the stay id and the
# first label value found for that stay, as a 2-column numpy array.
icustays = (
    vitals_sorted[['icustay_id', LABEL_NAME]]
    .groupby(['icustay_id'])
    .first()
    .reset_index()
    .to_numpy()
)
icustays

## Evaluate Model:

In [None]:
# Labeling-dependent variables: the continuous label uses MSE with 5 labels,
# every other label setting uses binary cross-entropy with 2 labels.
loss_fcn, label_cnt = (
    ('mse', 5)
    if LABEL_NAME == 'label_death_continuous'
    else ('binary_crossentropy', 2)
)

In [None]:
# Metrics tracked during training and evaluation.
# MAE/MSE are taken from tf.keras.metrics (stateful Metric objects) rather than
# tf.keras.losses: loss objects are not metrics and only work in a metrics list
# through implicit wrapping. The reported names ('MAE', 'MSE') are unchanged.
metrics = [
    hp.ContinuousAUC(curve='ROC', name='AUROC', num_labels=label_cnt),
    hp.ContinuousAUC(curve='PR', name='AUPRC', num_labels=label_cnt),
    hp.ContinuousRecall(name='recall', num_labels=label_cnt),
    hp.ContinuousPrecision(name='precision', num_labels=label_cnt),
    tf.keras.metrics.MeanAbsoluteError(name='MAE'),
    tf.keras.metrics.MeanSquaredError(name='MSE'),
]


In [None]:
# Keyword options for the trainer, collected separately from the positional inputs
trainer_options = {
    # ((vitals, labs), label) element structure described by the TensorSpecs
    'output_signature': ((vitals_spec, labs_spec), label_spec),
    'random_state': RANDOM_STATE,
    'threaded': False,
    'max_threads': 2,
}

# Positional inputs: grouped vitals, grouped labs, loss, metric list
trainer = hp.Trainer(vital_data, lab_data, loss_fcn, metrics, **trainer_options)

In [None]:
# Keyword arguments common to both evaluation modes
shared_eval_args = dict(
    n_clients=CLIENT_COUNT,
    n_labels=label_cnt,
    weighted=True,
    shuffle=True,
    # Clients are only stratified in the two-label setting
    stratify_clients=(label_cnt == 2),
)

if USE_FL:
    # Federated evaluation additionally takes the number of rounds
    trainer.evaluateFL(model, icustays, n_rounds=50, **shared_eval_args)
else:
    trainer.evaluate(model, icustays, **shared_eval_args)

## Statistics

In [None]:
# Averages of the ranking metrics over all recorded test scores
print('Average test AUROC:', trainer.test_scores['AUROC'].mean())
print('Average test AUPRC:', trainer.test_scores['AUPRC'].mean())

recall_sc = trainer.test_scores['recall'].mean()
precision_sc = trainer.test_scores['precision'].mean()

# F1 is the harmonic mean of the averaged precision and recall. When both are
# zero the quotient is 0/0, so report 0.0 in that case instead of NaN (or a
# ZeroDivisionError for plain Python floats). NaN inputs still propagate.
pr_sum = precision_sc + recall_sc
f1_sc = 0.0 if pr_sum == 0 else 2 * precision_sc * recall_sc / pr_sum
print('Average test F1:', f1_sc)

# Averages of the error metrics
print('Average test MAE:', trainer.test_scores['MAE'].mean())
print('Average test MSE:', trainer.test_scores['MSE'].mean())

In [None]:
# Layout: loss across the top half, precision and recall side by side below.
# Each entry is (history key, subplot grid position, extra plot options).
panel_layout = (
    ('loss', (2, 1, 1), {}),
    ('precision', (2, 2, 3), {'x_step': 4}),
    ('recall', (2, 2, 4), {'x_step': 4}),
)

for history_key, grid_position, plot_options in panel_layout:
    # The axes object is created right before it is handed to the plot call
    trainer.plot_history(history_key, plt.subplot(*grid_position), **plot_options)

plt.tight_layout()
plt.show()

In [None]:
# AUROC on the top panel, AUPRC on the bottom panel
for panel_index, history_key in enumerate(('AUROC', 'AUPRC'), start=1):
    trainer.plot_history(history_key, plt.subplot(2, 1, panel_index))

plt.tight_layout()
plt.show()

In [None]:
# MAE on the top panel, MSE on the bottom panel
for panel_index, history_key in enumerate(('MAE', 'MSE'), start=1):
    trainer.plot_history(history_key, plt.subplot(2, 1, panel_index))

plt.tight_layout()
plt.show()

# Save results

In [None]:
# Short tag for the labelling scheme; it appears in both the folder and the file name
label_tag = 'cont' if LABEL_NAME == 'label_death_continuous' else 'bin'
# Federated runs get their own file-name prefix
file_prefix = 'scores_fl_' if USE_FL else 'scores_'

scores_dir = f'./scores/{label_tag:s}_{WINDOW_LENGTH:d}h/'
scores_file = f'{file_prefix}{label_tag:s}_{CLIENT_COUNT:d}clients_{WINDOW_LENGTH:d}h.pickle'

trainer.save(scores_dir + scores_file)