
## This code corresponds to Experiment 1 described in the manuscript. The primary objective of this experiment is to assess the contribution and diagnostic value of each individual anatomical MRI plane — axial, sagittal, and coronal — derived from longitudinal 3D MRI acquired across four time steps. Details of the dataset preparation steps are given in the manuscript.


In [24]:
import os
import sys
import argparse
import numpy as np
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import BinaryAccuracy, Precision, Recall, AUC, SpecificityAtSensitivity
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import tensorflow.keras.backend as K
from sklearn.model_selection import KFold
from sklearn.model_selection import KFold, train_test_split
from classification_models_3D.kkeras import Classifiers

In [2]:
def get_backbone_model(backbone_name, input_shape):
    """Instantiate a headless 3D CNN backbone from ``classification_models_3D``.

    Args:
        backbone_name: 'resnet50', 'vgg16', 'densenet121', 'inceptionv3' or
            'efficientnetb7'.
        input_shape: shape tuple forwarded unchanged to the backbone constructor.

    Returns:
        (backbone, preprocess_input): the model built with ``weights=None``
        (random initialisation) and ``include_top=False``, plus the
        preprocessing function that ``Classifiers.get`` pairs with it.

    Raises:
        ValueError: for any other ``backbone_name``.
    """
    supported = {name: name for name in (
        'resnet50', 'vgg16', 'densenet121', 'inceptionv3', 'efficientnetb7')}
    registry_key = supported.get(backbone_name)
    if registry_key is None:
        raise ValueError(f"Unsupported backbone: {backbone_name}")
    model_factory, preprocess_fn = Classifiers.get(registry_key)
    headless = model_factory(input_shape=input_shape, weights=None, include_top=False)
    return headless, preprocess_fn

In [3]:
class F1Score(tf.keras.metrics.Metric):
    """Running F1 score built from Keras ``Precision`` and ``Recall``.

    Both sub-metrics accumulate over every batch seen since the last reset, so
    ``result()`` returns F1 = 2 * P * R / (P + R + K.epsilon()) over all of
    them. The epsilon keeps the result at 0 rather than NaN when P + R == 0.
    """

    def __init__(self, name='f1_score', **kwargs):
        super().__init__(name=name, **kwargs)
        # Sub-metrics are constructed with their Keras defaults
        # (per the Keras docs, a 0.5 decision threshold on predictions).
        self.precision_metric = Precision()
        self.recall_metric = Recall()

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Feed one batch to both the precision and recall accumulators."""
        self.precision_metric.update_state(y_true, y_pred, sample_weight)
        self.recall_metric.update_state(y_true, y_pred, sample_weight)

    def result(self):
        """Return the harmonic mean of the accumulated precision and recall."""
        precision = self.precision_metric.result()
        recall = self.recall_metric.result()
        return 2 * ((precision * recall) / (precision + recall + K.epsilon()))

    def reset_state(self):
        """Clear both sub-metric accumulators (e.g. between epochs)."""
        self.precision_metric.reset_state()
        self.recall_metric.reset_state()
def specificity(y_true, y_pred, threshold=0.5):
    """Batch specificity, TN / (TN + FP), for binary classification.

    Predictions are binarised at ``threshold`` so that TN and FP are counts,
    consistent with the thresholded ``Precision``/``Recall`` metrics used
    elsewhere, rather than sums of raw probabilities.

    Args:
        y_true: ground-truth labels, 1 for the positive class and 0 for the
            negative class (any numeric dtype; cast to ``y_pred``'s dtype).
        y_pred: predicted positive-class probabilities (floating tensor).
        threshold: a prediction >= ``threshold`` counts as positive.

    Returns:
        Scalar tensor with the specificity of this batch. ``K.epsilon()`` in
        the denominator avoids 0/0 when the batch contains no negatives.
        NOTE(review): when used as a plain Keras metric function, Keras
        averages per-batch values, so the epoch value is an approximation of
        the dataset-level specificity.
    """
    y_true = tf.cast(y_true, y_pred.dtype)
    predicted_pos = tf.cast(y_pred >= threshold, y_pred.dtype)
    actual_neg = 1 - y_true
    fp = K.sum(actual_neg * predicted_pos)
    tn = K.sum(actual_neg * (1 - predicted_pos))
    return tn / (tn + fp + K.epsilon())
# Shared early-stopping callback: stop once validation loss has not improved
# for 5 consecutive epochs and restore the weights from the best epoch.
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
    verbose=1,
)

In [4]:
def load_data(X_data, timesteps, view, slices_per_timestep=16):
    """Return the slices of one anatomical view for the first N time steps.

    Assumed layout of axis 1 of ``X_data`` (TODO confirm against the
    manuscript's dataset preparation): three consecutive view blocks, in the
    order coronal, axial, sagittal. Each block holds four consecutive time
    steps of ``slices_per_timestep`` slices (3 * 4 * 16 = 192 slices by default).

    Args:
        X_data: array or tensor to slice along axis 1 (NumPy arrays and
            TensorFlow tensors both support the slicing used here).
        timesteps: 'one', 'two', 'three' or 'four'. This is how many time
            steps to keep, counted from the first one.
        view: 'coronal', 'axial' or 'sagittal'.
        slices_per_timestep: number of slices per time step within a view block.

    Returns:
        ``X_data[:, start:stop]``, where ``start`` is the first slice of the
        requested view and ``stop - start == n_timesteps * slices_per_timestep``.

    Raises:
        KeyError: if ``view`` or ``timesteps`` is not one of the names above.
    """
    view_map = {'coronal': 0, 'axial': 1, 'sagittal': 2}
    view_option = view_map[view]
    timestep_map = {'one': 1, 'two': 2, 'three': 3, 'four': 4}
    n_timesteps = timestep_map[timesteps]

    # Python slice ends are exclusive: a run of k slices starting at s is s:s+k.
    view_width = len(timestep_map) * slices_per_timestep
    start = view_option * view_width
    stop = start + n_timesteps * slices_per_timestep
    data_formate = X_data[:, start:stop]

    count_word = ('Single', 'Two', 'Three', 'Four')[n_timesteps - 1]
    step_word = 'time-step' if n_timesteps == 1 else 'time-steps'
    print(f'Data for {count_word} {step_word} {view} view returned.')
    return data_formate

In [5]:
def train_model(args):
    """Train one model on a single stratified train/validation/test split.

    The data is split 80/20 into train+val vs. test, then 80/20 into train vs.
    val (64/16/20 % overall). Both splits are stratified on the labels with
    ``random_state=42``. The test portion is held out and not used here.

    Reads the module-level globals ``X_numpy``, ``Y_numpy``, ``backbone`` and
    ``early_stop``, and calls ``create_model``, which is not defined in this section.

    Args:
        args: optional configuration object, e.g. an ``argparse.Namespace``.
            Its attributes ``lr``, ``epochs``, ``batch_size`` and ``backbone``
            are used when present and not None. Otherwise the defaults are
            0.0001, 2, 8 and the global ``backbone``. ``args=None`` is accepted.

    Returns:
        (model, history): the trained Keras model and the History object
        returned by ``model.fit``.
    """
    def _arg(name, default):
        # Fall back to the default for missing attributes or unset (None) ones.
        value = getattr(args, name, None)
        return default if value is None else value

    lr = _arg('lr', 0.0001)
    epochs = _arg('epochs', 2)
    batch_size = _arg('batch_size', 8)
    backbone_name = getattr(args, 'backbone', None) or backbone

    X_train, X_test, y_train, y_test = train_test_split(
        X_numpy, Y_numpy, test_size=0.2, random_state=42, stratify=Y_numpy)
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=42, stratify=y_train)

    # model.fit accepts NumPy arrays directly, so no tensor copies are made.
    print(f"Train samples: {len(X_train)}")
    print(f"Validation samples: {len(X_val)}")
    print(f"Test samples: {len(X_test)}")

    input_shape = (X_train.shape[1], 110, 110, 1)
    # NOTE(review): the cross-validation loop builds its model with
    # build_model(backbone, input_shape, num_classes=2), while this function
    # uses create_model. Confirm which factory is current.
    model, loss, metrics, _preprocess_input = create_model(backbone_name, input_shape)
    model.compile(optimizer=Adam(learning_rate=lr), loss=loss, metrics=metrics)
    print("Starting training...")
    history = model.fit(
        X_train, y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_data=(X_val, y_val),
        callbacks=[early_stop],
        verbose=1,
    )
    return model, history


## Backbone can be one of the five predefined 3D CNN models: [resnet50, vgg16, densenet121, inceptionv3, efficientnetb7].
## View can be one of the following options: [axial, coronal, sagittal].
## Timesteps represent the number of longitudinal time steps the 3D model is trained on. The available options are: [one, two, three, four].
## n_splits is the number of folds N used for cross-validation. In this study, we report results for 10-fold cross-validation.



In [None]:

backbone = 'vgg16'   # resnet50, vgg16, densenet121, inceptionv3, and efficientnetb7
view = 'axial'       # axial, coronal, sagittal
time_steps = 'two'   # one, two, three, four
lr = 0.00001
n_splits = 10
epochs = 120
batch_size = 16


# Read the dataset (.npy files are read back with np.load).
cn_patients = np.load('./CN.npy')  # CN patients
progressed_to_AD = np.load('./pAD.npy')  # Progressed to AD patients
cn_labels = np.load('./CN_labels.npy')
progressed_to_AD_labels = np.load('./pAD_labels.npy')
# Stack CN samples first, then pAD samples, in the same order for images and
# labels. Concatenating in NumPy avoids building a full-dataset TensorFlow
# tensor only to convert it straight back to NumPy.
y_labels = np.concatenate([cn_labels, progressed_to_AD_labels], axis=0)
X_data = np.concatenate([cn_patients, progressed_to_AD], axis=0)

returned_data = load_data(X_data, time_steps, view)
X_numpy = np.asarray(returned_data)
Y_numpy = np.asarray(y_labels)

histories = []     # Keras History object per fold
test_history = []  # per fold: the metrics model.evaluate returns after the loss
# NOTE(review): the outer split uses plain KFold while the inner split is
# stratified. StratifiedKFold would keep class ratios equal across test folds.
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

for fold, (train_val_idx, test_idx) in enumerate(kf.split(X_numpy, Y_numpy), 1):
    print(f"\nFold {fold}")
    # Reset Keras global state so models built in earlier folds do not keep
    # accumulating (Keras docs recommend this when creating models in a loop).
    K.clear_session()

    X_train_val, X_test = X_numpy[train_val_idx], X_numpy[test_idx]
    y_train_val, y_test = Y_numpy[train_val_idx], Y_numpy[test_idx]

    # Hold out 20% of this fold's training portion as a stratified validation set.
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_val, y_train_val, test_size=0.2, stratify=y_train_val, random_state=fold)

    print(f"Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")
    input_shape = (X_train.shape[1], 110, 110, 1)
    model, loss, metrics, preprocess_input = build_model(backbone, input_shape, num_classes=2)
    optimizer = Adam(learning_rate=lr)
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    print("Starting training...")
    history = model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=epochs,
                        batch_size=batch_size, callbacks=[early_stop], verbose=1)

    test_loss, *test_metrics = model.evaluate(X_test, y_test, verbose=0)
    print(f"Test results - Loss: {test_loss}, Metrics: {test_metrics}")
    histories.append(history)
    test_history.append(test_metrics)

In [None]:
# Summarise the per-fold test metrics collected by the cross-validation loop:
# one row per fold, one column per metric returned by model.evaluate after the loss.
history_array = np.array(test_history)
# NOTE(review): these labels assume build_model compiles exactly these four
# metrics in this order. Confirm against build_model.
metric_names = ['mean-Accuracy', 'mean-Sensitivity', 'mean-Specificity', 'mean-AUC']
if history_array.size == 0:
    print("No fold results to summarise.")
else:
    if history_array.ndim != 2 or history_array.shape[1] != len(metric_names):
        # Without this check, zip() below silently drops or mislabels metrics.
        print(f"Warning: expected {len(metric_names)} metrics per fold, "
              f"got array of shape {history_array.shape}; labels may not line up.")
    # Mean and spread of each metric across folds.
    avg_metrics = np.mean(history_array, axis=0)
    std_metrics = np.std(history_array, axis=0)
    for name, mean_value, std_value in zip(metric_names, avg_metrics, std_metrics):
        print(f"{name}: {mean_value:.4f} +/- {std_value:.4f}")