## The following code corresponds to Experiment 2 in the manuscript, in which we explored whether fusing information from multiple MRI planes during training improves the overall performance of AD progression detection.

In [None]:
import os
import sys
import argparse
import numpy as np
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import BinaryAccuracy, Precision, Recall, AUC, SpecificityAtSensitivity
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import tensorflow.keras.backend as K
from sklearn.model_selection import KFold
from sklearn.model_selection import KFold, train_test_split
from classification_models_3D.kkeras import Classifiers

In [2]:
def get_backbone_model(backbone_name, input_shape):
    """Create a headless 3D CNN backbone through classification_models_3D.

    Args:
        backbone_name: one of 'resnet50', 'vgg16', 'densenet121',
            'inceptionv3' or 'efficientnetb7'.
        input_shape: shape of a single input volume, forwarded unchanged to
            the backbone constructor.

    Returns:
        (backbone, preprocess_input): the model built with ``weights=None``
        and ``include_top=False``, plus the preprocessing function that
        ``Classifiers.get`` returned alongside it.

    Raises:
        ValueError: if ``backbone_name`` is not one of the supported names.
    """
    supported = {'resnet50', 'vgg16', 'densenet121', 'inceptionv3', 'efficientnetb7'}
    if backbone_name in supported:
        model_factory, preprocess_input = Classifiers.get(backbone_name)
        # No pretrained weights and no classification head: the caller adds its own.
        backbone = model_factory(input_shape=input_shape, weights=None, include_top=False)
        return backbone, preprocess_input
    raise ValueError(f"Unsupported backbone: {backbone_name}")

In [3]:
class F1Score(tf.keras.metrics.Metric):
    """Streaming F1 score derived from Keras ``Precision`` and ``Recall``.

    Both sub-metrics accumulate state across batches; ``result`` combines
    their current values as 2 * P * R / (P + R + epsilon), with
    ``K.epsilon()`` keeping the division finite when P + R == 0.
    """

    def __init__(self, name='f1_score', **kwargs):
        super().__init__(name=name, **kwargs)
        self.precision_metric = Precision()
        self.recall_metric = Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Feed one batch to the precision and recall sub-metrics."""
        for sub_metric in (self.precision_metric, self.recall_metric):
            sub_metric.update_state(y_true, y_pred, sample_weight)

    def result(self):
        """Return the F1 score computed from the accumulated precision and recall."""
        p = self.precision_metric.result()
        r = self.recall_metric.result()
        denominator = p + r + K.epsilon()
        return 2 * ((p * r) / denominator)

    def reset_state(self):
        """Clear the state of both sub-metrics."""
        for sub_metric in (self.precision_metric, self.recall_metric):
            sub_metric.reset_state()
def specificity(y_true, y_pred, threshold=0.5):
    """Specificity TN / (TN + FP) for binary labels, usable as a Keras metric function.

    Predictions are binarised before counting (``y_pred > threshold`` is a
    positive), so TN and FP are actual counts rather than sums of
    probabilities. 0.5 matches the default threshold of Keras'
    BinaryAccuracy / Precision / Recall.

    Args:
        y_true: ground-truth labels; assumed to be 0/1 (TODO confirm encoding).
        y_pred: predicted scores, e.g. sigmoid probabilities.
        threshold: decision threshold applied to ``y_pred``.

    Returns:
        Scalar tensor. ``K.epsilon()`` in the denominator keeps the value
        finite for a batch with no negative samples.
    """
    # Cast both sides to the same float dtype; integer labels times float
    # predictions would otherwise raise a dtype error in TensorFlow.
    predicted_pos = tf.cast(tf.greater(y_pred, threshold), K.floatx())
    actual_neg = 1.0 - tf.cast(y_true, K.floatx())
    fp = tf.reduce_sum(actual_neg * predicted_pos)
    tn = tf.reduce_sum(actual_neg * (1.0 - predicted_pos))
    return tn / (tn + fp + K.epsilon())
# Callback passed to model.fit in the training loop: stops when val_loss has
# not improved for 5 epochs and restores the weights from the best epoch.
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
    verbose=1,
)

In [4]:
def load_data(X_data, timesteps, view):
    """Select the requested MRI views and time steps from X_data and stack them.

    X_data is sliced with five indices, so it must be 5-D. Along axis 1 the
    code treats it as consecutive 64-slice view blocks: axial [0:64),
    coronal [64:128) and, for the three-plane option, sagittal [128:192).
    It keeps the first 16 * n_timesteps slices of each block. This
    presumably means 16 slices per longitudinal time step; TODO confirm
    against the data-preparation code.

    Args:
        X_data: 5-D NumPy array or tensor holding the stacked view blocks on axis 1.
        timesteps: 'one', 'two', 'three' or 'four'.
        view: 'axial_coronal' or 'axial_coronal_sagittal'.

    Returns:
        tf.Tensor: the kept slices of each view concatenated along axis 1,
        in axial, coronal[, sagittal] order.

    Raises:
        KeyError: if ``view`` or ``timesteps`` is not one of the options above.
    """
    # Number of view blocks to keep: axial + coronal, optionally + sagittal.
    n_views = {'axial_coronal': 2, 'axial_coronal_sagittal': 3}[view]
    n_timesteps = {'one': 1, 'two': 2, 'three': 3, 'four': 4}[timesteps]

    slices_per_view = 64      # width of each view block on axis 1
    slices_per_timestep = 16  # slices kept per requested time step
    depth = slices_per_timestep * n_timesteps

    view_blocks = [
        X_data[:, start:start + depth, :, :, :]
        for start in range(0, n_views * slices_per_view, slices_per_view)
    ]
    stacked = tf.concat(view_blocks, axis=1)

    # Progress-message wording for each (view, number of time steps).
    step_labels = {
        'axial_coronal': ('Single time-step', 'two time-steps',
                          'three time-steps', 'four time-steps'),
        'axial_coronal_sagittal': ('Single time-steps', 'Two time-steps',
                                   'Three time-steps', 'Four time-steps'),
    }
    print(f'Data for {step_labels[view][n_timesteps - 1]} for {view} view returned.')
    return stacked

In [5]:
def build_model(backbone_name, input_shape, num_classes=2):
    """Build a 3D-CNN classifier and choose a matching loss and metric list.

    The backbone from get_backbone_model feeds global average pooling, then
    a small head whose shape depends on the pooled feature width:
      * more than 1024 features -> Dense(512, relu) + Dropout(0.3);
      * fewer than 128 features -> Dense(256, relu) + Dropout(0.3);
      * otherwise               -> Dropout(0.5) only.

    Args:
        backbone_name: backbone key accepted by get_backbone_model.
        input_shape: shape of one input volume, without the batch axis.
        num_classes: 2 gives a single sigmoid unit with binary cross-entropy.
            Any other value gives a softmax layer with sparse categorical
            cross-entropy and plain accuracy.

    Returns:
        (model, loss, metrics, preprocess_input):
          * the uncompiled Keras model;
          * the loss and metric list to pass to ``compile``;
          * the preprocessing function from get_backbone_model.
    """
    backbone, preprocess_input = get_backbone_model(backbone_name, input_shape)
    inputs = keras.Input(shape=input_shape)
    pooled = layers.GlobalAveragePooling3D(name='global_avg_pool')(backbone(inputs))
    n_features = pooled.shape[-1]

    if n_features > 1024:
        # Wide feature vector: project it down before the output layer.
        head = layers.Dense(512, activation='relu', name='dense_reduce')(pooled)
        head = layers.Dropout(0.3, name='dropout_dense')(head)
    elif n_features < 128:
        # Narrow feature vector: expand it before the output layer.
        head = layers.Dense(256, activation='relu', name='dense_expand')(pooled)
        head = layers.Dropout(0.3, name='dropout_dense')(head)
    else:
        head = layers.Dropout(0.5, name='dropout')(pooled)

    if num_classes != 2:
        outputs = layers.Dense(num_classes, activation='softmax', name='predictions')(head)
        loss = 'sparse_categorical_crossentropy'
        metrics = ['accuracy']
    else:
        outputs = layers.Dense(1, activation='sigmoid', name='predictions')(head)
        loss = BinaryCrossentropy()
        # evaluate() reports these after the loss, in this list order.
        metrics = [
            BinaryAccuracy(name='accuracy'),
            Recall(name='Sensitivity'),
            AUC(name='auc'),
            SpecificityAtSensitivity(0.5),
        ]

    model = Model(inputs, outputs, name=f'{backbone_name}_classifier')
    return model, loss, metrics, preprocess_input

## backbone: the backbone network, one of the five predefined 3D CNN models: [resnet50, vgg16, densenet121, inceptionv3, efficientnetb7].
## view: one of the following two options: [axial_coronal, axial_coronal_sagittal].
## time_steps: the number of longitudinal time steps the 3D model is trained on. Available options: [one, two, three, four].
## n_splits: the number of folds N used for cross-validation. In this study, we report results for 10-fold cross-validation.

In [None]:
# Experiment configuration.
backbone = 'inceptionv3'
view = 'axial_coronal_sagittal'    # axial_coronal, axial_coronal_sagittal
time_steps = "three"
lr = 0.00001
n_splits = 10
epochs = 120
batch_size = 16


# Read the dataset (.npy files are read with np.load).
cn_patients = np.load('./CN.npy')  # CN patients
progressed_to_AD = np.load('./pAD.npy')  # Progressed to AD patients
cn_labels = np.load('./CN_labels.npy')
progressed_to_AD_labels = np.load('./pAD_labels.npy')
# Stack both cohorts along the sample axis (CN first, then pAD). Row i of
# X_data then matches entry i of y_labels. Concatenating in NumPy keeps the
# full volume array in host memory; load_data converts only the slices it keeps.
y_labels = np.concatenate([cn_labels, progressed_to_AD_labels], axis=0)
X_data = np.concatenate([cn_patients, progressed_to_AD], axis=0)

returned_data = load_data(X_data, time_steps, view)
X = tf.convert_to_tensor(returned_data)
X_numpy = X.numpy()
Y = tf.convert_to_tensor(y_labels)
Y_numpy = Y.numpy()

histories = []
test_history = []
# NOTE(review): the outer split is unstratified KFold, while the inner
# train/val split below is stratified. StratifiedKFold would keep the CN/pAD
# ratio fixed in every test fold, but it would also change the reported splits.
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

for fold, (train_val_idx, test_idx) in enumerate(kf.split(X_numpy, Y_numpy), 1):
    print(f"\nFold {fold}")
    # Discard the previous fold's Keras graph/model state before building a new model.
    K.clear_session()

    X_train_val, X_test = X_numpy[train_val_idx], X_numpy[test_idx]
    y_train_val, y_test = Y_numpy[train_val_idx], Y_numpy[test_idx]
    # Hold out 20% of the training portion, stratified by label, as the
    # validation set monitored by early_stop.
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val, y_train_val,
        test_size=0.2,
        stratify=y_train_val,
        random_state=fold)
    # Take the per-sample shape (depth, height, width, channels) from the data
    # instead of hard-coding the in-plane size and channel count.
    input_shape = X_train.shape[1:]
    print("input shape", input_shape)

    # NOTE(review): preprocess_input is returned but not applied to the inputs.
    model, loss, metrics, preprocess_input = build_model(backbone, input_shape, num_classes=2)
    optimizer = Adam(learning_rate=lr)
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    # A single fit per fold; early_stop restores the best-val_loss weights.
    print("Starting training...")
    history = model.fit(X_train, y_train, validation_data=(X_val, y_val),
                        epochs=epochs, batch_size=batch_size,
                        callbacks=[early_stop], verbose=1)

    # evaluate() returns the loss first, then the compiled metrics.
    test_loss, *test_metrics = model.evaluate(X_test, y_test, verbose=0)
    print(f"Test results - Loss: {test_loss}, Metrics: {test_metrics}")
    histories.append(history)
    test_history.append(test_metrics)

In [None]:
# Average each test metric over the cross-validation folds.
# Each test_history entry is the evaluate() output with the loss stripped.
# Its entries follow the metric order compiled in build_model for the binary case:
#   [accuracy, Sensitivity, auc, SpecificityAtSensitivity(0.5)]
# The labels below must follow that same order.
history_array = np.array(test_history)
avg_metrics = np.mean(history_array, axis=0)
# 'mean-Specificity' is Keras' SpecificityAtSensitivity(0.5), i.e. the best
# specificity at a threshold giving sensitivity >= 0.5.
metric_names = ['mean-Accuracy', 'mean-Sensitivity', 'mean-AUC', 'mean-Specificity']
for name, value in zip(metric_names, avg_metrics):
    print(f"{name}: {value:.4f}")