## Deep neural network for ASD classification using resting-state fMRI

This notebook evaluates a deep neural network for ASD diagnosis using functional time series data from brain regions of interest. The resting-state fMRI data come from the ABIDE dataset and were preprocessed by the **Preprocessed Connectomes Project (PCP)** using four pipelines; the dataset involves 1100 subjects from multiple international sites.

### Configure data loading

The parameters needed to load the neuroimaging data are defined here. The `pipeline` used for preprocessing and the `atlases` used for ROI extraction are specified. Additionally, all neuroimaging sites available in the dataset are listed in `all_sites`, and those to include in the analysis are selected through the `sites` parameter. The site named in `test_site` is held out for testing.

In [1]:
# Preprocessing pipeline
pipeline = 'cpac'  

# List of ROI brain atlases
atlases = ['rois_cc200', 'rois_aal']

# Phenotypic data
phenotypic = 'all_cases'

# List of all available neuroimaging sites in the dataset
all_sites = [
    'caltech', 'cmu', 'kki', 'leuven_1', 'leuven_2', 'max_mun', 'nyu', 
    'ohsu', 'olin', 'pitt', 'sbl', 'sdsu', 'stanford', 'trinity', 
    'ucla_1', 'ucla_2', 'um_1', 'um_2', 'usm', 'yale'
]

# Sites included in the analysis
sites = all_sites

# Testing site
test_site = 'yale'

### Data loading function

Definition of the `load_atlas_data(pipeline, atlases, sites, phenotypic)` function, which retrieves subject time series and diagnostic labels for each neuroimaging site in `sites`. The function reads the `phenotypic` information from CSV files, then loads the time series of each subject for every atlas in `atlases`. Subjects with a missing data file or NaN values are reported and skipped, so that only complete data enter the analysis.
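
The per-subject loading step can be illustrated in isolation. The sketch below is not part of the original pipeline: it writes a tiny toy file to a temporary directory and reads it back with `np.loadtxt`, assuming that `.1D` files are plain-text tables whose header line starts with `#`. The `file_id` value is hypothetical.

```python
import os
import tempfile

import numpy as np

# Hypothetical subject identifier and a tiny 4 x 3 time series (time points x ROIs)
file_id = "Example_0000001"
atlas = "rois_cc200"
series = np.arange(12, dtype=float).reshape(4, 3)

with tempfile.TemporaryDirectory() as root:
    # Mirror the layout used by the loader: <pipeline>/<atlas>/<site>/<file_id>_<atlas>.1D
    site_dir = os.path.join(root, "cpac", atlas, "yale")
    os.makedirs(site_dir)
    path = os.path.join(site_dir, f"{file_id}_{atlas}.1D")

    # np.savetxt prefixes the header with '# ', so it is written as a comment line
    np.savetxt(path, series, header="ROI_1\tROI_2\tROI_3")

    loaded = np.loadtxt(path)                 # comment lines starting with '#' are skipped
    has_nan = bool(np.isnan(loaded).any())    # the same integrity check as in the loader

print(loaded.shape, has_nan)
```

Each row of the loaded array is one time point and each column one region, which is the layout later passed to the connectivity estimator.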

In [2]:
import os
import csv
import numpy as np


def load_atlas_data(pipeline, atlases, sites, phenotypic):
    """
    Loads time series and diagnostic labels from neuroimaging data files for the specified sites.
    
    Parameters:
        pipeline (str): Preprocessing pipeline used for the data.
        atlases (list of str): Atlases defining the regions of interest.
        sites (list of str): List of site names to load data from.
        phenotypic (str): Name of the phenotypic data folder.

    Returns:
        atlas_time_series (list of dict): One dictionary per atlas, mapping each site to its list of time series.
        atlas_labels (dict): Maps each site to an array of diagnostic labels (1 = ASD, 0 = control).
    """

    atlas_time_series = []  # One dictionary of per-site time series for each atlas
    atlas_labels = {}  # Dictionary to store labels for each site
    
    
    print("Load labes:")
    for site in sites:
        # Define path for phenotypic data for the current site
        phenotypic_path = f"data/phenotypic/{phenotypic}/{site}/phenotypic.csv"

        try:
            with open(phenotypic_path, 'r') as file:
                reader = csv.DictReader(file)
                site_labels = []  # List to store labels for each subject at the site

                for row in reader:
                    file_id = row['file_id']  # Unique subject identifier
                    dx_group = row['dx_group']  # Diagnostic group (ASD=1, Control=0)

                    # Define path for the time series data file
                    data_file_path = os.path.join(f"data/{pipeline}/{atlases[0]}/{site}", f"{file_id}_{atlases[0]}.1D")

                    # Check if the data file exists
                    if not os.path.exists(data_file_path):
                        print(f"File Not Found Error: Data file not found at path {data_file_path}")
                        continue
                    
                    data = np.loadtxt(data_file_path)

                    # Skip subjects with NaN values; otherwise record the label
                    if np.isnan(data).any():
                        print(f"Value Error: NaN value found for subject {file_id}")
                    else:            
                        site_labels.append(1 if dx_group == '1' else 0)  # Assign 1 for ASD, 0 for control

                # Store loaded data for the current site in the dictionaries
                atlas_labels[site] = np.array(site_labels)
                print(f"Loaded labels {len(site_labels)} subjects from site {site}.")
                
        except FileNotFoundError:
            print(f"File Not Found Error: Phenotypic data not found for site {site}")

    for atlas in atlases:
        atlas_data = {}
        print(f"Data for {atlas} Atlas:")
        for site in sites:
            # Define path for phenotypic data for the current site
            phenotypic_path = f"data/phenotypic/{phenotypic}/{site}/phenotypic.csv"

            try:
                with open(phenotypic_path, 'r') as file:
                    reader = csv.DictReader(file)
                    site_time_series = []  # List to store time series for each subject at the site
                
                    for row in reader:
                        file_id = row['file_id']  # Unique subject identifier
                    
                        # Define path for the time series data file
                        data_file_path = os.path.join(f"data/{pipeline}/{atlas}/{site}", f"{file_id}_{atlas}.1D")

                        # Check if the data file exists
                        if not os.path.exists(data_file_path):
                            print(f"File Not Found Error: Data file not found at path {data_file_path}")
                            continue
                        
                        data = np.loadtxt(data_file_path)

                        # Check for NaN values and add time series to the site list
                        if np.isnan(data).any():
                            print(f"Value Error: NaN value found for subject {file_id}")
                        else:
                            site_time_series.append(data)
                    
                    # Store loaded data for the current site in the dictionaries
                    atlas_data[site] = site_time_series
                    print(f"Loaded {len(site_time_series)} subjects from site {site}.")
                    
            except FileNotFoundError:
                print(f"File Not Found Error: Phenotypic data not found for site {site}")
    
        atlas_time_series.append(atlas_data)
    
    return atlas_time_series, atlas_labels

Load data to be used in the analysis based on specified parameters.

In [3]:
atlas_time_series, atlas_labels = load_atlas_data(pipeline, atlases, sites, phenotypic)

Load labes:
Loaded labels 38 subjects from site caltech.
Loaded labels 27 subjects from site cmu.
Loaded labels 55 subjects from site kki.
Loaded labels 29 subjects from site leuven_1.
Loaded labels 35 subjects from site leuven_2.
Loaded labels 57 subjects from site max_mun.
Loaded labels 184 subjects from site nyu.
Loaded labels 28 subjects from site ohsu.
Loaded labels 36 subjects from site olin.
Loaded labels 57 subjects from site pitt.
Loaded labels 30 subjects from site sbl.
Loaded labels 36 subjects from site sdsu.
Loaded labels 40 subjects from site stanford.
Loaded labels 49 subjects from site trinity.
Loaded labels 73 subjects from site ucla_1.
Loaded labels 26 subjects from site ucla_2.
Loaded labels 108 subjects from site um_1.
Loaded labels 35 subjects from site um_2.
Loaded labels 101 subjects from site usm.
Loaded labels 56 subjects from site yale.
Data for rois_cc200 Atlas:
Loaded 38 subjects from site caltech.
Loaded 27 subjects from site cmu.
Loaded 55 subjects from si

### Tangent space embedding

This method allows the translation of connectivity matrices from fMRI data into a form that is compatible with Euclidean machine learning techniques while preserving the important geometric properties of the data. This technique is particularly useful when analyzing covariance or correlation matrices in tasks involving brain connectivity and classification of neurological conditions.

The workflow used in this notebook involves two main steps:

**Estimate the reference tangent space**: Calculate the tangent space projection based on the mean covariance matrix of a training population. This establishes the "reference space" onto which individual test subjects can later be projected.

**Project subjects onto the reference space**: Using the reference tangent space precomputed from the population, the covariance matrices of new subjects are projected onto this space. This yields a tangent space connectivity matrix for each subject that is aligned with those of the population.
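
The projection step can be sketched with NumPy alone. This is a simplified illustration, not the routine used by `nilearn`: it takes the arithmetic mean of the covariances as the reference (an assumption made for brevity), whitens a subject covariance with the inverse square root of that reference, and applies the matrix logarithm. The helper names `_sym_apply` and `tangent_project` are hypothetical.

```python
import numpy as np


def _sym_apply(matrix, func):
    """Apply a scalar function to the eigenvalues of a symmetric matrix."""
    eigvals, eigvecs = np.linalg.eigh(matrix)
    return (eigvecs * func(eigvals)) @ eigvecs.T


def tangent_project(cov, reference):
    """Project one SPD covariance matrix onto the tangent space at `reference`."""
    whitening = _sym_apply(reference, lambda v: 1.0 / np.sqrt(v))  # reference^(-1/2)
    whitened = whitening @ cov @ whitening
    return _sym_apply(whitened, np.log)  # matrix logarithm of the whitened covariance


rng = np.random.default_rng(0)
# Toy data: 5 subjects, 50 time points, 4 regions
covs = [np.cov(rng.standard_normal((50, 4)), rowvar=False) for _ in range(5)]
reference = np.mean(covs, axis=0)  # simplification: arithmetic mean as the reference

tangent = tangent_project(covs[0], reference)
at_reference = tangent_project(reference, reference)
```

A subject whose covariance equals the reference maps to the zero matrix, so tangent space features measure deviations from the group-level connectivity.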

#### Create training population function

Definition of the `create_population(time_series_data)` function, which combines the per-site time series lists into a single flat list. To keep a separate testing set, the `test_site` data are excluded from the population used to estimate the reference tangent space.

In [4]:
def create_population(time_series_data):
    # Initialize an empty list for the population data 
    population_data = []

    # Loop through the time series data
    for item in time_series_data:
        # Extend each item
        population_data.extend(item)

    print(f"Total subjects in population data: {len(population_data)}")
    return population_data

#### Function to estimate tangent space functional connectivity

Definition of the `estimate_tangent_space(data)` function, which calculates the tangent space based on the geometric mean covariance matrix of a training population dataset. This creates a "reference space" that reflects the average connectivity patterns across the population.

The tangent space representation of functional connectivity is a powerful tool for analyzing brain connectivity. It allows the comparison of individual functional connectivity matrices in a standardized space, computed relative to a group average matrix.
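
When a symmetric connectivity matrix over `n` regions is vectorized without its diagonal, each subject contributes `n * (n - 1) / 2` features. The short calculation below assumes 200 regions for the CC200 atlas and 116 for AAL, which is consistent with the feature sizes printed later in this notebook; the helper name `n_tangent_features` is hypothetical.

```python
def n_tangent_features(n_rois):
    """Number of off-diagonal entries in the lower triangle of an n x n symmetric matrix."""
    return n_rois * (n_rois - 1) // 2


cc200 = n_tangent_features(200)   # assumed ROI count for rois_cc200
aal = n_tangent_features(116)     # assumed ROI count for rois_aal
print(cc200, aal, cc200 + aal)    # 19900 6670 26570
```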

In [5]:
from nilearn.connectome import ConnectivityMeasure

def estimate_tangent_space(data):
    """
    Estimate the tangent space functional connectivity.

    Parameters:
    -----------
    data : list or ndarray
        List or array of time series data for the training population, where each entry corresponds 
        to a subject's time series (time points x regions).

    Returns:
    --------
    ConnectivityMeasure
        Fitted ConnectivityMeasure object configured for tangent space transformation.
    """
    # Instantiate ConnectivityMeasure for tangent space, vectorizing and discarding the diagonal
    connectivity_measure = ConnectivityMeasure(kind='tangent', vectorize=True, discard_diagonal=True)

    # Fit the measure on the population data to establish a reference tangent space
    connectivity_measure.fit(data)

    return connectivity_measure


### Deep neural network model building function

Definition of the `build_model(input_shape)` function, which creates a DNN model with the following architecture:

Input Layer: Takes in the number of features from the input data, followed by dropout (rate 0.2).

Dense Layer 1: 128 neurons, ReLU activation, with L2 regularization (0.001) to reduce overfitting, followed by dropout (rate 0.4).

Dense Layer 2: 64 neurons, ReLU activation, L2 regularization (0.001), followed by dropout (rate 0.4).

Output Layer: A single neuron with sigmoid activation for binary classification.

The model is compiled with the Adam optimizer and binary cross-entropy loss, since the task is to classify subjects into two classes. Accuracy is tracked as a performance metric during training and evaluation.
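
To get a feel for the size of this network, the number of trainable weights can be counted by hand. The sketch below assumes the 26570-feature input reported in the outputs of this notebook; the helper `dense_params` is hypothetical. A dense layer with `n_in` inputs and `n_out` units has `n_in * n_out` weights plus `n_out` biases, and dropout layers add no parameters.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


n_features = 26570                        # assumed input size (CC200 + AAL tangent features)
layer_sizes = [n_features, 128, 64, 1]    # input -> Dense(128) -> Dense(64) -> Dense(1)

per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)
print(per_layer, total)                   # [3401088, 8256, 65] 3409409
```

Almost all parameters sit in the first dense layer, and they far outnumber the training subjects, which is one reason dropout and L2 regularization are applied.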

In [6]:
from keras import layers, models, regularizers

# Define the deep neural network model architecture
def build_model(input_shape):
    """
    Builds and compiles a deep neural network model for binary classification.

    Parameters:
    - input_shape: int, the shape of the input layer, matching the number of features in the dataset

    Returns:
    - model: compiled Keras Sequential model ready for training
    """
    
    model = models.Sequential()

    # Input layer (the shape is passed as a tuple, as expected by Keras)
    model.add(layers.Input(shape=(input_shape,)))
    model.add(layers.Dropout(0.2))

    # Hidden layers
    model.add(layers.Dense(128, activation='relu', kernel_regularizer=regularizers.l2(0.001)))
    model.add(layers.Dropout(0.4))

    model.add(layers.Dense(64, activation='relu', kernel_regularizer=regularizers.l2(0.001)))
    model.add(layers.Dropout(0.4))

    # Output layer for binary classification (ASD vs. Healthy)
    model.add(layers.Dense(1, activation='sigmoid'))

    # Compile the model with Adam optimizer and binary cross-entropy loss
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    return model

### Definition of training callbacks

To optimize training, we set up three callbacks:

**EarlyStopping:** Stops training if validation loss doesn't improve for 10 epochs, preventing overfitting and restoring the best weights.

**ReduceLROnPlateau:** Halves the learning rate when validation loss has not improved for 5 epochs, down to a floor of 1e-5, so that the step size shrinks as training approaches convergence.

**ModelCheckpoint:** Saves the model with the best validation loss to 'best_model.keras', allowing easy access to the optimal version of the model.
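
The patience mechanism behind early stopping can be sketched in plain Python. This is a simplified model of the idea (no minimum delta, no baseline), not the Keras implementation, and the function name `early_stopping_epoch` and the loss values are hypothetical.

```python
def early_stopping_epoch(val_losses, patience):
    """Return (stop_epoch, best_epoch), both zero-based, for a list of validation losses."""
    best = float("inf")
    best_epoch = -1
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # Improvement: remember this epoch and reset the patience counter
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # Training ran to the end without exhausting the patience
    return len(val_losses) - 1, best_epoch


# Toy losses: the best value occurs at epoch 1, then the loss keeps rising
stop_epoch, best_epoch = early_stopping_epoch([0.70, 0.65, 0.66, 0.67, 0.68], patience=3)
print(stop_epoch, best_epoch)  # 4 1
```

With `restore_best_weights=True`, the weights kept after stopping are those of `best_epoch`, not those of the last epoch.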

In [7]:
from keras import callbacks

# Early stopping to prevent overfitting by stopping training when validation loss stops improving
early_stopping = callbacks.EarlyStopping(
    monitor='val_loss',             # Monitor validation loss for early stopping
    patience=10,                    # Stop training if val_loss does not improve for 10 epochs
    restore_best_weights=True       # Restore the model weights from the epoch with the lowest val_loss
)

# Reduce learning rate when the validation loss plateaus
reduce_lr = callbacks.ReduceLROnPlateau(
    monitor='val_loss',             # Monitor validation loss for learning rate reduction
    factor=0.5,                     # Reduce learning rate by a factor of 0.5
    patience=5,                     # Trigger after 5 epochs without improvement in val_loss
    min_lr=1e-5                     # Set a floor on the learning rate to avoid overly small values
)

# Save the best model based on validation loss
checkpoint = callbacks.ModelCheckpoint(
    'best_model.keras',             # Filename for the best model
    monitor='val_loss',             # Monitor validation loss for checkpoint saving
    save_best_only=True             # Only save the model when it achieves a new best val_loss
)

# Callbacks list passed to the model
callbacks_list = [early_stopping, reduce_lr, checkpoint]

### Adjust class balance function

Definition of the `adjust_class_balance(indices, labels)` function, which ensures equal representation of both classes by undersampling the majority class. This is particularly important in supervised learning, where imbalanced classes can lead to biased models. The function returns a shuffled array of indices representing a class-balanced subset of the dataset.
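
The effect of undersampling is easy to check on toy labels. The sketch below is a compact variant written for illustration only; unlike the notebook's function, it takes a seeded NumPy `Generator` so that the selection is reproducible. The name `undersample` and the toy labels are hypothetical.

```python
import numpy as np


def undersample(indices, labels, rng):
    """Return a shuffled, class-balanced subset of `indices` (binary labels 0/1)."""
    indices = np.asarray(indices)
    labels = np.asarray(labels)
    # Split the candidate indices by class
    per_class = [indices[labels[indices] == c] for c in (0, 1)]
    n_keep = min(len(c) for c in per_class)
    # Draw the same number of indices from each class without replacement
    kept = np.concatenate([rng.choice(c, size=n_keep, replace=False) for c in per_class])
    rng.shuffle(kept)
    return kept


labels = np.array([0, 0, 0, 0, 0, 1, 1, 1])    # 5 controls, 3 ASD
rng = np.random.default_rng(42)                 # seeded for reproducibility
balanced = undersample(np.arange(len(labels)), labels, rng)
print(np.bincount(labels[balanced]))            # [3 3]
```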

In [8]:
import numpy as np

def adjust_class_balance(indices, labels):
    """
    Adjusts the balance of classes by undersampling the majority class.

    Parameters:
    ----------
    indices : list or ndarray
        Indices of the dataset.
    labels : list or ndarray
        Class labels corresponding to the indices.

    Returns:
    -------
    balanced_indices : ndarray
        Indices of the balanced dataset.
    """
    # Class labels to consider
    CLASS_LABELS = [0, 1]

    # Separate indices by class
    class_indices = {label: [idx for idx in indices if labels[idx] == label] for label in CLASS_LABELS}

    # Determine the minimum class count
    min_class_count = min(len(class_list) for class_list in class_indices.values())

    # Adjust class balance by undersampling the majority class
    balanced_indices = []
    for label, class_list in class_indices.items():
        if len(class_list) > min_class_count:
            sampled_indices = np.random.choice(class_list, size=min_class_count, replace=False)
            balanced_indices.extend(sampled_indices)
        else:
            balanced_indices.extend(class_list)

    # Shuffle the indices for randomization
    np.random.shuffle(balanced_indices)
    return np.array(balanced_indices)

### Stratified cross-validation setup for model training and validation

Set up stratified 10-fold cross-validation for each site (excluding `test_site`) to evaluate model performance across multiple splits. Here’s an overview of the process:

**Stratified k-folds**: `StratifiedKFold` maintains the balance of classes (ASD vs. TC) across each fold, reducing potential bias.

**Fold processing**: For each site, 10 training and validation folds are generated, and their indices are stored in the `train_indices` and `val_indices` dictionaries. To ensure class balance after the folds of all sites are combined, the majority class in each site is undersampled.

**Class balance checks**: For each fold, the number of ASD and TC samples is printed to confirm that each split maintains similar distributions.

In [9]:
from sklearn.model_selection import StratifiedKFold

# Number of cross-validation folds
n_folds = 10

# Dictionaries to store the training and validation indices
train_indices = {}
val_indices = {}

# Perform stratified k-fold cross-validation for each site, excluding 'test_site' for testing
for site in sites:
    if site == test_site:
        continue

    features = atlas_time_series[0][site]
    labels = atlas_labels[site]

    # Initialize StratifiedKFold with shuffle to ensure data randomization
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True)
    
    site_train_indices = []
    site_val_indices = []

    # Loop through each fold in the stratified split
    for fold, (train_idx, val_idx) in enumerate(skf.split(features, labels)):
        print(f"Processing fold #{fold} for site `{site}`")
      
        train_idx = adjust_class_balance(train_idx, labels)
        val_idx = adjust_class_balance(val_idx, labels)

        # Append training and validation indices for each fold
        site_train_indices.append(np.array(train_idx))
        site_val_indices.append(np.array(val_idx))
        
        # Print class distribution for training and validation sets for each fold
        print(f"Balance of classes in training -> ASD: {np.count_nonzero(labels[site_train_indices[fold]] == 1)} and TC: {np.count_nonzero(labels[site_train_indices[fold]] == 0)}")
        
        print(f"Balance of classes in validation -> ASD: {np.count_nonzero(labels[site_val_indices[fold]] == 1)} and TC: {np.count_nonzero(labels[site_val_indices[fold]] == 0)}")

    # Store indices for each fold in the dictionaries
    train_indices[site] = site_train_indices
    val_indices[site] = site_val_indices


Processing fold #0 for site `caltech`
Balance of classes in training -> ASD: 17 and TC: 17
Balance of classes in validation -> ASD: 2 and TC: 2
Processing fold #1 for site `caltech`
Balance of classes in training -> ASD: 17 and TC: 17
Balance of classes in validation -> ASD: 2 and TC: 2
Processing fold #2 for site `caltech`
Balance of classes in training -> ASD: 17 and TC: 17
Balance of classes in validation -> ASD: 2 and TC: 2
Processing fold #3 for site `caltech`
Balance of classes in training -> ASD: 17 and TC: 17
Balance of classes in validation -> ASD: 2 and TC: 2
Processing fold #4 for site `caltech`
Balance of classes in training -> ASD: 17 and TC: 17
Balance of classes in validation -> ASD: 2 and TC: 2
Processing fold #5 for site `caltech`
Balance of classes in training -> ASD: 17 and TC: 17
Balance of classes in validation -> ASD: 2 and TC: 2
Processing fold #6 for site `caltech`
Balance of classes in training -> ASD: 17 and TC: 17
Balance of classes in validation -> ASD: 2 an

### Evaluation metrics function

Definition of the `calculate_metrics(y_true, y_pred, y_pred_prob)` function to evaluate the performance of a binary classification model. The function computes several key metrics using true labels (`y_true`), predicted labels (`y_pred`), and predicted probabilities (`y_pred_prob`). These metrics include:

**Accuracy**: The proportion of correct predictions among all predictions.

**Sensitivity (Recall)**: The proportion of positive cases that are correctly identified.

**Precision**: The proportion of true positive predictions among all positive predictions.

**Specificity**: The proportion of negative cases that are correctly identified.

**Area Under the Curve (AUC)**: Measures the ability of the model to distinguish between classes.

**Confusion Matrix**: Summarizes true/false positives and negatives.

The function guards sensitivity, precision, and specificity against division by zero, returning 0.0 when a denominator is empty, so that metric computation stays stable on small folds.
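
As a worked example, the formulas can be applied to the confusion matrix `[[35, 12], [20, 27]]` reported for the validation set of Split 1 later in this notebook. With the scikit-learn layout (rows are true classes, columns are predicted classes), the entries read `tn, fp, fn, tp`.

```python
# Counts taken from the Split 1 validation confusion matrix
tn, fp, fn, tp = 35, 12, 20, 27

accuracy = (tp + tn) / (tp + tn + fp + fn)   # correct predictions over all predictions
sensitivity = tp / (tp + fn)                 # recall for the positive (ASD) class
precision = tp / (tp + fp)                   # reliability of positive predictions
specificity = tn / (tn + fp)                 # recall for the negative (TC) class

print(f"{accuracy:.4f} {sensitivity:.4f} {precision:.4f} {specificity:.4f}")
# 0.6596 0.5745 0.6923 0.7447
```

These values match the percentages printed for that split (65.96%, 57.45%, 69.23%, 74.47%).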

In [10]:
from sklearn.metrics import confusion_matrix, roc_auc_score

def calculate_metrics(y_true, y_pred, y_pred_prob):
    """
    Calculate key evaluation metrics for binary classification tasks.

    Parameters:
    -----------
    y_true : array-like
        Ground truth (true labels).
    y_pred : array-like
        Predicted labels (binary).
    y_pred_prob : array-like
        Predicted probabilities for the positive class.

    Returns:
    --------
    accuracy : float
        Proportion of correctly predicted instances.
    sensitivity : float
        True positive rate (recall for the positive class).
    precision : float
        Precision for the positive class.
    specificity : float
        True negative rate (recall for the negative class).
    auc : float
        Area Under the Receiver Operating Characteristic Curve (ROC AUC).
    cm : ndarray
        Confusion matrix as a NumPy array.
    """
    # Compute confusion matrix and unpack values (labels fixed so the matrix is always 2 x 2)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    
    # Calculate metrics with safeguards against division by zero
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    auc = roc_auc_score(y_true, y_pred_prob)
    
    return accuracy, sensitivity, precision, specificity, auc, cm


### Function to print metrics

Definition of the `print_metrics(split, dataset_type, accuracy, sensitivity, precision, specificity, auc, cm)` function, which displays key evaluation metrics for a specific dataset split and type. The function prints the following metrics in a consistent format:

**Accuracy**: The overall correctness of predictions.

**Sensitivity (Recall)**: The ability to correctly detect positive cases.

**Precision**: The reliability of positive predictions.

**Specificity**: The ability to correctly detect negative cases.

**AUC-ROC Score**: The model’s ability to distinguish between positive and negative classes.

**Confusion Matrix**: A tabular summary of prediction outcomes (true positives, false positives, etc.).

The parameters allow for flexible use across dataset types (e.g., training, validation, test) and cross-validation splits: `split` identifies the fold in k-fold cross-validation, while `dataset_type` names the dataset being evaluated. All metrics except the confusion matrix are displayed as percentages for readability.

In [11]:
def print_metrics(split, dataset_type, accuracy, sensitivity, precision, specificity, auc, cm):
    """
    Display evaluation metrics for a specific data split and dataset type.

    Parameters:
    -----------
    split : int
        The current split index (zero-based).
    dataset_type : str
        The type of dataset (e.g., "training", "validation", "test").
    accuracy : float
        Proportion of correctly predicted instances.
    sensitivity : float
        True positive rate (recall for the positive class).
    precision : float
        Precision for the positive class.
    specificity : float
        True negative rate (recall for the negative class).
    auc : float
        Area Under the Receiver Operating Characteristic Curve (ROC AUC).
    cm : ndarray
        Confusion matrix as a NumPy array.

    Returns:
    --------
    None
    """
    print(f"\n{dataset_type.capitalize()} Metrics for Split {split + 1}:")
    print(f"  Accuracy: {accuracy * 100:.2f}%")
    print(f"  Sensitivity (Recall): {sensitivity * 100:.2f}%")
    print(f"  Precision: {precision * 100:.2f}%")
    print(f"  Specificity: {specificity * 100:.2f}%")
    print(f"  AUC-ROC Score: {auc * 100:.2f}%")
    print(f"  Confusion Matrix:\n{cm}")

### Saving tangent spaces

Fit a tangent space reference for each atlas and each fold of cross-validation, so that connectivity matrices can later be projected using training data only.

**Steps**:

**Initialization**: The `connectivity_list` is created to store the fitted tangent space estimators for each atlas and split.

**Cross-validation Loop**:
Iterates through each fold defined by `n_folds`.
Initializes lists to store training and validation time series data (`X_train_time_series`, `X_val_time_series`) and their corresponding labels (`y_train`, `y_val`).

**Site-Specific Aggregation**:
For each site, the loop aggregates data while excluding the predefined `test_site`, which is reserved for external testing.
Training and validation data are selected based on the indices for the current split.

**Tangent Space Representation**:
The `estimate_tangent_space` function is fitted on the training data of the split, and the fitted `ConnectivityMeasure` is appended to the list for the current atlas.

**Output**:

A nested list (`connectivity_list[atlas][split]`) of fitted `ConnectivityMeasure` objects, one per atlas and cross-validation fold.

In [12]:
connectivity_list = []

for atlas in range(len(atlases)):
    connectivities = []
    # Perform cross-validation across all splits
    for split in range(n_folds):
        print(f"\n--- Split {split + 1} ---")

        # Initialize lists for training and validation data
        X_train_time_series, X_val_time_series = [], []
        y_train, y_val = [], []

        # Aggregate data from all sites except the test site
        for site in sites:
            if site == test_site:  # Skip the test site
                continue

            # Add training data for the current split
            X_train_time_series.extend(
                atlas_time_series[atlas][site][idx] for idx in train_indices[site][split]
            )
            y_train.extend(
                atlas_labels[site][idx] for idx in train_indices[site][split]
            )

            # Add validation data for the current split
            X_val_time_series.extend(
                atlas_time_series[atlas][site][idx] for idx in val_indices[site][split]
            )
            y_val.extend(
                atlas_labels[site][idx] for idx in val_indices[site][split]
            )

        # Estimate tangent space representation for training data
        connectivity_m = estimate_tangent_space(X_train_time_series)
        connectivities.append(connectivity_m)
    connectivity_list.append(connectivities)


--- Split 1 ---

--- Split 2 ---

--- Split 3 ---

--- Split 4 ---

--- Split 5 ---

--- Split 6 ---

--- Split 7 ---

--- Split 8 ---

--- Split 9 ---

--- Split 10 ---

--- Split 1 ---

--- Split 2 ---

--- Split 3 ---

--- Split 4 ---

--- Split 5 ---

--- Split 6 ---

--- Split 7 ---

--- Split 8 ---

--- Split 9 ---

--- Split 10 ---


### Cross-validation, feature transformation, and model evaluation

This code performs cross-validation to evaluate the performance of a deep neural network (DNN) model on functional connectivity data from multiple brain atlases. Key steps include:

1. **Data Aggregation**: 
   - Aggregates time-series data and labels across different sites, excluding the test site.
   - Transforms data using tangent space embeddings.

2. **Model Training**: 
   - Trains the DNN model using training data for each split.
   - Validates the model during training using the validation dataset.

3. **Performance Evaluation**: 
   - Evaluates the model's predictions using accuracy, sensitivity, precision, specificity, and AUC metrics.
   - Prints performance for both validation and test datasets for each split.

4. **Results**: 
   - Outputs mean performance metrics across all splits for both validation and test datasets.

**Key Variables**:
- `metrics`: Stores accumulated metrics for validation and test datasets.
- `X_train`, `X_val`, `X_test`: Feature matrices for training, validation, and testing.
- `y_train`, `y_val`, `y_test`: Corresponding labels.

This workflow helps assess model performance and generalization on unseen data.
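
Features from several atlases are combined by placing them side by side for each subject. The toy sketch below, with made-up shapes and values, shows the column-wise concatenation; the variable `features_by_atlas` is hypothetical.

```python
import numpy as np

n_subjects = 4
features_by_atlas = [
    np.ones((n_subjects, 6)),     # stand-in for the tangent features of one atlas
    np.zeros((n_subjects, 3)),    # stand-in for the tangent features of a second atlas
]

# axis=1 keeps one row per subject and appends the feature columns
X = np.concatenate(features_by_atlas, axis=1)
print(X.shape)  # (4, 9)
```

The rows must refer to the same subjects in the same order in every block, which is why the same per-site indices are used for all atlases.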

In [13]:
# Initialize accumulators for metrics
metrics = {
    "validation": {"accuracy": 0, "sensitivity": 0, "precision": 0, "specificity": 0, "auc": 0},
    "test": {"accuracy": 0, "sensitivity": 0, "precision": 0, "specificity": 0, "auc": 0}
}

# Cross-validation across all splits
for split in range(n_folds):
    print(f"\n--- Split {split + 1} ---")

    # Aggregate training and validation data across all sites
    y_train, y_val = [], []

    print("Labels")

    for site in sites:
        if site == test_site:
            continue

        y_train.extend([atlas_labels[site][idx] for idx in train_indices[site][split]])
        y_val.extend([atlas_labels[site][idx] for idx in val_indices[site][split]])

    y_test = atlas_labels[test_site]

    print("Time series data")

    # Per-atlas feature blocks, concatenated column-wise after the loop
    X_train_blocks, X_val_blocks, X_test_blocks = [], [], []

    for atlas in range(len(atlases)):
        X_train_time_series_atlas, X_val_time_series_atlas = [], []

        for site in sites:
            if site == test_site:
                continue

            X_train_time_series_atlas.extend([atlas_time_series[atlas][site][idx] for idx in train_indices[site][split]])
            X_val_time_series_atlas.extend([atlas_time_series[atlas][site][idx] for idx in val_indices[site][split]])

        # Project time series onto the tangent space fitted on this split's training data
        measure = connectivity_list[atlas][split]
        X_train_atlas = measure.transform(X_train_time_series_atlas)
        X_val_atlas = measure.transform(X_val_time_series_atlas)
        X_test_atlas = measure.transform(atlas_time_series[atlas][test_site])

        print(f"atlas {atlases[atlas]}: {X_train_atlas.shape}")

        X_train_blocks.append(X_train_atlas)
        X_val_blocks.append(X_val_atlas)
        X_test_blocks.append(X_test_atlas)

    # Concatenate the features of all atlases along the feature axis
    X_train = np.concatenate(X_train_blocks, axis=1)
    X_val = np.concatenate(X_val_blocks, axis=1)
    X_test = np.concatenate(X_test_blocks, axis=1)
    y_train, y_val, y_test = map(np.array, [y_train, y_val, y_test])

    # Print dataset statistics
    print(f"Training set shape: {X_train.shape}, class balance: ASD={np.sum(y_train == 1)}, TC={np.sum(y_train == 0)}")
    print(f"Validation set shape: {X_val.shape}, class balance: ASD={np.sum(y_val == 1)}, TC={np.sum(y_val == 0)}")
    print(f"Test set shape: {X_test.shape}, class balance: ASD={np.sum(y_test == 1)}, TC={np.sum(y_test == 0)}")

    # Build and train the model
    dnn = build_model(X_train.shape[1])
    history = dnn.fit(X_train, y_train, validation_data=(X_val, y_val), batch_size=32, epochs=100, callbacks=callbacks_list)

    # Evaluate on validation set
    validation_pred_prob = dnn.predict(X_val).ravel()
    validation_pred = (validation_pred_prob > 0.5).astype(int)
    acc, sens, prec, spec, auc, cm = calculate_metrics(y_val, validation_pred, validation_pred_prob)
    metrics["validation"]["accuracy"] += acc
    metrics["validation"]["sensitivity"] += sens
    metrics["validation"]["precision"] += prec
    metrics["validation"]["specificity"] += spec
    metrics["validation"]["auc"] += auc
    print_metrics(split, "validation", acc, sens, prec, spec, auc, cm)

    # Evaluate on test set
    test_pred_prob = dnn.predict(X_test).ravel()
    test_pred = (test_pred_prob > 0.5).astype(int)
    acc, sens, prec, spec, auc, cm = calculate_metrics(y_test, test_pred, test_pred_prob)
    metrics["test"]["accuracy"] += acc
    metrics["test"]["sensitivity"] += sens
    metrics["test"]["precision"] += prec
    metrics["test"]["specificity"] += spec
    metrics["test"]["auc"] += auc
    print_metrics(split, "test", acc, sens, prec, spec, auc, cm)

# Print mean metrics
for dataset in metrics:
    print(f"\n--- Mean {dataset.capitalize()} Metrics Across All Splits ---")
    for metric, value in metrics[dataset].items():
        print(f"{metric.capitalize()}: {(value / n_folds) * 100:.2f}%")


--- Split 1 ---
Labels
Time series data
atlas rois_cc200: (834, 19900)
atlas rois_aal: (834, 6670)
Training set shape: (834, 26570), class balance: ASD=417, TC=417
Validation set shape: (94, 26570), class balance: ASD=47, TC=47
Test set shape: (56, 26570), class balance: ASD=28, TC=28
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100

Validation Metrics for Split 1:
  Accuracy: 65.96%
  Sensitivity (Recall): 57.45%
  Precision: 69.23%
  Specificity: 74.47%
  AUC-ROC Score: 69.17%
  Confusion Matrix:
[[35 12]
 [20 27]]

Test Metrics for Split 1:
  Accuracy: 67.86%
  Sensitivity (Recall): 64.29%
  Precision: 69.23%
  Specificity: 71.43%
  AUC-ROC Score: 80.87%
  Confusion Matrix:
[[20  8]
 [10 18]]

--- Split 2 ---
Labels
Time series data
atlas rois_cc200: (832, 19900)
atlas rois_aal: (832, 6670)
Training set shape: (832, 26570), class balance: ASD=416, TC=416
Validation set shape: (96, 26570), class bal

In [14]:
X_data, X_test = np.array([None]), np.array([None])
y_data = []

# Aggregate labels from all sites except the test site
for site in sites:
    if site == test_site:  # Skip the test site
        continue

    # Add labels for the current site
    y_data.extend(atlas_labels[site])

y_test = atlas_labels[test_site]

for atlas in range(len(atlases)):
    # Initialize lists for training data
    X_train_time_series_atlas = []

    # Aggregate data from all sites except the test site
    for site in sites:
        if site == test_site:  # Skip the test site
            continue

        # Add training data for the current site
        X_train_time_series_atlas.extend(
            atlas_time_series[atlas][site]
        )

    # Estimate tangent space representation for training data
    connectivity_m = estimate_tangent_space(X_train_time_series_atlas)
    # Transform data into feature vectors
    X_train_atlas = connectivity_m.transform(X_train_time_series_atlas)
    X_test_atlas = connectivity_m.transform(atlas_time_series[atlas][test_site])

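    # Append this atlas's features column-wise. The np.array([None]) placeholder
    # is 1-D, so an ndim check detects it on any NumPy version (`.any() == None`
    # fails on NumPy >= 2.0, where `.any()` returns a bool for object arrays).
    if X_data.ndim < 2:
        X_data = X_train_atlas
    else:
        X_data = np.concatenate((X_data, X_train_atlas), axis=1)

    if X_test.ndim < 2:
        X_test = X_test_atlas
    else:
        X_test = np.concatenate((X_test, X_test_atlas), axis=1)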
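
The feature widths reported for each atlas in the cross-validation output (19900 and 6670, 26570 combined) are consistent with each connectivity matrix being vectorized without its diagonal, so that an atlas with n ROIs contributes n(n-1)/2 features. The sketch below only illustrates that arithmetic and the column-wise concatenation. The helper `n_tangent_features`, the ROI counts (200 for CC200, 116 for AAL), and the random toy blocks are assumptions made for illustration, not code from this notebook.

```python
import numpy as np


def n_tangent_features(n_rois):
    """Count the off-diagonal entries in one triangle of an n_rois x n_rois matrix."""
    return n_rois * (n_rois - 1) // 2


# Assumed ROI counts per atlas (hypothetical mapping, inferred from the atlas names)
roi_counts = {"rois_cc200": 200, "rois_aal": 116}
widths = {name: n_tangent_features(n) for name, n in roi_counts.items()}

# Toy stand-ins for the per-atlas feature blocks of 4 subjects
rng = np.random.default_rng(0)
blocks = [rng.standard_normal((4, width)) for width in widths.values()]

# Stack the blocks side by side: one row per subject, all atlases' features as columns
X = np.hstack(blocks)

print(widths)
print(X.shape)
```

Collecting the per-atlas blocks in a list and stacking them once avoids the placeholder array and the repeated `if`/`else` branches, at the cost of holding all blocks in memory until the final stack.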

The following cell prepares the data, trains the model, and evaluates it on the training, validation, and test sets.

1. **Splitting the Data**:
   - The training data (`X_data`, `y_data`) is split 80/20 into training (`X_train`, `y_train`) and validation (`X_val`, `y_val`) sets using `train_test_split` (`test_size=0.2`), stratified on the labels to preserve class balance.
   - The test set remains separate as `X_test` and `y_test`.

2. **Data Conversion**:
   - The training, validation, and test datasets (`X_train`, `X_val`, `X_test`, `y_train`, `y_val`, `y_test`) are converted to NumPy arrays for compatibility with the model.

3. **Dataset Statistics**:
   - Prints the shapes of the training, validation, and test sets along with the class distribution for `ASD` (positive class) and `TC` (negative class).

4. **Model Training**:
   - A deep neural network (DNN) is constructed using the `build_model` function, with the input shape derived from `X_train`.
   - The model is trained on the training set (`X_train`, `y_train`), with the validation set (`X_val`, `y_val`) evaluated at the end of each epoch, using:
     - Batch size: 32
     - Epochs: 100
     - Callbacks: `callbacks_list` for monitoring and early stopping.

5. **Evaluation**:
   - Predictions and metrics (accuracy, sensitivity, precision, specificity, AUC) are calculated for the following sets:
     - **Training Set**:
       - Predictions: `train_pred_prob`, thresholded at 0.5 for binary classification.
       - Metrics calculated using `calculate_metrics` and printed using `print_metrics`.
     - **Validation Set**:
       - Predictions: `val_pred_prob`, thresholded at 0.5 for binary classification.
       - Metrics calculated and printed as above.
     - **Test Set**:
       - Predictions: `test_pred_prob`, thresholded at 0.5 for binary classification.
       - Metrics calculated and printed as above.
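
The evaluation step can be sketched with NumPy alone. The function `sketch_metrics` below is a hypothetical stand-in for the notebook's `calculate_metrics`, whose actual implementation may differ. It omits AUC and assumes labels coded as ASD=1 and TC=0, with the confusion matrix laid out as `[[TN, FP], [FN, TP]]`, which is the layout the printed matrices and metrics are consistent with.

```python
import numpy as np


def sketch_metrics(y_true, y_prob, threshold=0.5):
    """Threshold probabilities and derive metrics from the confusion matrix.

    Hypothetical helper for illustration; not the notebook's calculate_metrics.
    """
    y_true = np.asarray(y_true)
    y_pred = (np.asarray(y_prob) > threshold).astype(int)

    tp = int(np.sum((y_true == 1) & (y_pred == 1)))  # ASD predicted as ASD
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))  # TC predicted as TC
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))  # TC predicted as ASD
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))  # ASD predicted as TC

    return {
        "accuracy": (tp + tn) / y_true.size,
        "sensitivity": tp / (tp + fn),
        "precision": tp / (tp + fp),
        "specificity": tn / (tn + fp),
        "confusion_matrix": np.array([[tn, fp], [fn, tp]]),
    }


# Toy labels and predicted probabilities, made up for this example
y_true = [1, 1, 1, 0, 0, 0]
y_prob = [0.9, 0.6, 0.4, 0.2, 0.7, 0.1]
m = sketch_metrics(y_true, y_prob)
print(m)
```

The sketch does not guard against empty denominators: if the model predicts no positives, the precision division raises `ZeroDivisionError`, so a real implementation would need to handle that case.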

In [16]:
from sklearn.model_selection import train_test_split

# Split the data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(X_data, y_data, stratify=y_data, test_size=0.2)

# Convert features and labels to NumPy arrays
X_train, X_val, X_test = map(np.array, [X_train, X_val, X_test])
y_train, y_val, y_test = map(np.array, [y_train, y_val, y_test])

# Print dataset statistics
print(f"Training set shape: {X_train.shape}, class balance: ASD={np.sum(y_train == 1)}, TC={np.sum(y_train == 0)}")
print(f"Validation set shape: {X_val.shape}, class balance: ASD={np.sum(y_val == 1)}, TC={np.sum(y_val == 0)}")
print(f"Test set shape: {X_test.shape}, class balance: ASD={np.sum(y_test == 1)}, TC={np.sum(y_test == 0)}")

# Build and train the model
dnn = build_model(X_train.shape[1])
history = dnn.fit(X_train, y_train, validation_data=(X_val, y_val), batch_size=32, epochs=100, callbacks=callbacks_list)

# Evaluate on training set
train_pred_prob = dnn.predict(X_train).ravel()
train_pred = (train_pred_prob > 0.5).astype(int)
acc, sens, prec, spec, auc, cm = calculate_metrics(y_train, train_pred, train_pred_prob)
print_metrics(1, "training", acc, sens, prec, spec, auc, cm)

# Evaluate on validation set
val_pred_prob = dnn.predict(X_val).ravel()
val_pred = (val_pred_prob > 0.5).astype(int)
acc, sens, prec, spec, auc, cm = calculate_metrics(y_val, val_pred, val_pred_prob)
print_metrics(1, "validation", acc, sens, prec, spec, auc, cm)

# Evaluate on test set
test_pred_prob = dnn.predict(X_test).ravel()
test_pred = (test_pred_prob > 0.5).astype(int)
acc, sens, prec, spec, auc, cm = calculate_metrics(y_test, test_pred, test_pred_prob)
print_metrics(1, "test", acc, sens, prec, spec, auc, cm)

Training set shape: (835, 26570), class balance: ASD=401, TC=434
Validation set shape: (209, 26570), class balance: ASD=100, TC=109
Test set shape: (56, 26570), class balance: ASD=28, TC=28
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100

Training Metrics for Split 2:
  Accuracy: 100.00%
  Sensitivity (Recall): 100.00%
  Precision: 100.00%
  Specificity: 100.00%
  AUC-ROC Score: 100.00%
  Confusion Matrix:
[[434   0]
 [  0 401]]

Validation Metrics for Split 2:
  Accuracy: 73.68%
  Sensitivity (Recall): 72.00%
  Precision: 72.73%
  Specificity: 75.23%
  AUC-ROC Score: 78.06%
  Confusion Matrix:
[[82 27]
 [28 72]]

Test Metrics for Split 2:
  Accuracy: 76.79%
  Sensitivity (Recall): 78.57%
  Precision: 75.86%
 