## This code is associated with Experiment 5 described in the manuscript for Homogeneous-3D-CNN-BiLSTM-ERMHA

In [None]:
import os
import sys
import argparse
import numpy as np
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import BinaryAccuracy, Precision, Recall, AUC, SpecificityAtSensitivity
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import tensorflow.keras.backend as K
from sklearn.model_selection import KFold
from sklearn.model_selection import KFold, train_test_split
from classification_models_3D.kkeras import Classifiers

In [2]:
def get_backbone_model(backbone_name, input_shape):
    """Build an untrained, headless 3D backbone for a single input view.

    Args:
        backbone_name: One of 'resnet50', 'vgg16', 'densenet121',
            'inceptionv3' or 'efficientnetb7'.
        input_shape: Forwarded unchanged as the backbone's ``input_shape``.

    Returns:
        A ``(backbone, preprocess_input)`` tuple. Both come from
        ``Classifiers.get``; the backbone is constructed with
        ``weights=None`` and ``include_top=False``.

    Raises:
        ValueError: If ``backbone_name`` is not in the supported set.
    """
    supported_names = {
        'resnet50',
        'vgg16',
        'densenet121',
        'inceptionv3',
        'efficientnetb7',
    }
    if backbone_name not in supported_names:
        raise ValueError(f"Unsupported backbone: {backbone_name}")

    # The registry key is the backbone name itself.
    constructor, preprocess_input = Classifiers.get(backbone_name)
    network = constructor(input_shape=input_shape, weights=None, include_top=False)
    return network, preprocess_input

In [3]:
class F1Score(tf.keras.metrics.Metric):
    """Streaming F1 score derived from Keras ``Precision`` and ``Recall``.

    The two wrapped metrics (created with their default arguments) hold all
    accumulated state; this class only forwards updates/resets to them and
    combines their results as ``2 * P * R / (P + R + epsilon)``.
    """

    def __init__(self, name='f1_score', **kwargs):
        super().__init__(name=name, **kwargs)
        self.precision_metric = Precision()
        self.recall_metric = Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Feed one batch to both underlying metrics."""
        for tracker in (self.precision_metric, self.recall_metric):
            tracker.update_state(y_true, y_pred, sample_weight)

    def result(self):
        """Return the F1 value for everything accumulated so far."""
        p = self.precision_metric.result()
        r = self.recall_metric.result()
        # K.epsilon() keeps the division defined when P + R == 0.
        ratio = (p * r) / (p + r + K.epsilon())
        return 2 * ratio

    def reset_state(self):
        """Clear the accumulated state of both underlying metrics."""
        for tracker in (self.precision_metric, self.recall_metric):
            tracker.reset_state()

def specificity(y_true, y_pred):
    """Return TN / (TN + FP) computed directly from ``y_true`` and ``y_pred``.

    ``y_pred`` is used as given (no thresholding is applied here), so
    non-binary predictions yield a "soft" specificity. ``K.epsilon()`` is
    added to the denominator to avoid division by zero when there are no
    negatives.
    """
    negatives = 1 - y_true
    false_positives = K.sum(negatives * y_pred)
    true_negatives = K.sum(negatives * (1 - y_pred))
    return true_negatives / (true_negatives + false_positives + K.epsilon())

# Callback used by the training loop: stop once validation loss has not
# improved for 5 consecutive epochs and roll back to the best weights seen.
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
    verbose=1,
)

In [29]:
def load_data(X_data, timesteps):
    """Split a stacked volume into axial, coronal and sagittal views.

    Axis 1 of ``X_data`` is read as three consecutive 64-slice bands:
    ``0:64`` (axial), ``64:128`` (coronal) and ``128:192`` (sagittal). Each
    time-step contributes 16 slices, so for ``n`` time-steps the first
    ``16 * n`` slices of every band are returned.

    Args:
        X_data: 5-D array (NumPy array or tensor) that supports
            ``X_data[:, a:b, :, :, :]`` and has at least 192 entries on
            axis 1.
        timesteps: One of ``'one'``, ``'two'``, ``'three'`` or ``'four'``.

    Returns:
        ``[axial_view, coronal_view, sagittal_view]`` - a list of three
        slices of ``X_data``, each with ``16 * n`` entries on axis 1.

    Raises:
        KeyError: If ``timesteps`` is not one of the supported names.
        ValueError: If axis 1 of ``X_data`` has fewer than 192 entries, which
            would otherwise yield silently truncated (or empty) views.
    """
    slices_per_timestep = 16
    band_width = 64
    n_bands = 3  # axial, coronal, sagittal - in that order along axis 1
    # name -> (number of time-steps, word used in the progress message)
    timestep_map = {
        'one': (1, 'Single'),
        'two': (2, 'Two'),
        'three': (3, 'Three'),
        'four': (4, 'Four'),
    }

    try:
        n_timesteps, label = timestep_map[timesteps]
    except KeyError:
        # Same exception type as a plain dict lookup, but with a usable message.
        raise KeyError(
            f"Unsupported timesteps {timesteps!r}; "
            f"expected one of {list(timestep_map)}"
        ) from None

    available = X_data.shape[1]
    required = n_bands * band_width
    if available is not None and available < required:
        raise ValueError(
            f"X_data needs at least {required} slices on axis 1 "
            f"(three bands of {band_width}); got {available}"
        )

    depth = slices_per_timestep * n_timesteps
    views = [
        X_data[:, start:start + depth, :, :, :]
        for start in range(0, required, band_width)
    ]
    print(f'Data for {label} time-steps for axial_coronal_sagittal view returned.')
    return views

In [34]:
def build_model(backbone_name, view1_shape, view2_shape, view3_shape, num_classes=2):
    """Assemble the three-view 3D-CNN -> BiLSTM -> attention classifier.

    Each view gets its own backbone instance (no weight sharing). The three
    backbone outputs are global-average-pooled, concatenated, reshaped into a
    short pseudo-sequence, passed through two bidirectional LSTMs and two
    residual multi-head self-attention blocks, then pooled and classified.

    Args:
        backbone_name: Backbone identifier accepted by ``get_backbone_model``.
        view1_shape: Input shape (without batch axis) of the axial view.
        view2_shape: Input shape (without batch axis) of the coronal view.
        view3_shape: Input shape (without batch axis) of the sagittal view.
        num_classes: ``2`` builds a single sigmoid output with binary
            cross-entropy; any other value builds a softmax over
            ``num_classes`` with sparse categorical cross-entropy.

    Returns:
        ``(model, loss, metrics, preprocess_input)``. The model is returned
        un-compiled; ``preprocess_input`` is the one returned for the first
        backbone.
    """
    input_1 = keras.Input(shape=view1_shape, name='input_1')  # axial view
    input_2 = keras.Input(shape=view2_shape, name='input_2')  # coronal view
    input_3 = keras.Input(shape=view3_shape, name='input_3')  # sagittal view
    
    # One independent backbone per view; only the first preprocess_input is kept.
    backbone_1, preprocess_input = get_backbone_model(backbone_name, view1_shape)
    backbone_2, _ = get_backbone_model(backbone_name, view2_shape)
    backbone_3, _ = get_backbone_model(backbone_name, view3_shape)
    
    # Re-wrap each backbone in a Model with a branch-specific name.
    backbone_1_wrapped = keras.Model(inputs=backbone_1.input, outputs=backbone_1.output, 
                                     name=f'{backbone_name}_branch_1' )    
    backbone_2_wrapped = keras.Model(inputs=backbone_2.input, outputs=backbone_2.output, 
                                     name=f'{backbone_name}_branch_2')
    backbone_3_wrapped = keras.Model(inputs=backbone_3.input, outputs=backbone_3.output, 
                                     name=f'{backbone_name}_branch_3')
    
    features_1 = backbone_1_wrapped(input_1)
    features_2 = backbone_2_wrapped(input_2)
    features_3 = backbone_3_wrapped(input_3)
    
    # Collapse each branch's feature map with global average pooling.
    gap_1 = layers.GlobalAveragePooling3D(name='gap_branch_1')(features_1)
    gap_2 = layers.GlobalAveragePooling3D(name='gap_branch_2')(features_2)
    gap_3 = layers.GlobalAveragePooling3D(name='gap_branch_3')(features_3)
    
    # Late fusion: concatenate the three pooled vectors along the feature axis.
    fused_features = layers.Concatenate(axis=-1, name='feature_fusion')([gap_1, gap_2, gap_3])
    gap_size = fused_features.shape[-1]
    
    # Turn the fused vector into a (timesteps, features_per_step) sequence for
    # the BiLSTM stack.
    if gap_size >= 128:
        timesteps = 16
        features_per_step = gap_size // timesteps
        if gap_size % timesteps != 0:
            # Zero-pad the feature axis up to the next multiple of `timesteps`
            # so the Reshape below is exact.
            padding_size = timesteps - (gap_size % timesteps)
            gap_padded = layers.Lambda(lambda x: tf.pad(x, [[0, 0], [0, padding_size]], 'constant'))(fused_features)
            gap_size = gap_size + padding_size
            features_per_step = gap_size // timesteps
        else:
            gap_padded = fused_features
        x = layers.Reshape((timesteps, features_per_step), name='reshape_for_bilstm')(gap_padded)
    else:
        # Narrow fused vector: project it with a Dense layer to
        # 8 steps x at least 16 features before reshaping.
        timesteps = 8
        features_per_step = max(16, gap_size // timesteps)
        gap_expanded = layers.Dense(timesteps * features_per_step, activation='relu', name='expand_for_bilstm')(fused_features)
        x = layers.Reshape((timesteps, features_per_step), name='reshape_for_bilstm')(gap_expanded)
    
    # Two stacked bidirectional LSTMs; both keep the full sequence so the
    # attention blocks below can attend over every step.
    x = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.1, recurrent_dropout=0.1,
                                         recurrent_regularizer=keras.regularizers.l1(0.01),
                                         name='bilstm_1'),name='bidirectional_lstm_1')(x)
    
    bilstm_output = layers.Bidirectional(layers.LSTM(64, return_sequences=True, dropout=0.1,
                                                     recurrent_dropout=0.1, recurrent_regularizer=keras.regularizers.l1(0.03),
                                                     name='bilstm_2' ), name='bidirectional_lstm_2')(x)
    
    def create_ermha_layer(inputs, num_heads=8, key_dim=64, name_prefix='ermha'):
        """Self-attention over `inputs`, then residual add and layer norm."""
        
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=key_dim, name=f'{name_prefix}_mhsa'
        )(inputs, inputs)  
        residual_output = layers.Add(name=f'{name_prefix}_residual')([inputs, attention_output])        
        normalized_output = layers.LayerNormalization(name=f'{name_prefix}_layernorm')(residual_output)        
        return normalized_output
    
    # Two residual attention blocks, then average over the sequence axis.
    ermha_1 = create_ermha_layer(bilstm_output, num_heads=8, key_dim=64, name_prefix='ermha_1')    
    ermha_2 = create_ermha_layer(ermha_1, num_heads=8, key_dim=64, name_prefix='ermha_2')
    ermha_pooled = layers.GlobalAveragePooling1D(name='ermha_gap')(ermha_2)    
    x = layers.Dense(128, activation='relu', name='dense_final')(ermha_pooled)
    x = layers.Dropout(0.2, name='dropout_final')(x)    
    if num_classes == 2:
        outputs = layers.Dense(1, activation='sigmoid', name='predictions')(x)
        loss = BinaryCrossentropy()
        # NOTE(review): anything that labels evaluation results positionally
        # must follow this order: accuracy, sensitivity, AUC, specificity.
        # NOTE(review): SpecificityAtSensitivity(0.5) receives 0.5 as its
        # target sensitivity - presumably not the same as specificity at a 0.5
        # decision threshold; confirm this is the intended metric.
        metrics = [
            BinaryAccuracy(name='accuracy'),    
            Recall(name='Sensitivity'),            
            AUC(name='auc'),                   
            SpecificityAtSensitivity(0.5) ]
    else:
        outputs = layers.Dense(num_classes, activation='softmax', name='predictions')(x)
        loss = 'sparse_categorical_crossentropy'
        metrics = ['accuracy']
    model = Model(inputs=[input_1, input_2, input_3], outputs=outputs, name=f'{backbone_name}_multi_input_bilstm_classifier')
    return model, loss, metrics, preprocess_input

## A separate instance of the same backbone feature extractor will be used for each plane: axial, coronal, and sagittal.
## Backbone network can be: [resnet50, vgg16, densenet121, inceptionv3, efficientnetb7]
## View: axial_coronal_sagittal
## time_steps can be: one, two, three, four
## n_splits: number of folds (N) used for N-fold cross-validation.

In [None]:

# ---- Experiment configuration ----
backbone = 'vgg16'
time_steps = "three" # one, two, three, four
lr = 0.00001
n_splits = 7
epochs = 200
batch_size = 8

# ---- Read the dataset ----
# np.load is the NumPy reader for .npy files (there is no np.read).
cn_patients = np.load('./CN.npy')  # CN patients
progressed_to_AD = np.load('./pAD.npy')  # Progressed to AD patients
cn_labels = np.load('./CN_labels.npy')
progressed_to_AD_labels = np.load('./pAD_labels.npy')

# Stack both cohorts along the sample axis. Everything downstream (KFold,
# train_test_split, load_data) indexes NumPy arrays, so the data is kept in
# NumPy instead of round-tripping through TensorFlow tensors.
X_data = np.concatenate([cn_patients, progressed_to_AD], axis=0)
y_labels = np.concatenate([cn_labels, progressed_to_AD_labels], axis=0)
X_numpy = np.asarray(X_data)
Y_numpy = np.asarray(y_labels)

histories = []     # one Keras History object per fold
test_history = []  # one list of test metrics (loss excluded) per fold
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

for fold, (train_val_idx, test_idx) in enumerate(kf.split(X_numpy, Y_numpy), 1):
    print(f"\nFold {fold}")

    # Drop graph/layer state left over from the previous fold's model so
    # memory use does not grow with every fold.
    K.clear_session()

    X_train_val, X_test = X_numpy[train_val_idx], X_numpy[test_idx]
    y_train_val, y_test = Y_numpy[train_val_idx], Y_numpy[test_idx]

    # Hold out 20% of the non-test samples for validation, stratified by
    # label; seeding with the fold number keeps each split reproducible.
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val, y_train_val, test_size=0.2, stratify=y_train_val,
        random_state=fold)

    # Each call returns [axial, coronal, sagittal] for the chosen time-steps.
    train_axial_coronal_sagittal = load_data(X_train, time_steps)
    valid_axial_coronal_sagittal = load_data(X_val, time_steps)
    test_axial_coronal_sagittal = load_data(X_test, time_steps)

    # Per-sample shape taken from the data itself; for 110x110 single-channel
    # volumes this equals the previously hard-coded (depth, 110, 110, 1).
    input_shape = tuple(train_axial_coronal_sagittal[0].shape[1:])
    print("input shape", input_shape)

    # preprocess_input is returned by build_model but not applied here.
    model, loss, metrics, preprocess_input = build_model(
        backbone, input_shape, input_shape, input_shape, num_classes=2)

    optimizer = Adam(learning_rate=lr)
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    print("Starting training...")
    history = model.fit(train_axial_coronal_sagittal, y_train,
                        validation_data=(valid_axial_coronal_sagittal, y_val),
                        epochs=epochs, batch_size=batch_size,
                        callbacks=[early_stop], verbose=1)

    # evaluate() returns [loss, *metrics] in the order passed to compile().
    test_loss, *test_metrics = model.evaluate(test_axial_coronal_sagittal, y_test, verbose=0)
    print(f"Test results - Loss: {test_loss}, Metrics: {test_metrics}")
    histories.append(history)
    test_history.append(test_metrics)

In [None]:
# Average every test metric over the cross-validation folds.
# Each row of test_history is model.evaluate()'s output without the loss, so
# its columns follow the order of the metrics list given to model.compile():
# accuracy, sensitivity (recall), AUC, specificity. The labels below must be
# in that same order, otherwise AUC and specificity are reported swapped.
history_array = np.array(test_history)
avg_metrics = np.mean(history_array, axis=0)
metric_names = ['mean-Accuracy', 'mean-Sensitivity', 'mean-AUC', 'mean-Specificity']
for name, value in zip(metric_names, avg_metrics):
    print(f"{name}: {value:.4f}")