In [29]:
import tensorflow as tf
import pandas as pd
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.optimizers import Adam
import pyarrow.parquet as pq
from sklearn.model_selection import train_test_split
import csv
import subprocess
import getpass
import os
import gzip
from os import listdir
from os.path import isfile, join
from SciServer import Authentication
import tensorflow as tf
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import AUC

# Data Loading

In [31]:
df_dataset = pd.read_csv("/home/idies/workspace/SAFE/MinooEmir/new_complete_features.csv")

In [32]:
df_dataset = df_dataset[['formatted_time', 'hf_original', 'hf_type_original', 'HDL',
       'tot_cholesterol', 'glucose', 'bnp',
       'Arterial Blood Pressure diastolic', 'Arterial Blood Pressure systolic',
       'Heart Rate', 'gender', 'race', 'age']]

In [33]:
df_dataset.columns

Index(['formatted_time', 'hf_original', 'hf_type_original', 'HDL',
       'tot_cholesterol', 'glucose', 'bnp',
       'Arterial Blood Pressure diastolic', 'Arterial Blood Pressure systolic',
       'Heart Rate', 'gender', 'race', 'age'],
      dtype='object')

In [34]:
feature_list = ['HDL',
               'tot_cholesterol', 
               'glucose', 
               'bnp',
               'Arterial Blood Pressure diastolic', 
               'Arterial Blood Pressure systolic',
               'Heart Rate', 
               'gender', 
               'race', 
               'age']

In [35]:
df_dataset.set_index("formatted_time", inplace=True)
df_dataset.head()

Unnamed: 0_level_0,hf_original,hf_type_original,HDL,tot_cholesterol,glucose,bnp,Arterial Blood Pressure diastolic,Arterial Blood Pressure systolic,Heart Rate,gender,race,age
formatted_time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
13:45:00_11_01_2110_18106347,0,Non-HF,,,95.0,,66.068966,108.655172,100.8,0,6,48
04:02:00_17_01_2110_18780420,0,Non-HF,56.0,159.0,113.0,,,,59.958333,1,7,84
02:02:00_22_01_2110_16006168,0,Non-HF,,,96.142857,,80.428571,131.785714,92.785714,1,2,20
17:21:00_30_01_2110_14816979,0,Non-HF,39.0,189.0,98.133333,,,,96.173913,1,7,30
17:07:00_01_02_2110_13956717,0,Non-HF,,,106.777778,,52.473684,98.684211,80.157895,1,7,72


## Impute BNP

In [36]:
import random
# Function to generate random NT-proBNP values based on age
def generate_nt_proBNP(age):
    if age < 75:
        return random.uniform(0, 125)  # For adults younger than 75 years
    else:
        return random.uniform(0, 450)  # For adults 75 years or older

# Identify missing values in 'probnp'
missing_values = df_dataset['bnp'].isnull()

# Determine age of individuals with missing 'probnp' values (replace 'age_column' with the actual age column name)
missing_age = df_dataset.loc[missing_values, 'age']

# Generate random NT-proBNP values based on age
imputed_values = missing_age.apply(generate_nt_proBNP)

# Replace missing values in 'probnp' with generated values
df_dataset.loc[missing_values, 'bnp'] = imputed_values

In [37]:
df_dataset.head()

Unnamed: 0_level_0,hf_original,hf_type_original,HDL,tot_cholesterol,glucose,bnp,Arterial Blood Pressure diastolic,Arterial Blood Pressure systolic,Heart Rate,gender,race,age
formatted_time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
13:45:00_11_01_2110_18106347,0,Non-HF,,,95.0,50.134772,66.068966,108.655172,100.8,0,6,48
04:02:00_17_01_2110_18780420,0,Non-HF,56.0,159.0,113.0,394.510415,,,59.958333,1,7,84
02:02:00_22_01_2110_16006168,0,Non-HF,,,96.142857,100.554047,80.428571,131.785714,92.785714,1,2,20
17:21:00_30_01_2110_14816979,0,Non-HF,39.0,189.0,98.133333,24.524234,,,96.173913,1,7,30
17:07:00_01_02_2110_13956717,0,Non-HF,,,106.777778,90.666634,52.473684,98.684211,80.157895,1,7,72


## Data splitting into train, validation, and test

In [38]:
print("for HF vs Non-HF")
X_id = df_dataset.index
X = df_dataset[feature_list]
y = df_dataset["hf_original"]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3, random_state = 0, stratify = y)    

X_train_id, X_test_id, _, _ = train_test_split(X_id, y, test_size = 0.3, random_state = 0, stratify = y)
X_train_id, X_val_id, _, _ = train_test_split(X_train_id, y_train, test_size = 0.5, random_state = 0, stratify = y_train)


for HF vs Non-HF


In [39]:
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer

imputer = SimpleImputer(strategy="median")
X_train_imputed = imputer.fit_transform(X_train)
X_test_imputed = imputer.transform(X_test)

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train_imputed)
X_test_scaled = scaler.transform(X_test_imputed)

X_train_scaled_df = pd.DataFrame(X_train_scaled, columns=X.columns, index=X_train.index)
X_test_scaled_df = pd.DataFrame(X_test_scaled, columns=X.columns, index=X_test.index)

In [40]:
X_train_scaled_df, X_val, y_train, y_val = train_test_split(X_train_scaled_df, y_train, test_size = 0.5, random_state = 0, stratify = y_train) 

In [41]:
base_dir = '/home/idies/workspace/SAFE/ecg_preprocessed'

# Create the full file paths
# waveform_files = [f for f in os.listdir(base_dir)]

train_file_paths = [f"{file_id}" for file_id in X_train_id]
val_file_paths = [f"{file_id}" for file_id in X_val_id]
test_file_paths = [f"{file_id}" for file_id in X_test_id]

## Aligning waveform and tabular data

In [54]:
def aligning_dataset(waveform_files, waveform_dir, tab_dataset_feature, tab_dataset_y):
    waveform_paths = []
    tabular_dat_list = []
    label = []
    
    for i, waveform_file_name in enumerate(waveform_files):
        patient_id = waveform_file_name
        if patient_id in tab_dataset_feature.index:
            waveform_paths.append(os.path.join(waveform_dir, patient_id))
            tabular_dat_list.append(tab_dataset_feature.loc[patient_id].values)
            # Assume labels are included in the tabular data
            label.append(tab_dataset_y.loc[patient_id])

    # Convert lists to numpy arrays
    tabular_dat_arr = np.array(tabular_dat_list)
    labels_arr = np.array(label)

    return waveform_paths, tabular_dat_arr, labels_arr

In [55]:
train_waveform_paths, train_tab_data, train_labels = aligning_dataset(train_file_paths,base_dir, X_train_scaled_df, y_train)

In [56]:
train_waveform_paths, train_tab_data, train_labels = aligning_dataset(train_file_paths,base_dir, X_train_scaled_df, y_train)

val_waveform_paths, val_tab_data, val_labels = aligning_dataset(val_file_paths,base_dir, X_val, y_val)

test_waveform_paths, test_tab_data, test_labels = aligning_dataset(test_file_paths,base_dir, X_test_scaled_df, y_test)

## Create data generator and loader for NN

In [57]:
# Define the data generator
def ecg_data_generator(waveform_paths, tabular_data, labels):
    for i in range(len(waveform_paths)):
        try:
            # Load the waveform data from Parquet
            waveform_data = pq.read_table(waveform_paths[i]).to_pandas().values  # Ensure it's a numpy array
            tabular_data_sample = tabular_data[i]  # Corresponding tabular data
            label = labels[i]
            yield (waveform_data, tabular_data_sample), label
        except Exception as e:
            print(f'Error loading {waveform_paths[i]}: {e}')
            continue

# Create TensorFlow datasets
def create_dataset(waveform_paths, tabular_data, labels, batch_size):
    dataset = tf.data.Dataset.from_generator(
        lambda: ecg_data_generator(waveform_paths, tabular_data, labels),
        output_signature=(
            (tf.TensorSpec(shape=(5000, 12), dtype=tf.float32), tf.TensorSpec(shape=(tabular_data.shape[1],), dtype=tf.float32)),
            tf.TensorSpec(shape=(), dtype=tf.int32)
        )
    )
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE).repeat()

In [58]:

print("Shape of train_tab_data:", train_tab_data.shape)
print("Shape of val_tab_data:", val_tab_data.shape)
print("Shape of test_tab_data:", test_tab_data.shape)


Shape of train_tab_data: (6414, 10)
Shape of val_tab_data: (6414, 10)
Shape of test_tab_data: (5499, 10)


In [59]:
# Shuffle and batch the dataset
batch_size = 32

train_dataset = create_dataset(train_waveform_paths,train_tab_data, train_labels, batch_size)

val_dataset = create_dataset(val_waveform_paths,val_tab_data, val_labels, batch_size)

test_dataset = create_dataset(test_waveform_paths,test_tab_data, test_labels, batch_size)

In [60]:
# Inspect the data inside the train_dataset
for (waveform, tabular), label in val_dataset.take(1):  # Only take the first batch for inspection
    print("Waveform data shape:", waveform.shape)
    print("Waveform data:", waveform)
    print("Tabular data shape:", tabular.shape)
    print("Tabular data:", tabular)
    print("Label data shape:", label.shape)
    print("Label data:", label)

Waveform data shape: (32, 5000, 12)
Waveform data: tf.Tensor(
[[[ 0.    -0.02  -0.02  ...  0.015  0.015 -0.02 ]
  [ 0.    -0.015 -0.015 ...  0.     0.005 -0.02 ]
  [ 0.    -0.015 -0.015 ... -0.005  0.    -0.02 ]
  ...
  [-0.035  0.     0.035 ... -0.025 -0.025 -0.04 ]
  [-0.035 -0.01   0.025 ... -0.025 -0.025 -0.04 ]
  [-0.04  -0.02   0.02  ... -0.025 -0.025 -0.04 ]]

 [[ 0.115  0.065 -0.04  ... -0.065 -0.04  -0.06 ]
  [ 0.115  0.065 -0.04  ... -0.065 -0.04  -0.06 ]
  [ 0.115  0.065 -0.04  ... -0.075 -0.04  -0.07 ]
  ...
  [-0.155 -0.025  0.14  ... -0.065  0.     0.05 ]
  [-0.14  -0.025  0.125 ... -0.06   0.005  0.05 ]
  [-0.125 -0.015  0.12  ... -0.045  0.005  0.06 ]]

 [[ 0.01   0.02   0.005 ... -0.01  -0.025 -0.025]
  [ 0.     0.02   0.015 ... -0.01  -0.025 -0.025]
  [-0.005  0.01   0.01  ... -0.01  -0.025 -0.025]
  ...
  [-0.01   0.     0.005 ... -0.005 -0.015  0.005]
  [-0.01  -0.01  -0.005 ... -0.005 -0.015  0.   ]
  [-0.01  -0.02  -0.005 ... -0.01  -0.005 -0.005]]

 ...

 [[ 0.11

2024-06-20 21:15:30.124535: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


# Define model architecture

In [160]:
import tensorflow as tf
from keras import backend as K
from keras.layers import (Input, Dense, Conv1D, Dropout, MaxPooling1D, 
                          Activation, Lambda, BatchNormalization, Add,
                          Flatten, Attention, MultiHeadAttention)
from keras.optimizers import Adam
from keras.models import Model
from keras.metrics import AUC
from keras.models import Model, Sequential

from tensorflow.keras.layers import Input, Conv1D, BatchNormalization, Conv2D, MaxPooling2D, \
    ReLU, Reshape, GlobalAveragePooling1D, Dense, Concatenate, Dropout, concatenate, LeakyReLU, SpatialDropout1D, Attention
import logging
from tensorflow.keras.layers import Layer, Dense, MultiHeadAttention, LayerNormalization

# PAPER: Screening for cardiac contractile dysfunction using an artificial intelligence–enabled electrocardiogram
#        https://www.nature.com/articles/s41591-018-0240-2
# SOURCE REPO: https://github.com/chrisby/DeepCardiology
class Attia_et_al_CNN():
    def __init__(self, 
                 filter_numbers=[16, 16, 32, 32, 64, 64], 
                 kernel_widths=[7, 7, 5, 5, 3, 3], 
                 pool_sizes=[2, 2, 4, 2, 2, 4], 
                 spatial_num_filters=64, 
                 dense_dropout_rate=0.2, 
                 spatial_dropout_rate=0.2,
                 dense_units=[64, 32], 
                 use_spatial_layer=False,
                 verbose=1,
                 use_residual=True):

        self.filter_numbers = filter_numbers
        self.kernel_widths = kernel_widths
        self.pool_sizes = pool_sizes
        self.spatial_num_filters = spatial_num_filters
        self.dense_dropout_rate = dense_dropout_rate
        self.spatial_dropout_rate = spatial_dropout_rate
        self.dense_units = dense_units
        self.use_spatial_layer = use_spatial_layer
        self.verbose = verbose
        self.use_residual = use_residual

        self.att = Attention()

        self.model = None

        if self.verbose == 0:
            return
        
        print("Attia et al. CNN model initialized with the following parameters:")
        print(f"  filter_numbers: {self.filter_numbers}")
        print(f"  kernel_widths: {self.kernel_widths}")
        print(f"  pool_sizes: {self.pool_sizes}")
        print(f"  spatial_num_filters: {self.spatial_num_filters}")
        print(f"  dense_dropout_rate: {self.dense_dropout_rate}")
        print(f"  spatial_dropout_rate: {self.spatial_dropout_rate}")
        print(f"  dense_units: {self.dense_units}")
        print(f"  use_spatial_layer: {self.use_spatial_layer}")
        print(f"  use_residual: {self.use_residual}")
    
    def get_temporal_layer(self, N, k, p, input_layer):
        c = Conv1D(N, k, padding='same', kernel_initializer='he_normal')(input_layer)
        b = tf.keras.layers.BatchNormalization()(c)
        a = Activation('relu')(b)
        p = MaxPooling1D(pool_size=p)(a)
        do = SpatialDropout1D(self.spatial_dropout_rate)(p)
        return do

    def get_temporal_layer_with_residual(self, N, k, p, input_layer):
        # Main pathway
        x = Conv1D(N, k, padding='same', kernel_initializer='he_normal')(input_layer)
        x = SpatialDropout1D(self.spatial_dropout_rate)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        
        # Shortcut pathway
        # Ensure the shortcut matches the dimension of the main pathway's output, adjust filters and stride as necessary
        shortcut = Conv1D(N, 1, padding='same', kernel_initializer='he_normal')(input_layer)  # 1x1 conv for matching dimension
        shortcut = BatchNormalization()(shortcut)  # Optional, for matching feature-wise statistics
        
        # Merging the shortcut with the main pathway
        merged_output = Add()([x, shortcut])  # Element-wise addition

        x = MaxPooling1D(pool_size=p)(merged_output)
        
        return x
    
    def get_spatial_layer(self, kernel_size, input_layer):
        c = Conv1D(self.spatial_num_filters, kernel_size, padding='same', data_format="channels_first", kernel_initializer='he_normal')(input_layer)
        b = tf.keras.layers.BatchNormalization()(c)
        a = Activation('relu')(b)
        do = SpatialDropout1D(self.spatial_dropout_rate)(a)
        return do
    
    def get_fully_connected_layer(self, units, input_layer):
        d = Dense(units, kernel_initializer='he_normal')(input_layer)
        b = tf.keras.layers.BatchNormalization()(d)
        a = Activation('relu')(b)
        do = Dropout(self.dense_dropout_rate)(a)
        return do

    def build(self, input_shape=(5000, 12)):
        input_layer = Input(shape=input_shape)
        last_layer = input_layer
        
        for i in range(len(self.pool_sizes)):
            if self.use_residual:
                temp_layer = self.get_temporal_layer_with_residual(self.filter_numbers[i], self.kernel_widths[i],
                                            self.pool_sizes[i], last_layer)
            else:
                temp_layer = self.get_temporal_layer(self.filter_numbers[i], self.kernel_widths[i],
                                            self.pool_sizes[i], last_layer)
            last_layer = temp_layer
        
        if self.use_spatial_layer:
            last_layer = self.get_spatial_layer(input_shape[1], last_layer)

        last_layer = Flatten()(last_layer)

        for i in range(len(self.dense_units)):
            dense_layer = self.get_fully_connected_layer(self.dense_units[i], last_layer)
            last_layer = dense_layer

        output_layer = Dense(1, activation='sigmoid')(last_layer)
        self.model = Model(inputs=input_layer, outputs=output_layer)

        if self.verbose > 0:
            print(self.model.summary())
        return self.model

class MultiHeadCrossAttention(Layer):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.cross_attention = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.dense_proj = Dense(embed_dim, activation='relu')
        self.layer_norm = LayerNormalization(epsilon=1e-6)

        # Project input dimensions to match expected [batch_size, sequence_length, embed_dim]
        self.query_projection = Dense(embed_dim)
        self.value_projection = Dense(embed_dim)

    def call(self, query, value):
        # Ensure query and value match required shape [batch_size, seq_length, embed_dim]
        query = self.query_projection(tf.expand_dims(query, axis=1))
        value = self.value_projection(tf.expand_dims(value, axis=1))

        attn_output = self.cross_attention(query=query, value=value, key=value)
        attn_output = self.dense_proj(attn_output[:, 0, :])  # Reshape output if needed
        output = self.layer_norm(query[:, 0, :] + attn_output)

        return output
    
# PAPER: Screening for cardiac contractile dysfunction using an artificial intelligence–enabled electrocardiogram
#        https://www.nature.com/articles/s41591-018-0240-2
# SOURCE REPO: https://github.com/chrisby/DeepCardiology
class Attia_et_al_fusion():
    def __init__(self, 
                 filter_numbers=[16, 16, 32, 32, 64, 64], 
                 kernel_widths=[5, 5, 5, 3, 3, 3], 
                 pool_sizes=[2, 2, 4, 2, 2, 4], 
                 spatial_num_filters=64, 
                 dropout_rate=0.2, 
                 dense_units=[64, 32], 
                 fusion_strategy="concat",
                 use_waveforms=True,
                 use_residual = True,
                 spatial_dropout_rate=0.2,
                 verbose=1):
        
        # Fusion strategy options
        # concat, self_attn, cross_attn
        # mlp, tab_mlp
        
        self.filter_numbers = filter_numbers
        self.kernel_widths = kernel_widths
        self.pool_sizes = pool_sizes
        self.spatial_num_filters = spatial_num_filters
        self.dropout_rate = dropout_rate
        self.dense_units = dense_units
        self.fusion_strategy = fusion_strategy
        self.use_waveforms = use_waveforms
        self.use_residual = use_residual
        self.spatial_dropout_rate = spatial_dropout_rate

        self.verbose = verbose

        self.model = None

        self.att = Attention()

        if self.verbose == 0:
            return
        
        print("Attia et al. CNN model initialized with the following parameters:")
        print(f"  filter_numbers: {self.filter_numbers}")
        print(f"  kernel_widths: {self.kernel_widths}")
        print(f"  pool_sizes: {self.pool_sizes}")
        print(f"  spatial_num_filters: {self.spatial_num_filters}")
        print(f"  dropout_rate: {self.dropout_rate}")
        print(f"  dense_units: {self.dense_units}")
        print(f"  fusion_strategy: {self.fusion_strategy}")
        print(f"  use_waveforms: {self.use_waveforms}")
        print(f"  use_residual: {self.use_residual}")


    def create_mlp(self, input_shape):
        mlp = tf.keras.Sequential([Dense(128, input_shape=(40,), activation='relu'),
                    Dropout(self.dropout_rate),
                    Dense(32, activation='relu'),
                    Dropout(self.dropout_rate),
                    Dense(8, activation='relu'),
                    Dropout(self.dropout_rate),
                    Dense(1, activation='sigmoid')])
        return mlp

    def create_attn(self, input_shape):
        attn = Attention()
        return attn

    def get_temporal_layer(self, N, k, p, input_layer):
        c = Conv1D(N, k, padding='same')(input_layer)
        # c = SeparableConv1D(N, k, padding='same', activation='relu')(input_layer)
        b = tf.keras.layers.BatchNormalization()(c)
        a = Activation('relu')(b)
        p = MaxPooling1D(pool_size=p)(a)
        do = SpatialDropout1D(0.1)(p)
        return do
    
    def get_temporal_layer_with_residual(self, N, k, p, input_layer):
        # Main pathway
        x = Conv1D(N, k, padding='same', kernel_initializer='he_normal')(input_layer)
        x = SpatialDropout1D(self.spatial_dropout_rate)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        
        # Shortcut pathway
        # Ensure the shortcut matches the dimension of the main pathway's output, adjust filters and stride as necessary
        shortcut = Conv1D(N, 1, padding='same', kernel_initializer='he_normal')(input_layer)  # 1x1 conv for matching dimension
        shortcut = BatchNormalization()(shortcut)  # Optional, for matching feature-wise statistics
        
        # Merging the shortcut with the main pathway
        merged_output = Add()([x, shortcut])  # Element-wise addition

        x = MaxPooling1D(pool_size=p)(merged_output)
        
        return x
    
    def get_spatial_layer(self, kernel_size, input_layer):
        c = Conv1D(self.spatial_num_filters, kernel_size, kernel_initializer='he_normal')(input_layer)
        # c = Conv1D(self.spatial_num_filters, kernel_size, data_format="channels_first", kernel_initializer='he_normal')(input_layer)
        b = tf.keras.layers.BatchNormalization()(c)
        a = Activation('relu')(b)
        do = SpatialDropout1D(0.1)(a)
        return do

    def build(self, input_shape=(5000, 12), fusion_shape=(1,)):
        waveform_input = Input(shape=input_shape)
        last_layer = waveform_input
        
        # Building CNN layers for waveform processing
        for i in range(len(self.pool_sizes)):
            if self.use_residual:
                temp_layer = self.get_temporal_layer_with_residual(self.filter_numbers[i], self.kernel_widths[i],
                                            self.pool_sizes[i], last_layer)
            else:
                temp_layer = self.get_temporal_layer(self.filter_numbers[i], self.kernel_widths[i],
                                            self.pool_sizes[i], last_layer)
            last_layer = temp_layer
        
        # last_layer = self.get_spatial_layer(input_shape[1], last_layer)
        flattened_waveform = Flatten()(last_layer)

        # Final Dense layers
        x = Dense(64, activation='relu')(flattened_waveform)
        x = Dropout(self.dropout_rate)(x)
        x = Dense(32, activation='relu')(x)
        x = Dropout(self.dropout_rate)(x)

        fusion_input = Input(shape=fusion_shape)
        # f = Dense(32, activation='relu')(fusion_input)
        # f = Dropout(self.dropout_rate)(f)
        # f = Dense(16, activation='relu')(f)
        # f = Dropout(self.dropout_rate)(f)

        if self.use_waveforms:
            if self.fusion_strategy == "concat":
                x = concatenate([x, fusion_input])
                output = Dense(1, activation='sigmoid')(x)
            
            elif self.fusion_strategy == "self_attn":
                # fusion_input = Dropout(self.dropout_rate)(fusion_input)
                x = concatenate([x, fusion_input])

                embed_dim = 16  # Dimensionality of the encoder.
                num_heads = 4    # Number of attention heads.

                cross_attention_layer = MultiHeadCrossAttention(embed_dim, num_heads)
                x = cross_attention_layer(x, x)
                x = Dropout(self.dropout_rate)(x)
                output = Dense(1, activation='sigmoid')(x)

            elif self.fusion_strategy == "cross_attn":
                embed_dim = 16  # Dimensionality of the encoder.
                num_heads = 4    # Number of attention heads.

                cross_attention_layer = MultiHeadCrossAttention(embed_dim, num_heads)

                # fusion_input = Dropout(self.dropout_rate)(fusion_input)
                x = cross_attention_layer(x, fusion_input)
                x = Dropout(self.dropout_rate)(x)
                output = Dense(1, activation='sigmoid')(x)

            if self.fusion_strategy == "mlp":
                x = concatenate([x, fusion_input])
                # WF/tab concat -> MLP -> output
                x = Dense(32, activation='relu')(x)
                x = Dropout(self.dropout_rate)(x)
                x = Dense(16, activation='relu')(x)
                x = Dropout(self.dropout_rate)(x)
                x = Dense(8, activation='relu')(x)
                x = Dropout(self.dropout_rate)(x)
                output = Dense(1, activation='sigmoid')(x)

            elif self.fusion_strategy == "tab_mlp":
                x = Dense(16, activation='relu')(x)
                x = Dropout(self.dropout_rate)(x)

                fus = Dense(32, activation='relu')(fusion_input)
                fus = Dropout(self.dropout_rate)(fus)
                fus = Dense(16, activation='relu')(fus)
                fus = Dropout(self.dropout_rate)(fus)

                x = concatenate([x, fus])
                output = Dense(1, activation='sigmoid')(x)
        else:
            x = fusion_input
            output = Dense(1, activation='sigmoid')(x)

            # x = Dense(16, activation='relu')(x)
            # x = Dropout(self.dropout_rate)(x)
            # x = Dense(8, activation='relu')(x)
            # x = Dropout(self.dropout_rate)(x)

        # x = Dense(16, activation='relu')(x)
        # x = Dropout(self.dropout_rate)(x)
        # x = concatenate([x, f])
        # x = Dense(8, activation='relu')(x)

        # WF/tab concat -> MLP -> output
        # x = Dense(32, activation='relu')(x)
        # x = Dropout(self.dropout_rate)(x)
        # x = Dense(16, activation='relu')(x)
        # x = Dropout(self.dropout_rate)(x)
        # x = Dense(8, activation='relu')(x)
        # x = Dropout(self.dropout_rate)(x)


        # WF/tab concat -> attn -> output
        # x = self.att([x, x])
        # output = Dense(1, activation='sigmoid')(x)

        self.model = Model(inputs=[waveform_input, fusion_input], outputs=output)

        if self.verbose > 0:
            print(self.model.summary())

        return self.model


# Model training and evaluation

In [26]:
# Calculate steps per epoch
steps_per_epoch = len(train_waveform_paths) // batch_size
validation_steps = len(val_waveform_paths) // batch_size
test_steps = len(test_waveform_paths) // batch_size

In [172]:

# Define and build the model
fusion_model = Attia_et_al_fusion()
model = fusion_model.build(input_shape=(5000, 12), fusion_shape=(train_tab_data.shape[1],))

# Compile the model
model.compile(optimizer='adam', 
              loss='binary_crossentropy', 
              metrics=['accuracy', AUC(name='auc')])

# Training parameters
EPOCHS = 50

# Callbacks (example)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001)
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

# Train the model
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    callbacks=[reduce_lr, early_stopping],
)

Attia et al. CNN model initialized with the following parameters:
  filter_numbers: [16, 16, 32, 32, 64, 64]
  kernel_widths: [5, 5, 5, 3, 3, 3]
  pool_sizes: [2, 2, 4, 2, 2, 4]
  spatial_num_filters: 64
  dropout_rate: 0.2
  dense_units: [64, 32]
  fusion_strategy: concat
  use_waveforms: True
  use_residual: True


None
Epoch 1/20
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m668s[0m 3s/step - accuracy: 0.7081 - auc: 0.5384 - loss: 0.6457 - val_accuracy: 0.7713 - val_auc: 0.7270 - val_loss: 0.4987 - learning_rate: 0.0010
Epoch 2/20
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m622s[0m 3s/step - accuracy: 0.7689 - auc: 0.6635 - loss: 0.5218 - val_accuracy: 0.7715 - val_auc: 0.7542 - val_loss: 0.4824 - learning_rate: 0.0010
Epoch 3/20
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m624s[0m 3s/step - accuracy: 0.7800 - auc: 0.6974 - loss: 0.5016 - val_accuracy: 0.7755 - val_auc: 0.7617 - val_loss: 0.4734 - learning_rate: 0.0010
Epoch 4/20
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m626s[0m 3s/step - accuracy: 0.7752 - auc: 0.7217 - loss: 0.4901 - val_accuracy: 0.7767 - val_auc: 0.7644 - val_loss: 0.4692 - learning_rate: 0.0010
Epoch 5/20
[1m200/200[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m627s[0m 3s/step - accuracy: 0.7808 - auc: 0.7482

In [None]:
# Save the trained model
model.save('trained_fusion_model.keras')

In [62]:
trained_model = tf.keras.models.load_model('trained_fusion_model.keras', custom_objects={'AUC': tf.keras.metrics.AUC})

  saveable.load_own_variables(weights_store.get(inner_path))


In [63]:
# Evaluate the model on the test dataset
test_loss, test_accuracy, test_auc = trained_model.evaluate(test_dataset, steps=test_steps)
print(f"Test Loss: {test_loss}")
print(f"Test Accuracy: {test_accuracy}")
print(f"Test AUC: {test_auc}")

[1m171/171[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m370s[0m 2s/step - accuracy: 0.7855 - auc: 0.7844 - loss: 0.4533
Test Loss: 0.463108092546463
Test Accuracy: 0.7799707651138306
Test AUC: 0.7763502597808838
