In [29]:
import os
import math
import sys
import shutil
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
import pandas as pd
import numpy as np
from tqdm import tqdm
from joblib import Parallel, delayed

import tensorflow as tf
from keras.layers import Input, Conv1D, GlobalAveragePooling1D, Dense, Attention, BatchNormalization, ReLU, LSTM, Dropout, Conv1DTranspose, Flatten, TimeDistributed, Concatenate, Layer
from keras.models import Model, Sequential
import keras
from keras import layers
from sklearn.utils import class_weight
import ipywidgets as widgets



In [2]:
# Input disease name
disease_box = widgets.Text(
    value='respiratory_HiRID',
    placeholder='respiratory_HiRID, circulatory, kidney or sepsis',
    description='Disease:',
    disabled=False,
    layout=widgets.Layout(width='500px')
)

# Display the text box
display(disease_box)

# Input hours before & after onset to label as positive
label_before = widgets.Label('Enter the # of hours before onset that are considered positive:')

before_scrollbar = widgets.IntSlider(
    value=6,
    min=0,  # minimum value
    max=24,  # maximum value
    step=1,  # increment size
    description='Before:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)

# Display the scrollbar
display(label_before, before_scrollbar)

# Input hours before & after onset to label as positive
label_after = widgets.Label('Enter the # of hours after onset that are considered positive:')

after_scrollbar = widgets.IntSlider(
    value=24,
    min=0,  # minimum value
    max=24,  # maximum value
    step=1,  # increment size
    description='Before:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d'
)

# Display the scrollbar
display(label_after, after_scrollbar)

# Input hours before & after onset to label as positive
label_freq = widgets.Label('Enter the sampling frequency of the input (minutes):')

sampling_freq_box = widgets.Text(
    value='30',
    placeholder='30',
    description='Minutes:',
    disabled=False,
    layout=widgets.Layout(width='500px')
)

# Display the text box
display(label_freq, sampling_freq_box)

Text(value='respiratory_HiRID', description='Disease:', layout=Layout(width='500px'), placeholder='respiratory…

Label(value='Enter the # of hours before onset that are considered positive:')

IntSlider(value=6, continuous_update=False, description='Before:', max=24)

Label(value='Enter the # of hours after onset that are considered positive:')

IntSlider(value=24, continuous_update=False, description='Before:', max=24)

Label(value='Enter the sampling frequency of the input (minutes):')

Text(value='30', description='Minutes:', layout=Layout(width='500px'), placeholder='30')

In [30]:
main_folder_path = '/datasets/amelatur/mimic_m1pz92hj/' # path with patient IDs
ids_file = pd.read_csv( main_folder_path + '0labels.txt')
patient_ids = ids_file['id']
vitals_folder_path = '/datasets/amelatur/mimic_kih7jlb3/' # path with patient vitals

disease = str(disease_box.value)
onset_file_path = '/datasets/amelatur/data_slices/onset_files/' + disease + '_onset_index.csv' # path with onset indices
onset_indices_csv = pd.read_csv(onset_file_path, header=None, index_col=False).rename({0: 'Patient_ID', 1: 'Onset_Index'}, axis=1)

save_file_path = '/datasets/amelatur/whole_sequences/' + disease +'/'

statics_csv = pd.read_csv(main_folder_path + '0labels.txt')

# # delete save_file_path if it already exists
# if os.path.exists(save_file_path):
#     shutil.rmtree(save_file_path)

# # create empty save_file_path
# if not os.path.exists(save_file_path):
#     os.makedirs(save_file_path)

sampling_freq_minutes = int(sampling_freq_box.value)
max_seq_length = (60 // sampling_freq_minutes) * 24 * 7 # max seq length is 7 days
min_seq_length = (60 // sampling_freq_minutes) * 12 # min seq length is 12 hours

start_event_labeling = before_scrollbar.value * ( 60 // sampling_freq_minutes ) # 4 hours before onset 
end_event_labelling = after_scrollbar.value * ( 60 // sampling_freq_minutes ) # 2 hours after onset

In [31]:
overall_mean = np.load(save_file_path + 'overall_mean.npy')
overall_std = np.load(save_file_path + 'overall_std.npy')
overall_mean_weight = np.load(save_file_path + 'overall_mean_weight.npy')
overall_mean_age = np.load(save_file_path + 'overall_mean_age.npy')
overall_std_weight = np.load(save_file_path + 'overall_std_weight.npy')
overall_std_age = np.load(save_file_path + 'overall_std_age.npy')


In [32]:
# train test split
train_size = 0.8
train_patients = patient_ids.sample(frac=train_size, random_state=0)
test_patients = patient_ids.drop(train_patients.index)
print('Test patients:', len(test_patients))


# validation
validation_size = 0.15
validation_patients = train_patients.sample(frac=validation_size, random_state=0)
train_patients = train_patients.drop(validation_patients.index)
print('Validation patients:', len(validation_patients))
print('Train patients:', len(train_patients))


Test patients: 14548
Validation patients: 8729
Train patients: 49464


In [33]:
import random
from scipy.stats import expon


def get_data(patient_id):

    try:

        exclusion_flag = 0

        # get labs to apply exclusion criteria
        labs_file = pd.read_csv(main_folder_path + str(patient_id) + '_all_vals.csv', index_col=False)
        labs_file['charttime'] = pd.to_datetime(labs_file['charttime'])
        labs_file = labs_file.sort_values(by='charttime').set_index('charttime')

        n_steps = len(labs_file.index)

        # P/F ratio
        if (labs_file['po2'].isna().all() or n_steps < min_seq_length or n_steps > (60 // sampling_freq_minutes) * 24 * 7):
        #if (n_steps < min_seq_length or n_steps > (60 // sampling_freq_minutes) * 24 * 7):
        #if (labs_file['lactate'].isna().all() or labs_file['mbp'].isna().all() or n_steps < min_seq_length or n_steps > (60 // sampling_freq_minutes) * 24 * 7):
            exclusion_flag = 1

        # get onset time
        onset_index = (onset_indices_csv[onset_indices_csv['Patient_ID'] == patient_id]['Onset_Index'].values[0])
        
        # load vitals
        vitals_file = pd.read_csv(vitals_folder_path + str(patient_id) + '_vitals.csv', index_col=False)
        vitals_file['time'] = pd.to_datetime(vitals_file['time'])
        vitals_file = vitals_file.set_index('time')
        vitals_file = vitals_file.drop(['Unnamed: 0', 'id'], axis=1)

        ranges = {'heartrate':(0.0, 300.0), 'sbp': (10.0, 300.0), 'dbp': (10.0, 175.0), 'mbp': (10.0, 200.0), 'respiration': (0.0, 45.0), 'temperature': (25.0, 45.0), 'spo2': (10.0, 100.0)}

        # Replace values with NaNs if they are outside the specified range for each column
        vitals_file = vitals_file.apply(lambda col: np.where((col < ranges[col.name][0]) | (col > ranges[col.name][1]), np.nan, col))
        vitals = vitals_file.resample('30T').mean() 
        vitals = vitals.interpolate().ffill().bfill()
        #vitals = vitals.drop(['mbp'], axis=1)

        if vitals.isna().any().any():
            exclusion_flag = 1
        vitals = vitals.to_numpy()

        # construct target sequence
        target = np.zeros(vitals.shape[0])
        if np.isnan(onset_index) == False:
            
            onset_index = int(onset_index)
            if onset_index < 4:
                exclusion_flag = 1
            # target[onset_index:] = 1 # everything after onset is 1
            start_event = max(0, onset_index - start_event_labeling)
            end_event = min(vitals.shape[0], onset_index + end_event_labelling)
            target[start_event:end_event] = 1 # 4H before onset till 2H after onset is the event definition

            # end sequences after event end
            vitals = vitals[:end_event]
            target = target[:end_event] 

        # else:

        #     idx = np.random.choice(np.arange(len(new_neg_lengths)))
        #     sample_length = new_neg_lengths[idx]
        #     new_neg_lengths = np.delete(new_neg_lengths, idx)

        #     end_seq_negative = min(vitals.shape[0], int(sample_length))
        #     vitals = vitals[:end_seq_negative]
        #     target = target[:end_seq_negative]

        if vitals.shape[0] != target.shape[0]:
            exclusion_flag = 1

        
        # statics = statics_csv[statics_csv['id'] == patient_id][['weight', 'age', 'gender']]
        # statics['gender'] = 0 if statics['gender'].values[0] == 'M' else 1
        # statics = statics.T.to_numpy()
        # statics = statics.squeeze()
            
        statics = np.load(save_file_path + str(patient_id) + '_statics.npy')

        return patient_id, target, vitals, statics, exclusion_flag
    
    except Exception as e:
        print(e, patient_id)
        return patient_id, None, None, None, 1
        

In [34]:
def save_data(patient_id, target, vitals, statics):

    # save data
    np.save(save_file_path + str(patient_id) + '_target.npy', np.array(target))
    np.save(save_file_path + str(patient_id) + '_vitals.npy', np.array(vitals))
    np.save(save_file_path + str(patient_id) + '_statics.npy', np.array(statics))

    # print('Patient:', patient_id, 'saved')

    return 0


In [35]:
from joblib import Parallel, delayed
from tqdm import tqdm

def compute_mean_and_var(pat_id):
    patient_id, target, vitals, statics, exclusion_flag = get_data(pat_id)
    
    if exclusion_flag == 0:
        mean = np.mean(vitals, axis=0)
        var = np.var(vitals, axis=0)
        weight = statics[0]
        age = statics[1]
        

        # if np.sum(target) == 0: # negative patient, then cut its length
        #     return None, None, None, None, None

        nb_ones = np.sum(target)
        total_len = len(target)

        pos_or_neg = 0 if nb_ones == 0 else 1

        # also compute class weights
        return mean, var, nb_ones, total_len, pos_or_neg, weight, age
    else:
        return None, None, None, None, None, None, None

results = Parallel(n_jobs=-1)(delayed(compute_mean_and_var)(pat_id) for pat_id in tqdm(train_patients))

# Filter out None results
results = [result for result in results if result[0] is not None]

means, vars, nb_ones, total_len, pos_or_neg, weight, age = zip(*results)

overall_mean = np.mean(np.array(means), axis=0)
overall_std = np.sqrt(np.mean(np.array(vars), axis=0))   
overall_mean_weight = np.mean(np.array(weight))
overall_mean_age = np.mean(np.array(age))
overall_std_weight = np.std(np.array(weight))
overall_std_age = np.std(np.array(age))

  0%|          | 16/49464 [00:00<24:08, 34.15it/s]

[Errno 2] No such file or directory: '/datasets/amelatur/whole_sequences/respiratory_HiRID/35185673_statics.npy' 35185673
[Errno 2] No such file or directory: '/datasets/amelatur/whole_sequences/respiratory_HiRID/33804298_statics.npy' 33804298
[Errno 2] No such file or directory: '/datasets/amelatur/whole_sequences/respiratory_HiRID/39300549_statics.npy' 39300549
[Errno 2] No such file or directory: '/datasets/amelatur/whole_sequences/respiratory_HiRID/34040313_statics.npy' 34040313
[Errno 2] No such file or directory: '/datasets/amelatur/whole_sequences/respiratory_HiRID/30548152_statics.npy' 30548152
[Errno 2] No such file or directory: '/datasets/amelatur/whole_sequences/respiratory_HiRID/32142080_statics.npy' 32142080
[Errno 2] No such file or directory: '/datasets/amelatur/whole_sequences/respiratory_HiRID/39855725_statics.npy' 39855725
[Errno 2] No such file or directory: '/datasets/amelatur/whole_sequences/respiratory_HiRID/35958260_statics.npy' 35958260
[Errno 2] No such file o

KeyboardInterrupt: 

In [None]:

prevalence_ones = np.sum(nb_ones) / np.sum(total_len)
prevalence_zeros = 1 - prevalence_ones


print("Dataset yields", int(prevalence_zeros * 100), "% of zeros and", int(prevalence_ones * 100), "% of ones")
print("Positive patients:", np.sum(pos_or_neg), "Negative patients:", len(pos_or_neg) - np.sum(pos_or_neg))

len_pos = [total_len[i] for i in range(len(total_len)) if pos_or_neg[i] == 1]
len_neg = [total_len[i] for i in range(len(total_len)) if pos_or_neg[i] == 0]
avg_len_pos = np.mean(len_pos)
avg_len_neg = np.mean(len_neg)
print("Average seq length for positive patient:", avg_len_pos * 30 // 60, "hours")
print("Average seq length for negative patient:", avg_len_neg * 30 // 60, "hours")

print('Mean weight:', overall_mean_weight, 'Std weight:', overall_std_weight)
print('Mean age:', overall_mean_age, 'Std age:', overall_std_age)

# histogram of sequence lengths
import matplotlib.pyplot as plt

# Create a figure and a 2x1 grid of subplots
fig, axs = plt.subplots(2)

# Plot the histogram of len_pos on the first subplot
axs[0].hist(len_pos, alpha=0.5, label='Positive')
axs[0].set_title('Positive')
axs[0].set_xlim(xmin=0, xmax=300)


# Plot the histogram of len_neg on the second subplot
axs[1].hist(len_neg, alpha=0.5, label='Negative')
axs[1].set_title('Negative')
axs[1].set_xlim(xmin=0, xmax=300)


# Display the figure with the subplots
plt.tight_layout()
plt.show()




In [None]:
from sklearn.neighbors import KernelDensity

pos_lengths = np.array(len_pos)
neg_lengths = np.array(len_neg)

# Fit KDE to positive lengths
kde = KernelDensity(kernel='gaussian', bandwidth=0.5).fit(pos_lengths.reshape(-1, 1))

# Sample new lengths for negative class
new_neg_lengths = kde.sample(len(neg_lengths), random_state=0)

# Since the KDE can generate negative samples, we take the absolute value
new_neg_lengths = np.abs(new_neg_lengths)

neg_lengths.shape

# # Create a list to keep track of indices where we've replaced values
# replaced_indices = []

# # Create a new list with replaced values
# new_total_len = [new_len_neg.pop(0) if (pos_or_neg[i] == 0 and i not in replaced_indices) else val for i, val in enumerate(total_len)]

# # If new_len_neg is not empty after replacing all values, append the remaining values to new_total_len
# if new_len_neg:
#     new_total_len.extend(new_len_neg)


In [None]:
print('We ended up with ' + str(len(means)) + ' patients to train on after the exclusion criteria was applied')

In [None]:
def iterate_patients_save(pat_id):
    patient_id, respi_status, vitals, statics, exclusion_flag = get_data(pat_id) # get inputs and targets

    if exclusion_flag == 0:

        # normalize the inputs/vitals
        vitals = (vitals - overall_mean) / overall_std

        # pad sequences expects lists, so create a list of length 1 with the inputs and targets
        respi_status = tf.keras.preprocessing.sequence.pad_sequences([respi_status],maxlen=max_seq_length, padding='post', truncating='post', value=100000, dtype='float32')

        
        vitals = tf.keras.preprocessing.sequence.pad_sequences([vitals],maxlen=max_seq_length, padding='post', truncating='post', value=100000, dtype='float32')

        # remove the first dimension of the inputs and targets
        respi_status = np.squeeze(respi_status, axis=0)
        vitals = np.squeeze(vitals, axis=0)

        # add dimension to targets
        respi_status = np.expand_dims(respi_status, axis=-1)

        save_data(patient_id, respi_status, vitals, statics)

        return respi_status, vitals
    else:
        return None, None


res = Parallel(n_jobs=-1)(delayed(iterate_patients_save)(pat_id) for pat_id in tqdm(train_patients))

In [None]:
res = Parallel(n_jobs=-1)(delayed(iterate_patients_save)(pat_id) for pat_id in tqdm(validation_patients))

In [9]:

def data_generator(patient_ids_selected):

    for pat_id in patient_ids_selected:
        
        vital_path = save_file_path + str(pat_id) + '_vitals.npy'
        target_path = save_file_path + str(pat_id) + '_target.npy'


        if os.path.exists(vital_path):
            try:
                respi_status = np.load(target_path)
                vitals = np.load(vital_path)

                first_minus_one_index = np.where(respi_status > 1.0)[0][0]    

                decoder_input = np.zeros((vitals.shape[0], 1))
                decoder_input[first_minus_one_index:, :] = 100000.0

                
                statics = np.load(save_file_path + str(pat_id) + '_statics.npy')

                statics[0] = (statics[0] - overall_mean_weight) / overall_std_weight
                statics[1] = (statics[1] - overall_mean_age) / overall_std_age

                #yield (vitals, decoder_input), respi_status
                yield (vitals, statics), respi_status, 
            except:
                pass
            
                


def custom_data_loader(batch_size, patient_ids_selected):

    # Shuffle the patient ids
    shuffled_series_patients = patient_ids_selected.sample(frac=1)

    dataset = tf.data.Dataset.from_generator(
        lambda: data_generator(shuffled_series_patients), 
        # output_signature=(
        #         (tf.TensorSpec(shape=(336, 7), dtype=tf.float32), tf.TensorSpec(shape=(336, 1), dtype=tf.float32)),
        #     tf.TensorSpec(shape=(336, 1), dtype=tf.float32),
        # )
        output_signature=(
                (tf.TensorSpec(shape=(336, 7), dtype=tf.float32),tf.TensorSpec(shape=(3,), dtype=tf.float32)),
            tf.TensorSpec(shape=(336, 1), dtype=tf.float32),
        )
    )


    dataset = dataset.batch(batch_size).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset

batch_size = 32

data_loader = custom_data_loader(batch_size, train_patients)
val_data_loader = custom_data_loader(batch_size, validation_patients)

#for batch in data_loader:
#for (inputs, decoder_inputs), targets in data_loader:
for (inputs, statics), targets in data_loader:

    print("Inputs shape:", inputs.shape)
    #print("Decoder inputs shape:", decoder_inputs.shape)
    print("Targets shape:", targets.shape)
    print("Statics shape:", statics.shape)
    break

Exception ignored in: <generator object data_generator at 0x7fe97880d0e0>
Traceback (most recent call last):
  File "/home/amelatur/.pyenv/versions/3.10.12/envs/my_proj/lib/python3.10/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 870, in iterator_completed
    del self._iterators[self._normalize_id(iterator_id)]
RuntimeError: generator ignored GeneratorExit


Inputs shape: (32, 336, 7)
Targets shape: (32, 336, 1)
Statics shape: (32, 3)


In [None]:
from tensorflow import linalg, ones, math, cast, float32, maximum

def padding_mask(input):
    # Create mask which marks the 100000.0 padding values in the input by a 1
    mask = math.equal(input, 100000.0)
    mask = cast(mask, float32)
    print(mask)
 
    return mask
 
def lookahead_mask(shape):
    # Mask out future entries by marking them with a 1.0
    mask = 1 - linalg.band_part(ones((shape, shape)), -1, 0)
 
    return mask

def create_lstm_encoder_decoder(input_shape, target_shape):
    # Define an input sequence and process it.
    latent_dim = 256
    num_decoder_tokens = target_shape[0]

    # Define an input sequence and process it.
    enc_inputs = Input(shape=input_shape)
    encoder_inputs = keras.layers.Masking(mask_value=100000.0)(enc_inputs)

    encoder = LSTM(64, return_state=True)
    encoder_outputs, state_h, state_c = encoder(encoder_inputs)

    # We discard `encoder_outputs` and only keep the states.
    encoder_states = [state_h, state_c]

    # Set up the decoder, using `encoder_states` as initial state.
    dec_inputs = Input(shape=target_shape)

    # Create and combine padding and look-ahead masks to be fed into the decoder
    dec_in_padding_mask = padding_mask(dec_inputs)
    #print(dec_in_padding_mask)
    dec_in_lookahead_mask = lookahead_mask(dec_inputs.shape[1])
    #dec_in_lookahead_mask = maximum(dec_in_padding_mask, dec_in_lookahead_mask)

    #decoder_inputs = keras.layers.Masking(mask_value=100000.0)(dec_inputs)
    decoder_lstm = LSTM(64, return_sequences=True)
    decoder_outputs = decoder_lstm(dec_inputs, initial_state=encoder_states, mask=dec_in_padding_mask)

    # Add a TimeDistributed Dense layer to make a prediction for each timestep
    decoder_dense = TimeDistributed(Dense(1, activation='sigmoid'))
    decoder_outputs = decoder_dense(decoder_outputs)

    # Define the model that will turn
    # `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
    model = Model([enc_inputs, dec_inputs], decoder_outputs)

    return model

In [None]:
def create_lstm_encoder_decoder(input_shape, target_shape):
    # Define an input sequence and process it.

    latent_dim = 256
    num_decoder_tokens = target_shape[0]

    # Define an input sequence and process it.
    enc_inputs = Input(shape=input_shape)
    encoder_inputs = keras.layers.Masking(mask_value=100000.0)(enc_inputs)

    encoder = (LSTM(64, return_state=True))
    encoder_outputs, state_h, state_c = encoder(encoder_inputs)

    # We discard `encoder_outputs` and only keep the states.
    encoder_states = [state_h, state_c]

    # Set up the decoder, using `encoder_states` as initial state.
    dec_inputs = Input(shape=target_shape)
    decoder_inputs = keras.layers.Masking(mask_value=100000.0)(dec_inputs)
    decoder_lstm = (LSTM(64, return_sequences=True))
    decoder_outputs = decoder_lstm(decoder_inputs, initial_state=encoder_states)


    # Add a TimeDistributed Dense layer to make a prediction for each timestep
    decoder_dense = TimeDistributed(Dense(1, activation='sigmoid'))
    decoder_outputs = decoder_dense(decoder_outputs)

    # Define the model that will turn
    # `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
    model = Model([enc_inputs, dec_inputs], decoder_outputs)

    return model


loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, weight_decay=None, epsilon=1e-8)


input_shape = (None, 7)
target_shape = (None, 1)
model = create_lstm_encoder_decoder(input_shape,  target_shape)

model.compile(loss=loss_object, optimizer=optimizer, metrics=['accuracy'])

model.summary()




In [None]:
from tensorflow.keras.layers import Bidirectional, Concatenate

def create_lstm_encoder_decoder(input_shape, target_shape):
    # Define an input sequence and process it.

    latent_dim = 256
    num_decoder_tokens = target_shape[0]

    # Define an input sequence and process it.
    enc_inputs = Input(shape=input_shape)
    encoder_inputs = keras.layers.Masking(mask_value=100000.0)(enc_inputs)

    encoder = Bidirectional(LSTM(64, return_state=True))  # Changed to 128
    encoder_outputs, forward_h, forward_c, backward_h, backward_c = encoder(encoder_inputs)

    # We discard `encoder_outputs` and only keep the states.
    state_h = Concatenate()([forward_h, backward_h])
    state_c = Concatenate()([forward_c, backward_c])
    encoder_states = [state_h, state_c]  # Duplicate the states

    # Set up the decoder, using `encoder_states` as initial state.
    dec_inputs = Input(shape=target_shape)
    decoder_inputs = keras.layers.Masking(mask_value=100000.0)(dec_inputs)
    decoder_lstm = (LSTM(128, return_sequences=True, return_state=True))  # Remains 128
    decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)

    # Add a TimeDistributed Dense layer to make a prediction for each timestep
    decoder_dense = TimeDistributed(Dense(1, activation='sigmoid'))
    decoder_outputs = decoder_dense(decoder_outputs)

    # Define the model that will turn
    # `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
    model = Model([enc_inputs, dec_inputs], decoder_outputs)

    return model


loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, weight_decay=None, epsilon=1e-8)


input_shape = (None, 7)
target_shape = (None, 1)
model = create_lstm_encoder_decoder(input_shape,  target_shape)

model.compile(loss=loss_object, optimizer=optimizer, metrics=['accuracy'])

model.summary()


In [None]:
import keras_nlp

def padding_mask(input):
    # Create mask which marks the 100000.0 padding values in the input by a 1
    mask = tf.math.not_equal(input, 100000.0) # want 100000.0 to produce False
    mask = tf.cast(mask, tf.bool)
    mask = mask[:, :, 0]
    return mask

# Create a single transformer encoder layer.
encoder = keras_nlp.layers.TransformerEncoder(
    intermediate_dim=64, num_heads=7)

# Create a simple model containing the encoder.
input = keras.Input(shape=(336, 7))
mask = padding_mask(input)
output = encoder(input, padding_mask=mask)
output = TimeDistributed(Dense(1, activation='sigmoid'))(output)
model = keras.Model(inputs=input, outputs=output)
model.summary()

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, weight_decay=None, epsilon=1e-8)
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])

In [10]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, Masking, Bidirectional, GRU, Embedding


input_shape = (None, 7)
patient_data_shape = (3,)  # age, weight, gender

# Define patient data input and process it through a Dense layer
patient_data_input = Input(shape=(3,))
patient_data_processed = Dense(200, activation='relu')(patient_data_input)  # Dense layer to process the patient data

input_state = [patient_data_processed, patient_data_processed]  # Initial state for the LSTM

input_layer =  Input(input_shape)
mask_layer = keras.layers.Masking(mask_value=100000.0)(input_layer)

lstm1 = (LSTM(200, return_sequences=True))(mask_layer, initial_state=input_state)
# lstm2 = LSTM(100, return_sequences=True)(lstm1)
# lstm3 = LSTM(50, return_sequences=True)(lstm2)  # Return state for lstm3

output_layer = TimeDistributed(Dense(1, activation='sigmoid'))(lstm1)

model = Model([input_layer, patient_data_input], output_layer)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, weight_decay=None, epsilon=1e-8)

from tensorflow.keras import backend as K

model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
model.summary()

2024-04-03 10:06:14.662148: I external/local_tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_2 (InputLayer)        [(None, None, 7)]            0         []                            
                                                                                                  
 input_1 (InputLayer)        [(None, 3)]                  0         []                            
                                                                                                  
 masking (Masking)           (None, None, 7)              0         ['input_2[0][0]']             
                                                                                                  
 dense (Dense)               (None, 200)                  800       ['input_1[0][0]']             
                                                                                              

In [11]:

model.fit(data_loader, epochs=150, validation_data=val_data_loader)  


Epoch 1/150


2024-04-03 10:06:39.369634: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8904
2024-04-03 10:06:42.193712: I external/local_xla/xla/service/service.cc:168] XLA service 0x7fe8bc03ed80 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-04-03 10:06:42.193739: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce RTX 3080 Ti, Compute Capability 8.6
2024-04-03 10:06:42.213807: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1712131602.447559 1545265 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


    543/Unknown - 30s 38ms/step - loss: 0.6710 - accuracy: 0.5935

2024-04-03 10:07:03.657128: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3710320218543396077
2024-04-03 10:07:03.657154: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3085380735194330737
2024-04-03 10:07:03.657161: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 5998072374665537099
2024-04-03 10:07:03.657166: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7238339185496156813
2024-04-03 10:07:03.657171: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2842405305701600385
2024-04-03 10:07:03.657175: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 735753060211240393
2024-04-03 10:07:03.657180: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item

Epoch 2/150
Epoch 3/150
Epoch 4/150
Epoch 5/150
Epoch 6/150
Epoch 7/150
Epoch 8/150

KeyboardInterrupt: 

In [None]:

#res = Parallel(n_jobs=-1)(delayed(iterate_patients_save)(pat_id) for pat_id in tqdm(test_patients))
data_loader_test = custom_data_loader(batch_size, test_patients)

model.evaluate(data_loader_test)

### EVALUATION

In [13]:
def mark_sequences(binary_array, k):
    count = 0
    start_index = None
    for i in range(len(binary_array)):
        if binary_array[i] == 1:
            if count == 0:
                start_index = i
            count += 1
            if count > k:
                binary_array[start_index+1:i+1] = 0
        else:
            if count <= k and start_index is not None:
                binary_array[start_index:i] = 0
            count = 0
            start_index = None
    # Handle the last sequence of 1s
    if count <= k and start_index is not None:
        binary_array[start_index:i+1] = 0
    return binary_array

In [24]:
from keras import backend as K

def eval(pat_id, idx):
    try:
        
        patient_id, target, vitals, statics, exclusion_flag = get_data(pat_id)
        #vitals = np.load(save_file_path + str(pat_id) + '_vitals.npy')
        

        if exclusion_flag == 0:

            vitals_pt = (vitals - overall_mean) / overall_std
            decoder_input = np.zeros((vitals.shape[0], 1))

            decoder_input = np.expand_dims(decoder_input, axis=0)
            vitals_pt = np.expand_dims(vitals_pt, axis=0)
            #prediction = model.predict([vitals_pt, decoder_input])[0]
            
            #print(prediction)

            statics[0] = (statics[0] - overall_mean_weight) / overall_std_weight
            statics[1] = (statics[1] - overall_mean_age) / overall_std_age

            statics = np.expand_dims(statics, axis=0)


            prediction = model.predict([vitals_pt, statics])[0]

            K.clear_session()


                                    
            onset_index = (onset_indices_csv[onset_indices_csv['Patient_ID'] == pat_id]['Onset_Index'].values[0])
                        
            #prediction = mark_sequences((prediction > 0.5).astype(int), 1)
            
            prediction = (prediction > 0.5).astype(int)
            #print(onset_index)

            if np.isnan(onset_index) == False: # ground truth is positive

                ground_truth = 1

                onset_index = int(onset_index)
                #prediction = (prediction[:onset_index] > 0.5).astype(int)

                prediction = prediction[:onset_index]
                final_pred = (np.sum(prediction) > 0).astype(int)

                
            else: # ground truth is negative
                ground_truth = 0
                final_pred = (np.sum(prediction) > 0).astype(int)

            #print(ground_truth, final_pred)

            return ground_truth, final_pred
        
        else:
            return None, None
        
    except Exception as e:
        print(e)
        return None, None



In [25]:
subset_test_patients = test_patients.sample(1000)

In [26]:
ground_truths = []
final_preds = []

for pat_id, idx in tqdm(zip(subset_test_patients, range(len(subset_test_patients)))):
    ground_truth, final_pred = eval(pat_id, idx)
    if ground_truth is not None and final_pred is not None:
        ground_truths.append(ground_truth)
        final_preds.append(final_pred)
    



0it [00:00, ?it/s]



3it [00:00, 15.46it/s]



5it [00:00,  5.96it/s]



9it [00:01,  6.52it/s]



10it [00:01,  4.70it/s]



13it [00:02,  4.76it/s]



14it [00:02,  3.87it/s]



16it [00:03,  3.89it/s]



17it [00:03,  3.41it/s]



18it [00:04,  3.02it/s]



19it [00:04,  2.67it/s]



20it [00:05,  2.49it/s]



22it [00:05,  2.95it/s]



24it [00:06,  3.27it/s]



25it [00:06,  2.72it/s]



28it [00:07,  3.53it/s]



30it [00:08,  3.70it/s]



32it [00:08,  3.99it/s]



37it [00:09,  5.86it/s]



39it [00:09,  5.34it/s]



41it [00:10,  4.87it/s]



45it [00:10,  5.34it/s]



46it [00:11,  4.33it/s]



50it [00:11,  5.11it/s]



52it [00:12,  4.96it/s]



55it [00:12,  5.08it/s]



57it [00:13,  4.64it/s]



58it [00:13,  4.07it/s]



61it [00:14,  4.26it/s]



63it [00:14,  4.28it/s]



66it [00:15,  4.66it/s]



73it [00:16,  7.90it/s]



75it [00:16,  6.89it/s]



77it [00:17,  4.81it/s]



78it [00:17,  3.82it/s]



82it [00:18,  4.85it/s]



85it [00:18,  5.09it/s]



98it [00:19, 11.90it/s]



101it [00:20,  8.93it/s]



103it [00:21,  5.22it/s]



105it [00:22,  3.79it/s]



107it [00:22,  3.86it/s]



110it [00:23,  4.44it/s]



112it [00:23,  4.37it/s]



116it [00:24,  5.26it/s]



117it [00:24,  4.55it/s]



129it [00:25, 13.16it/s]



132it [00:26,  6.79it/s]



141it [00:27, 10.18it/s]



154it [00:28, 12.43it/s]



157it [00:29,  6.91it/s]



161it [00:30,  7.04it/s]



164it [00:30,  6.94it/s]



167it [00:31,  6.41it/s]



170it [00:31,  6.23it/s]



171it [00:32,  4.90it/s]



176it [00:32,  6.90it/s]



178it [00:33,  5.76it/s]



180it [00:34,  4.12it/s]



181it [00:34,  3.77it/s]



184it [00:35,  4.44it/s]



185it [00:35,  3.98it/s]



191it [00:35,  7.36it/s]



193it [00:36,  5.58it/s]



195it [00:37,  5.15it/s]



196it [00:37,  3.88it/s]



197it [00:38,  3.41it/s]



200it [00:38,  3.92it/s]



203it [00:39,  4.48it/s]



204it [00:39,  3.91it/s]



205it [00:40,  3.46it/s]



208it [00:40,  4.07it/s]



215it [00:41,  7.18it/s]



232it [00:42, 17.09it/s]



235it [00:42, 11.95it/s]



242it [00:43, 12.61it/s]



245it [00:43,  9.07it/s]



250it [00:44,  9.33it/s]



252it [00:45,  6.77it/s]



254it [00:45,  4.58it/s]



258it [00:46,  5.14it/s]



260it [00:47,  4.84it/s]



261it [00:47,  4.26it/s]



264it [00:48,  4.79it/s]



267it [00:48,  5.11it/s]



269it [00:49,  4.77it/s]



270it [00:49,  4.22it/s]



275it [00:50,  5.65it/s]



278it [00:50,  5.80it/s]



280it [00:50,  5.46it/s]



281it [00:51,  4.21it/s]



286it [00:52,  5.38it/s]



287it [00:52,  4.52it/s]



289it [00:53,  4.13it/s]



292it [00:53,  4.55it/s]



295it [00:54,  4.81it/s]



296it [00:55,  3.64it/s]



299it [00:55,  4.32it/s]



300it [00:55,  3.77it/s]



302it [00:56,  3.81it/s]



307it [00:57,  6.26it/s]



309it [00:57,  5.55it/s]



311it [00:58,  4.78it/s]



312it [00:58,  3.64it/s]



314it [00:59,  3.80it/s]



315it [00:59,  3.25it/s]



323it [01:00,  7.25it/s]



325it [01:00,  6.24it/s]



328it [01:01,  5.54it/s]



331it [01:01,  5.84it/s]



335it [01:02,  6.49it/s]



340it [01:03,  7.05it/s]



344it [01:03,  7.22it/s]



345it [01:04,  5.67it/s]



346it [01:04,  4.79it/s]



347it [01:04,  4.08it/s]



348it [01:05,  3.48it/s]



353it [01:06,  5.98it/s]



361it [01:06,  9.44it/s]



363it [01:07,  6.70it/s]



365it [01:07,  5.55it/s]



367it [01:08,  4.93it/s]



368it [01:08,  3.85it/s]



369it [01:09,  3.40it/s]



373it [01:10,  5.11it/s]



375it [01:10,  4.61it/s]



376it [01:11,  3.50it/s]



380it [01:11,  4.82it/s]



382it [01:12,  4.49it/s]



383it [01:12,  3.85it/s]



384it [01:13,  3.35it/s]



392it [01:13,  7.36it/s]



399it [01:14,  9.85it/s]



404it [01:15,  8.33it/s]



406it [01:16,  5.52it/s]



408it [01:16,  5.01it/s]



411it [01:17,  5.06it/s]



412it [01:17,  4.36it/s]



413it [01:18,  3.49it/s]



414it [01:18,  2.96it/s]



415it [01:19,  2.63it/s]



417it [01:19,  2.90it/s]



419it [01:20,  3.29it/s]



420it [01:20,  2.93it/s]



423it [01:21,  3.83it/s]



426it [01:21,  4.40it/s]



427it [01:22,  3.85it/s]



428it [01:22,  3.17it/s]



430it [01:23,  3.39it/s]



432it [01:23,  3.60it/s]



440it [01:24,  7.25it/s]



442it [01:25,  5.89it/s]



444it [01:25,  5.12it/s]



445it [01:26,  4.24it/s]



447it [01:26,  4.02it/s]



450it [01:27,  4.66it/s]



451it [01:27,  3.86it/s]



453it [01:28,  3.89it/s]



458it [01:28,  5.51it/s]



462it [01:29,  5.88it/s]



463it [01:29,  4.91it/s]



464it [01:30,  3.93it/s]



471it [01:31,  7.02it/s]



473it [01:31,  6.17it/s]



476it [01:31,  6.39it/s]



477it [01:32,  4.88it/s]



480it [01:33,  4.95it/s]



481it [01:33,  4.30it/s]



487it [01:33,  7.73it/s]



493it [01:34, 10.68it/s]



496it [01:34, 10.71it/s]



503it [01:34, 13.52it/s]



510it [01:35, 16.14it/s]



512it [01:35, 12.96it/s]



514it [01:36,  9.30it/s]



520it [01:36,  9.33it/s]



526it [01:37, 10.91it/s]



528it [01:38,  7.10it/s]



530it [01:38,  6.23it/s]



534it [01:39,  6.99it/s]



535it [01:39,  5.74it/s]



537it [01:39,  5.59it/s]



541it [01:40,  6.62it/s]



542it [01:40,  5.54it/s]



545it [01:41,  5.58it/s]



546it [01:41,  4.79it/s]



548it [01:41,  4.77it/s]



  labs_file['charttime'] = pd.to_datetime(labs_file['charttime'])




554it [01:42,  5.84it/s]



555it [01:43,  5.09it/s]



556it [01:43,  4.50it/s]



557it [01:43,  4.11it/s]



559it [01:44,  4.32it/s]



566it [01:44, 10.02it/s]



569it [01:45,  6.67it/s]



573it [01:45,  7.52it/s]



575it [01:46,  5.50it/s]



577it [01:47,  5.46it/s]



578it [01:47,  4.67it/s]



580it [01:47,  4.71it/s]



581it [01:48,  4.20it/s]



582it [01:48,  3.94it/s]



584it [01:48,  4.25it/s]



592it [01:49,  8.58it/s]



593it [01:49,  7.13it/s]



595it [01:50,  6.80it/s]



596it [01:50,  5.39it/s]



598it [01:50,  5.25it/s]



600it [01:51,  5.31it/s]



601it [01:51,  4.45it/s]



602it [01:52,  4.16it/s]



605it [01:52,  4.86it/s]



607it [01:52,  5.01it/s]



609it [01:53,  5.08it/s]



610it [01:53,  4.40it/s]



611it [01:54,  3.78it/s]



612it [01:54,  3.37it/s]



615it [01:54,  4.75it/s]



616it [01:55,  4.01it/s]



618it [01:55,  4.32it/s]



622it [01:56,  5.63it/s]



628it [01:56,  7.89it/s]



630it [01:56,  7.15it/s]



631it [01:57,  5.66it/s]



632it [01:57,  4.92it/s]



634it [01:58,  4.77it/s]



637it [01:58,  5.58it/s]



638it [01:58,  4.85it/s]



646it [01:59,  9.32it/s]



648it [01:59,  8.07it/s]



650it [02:00,  7.57it/s]



652it [02:00,  7.66it/s]



654it [02:00,  7.55it/s]



657it [02:00,  8.39it/s]



658it [02:01,  7.19it/s]



660it [02:01,  7.06it/s]



663it [02:01,  8.04it/s]



665it [02:02,  7.60it/s]



666it [02:02,  5.69it/s]



667it [02:02,  4.96it/s]



670it [02:03,  5.63it/s]



672it [02:03,  5.65it/s]



675it [02:03,  7.03it/s]



678it [02:04,  7.92it/s]



682it [02:04,  9.57it/s]



683it [02:04,  7.92it/s]



684it [02:04,  6.90it/s]



700it [02:05, 22.38it/s]



704it [02:05, 18.95it/s]



707it [02:05, 15.84it/s]



710it [02:06,  8.26it/s]



712it [02:07,  7.11it/s]



714it [02:08,  4.48it/s]



715it [02:08,  3.96it/s]



720it [02:09,  5.94it/s]



722it [02:09,  5.55it/s]



725it [02:10,  6.33it/s]



731it [02:10,  8.94it/s]



733it [02:11,  8.02it/s]



742it [02:11, 12.94it/s]



745it [02:12,  7.82it/s]



747it [02:12,  7.27it/s]



749it [02:13,  7.00it/s]



754it [02:13,  9.10it/s]



762it [02:13, 12.63it/s]



764it [02:14,  8.22it/s]



771it [02:14, 11.22it/s]



778it [02:15,  9.97it/s]



783it [02:16, 10.18it/s]



785it [02:16,  9.38it/s]



787it [02:17,  7.21it/s]



789it [02:17,  7.11it/s]



791it [02:17,  6.82it/s]



792it [02:18,  6.13it/s]



796it [02:18,  8.15it/s]



797it [02:18,  6.72it/s]



798it [02:19,  5.78it/s]



800it [02:19,  4.92it/s]



804it [02:20,  5.40it/s]



805it [02:20,  4.38it/s]



811it [02:21,  7.29it/s]



813it [02:22,  4.57it/s]



816it [02:22,  5.06it/s]



818it [02:23,  5.17it/s]



822it [02:23,  7.07it/s]



826it [02:23,  8.60it/s]



828it [02:23,  8.41it/s]



833it [02:24, 10.62it/s]



835it [02:24,  7.50it/s]



840it [02:25,  9.16it/s]



842it [02:25,  7.75it/s]



845it [02:25,  7.52it/s]



848it [02:26,  7.47it/s]



852it [02:26,  8.04it/s]



853it [02:27,  6.25it/s]



857it [02:27,  6.82it/s]



859it [02:28,  6.13it/s]



860it [02:28,  5.02it/s]



861it [02:29,  4.42it/s]



866it [02:29,  7.36it/s]



868it [02:29,  6.39it/s]



870it [02:30,  4.35it/s]



872it [02:31,  4.59it/s]



883it [02:31, 12.06it/s]



886it [02:32,  9.00it/s]



888it [02:33,  6.40it/s]



896it [02:33,  8.84it/s]



900it [02:34,  9.59it/s]



903it [02:34,  9.27it/s]



906it [02:34,  8.63it/s]



909it [02:35,  8.06it/s]



912it [02:35,  8.01it/s]



914it [02:36,  7.22it/s]



918it [02:36,  8.05it/s]



919it [02:36,  6.78it/s]



921it [02:37,  5.97it/s]



922it [02:37,  4.88it/s]



923it [02:38,  4.35it/s]



925it [02:38,  4.53it/s]



931it [02:38,  7.74it/s]



932it [02:39,  6.41it/s]



933it [02:39,  5.50it/s]



934it [02:39,  4.82it/s]



935it [02:40,  3.92it/s]



938it [02:40,  5.08it/s]



940it [02:41,  5.18it/s]



942it [02:41,  5.17it/s]



945it [02:41,  5.91it/s]



947it [02:42,  5.43it/s]



952it [02:42,  7.42it/s]



960it [02:43, 12.28it/s]



963it [02:43,  8.22it/s]



965it [02:44,  7.72it/s]



967it [02:44,  5.83it/s]



976it [02:45, 11.63it/s]



985it [02:46, 12.70it/s]



989it [02:47,  8.28it/s]



992it [02:48,  5.96it/s]



995it [02:48,  6.57it/s]



1000it [02:48,  5.92it/s]


In [27]:
from sklearn.metrics import confusion_matrix, classification_report


confusion_matrix(ground_truths, final_preds)

array([[ 42, 135],
       [123,  68]])

In [28]:
print(classification_report(ground_truths, final_preds))

              precision    recall  f1-score   support

           0       0.25      0.24      0.25       177
           1       0.33      0.36      0.35       191

    accuracy                           0.30       368
   macro avg       0.29      0.30      0.30       368
weighted avg       0.30      0.30      0.30       368



In [None]:
subset_test_patients = test_patients.sample(1)

patient_id, target, vitals, exclusion_flag = get_data(subset_test_patients.values[0])
#vitals = np.load(save_file_path + str(pat_id) + '_vitals.npy')


if exclusion_flag == 0:

    vitals_pt = (vitals - overall_mean) / overall_std
    decoder_input = np.zeros((vitals.shape[0], 1))

    decoder_input = np.expand_dims(decoder_input, axis=0)
    vitals_pt = np.expand_dims(vitals_pt, axis=0)
    
    prediction = model.predict([vitals_pt, decoder_input])[0]
    #print(prediction)

                            
                            
    onset_index = (onset_indices_csv[onset_indices_csv['Patient_ID'] == pat_id]['Onset_Index'].values[0])
                
    #prediction = mark_sequences((prediction > 0.5).astype(int), 1)
    plt.bar(x=np.arange(len(prediction)), height=prediction.flatten())
    plt.ylim(0, 1)
    print(onset_index)
    print(patient_id)

In [None]:
model.save(save_file_path + 'respi_init_state_gru_model.h5')

In [None]:
save_file_path