### Import Modules

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # %env CUDA_VISIBLE_DEVICES=0
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset
import time
import logging
import matplotlib.pyplot as plt
import seaborn as sns
import sys
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, TransformerConv, global_mean_pool
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torchmetrics import Accuracy
from monai.utils import set_determinism
import nibabel as nib
from nilearn import plotting
import json
import warnings

### Define Functions and Classes

In [None]:
class BrainConnectivityDataset(Dataset):
    """PyG dataset that builds one brain-connectivity graph per subject.

    Subjects are listed in ``<data_dir>/Subjects.csv`` (column ``ID``; an optional
    ``Sex`` column becomes the graph label ``y``). For every modality, each subject
    has a header-less CSV at ``<data_dir>/<modality>/<ID:03d>.csv`` (IDs are
    formatted with ``:03d``, so they must be integers). A square N x N file is
    treated as a connectivity matrix; an N x 1 file is treated as a per-node vector.

    Args:
        data_dir: Root directory holding ``Subjects.csv`` and one folder per modality.
        model_name: ``'GCN'`` / ``'GraphSAGE'`` use sparse edges (non-zero,
            off-diagonal entries of the first edge modality); ``'GAT'`` /
            ``'TransformerConv'`` use dense edges (every ordered pair i != j).
        node_feature_modalities: Modalities stacked column-wise into ``x``. Matrix
            modalities are reduced to the mean of each row's off-diagonal entries.
            When empty, ``x`` is a single column of ones.
        edge_feature_modalities: Matrix modalities stacked column-wise into
            ``edge_attr``. When empty, the graph is fully connected with ones.
        scaling: ``'global'`` (z-score with dataset-level statistics),
            ``'per_subject'`` (divide by the subject's max absolute value), or any
            other value for no scaling.
        global_stats: Precomputed statistics (e.g. from the training dataset) reused
            for ``scaling='global'``; computed from this dataset when omitted.

    Raises:
        ValueError: If ``model_name`` is not supported.

    NOTE(review): with ``scaling='global'`` non-zero edge values are z-scored and
    can become negative; BrainGNN feeds the first ``edge_attr`` column to GCNConv as
    edge weights — confirm GCN's degree normalisation is intended to see these.
    """

    def __init__(self, data_dir, model_name, node_feature_modalities=None, edge_feature_modalities=None, scaling='global', global_stats=None):
        super().__init__()
        self.data_dir = data_dir
        # Copy into lists so tuples are accepted and callers' lists are not aliased.
        self.node_feature_modalities = list(node_feature_modalities) if node_feature_modalities else []
        self.edge_feature_modalities = list(edge_feature_modalities) if edge_feature_modalities else []
        self.model_type = model_name
        self.scaling = scaling
        # Edge construction follows the layer family (see class docstring).
        if model_name in ('GCN', 'GraphSAGE'):
            self.edge_construction = 'sparse'
        elif model_name in ('GAT', 'TransformerConv'):
            self.edge_construction = 'dense'
        else:
            raise ValueError(f"Unsupported model name: {model_name}")
        # Load subject list from Subjects.csv
        self.df = pd.read_csv(os.path.join(data_dir, 'Subjects.csv'))
        self.subjects = self.df['ID'].to_numpy()
        self.global_stats = global_stats
        if scaling == 'global':
            if global_stats is not None:
                print("Using provided global statistics for scaling.")
            else:
                print("\n" + "="*70)
                print("Computing global statistics from training set...")
                print("="*70)
                self.global_stats = self._compute_global_stats()
                print("="*70 + "\n")
        # Warn when GCN/GraphSAGE will silently ignore some of the requested edge modalities.
        if model_name in ('GCN', 'GraphSAGE') and len(self.edge_feature_modalities) > 1:
            if model_name == 'GCN':
                warnings.warn(
                    f"GCN only supports scalar edge weights. "
                    f"Only the first ('{self.edge_feature_modalities[0]}') will be used as edge weights.",
                    UserWarning
                )
            else:
                warnings.warn(
                    "GraphSAGE does not use edge weights or features. "
                    "Only binary connectivity (edge_index) will be used.",
                    UserWarning
                )

    def _read_modality(self, subject, modality):
        """Read one subject's header-less CSV for ``modality`` as a 2-D numpy array."""
        path = os.path.join(self.data_dir, modality, f"{subject:03d}.csv")
        return pd.read_csv(path, header=None).values

    def _compute_global_stats(self):
        """Compute mean/std/min/max per modality over every subject in this dataset.

        Vector files contribute all entries; square matrices contribute only their
        non-zero off-diagonal entries (the entries ``_scale_global`` rescales).
        Files of any other shape are skipped.

        Returns:
            dict mapping modality -> {'n_values', 'n_subjects', 'mean', 'std', 'min', 'max'}.

        Raises:
            ValueError: If a modality yields no usable values.
        """
        stats = {}
        # dict.fromkeys de-duplicates while keeping the user-given order (deterministic output).
        all_modalities = list(dict.fromkeys(self.node_feature_modalities + self.edge_feature_modalities))
        for modality in all_modalities:
            print(f"\n{modality}:")
            chunks = []
            for subject in self.subjects:
                data = self._read_modality(subject, modality)
                if data.ndim == 1 or (data.ndim == 2 and data.shape[1] == 1):  # Vector: all values
                    chunks.append(data.ravel())
                elif data.ndim == 2 and data.shape[0] == data.shape[1]:  # Matrix: exclude diagonal and zeros
                    mask = (np.abs(data) > 0) & ~np.eye(data.shape[0], dtype=bool)
                    chunks.append(data[mask])
            all_values = np.concatenate(chunks) if chunks else np.array([])
            if all_values.size == 0:
                raise ValueError(f"No usable values found for modality '{modality}' when computing global statistics")
            stats[modality] = {
                'n_values': len(all_values),
                'n_subjects': len(self.subjects),
                'mean': np.mean(all_values),
                'std': np.std(all_values),
                'min': float(np.min(all_values)),
                'max': float(np.max(all_values))
            }
            print(f"  Data points: {stats[modality]['n_values']:,} from {stats[modality]['n_subjects']} subjects")
            print(f"  Mean: {stats[modality]['mean']:.6f}")
            print(f"  Std: {stats[modality]['std']:.6f}")
            print(f"  Range: [{stats[modality]['min']:.6f}, {stats[modality]['max']:.6f}]")
        return stats

    def len(self):
        """Number of subjects (graphs)."""
        return len(self.subjects)

    def get(self, idx):
        """Build the graph for the subject at position ``idx`` in Subjects.csv."""
        subject = self.subjects[idx]
        num_nodes = self._get_num_nodes(subject)
        if self.node_feature_modalities:
            x = self._load_node_features(subject, num_nodes)
        else:  # Dummy node features (all ones)
            x = torch.ones(num_nodes, 1, dtype=torch.float)
        if self.edge_feature_modalities:
            edge_matrices = self._load_edge_matrices(subject, num_nodes)
            if self.edge_construction == 'dense':
                edge_index, edge_attr = self._matrices_to_edge_index_attr_dense(edge_matrices, num_nodes)
            else:
                edge_index, edge_attr = self._matrices_to_edge_index_attr_sparse(edge_matrices, num_nodes)
        else:  # Dummy edges: fully connected, attribute 1
            edge_index = self._create_fully_connected_edge_index(num_nodes)
            edge_attr = torch.ones(edge_index.size(1), 1, dtype=torch.float)
        # Positional lookup so the label always matches self.subjects[idx].
        label = torch.tensor(self.df['Sex'].iloc[idx], dtype=torch.long) if 'Sex' in self.df.columns else None
        return Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=label,
            num_nodes=num_nodes
        )

    def _get_num_nodes(self, subject):
        """Infer graph size from the first configured modality (matrix side or vector length)."""
        all_modalities = self.node_feature_modalities + self.edge_feature_modalities
        if not all_modalities:
            raise ValueError("Must specify at least one modality in node_feature_modalities or edge_feature_modalities to determine graph size")
        data = self._read_modality(subject, all_modalities[0])
        if data.ndim == 2 and data.shape[0] == data.shape[1]:
            return data.shape[0]
        return len(data.flatten())

    def _load_node_features(self, subject, num_nodes):
        """Return a [num_nodes, num_node_modalities] float tensor of scaled node features."""
        node_features_list = []
        for modality in self.node_feature_modalities:
            data = self._read_modality(subject, modality)
            if data.ndim == 1 or (data.ndim == 2 and data.shape[1] == 1):  # Vector: used directly
                node_feature = data.flatten()
                if len(node_feature) != num_nodes:
                    raise ValueError(f"Node feature dimension mismatch for {modality}: expected {num_nodes}, got {len(node_feature)}")
                if self.scaling == 'global':
                    node_feature = self._scale_global(node_feature, modality)
                elif self.scaling == 'per_subject':
                    node_feature = self._scale_persubject(node_feature)
            elif data.ndim == 2 and data.shape[0] == data.shape[1]:  # Matrix: scale first, then reduce per node
                if data.shape[0] != num_nodes:
                    raise ValueError(f"Matrix size mismatch for {modality}: expected ({num_nodes}, {num_nodes}), got {data.shape}")
                if self.scaling == 'global':
                    data = self._scale_global(data, modality, is_matrix=True)
                elif self.scaling == 'per_subject':
                    data = self._scale_persubject(data, is_matrix=True)
                node_feature = self._compute_node_feature_from_edges(data)
            else:
                raise ValueError(f"Invalid data shape for {modality}: {data.shape}: expected (N,) or (N, N)")
            node_features_list.append(torch.tensor(node_feature, dtype=torch.float).unsqueeze(1))
        return torch.cat(node_features_list, dim=1)  # [num_nodes, num_node_features]

    def _load_edge_matrices(self, subject, num_nodes):
        """Load and scale every edge modality; returns {modality: N x N array} in modality order."""
        edge_matrices = {}
        for modality in self.edge_feature_modalities:
            conn_matrix = self._read_modality(subject, modality)
            if not (conn_matrix.ndim == 2 and conn_matrix.shape[0] == conn_matrix.shape[1]):
                raise ValueError(f"Edge feature dimension mismatch for {modality}: got {conn_matrix.shape}")
            if conn_matrix.shape[0] != num_nodes:
                raise ValueError(f"Matrix size mismatch for {modality}: expected ({num_nodes}, {num_nodes}), got {conn_matrix.shape}")
            if self.scaling == 'global':
                conn_matrix = self._scale_global(conn_matrix, modality, is_matrix=True)
            elif self.scaling == 'per_subject':
                conn_matrix = self._scale_persubject(conn_matrix, is_matrix=True)
            edge_matrices[modality] = conn_matrix
        return edge_matrices

    def _scale_global(self, data, modality, is_matrix=False):
        """Z-score with global stats: (data - mean) / std.

        Matrices: only non-zero off-diagonal entries are scaled (zeros stay zero to
        keep sparsity) and the diagonal is set to 1. Vectors: every entry is scaled.
        """
        mean = self.global_stats[modality]['mean']
        std = self.global_stats[modality]['std']
        if is_matrix:
            mask = (np.abs(data) > 0) & ~np.eye(data.shape[0], dtype=bool)
            scaled_data = np.zeros_like(data, dtype=np.float64)
            if mask.any():
                scaled_data[mask] = (data[mask] - mean) / (std + 1e-8)
            np.fill_diagonal(scaled_data, 1)
        else:
            scaled_data = (data - mean) / (std + 1e-8)
        return scaled_data

    def _scale_persubject(self, data, is_matrix=False):
        """Divide by the subject's max absolute value (off-diagonal for matrices) into [-1, 1].

        For matrices the diagonal is then set to 1.
        """
        if is_matrix:
            values = data[~np.eye(data.shape[0], dtype=bool)]
            scaled_data = data / (np.abs(values).max() + 1e-8)
            np.fill_diagonal(scaled_data, 1)
        else:
            scaled_data = data / (np.abs(data).max() + 1e-8)
        return scaled_data

    def _compute_node_feature_from_edges(self, conn_matrix):
        """Per-node mean of the row's off-diagonal entries (vectorised)."""
        num_nodes = conn_matrix.shape[0]
        # Boolean indexing walks row-major, so each reshaped row holds node i's off-diagonal entries.
        off_diagonal = conn_matrix[~np.eye(num_nodes, dtype=bool)].reshape(num_nodes, num_nodes - 1)
        return off_diagonal.mean(axis=1)

    @staticmethod
    def _off_diagonal_pairs(num_nodes):
        """Row/column indices of every ordered pair (i, j), i != j, in row-major order."""
        return np.nonzero(~np.eye(num_nodes, dtype=bool))

    @staticmethod
    def _to_edge_tensors(edge_matrices, rows, cols):
        """Build edge_index [2, E] (long) and edge_attr [E, num_modalities] (float) for the given entries."""
        edge_index = torch.tensor(np.stack([rows, cols]), dtype=torch.long)
        columns = [matrix[rows, cols] for matrix in edge_matrices.values()]
        if columns:
            edge_attr = torch.tensor(np.stack(columns, axis=1), dtype=torch.float)
        else:
            edge_attr = torch.zeros(len(rows), 0, dtype=torch.float)
        return edge_index, edge_attr

    def _matrices_to_edge_index_attr_dense(self, edge_matrices, num_nodes):
        """All N*(N-1) directed edges (no self-loops); attributes from every modality."""
        rows, cols = self._off_diagonal_pairs(num_nodes)
        return self._to_edge_tensors(edge_matrices, rows, cols)

    def _matrices_to_edge_index_attr_sparse(self, edge_matrices, num_nodes):
        """Edges where the FIRST modality is non-zero (no self-loops); attributes from every modality."""
        primary_matrix = next(iter(edge_matrices.values()))
        mask = (np.abs(primary_matrix) > 0) & ~np.eye(num_nodes, dtype=bool)
        rows, cols = np.nonzero(mask)
        return self._to_edge_tensors(edge_matrices, rows, cols)

    def _create_fully_connected_edge_index(self, num_nodes):
        """edge_index [2, N*(N-1)] of a fully connected graph without self-loops."""
        rows, cols = self._off_diagonal_pairs(num_nodes)
        return torch.tensor(np.stack([rows, cols]), dtype=torch.long)
    
def _stratification_labels(dataset):
    """Return the class label of every item in ``dataset``, in dataset order.

    Fast path: if the dataset exposes a ``df`` table with a ``Sex`` column and an
    ``indices()`` mapping (as BrainConnectivityDataset, a PyG Dataset, does), labels
    are read from the table instead of building every graph. Otherwise each item
    is materialised and its ``.y`` read.
    """
    df = getattr(dataset, 'df', None)
    indices = getattr(dataset, 'indices', None)
    if df is not None and 'Sex' in df.columns and callable(indices):
        sex = df['Sex'].to_numpy()
        return [int(sex[i]) for i in indices()]
    return [dataset[i].y.item() for i in range(len(dataset))]


def load_data(dataset, batch_size, inference=False, test_size=0.2, random_state=42):
    """Build DataLoaders for inference, or a stratified train/validation split.

    Args:
        dataset: Graph dataset; items must carry a class label ``y`` for training.
        batch_size: Batch size for the train/validation loaders (inference uses 1).
        inference: If True, return a loader over the whole dataset in order.
        test_size: Validation fraction passed to ``train_test_split``.
        random_state: Seed for the split (default 42 keeps earlier splits reproducible).

    Returns:
        inference=True:  (test_loader, dataset.subjects)
        inference=False: (train_loader [shuffled], val_loader)
    """
    pin_memory = torch.cuda.is_available()
    if inference:
        test_loader = DataLoader(dataset, batch_size=1, num_workers=0, pin_memory=pin_memory)
        return test_loader, dataset.subjects
    indices = list(range(len(dataset)))
    labels = _stratification_labels(dataset)
    train_indices, val_indices = train_test_split(indices, test_size=test_size, stratify=labels, random_state=random_state)
    train_ds = Subset(dataset, train_indices)
    val_ds = Subset(dataset, val_indices)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=0, pin_memory=pin_memory)
    return train_loader, val_loader

class BrainGNN(torch.nn.Module):
    """Graph classifier: stacked graph convolutions -> global mean pool -> 3-layer MLP.

    Args:
        model_type: ``'GCN'``, ``'GraphSAGE'``, ``'GAT'`` or ``'TransformerConv'``.
        num_node_features: Width of ``data.x``.
        num_edge_features: Width of ``data.edge_attr`` (``edge_dim`` for GAT/TransformerConv).
        num_channels: Output width of each conv layer. For attention models each
            value must be divisible by ``heads``; every head gets ``c // heads``
            channels and the heads are concatenated back to ``c``.
        **kwargs: ``heads`` (default 4, attention models only) and ``dropout``
            (default 0.2, used after every conv layer and in the MLP head).

    Edge usage: GCN takes the first ``edge_attr`` column as a scalar edge weight,
    GraphSAGE ignores edge attributes, GAT/TransformerConv use the full vector.
    ``forward`` returns logits of shape [num_graphs, 2].

    Raises:
        ValueError: Unsupported ``model_type``, empty ``num_channels``, or a channel
            count not divisible by ``heads`` for attention models.
    """

    _ATTENTION_MODELS = ('GAT', 'TransformerConv')
    _SUPPORTED_MODELS = ('GCN', 'GraphSAGE', 'GAT', 'TransformerConv')

    def __init__(self, model_type, num_node_features, num_edge_features,
                 num_channels=(32, 32, 64, 64, 128), **kwargs):
        super().__init__()
        if model_type not in self._SUPPORTED_MODELS:
            raise ValueError(f"Unsupported model type: {model_type}")
        self.model_type = model_type
        heads = kwargs.get('heads', 4)
        dropout = kwargs.get('dropout', 0.2)
        num_channels = list(num_channels)
        if not num_channels:
            raise ValueError("num_channels must contain at least one layer width")
        is_attention = model_type in self._ATTENTION_MODELS
        if is_attention:
            for channel in num_channels:
                if channel % heads != 0:
                    raise ValueError(f"All channel numbers must be divisible by number of heads ({heads}): got channel={channel}")
            channels_to_use = [c // heads for c in num_channels]
        else:
            channels_to_use = num_channels
        # Actual output width of each layer (attention heads are concatenated).
        head_factor = heads if is_attention else 1
        out_dims = [c * head_factor for c in channels_to_use]
        in_dims = [num_node_features] + out_dims[:-1]
        self.conv_layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for in_dim, per_head_dim, out_dim in zip(in_dims, channels_to_use, out_dims):
            self.conv_layers.append(self._make_conv(model_type, in_dim, per_head_dim, num_edge_features, heads))
            self.batch_norms.append(nn.BatchNorm1d(out_dim))
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        num_final_channels = out_dims[-1]
        self.fc_layers = nn.Sequential(
            nn.Linear(num_final_channels, num_final_channels // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(num_final_channels // 2, num_final_channels // 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(num_final_channels // 4, 2)  # Binary classification
        )

    @staticmethod
    def _make_conv(model_type, in_channels, out_channels, num_edge_features, heads):
        """Build one graph-convolution layer of the requested family."""
        if model_type == 'GCN':
            return GCNConv(in_channels, out_channels)
        if model_type == 'GraphSAGE':
            return SAGEConv(in_channels, out_channels)
        if model_type == 'GAT':
            return GATConv(in_channels, out_channels, heads=heads, edge_dim=num_edge_features)
        return TransformerConv(in_channels, out_channels, heads=heads, edge_dim=num_edge_features)

    def forward(self, data):
        """Return class logits [num_graphs, 2] for a (batched) PyG graph."""
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        edge_weight = None
        if self.model_type == 'GCN':
            # GCN accepts only a scalar weight per edge: use the first attribute column.
            edge_weight = edge_attr[:, 0] if edge_attr.dim() > 1 else edge_attr
        for i, (conv, bn) in enumerate(zip(self.conv_layers, self.batch_norms)):
            identity = x
            if self.model_type == 'GCN':
                x = conv(x, edge_index, edge_weight=edge_weight)
            elif self.model_type == 'GraphSAGE':
                x = conv(x, edge_index)
            else:  # GAT / TransformerConv consume the full edge-attribute vector
                x = conv(x, edge_index, edge_attr=edge_attr)
            x = self.dropout(self.relu(bn(x)))
            # Residual connection whenever the width is unchanged (never on the input layer).
            if i > 0 and x.size() == identity.size():
                x = x + identity
        x = global_mean_pool(x, data.batch)
        return self.fc_layers(x)

def get_model(model_name, num_node_features, num_edge_features, num_channels=(32, 32, 64, 64, 128), heads=4):
    """Construct a BrainGNN of the requested family.

    Args:
        model_name: ``'GCN'``, ``'GraphSAGE'``, ``'GAT'`` or ``'TransformerConv'``.
        num_node_features: Width of node features ``x``.
        num_edge_features: Width of ``edge_attr``.
        num_channels: Per-layer output widths (an immutable default avoids the
            shared-mutable-default pitfall; a fresh list is passed on).
        heads: Attention heads; forwarded only for GAT/TransformerConv.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """
    supported_models = ('GCN', 'GraphSAGE', 'GAT', 'TransformerConv')
    if model_name not in supported_models:
        raise ValueError(f"Unsupported model name: {model_name}")
    params = {
        'model_type': model_name,
        'num_node_features': num_node_features,
        'num_edge_features': num_edge_features,
        'num_channels': list(num_channels)
    }
    if model_name in ('GAT', 'TransformerConv'):
        params['heads'] = heads
    return BrainGNN(**params)

def visualize_conn_matrix(sample, edge_feature_modalities, title=None, graph_idx=0):
    """Plot one graph's edge attributes as dense N x N heatmaps, one panel per modality.

    Args:
        sample: A PyG batch (``ptr`` locates graph boundaries) or a single
            un-batched graph without ``ptr`` (treated as one graph of ``num_nodes`` nodes).
        edge_feature_modalities: Panel titles for the ``edge_attr`` columns, in order.
        title: Optional figure super-title.
        graph_idx: Index of the graph within the batch.

    Raises:
        ValueError: If ``graph_idx`` is outside ``[0, num_graphs)``.

    Missing edges are drawn as 0. Panels with negative values use a symmetric
    diverging colormap; otherwise a sequential one starting at 0.
    """
    ptr = getattr(sample, 'ptr', None)
    bounds = ptr.tolist() if ptr is not None else [0, int(sample.num_nodes)]
    batch_size = len(bounds) - 1
    # Validate before creating the figure so a bad index does not leave an empty figure open.
    if not 0 <= graph_idx < batch_size:
        raise ValueError(f"graph_idx {graph_idx} out of range for batch_size {batch_size}")
    node_start, node_end = bounds[graph_idx], bounds[graph_idx + 1]
    num_nodes = node_end - node_start
    # Keep only edges whose endpoints both lie in this graph, then re-index locally.
    src, dst = sample.edge_index[0], sample.edge_index[1]
    edge_mask = (src >= node_start) & (src < node_end) & (dst >= node_start) & (dst < node_end)
    edge_index_local = sample.edge_index[:, edge_mask] - node_start
    edge_attr_graph = sample.edge_attr[edge_mask]
    n_modalities = len(edge_feature_modalities)
    _, axs = plt.subplots(1, n_modalities, figsize=(n_modalities * 4, 4), squeeze=False)
    for col, modality in enumerate(edge_feature_modalities):
        edge_attr = edge_attr_graph[:, col]
        conn_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float32, device=edge_attr.device)
        conn_matrix[edge_index_local[0], edge_index_local[1]] = edge_attr
        matrix_numpy = conn_matrix.cpu().numpy()
        non_zero_values = matrix_numpy[matrix_numpy != 0]
        if len(non_zero_values) > 0:
            abs_max = np.abs(non_zero_values).max()
            has_negative = np.any(non_zero_values < 0)
        else:
            abs_max = 1
            has_negative = False
        ax = axs[0, col]
        im = ax.imshow(
            matrix_numpy,
            cmap='RdBu_r' if has_negative else 'hot',
            vmin=-abs_max if has_negative else 0,
            vmax=abs_max,
            aspect='equal'
        )
        ax.set_title(f'{modality}', fontsize=11)
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    if title:
        plt.suptitle(f"{title} (Graph {graph_idx})", fontsize=13, y=1.02)
    plt.tight_layout()
    plt.show()

def get_grad_scaler(device):
    """Return a CUDA gradient scaler for mixed-precision training, or None off-GPU.

    Scaler constructors are tried from the newest to the oldest PyTorch spelling;
    one that is missing or rejects the ``"cuda"`` argument (AttributeError /
    TypeError) falls through to the next, ending at the legacy CUDA-only API.
    """
    if device.type == "cuda":
        candidates = (
            lambda: torch.GradScaler("cuda"),      # top-level export in recent releases
            lambda: torch.amp.GradScaler("cuda"),  # device-generic torch.amp API
        )
        for build in candidates:
            try:
                return build()
            except (AttributeError, TypeError):
                continue
        return torch.cuda.amp.GradScaler()         # legacy CUDA-only API
    return None

def get_autocast_context(device, enabled=True):
    """Return a context manager running the enclosed ops under float16 autocast.

    ``enabled=False`` yields a no-op context. Otherwise the newest autocast API is
    tried first; if a constructor is missing or rejects its arguments
    (AttributeError / TypeError) the next, older spelling is tried. The second
    fallback passes only the device type, so it uses that API's default dtype. On
    versions without any generic API, non-CUDA devices get a no-op context.
    """
    from contextlib import nullcontext
    if not enabled:
        return nullcontext()
    try:
        return torch.autocast(device_type=device.type, dtype=torch.float16)
    except (AttributeError, TypeError):
        pass
    try:
        return torch.amp.autocast(device.type)
    except (AttributeError, TypeError):
        pass
    if device.type != "cuda":
        return nullcontext()
    return torch.cuda.amp.autocast()

def train_one_epoch(model, device, train_loader, optimizer, criterion, scaler, metric):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network returning logits [batch, 2].
        device: Device each batch is moved to.
        train_loader: Iterable of batches exposing ``.to(device)`` and ``.y``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss on (logits, ``batch.y``).
        scaler: GradScaler enabling autocast + loss scaling, or None for full precision.
        metric: Stateful metric (reset here) updated with hard class predictions
            (argmax of the logits) against ``batch.y``.

    Returns:
        (mean per-batch loss, metric value over the epoch as a float).
    """
    model.train()
    epoch_loss = 0.0
    metric.reset()
    for batch_data in train_loader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        with get_autocast_context(device, enabled=(scaler is not None)):
            outputs = model(batch_data)
            loss = criterion(outputs, batch_data.y)
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        epoch_loss += loss.item()
        # Score the class the model actually predicts (argmax, consistent with
        # CrossEntropy). Passing only the class-1 logit to a binary metric thresholds
        # that logit alone and ignores the class-0 logit.
        metric(outputs.detach().argmax(dim=1), batch_data.y)
    epoch_metric = metric.compute().item()
    return epoch_loss / len(train_loader), epoch_metric

def validate_one_epoch(model, device, val_loader, metric):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    ``metric`` is reset, then updated per batch with hard class predictions
    (argmax of the output logits) against ``batch.y``; its final value is
    returned as a Python float.
    """
    model.eval()
    metric.reset()
    with torch.no_grad():
        for batch_data in val_loader:
            batch_data = batch_data.to(device)
            outputs = model(batch_data)
            # Predicted class = argmax over logits (matches CrossEntropy training);
            # scoring only the class-1 logit would ignore the class-0 logit.
            metric(outputs.argmax(dim=1), batch_data.y)
    return metric.compute().item()

class EarlyStopping:
    """Signal when a higher-is-better score stops improving.

    Call the instance once per validation round with the latest score. A round
    counts as "no improvement" when ``score < best_score + delta``; after
    ``patience`` consecutive such rounds ``early_stop`` latches to True.
    Otherwise (ties included when ``delta == 0``) the score becomes the new best
    and the counter resets.
    """
    def __init__(self, patience=30, delta=0):
        self.patience = patience  # rounds without improvement tolerated before stopping
        self.delta = delta        # minimum gain over best_score that counts as improvement
        self.best_score = None    # best score so far; None until the first call
        self.early_stop = False   # becomes True once patience is exhausted (never reset)
        self.counter = 0          # consecutive rounds without improvement
    def __call__(self, metric):
        if self.best_score is None:
            self.best_score = metric
            return
        stagnated = metric < self.best_score + self.delta
        if not stagnated:
            self.best_score = metric
            self.counter = 0
            return
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

def train_model(model_dir, model, device, train_loader, val_loader, logger,
        criterion, metric, max_epochs=100, learning_rate=1e-4, weight_decay=1e-5, val_interval=1, es_patience=30):
    """Train ``model`` with Adam + cosine LR decay, keeping the best validation checkpoint.

    Every ``val_interval`` epochs the model is validated; a new best score saves
    ``<model_dir>/BestMetricModel.pth`` and snapshots the weights in memory, and
    feeds early stopping (``es_patience`` validation rounds). Mixed precision is
    used when ``get_grad_scaler`` returns a scaler.

    Returns:
        (model with the best weights restored, per-epoch training losses,
         per-epoch training metric values, per-validation-round metric values).
    """
    import copy

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
    scaler = get_grad_scaler(device)
    start_time = time.time()
    best_metric = -1
    best_metric_epoch = -1
    best_model_state = None
    early_stopping = EarlyStopping(patience=es_patience, delta=0)
    epoch_loss_values, epoch_metric_values, metric_values = [], [], []
    for epoch in range(max_epochs):
        epoch_start_time = time.time()
        epoch_loss, epoch_metric = train_one_epoch(model, device, train_loader, optimizer, criterion, scaler, metric)
        epoch_loss_values.append(epoch_loss)
        epoch_metric_values.append(epoch_metric)
        val_metric = None  # only set on validation epochs
        if (epoch + 1) % val_interval == 0:
            val_metric = validate_one_epoch(model, device, val_loader, metric)
            metric_values.append(val_metric)
            if val_metric > best_metric:
                best_metric = val_metric
                best_metric_epoch = epoch + 1
                # state_dict() returns references to the live parameters; without a
                # deep copy later training would overwrite this "best" snapshot.
                best_model_state = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), os.path.join(model_dir, "BestMetricModel.pth"))
                logger.info(f"Best accuracy: {best_metric:.4f} at epoch {best_metric_epoch}")
            early_stopping(val_metric)
            if early_stopping.early_stop:
                logger.info(f"Early stopping triggered at epoch {epoch + 1}")
                print(f"; Early stopping triggered at epoch {epoch + 1}", end="")
                break
        epoch_end_time = time.time()
        val_text = f", Validation accuracy: {val_metric:.4f}" if val_metric is not None else ""
        logger.info(
            f"Epoch {epoch + 1} completed for {(epoch_end_time - epoch_start_time)/60:.2f} mins - "
            f"Training loss: {epoch_loss:.4f}, Training accuracy: {epoch_metric:.4f}{val_text}"
        )
        lr_scheduler.step()
        sys.stdout.write(f"\rEpoch {epoch + 1}/{max_epochs} completed")
        sys.stdout.flush()
    total_time = time.time() - start_time
    logger.info(
        f"Best accuracy: {best_metric:.3f} at epoch {best_metric_epoch}; "
        f"Total time consumed: {total_time/60:.2f} mins"
    )
    print(
        f"\nBest accuracy: {best_metric:.3f} at epoch {best_metric_epoch}; "
        f"Total time consumed: {total_time/60:.2f} mins"
    )
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model, epoch_loss_values, epoch_metric_values, metric_values

def plot_metric_values(model_dir, epoch_loss_values, epoch_metric_values, metric_values, val_interval=1):
    """Draw loss and accuracy curves side by side; save to ``<model_dir>/Performance.png``.

    Training series are plotted at epochs 1..N. Validation points are plotted at
    epochs ``val_interval, 2*val_interval, ...``. The figure is saved at 300 dpi
    and left open (presumably so a notebook can still display it).
    """
    loss_epochs = list(range(1, len(epoch_loss_values) + 1))
    train_epochs = list(range(1, len(epoch_metric_values) + 1))
    val_epochs = [val_interval * k for k in range(1, len(metric_values) + 1)]
    _, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    loss_ax.plot(loss_epochs, epoch_loss_values, label='Training Loss', color='red')
    loss_ax.set_title('Training Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')

    acc_ax.plot(train_epochs, epoch_metric_values, label='Training Accuracy', color='red')
    acc_ax.plot(val_epochs, metric_values, label='Validation Accuracy', color='blue')
    acc_ax.set_title('Training Accuracy vs. Validation Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(model_dir, "Performance.png"), dpi=300)

def occlusion_based_sensitivity(model, sample, style='diverging', directed=False, window_size=10):
    """Estimate edge importance by zeroing edge attributes in windows of node pairs.

    For each window of up to ``window_size`` node pairs, the ``edge_attr`` rows of
    the matching edges are set to 0 and the model is re-run. The change in the
    predicted probability of class 1 (softmax of the logits, first graph) relative
    to the un-occluded baseline is written to every pair in that window. Pairs
    sharing a window therefore share one joint score; ``window_size=1`` gives
    per-pair scores. Only ``edge_attr`` is zeroed (``edge_index`` is kept), so a
    model that ignores edge attributes (e.g. BrainGNN's GraphSAGE branch) shows
    zero sensitivity.

    Args:
        model: Trained model returning logits [num_graphs, 2].
        sample: A single graph; ``sample.num_nodes`` is used as the matrix size.
        style: ``'diverging'`` (signed change) or ``'absolute'`` (magnitude).
        directed: If False, evaluate unordered pairs i < j, occlude both
            directions and fill the matrix symmetrically.
        window_size: Number of node pairs occluded together per forward pass.

    Returns:
        [num_nodes, num_nodes] float tensor on the model's device.

    Raises:
        ValueError: Unknown ``style`` or ``window_size < 1``.
    """
    if style not in ('diverging', 'absolute'):
        raise ValueError(f"Unsupported style: {style!r} (expected 'diverging' or 'absolute')")
    if window_size < 1:
        raise ValueError(f"window_size must be a positive integer, got {window_size}")
    model.eval()
    device = next(model.parameters()).device
    sample = sample.to(device)
    num_nodes_per_graph = sample.num_nodes
    sensitivity_matrix = torch.zeros((num_nodes_per_graph, num_nodes_per_graph), device=device)
    if directed:
        node_pairs = [(i, j) for i in range(num_nodes_per_graph) for j in range(num_nodes_per_graph) if i != j]
    else:
        node_pairs = [(i, j) for i in range(num_nodes_per_graph) for j in range(i + 1, num_nodes_per_graph)]
    src, dst = sample.edge_index[0], sample.edge_index[1]
    with torch.no_grad():  # inference only: avoid building autograd graphs
        baseline_pred = F.softmax(model(sample), dim=1)[0, 1].item()
        for window_start in range(0, len(node_pairs), window_size):
            current_pairs = node_pairs[window_start:window_start + window_size]
            mask = torch.zeros_like(src, dtype=torch.bool)
            for i, j in current_pairs:
                mask |= (src == i) & (dst == j)
                if not directed:
                    mask |= (src == j) & (dst == i)
            occluded_attr = sample.edge_attr.clone()
            occluded_attr[mask] = 0
            temp_data = Data(
                x=sample.x,
                edge_index=sample.edge_index,
                edge_attr=occluded_attr,
                y=sample.y,
                num_nodes=sample.num_nodes
            )
            pred = F.softmax(model(temp_data), dim=1)[0, 1].item()
            change = pred - baseline_pred
            sensitivity = change if style == 'diverging' else abs(change)
            for i, j in current_pairs:
                sensitivity_matrix[i, j] = sensitivity
                if not directed:
                    sensitivity_matrix[j, i] = sensitivity
    return sensitivity_matrix

def _load_region_labels(labels_file):
    """Load 0-indexed region abbreviation / full-name lookups from *labels_file*.

    Returns ``(labels_dict, full_names_dict)``; both are ``None`` when the file does
    not exist, so callers fall back to plain node indices.
    """
    try:
        df_labels = pd.read_csv(labels_file)
    except FileNotFoundError:
        print(f"Warning: {labels_file} not found. Using node indices only.")
        return None, None
    # The CSV 'Index' column is 1-based; shift it so keys match matrix row indices.
    zero_based_index = df_labels['Index'] - 1
    labels_dict = dict(zip(zero_based_index, df_labels['Abbreviation']))
    full_names_dict = dict(zip(zero_based_index, df_labels['Full_Name']))
    print(f"Loaded {len(df_labels)} brain region labels from {labels_file}")
    return labels_dict, full_names_dict


def _iter_graphs(batch_data):
    """Yield each graph of a PyG mini-batch as a standalone single-graph ``Data`` object."""
    num_graphs = len(batch_data.ptr) - 1
    for graph_idx in range(num_graphs):
        start_idx = batch_data.ptr[graph_idx].item()
        end_idx = batch_data.ptr[graph_idx + 1].item()
        # Keep edges whose source node lies in this graph, then re-base node indices to 0.
        edge_mask = (batch_data.edge_index[0] >= start_idx) & (batch_data.edge_index[0] < end_idx)
        num_nodes = end_idx - start_idx
        yield Data(
            x=batch_data.x[start_idx:end_idx, :],
            edge_index=batch_data.edge_index[:, edge_mask] - start_idx,
            edge_attr=batch_data.edge_attr[edge_mask, :],
            y=batch_data.y[graph_idx],
            num_nodes=num_nodes,
            batch=torch.zeros(num_nodes, dtype=torch.long)
        )


def _select_top_sensitive_pairs(mean_sensitivity, directed, labels_dict, quantile=0.95):
    """Return one row dict per node pair whose |sensitivity| reaches the given quantile.

    Directed mode considers every ordered pair (i, j) with i != j; undirected mode
    considers each unordered pair once (i < j). Node numbers in the rows are 1-based.
    """
    num_nodes = mean_sensitivity.shape[0]
    device = mean_sensitivity.device
    if directed:
        candidates = ~torch.eye(num_nodes, dtype=torch.bool, device=device)
    else:
        # Strict upper triangle. Zero-valued pairs stay in the pool so they count toward
        # the quantile instead of being discarded (which would inflate the threshold).
        candidates = torch.triu(torch.ones((num_nodes, num_nodes), device=device), diagonal=1).bool()
    abs_sensitivity = mean_sensitivity.abs()
    candidate_values = abs_sensitivity[candidates]
    if candidate_values.numel() == 0:
        return []
    threshold = torch.quantile(candidate_values, quantile)
    # nonzero() yields (i, j) index pairs in row-major order.
    selected = torch.nonzero(candidates & (abs_sensitivity >= threshold)).tolist()
    rows = []
    for i, j in selected:
        if labels_dict:
            node_i_label = labels_dict.get(i, f"Node_{i+1}")
            node_j_label = labels_dict.get(j, f"Node_{j+1}")
        else:
            node_i_label = f"Node_{i+1}"
            node_j_label = f"Node_{j+1}"
        rows.append({
            'Node_i': i + 1,
            'Node_i_Label': node_i_label,
            'Node_j': j + 1,
            'Node_j_Label': node_j_label,
            'Sensitivity': mean_sensitivity[i, j].item()
        })
    return rows


def _print_node_ranking(title, nodes, node_sensitivity, labels_dict, full_names_dict):
    """Print a ranked list of nodes (0-based ``nodes``) with their summed sensitivity."""
    print(f"\n{'='*70}")
    print(title)
    print(f"{'='*70}")
    for rank, node in enumerate(nodes, 1):
        if labels_dict:
            label = labels_dict.get(node, "Unknown")
            full_name = full_names_dict.get(node, "Unknown region")
            print(f"{rank}. Node {node+1:3d} ({label:15s}): {node_sensitivity[node]:7.3f}")
            print(f"    {full_name}")
        else:
            print(f"{rank}. Node {node+1:3d}: {node_sensitivity[node]:7.3f}")


def occlusion_sensitivity_analysis(model, loader, model_dir, edge_feature_modalities, 
        style='diverging', directed=False, num_samples=20, labels_file='brainnetome_labels.csv'):
    """Average occlusion sensitivity maps over graphs from *loader* and report the results.

    Each graph is scored with ``occlusion_based_sensitivity``; the element-wise mean over
    the graphs actually evaluated (at most ``num_samples``) is then:

    * saved as ``SensitivityMap_<modalities joined by '_'>.csv`` in ``model_dir``;
    * thresholded at the 95th percentile of |sensitivity| and the surviving pairs saved,
      sorted by |sensitivity|, as ``Top5PercentEdges_<modalities joined by '+'>.csv``;
    * drawn as a heatmap saved to ``Sensitivity_<modalities joined by '+'>.png``;
    * summarised on stdout with the five highest / lowest node-level sums.

    Args:
        model: Trained model passed through to ``occlusion_based_sensitivity``.
        loader: Iterable of PyG mini-batches (must expose ``ptr``).
        model_dir: Output directory for the CSV/PNG artifacts.
        edge_feature_modalities: Names used to build output file names and titles.
        style: ``'diverging'`` (signed change) or ``'absolute'`` (magnitude of change).
        directed: Treat (i, j) and (j, i) as distinct edges.
        num_samples: Maximum number of graphs to average over.
        labels_file: CSV with ``Index``, ``Abbreviation`` and ``Full_Name`` columns.

    Raises:
        ValueError: If ``style`` is unsupported or no graph could be evaluated.
    """
    if style not in ('diverging', 'absolute'):
        raise ValueError(f"Unsupported style: {style!r} (expected 'diverging' or 'absolute')")
    labels_dict, full_names_dict = _load_region_labels(labels_file)
    # Inference mode so the repeated forward passes are comparable (e.g. no dropout).
    model.eval()
    accumulated_sensitivity = None
    accumulated_samples = 0
    for batch_data in loader:
        for sample in _iter_graphs(batch_data):
            if accumulated_samples >= num_samples:
                break
            sensitivity_matrix = occlusion_based_sensitivity(
                model, sample, style=style, directed=directed
            )
            if accumulated_sensitivity is None:
                accumulated_sensitivity = sensitivity_matrix
            else:
                accumulated_sensitivity = accumulated_sensitivity + sensitivity_matrix
            accumulated_samples += 1
        if accumulated_samples >= num_samples:
            break
    if accumulated_samples == 0:
        raise ValueError("No graphs were evaluated: the loader is empty or num_samples <= 0.")
    if accumulated_samples < num_samples:
        print(f"Warning: loader yielded only {accumulated_samples} of {num_samples} requested samples.")
    # Average over the graphs actually evaluated, not the requested count.
    mean_sensitivity = accumulated_sensitivity / accumulated_samples
    # Save sensitivity matrix (note: this file name joins modalities with '_').
    df_sensitivity = pd.DataFrame(mean_sensitivity.cpu().numpy())
    df_sensitivity.to_csv(
        os.path.join(model_dir, f"SensitivityMap_{'_'.join(edge_feature_modalities)}.csv"), 
        index=False, 
        header=False
    )
    # Extract top 5% most sensitive edges
    num_nodes_per_graph = mean_sensitivity.shape[0]
    high_sensitivity_pairs = _select_top_sensitive_pairs(mean_sensitivity, directed, labels_dict)
    df_top_pairs = pd.DataFrame(
        high_sensitivity_pairs,
        columns=['Node_i', 'Node_i_Label', 'Node_j', 'Node_j_Label', 'Sensitivity']
    )
    df_top_pairs = df_top_pairs.sort_values('Sensitivity', key=abs, ascending=False)
    df_top_pairs.to_csv(
        os.path.join(model_dir, f"Top5PercentEdges_{'+'.join(edge_feature_modalities)}.csv"), 
        index=False
    )
    # Visualize sensitivity heatmap
    plt.figure(figsize=(10, 8))
    sensitivity_numpy = mean_sensitivity.cpu().numpy()
    heatmap_params = {
        'xticklabels': False,
        'yticklabels': False,
        'square': True,
    }
    if style == 'diverging':
        # Symmetric color range centred on zero so sign is visually comparable.
        heatmap_params.update({
            'cmap': 'RdBu_r',
            'center': 0,
            'vmin': -np.abs(sensitivity_numpy).max(),
            'vmax': np.abs(sensitivity_numpy).max()
        })
    else:
        heatmap_params.update({
            'cmap': 'hot',
            'vmin': 0,
            'vmax': sensitivity_numpy.max()
        })
    sns.heatmap(sensitivity_numpy, **heatmap_params)
    plt.title(f"Sensitivity Map: {'+'.join(edge_feature_modalities)}", fontsize=16, pad=20)
    plt.tight_layout()
    plt.savefig(
        os.path.join(model_dir, f"Sensitivity_{'+'.join(edge_feature_modalities)}.png"), 
        dpi=300, 
        bbox_inches='tight'
    )
    plt.show()
    plt.close()
    # Print summary with labels
    print("\n" + "="*70)
    print("Network Summary")
    print("="*70)
    # Node-level sensitivity: row sums (directed: mean of outgoing and incoming sums).
    if directed:
        out_sensitivity = mean_sensitivity.sum(dim=1).cpu().numpy()
        in_sensitivity = mean_sensitivity.sum(dim=0).cpu().numpy()
        node_sensitivity = (out_sensitivity + in_sensitivity) / 2
    else:
        node_sensitivity = mean_sensitivity.sum(dim=1).cpu().numpy()
    print(f"\nNumber of nodes: {num_nodes_per_graph}")
    print(f"Number of top 5% edges: {len(df_top_pairs)}")
    print(f"Maximum absolute sensitivity: {np.abs(sensitivity_numpy).max():.6f}")
    order = np.argsort(node_sensitivity)
    _print_node_ranking("Nodes with highest positive sensitivity:", order[-5:][::-1],
                        node_sensitivity, labels_dict, full_names_dict)
    _print_node_ranking("Nodes with lowest negative sensitivity:", order[:5],
                        node_sensitivity, labels_dict, full_names_dict)
    print(f"\n{'='*70}\n")

def visualize_brain_network(model_dir, atlas_path, edge_feature_modalities, directed=False):
    """Plot the top-sensitivity edges on the atlas parcel coordinates and save a PNG.

    Reads ``Top5PercentEdges_<modalities joined by '+'>.csv`` from ``model_dir``
    (1-based ``Node_i``/``Node_j`` and a signed ``Sensitivity`` column). Node positions
    come from ``plotting.find_parcellation_cut_coords`` on the image at ``atlas_path``.
    Edge values are scaled to [-1, 1] by the largest absolute sensitivity. Node colour
    (red positive / blue negative, darker when weaker) and size follow each node's summed
    scaled edge sensitivity, normalised by its largest absolute value. The figure is saved
    as ``Top5PercentEdges_<modalities>_BrainNetwork.png`` in ``model_dir``.

    Args:
        model_dir: Directory holding the edges CSV; also receives the PNG.
        atlas_path: Path to the parcellation image.
        edge_feature_modalities: Modality names used to build the file names.
        directed: If False, each edge is mirrored to make the matrix symmetric.
    """
    from matplotlib.colors import to_rgba

    modality_str = '+'.join(edge_feature_modalities)
    edges_path = os.path.join(model_dir, f"Top5PercentEdges_{modality_str}.csv")
    edges_df = pd.read_csv(edges_path)
    # Adjust node indices (1-based to 0-based)
    edges_df['Node_i'] = edges_df['Node_i'] - 1
    edges_df['Node_j'] = edges_df['Node_j'] - 1
    # Load atlas and get coordinates
    atlas_img = nib.load(atlas_path)
    coords = plotting.find_parcellation_cut_coords(atlas_img)
    n_nodes = len(coords)
    # Build connectivity matrix
    connectivity_matrix = np.zeros((n_nodes, n_nodes))
    max_abs_sensitivity = np.abs(edges_df['Sensitivity']).max()
    # Guard against an empty CSV (max is NaN) or all-zero sensitivities (max is 0).
    edge_scale = max_abs_sensitivity if max_abs_sensitivity > 0 else 1.0
    edges_df['Scaled_Sensitivity'] = edges_df['Sensitivity'] / edge_scale
    for i, j, value in zip(edges_df['Node_i'].astype(int),
                           edges_df['Node_j'].astype(int),
                           edges_df['Scaled_Sensitivity']):
        connectivity_matrix[i, j] = value
        if not directed:
            connectivity_matrix[j, i] = value
    # Calculate node sensitivity
    if directed:
        out_sensitivity = np.sum(connectivity_matrix, axis=1)
        in_sensitivity = np.sum(connectivity_matrix, axis=0)
        node_sensitivity = np.mean([out_sensitivity, in_sensitivity], axis=0)
    else:
        node_sensitivity = np.sum(connectivity_matrix, axis=1)
    max_abs_node_sensitivity = np.abs(node_sensitivity).max()
    # Skip normalisation when every node sums to zero (avoids 0/0 -> NaN).
    if max_abs_node_sensitivity > 0:
        node_sensitivity = node_sensitivity / max_abs_node_sensitivity
    # Create node colors: red for positive, blue for negative, scaled toward black by |value|.
    node_rgb = np.zeros((n_nodes, 3))
    node_rgb[node_sensitivity > 0] = np.array([1, 0, 0])
    node_rgb[node_sensitivity < 0] = np.array([0, 0, 1])
    node_rgb *= np.abs(node_sensitivity)[:, np.newaxis]
    alpha = 0.3
    node_colors = [to_rgba(c, alpha) for c in node_rgb]
    # Plot brain network. The figure is passed explicitly: with figure=None nilearn
    # creates a new figure, leaving this 12x12 one empty and never closed.
    fig = plt.figure(figsize=(12, 12))
    plotting.plot_connectome(
        connectivity_matrix, coords,
        node_color=node_colors,
        node_size=100 * np.abs(node_sensitivity) + 20,
        edge_cmap='seismic',
        edge_vmin=-1,
        edge_vmax=1,
        colorbar=True,
        edge_threshold=None,
        annotate=False,
        black_bg=False,
        figure=fig
    )
    fig.suptitle("Connectional and regional sensitivity", bbox=dict(facecolor='white', edgecolor=None))
    # Add legend (empty scatter handles illustrate the node-size scale)
    sizes = [0.3, 0.6, 0.9]
    legend_elements = [plt.scatter([], [], c='black', s=(100 * s + 20), label=f'Node sensitivity: {s:.1f}') for s in sizes]
    plt.legend(handles=legend_elements, loc='best', title='Node size scale', frameon=True, facecolor='white', edgecolor='black', labelcolor='black')
    bn_path = os.path.join(model_dir, f"Top5PercentEdges_{modality_str}_BrainNetwork.png")
    fig.savefig(bn_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()
    plt.close(fig)

def apply_best_model(model_dir, model, device, test_loader, pred_dir, subjects):
    """Run the best saved checkpoint over ``test_loader`` and write per-subject predictions.

    Loads ``BestMetricModel.pth`` from ``model_dir`` into ``model``, switches it to eval
    mode, and writes ``PredictedSex.csv`` (columns ``ID``, ``Probability``,
    ``PredictedSex``) into ``pred_dir``.

    Args:
        model_dir: Directory containing ``BestMetricModel.pth`` (a state dict).
        model: Model whose architecture matches the checkpoint.
        device: Device used for inference; the checkpoint tensors are mapped onto it.
        test_loader: Iterable of batches accepted by ``model``; each must support ``.to``.
        pred_dir: Output directory, created if missing.
        subjects: Subject IDs, one per predicted sample, in loader order.
    """
    checkpoint_path = os.path.join(model_dir, "BestMetricModel.pth")
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    os.makedirs(pred_dir, exist_ok=True)
    prob_values = []
    pred_values = []
    with torch.no_grad():
        for batch_data in test_loader:
            batch_data = batch_data.to(device)
            # Model inference
            outputs = model(batch_data)
            pred_probs = F.softmax(outputs, dim=1)
            pred_labels = torch.argmax(pred_probs, dim=1)
            # 'Probability' is the softmax score of class 1.
            prob_values.extend(pred_probs[:, 1].cpu().numpy())
            pred_values.extend(pred_labels.cpu().numpy())
    # Save predictions
    pred_df = pd.DataFrame({
        'ID': subjects,
        'Probability': prob_values,
        'PredictedSex': pred_values
    })
    pred_df.to_csv(os.path.join(pred_dir, "PredictedSex.csv"), index=False)

def _plot_confusion_matrix(cm, accuracy, pred_dir):
    """Draw the 2x2 confusion matrix (rows = actual, cols = predicted) and save ConfusionMatrix.png."""
    # Row-normalise per actual class; a class with no samples gets 0 instead of NaN.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)
    correct_rgb = (0.2, 0.4, 0.8)  # blue: TN and TP
    wrong_rgb = (0.8, 0.2, 0.2)    # red: FP and FN
    fig, ax = plt.subplots(figsize=(10, 8))
    for i in range(2):
        for j in range(2):
            # Diagonal cells are correct predictions; opacity encodes the row-normalised rate.
            facecolor = (*(correct_rgb if i == j else wrong_rgb), cm_normalized[i, j])
            ax.add_patch(plt.Rectangle((j, i), 1, 1, facecolor=facecolor, edgecolor='white', linewidth=2))
            # Add text annotations
            text_color = 'white' if cm_normalized[i, j] > 0.5 else 'black'
            ax.text(j + 0.5, i + 0.5, str(cm[i, j]), ha='center', va='center', fontsize=24, fontweight='bold', color=text_color)
            # Add percentage
            ax.text(j + 0.5, i + 0.7, f'({cm_normalized[i, j]*100:.1f}%)', ha='center', va='center', fontsize=12, color=text_color)
    ax.set_xlim(0, 2)
    ax.set_ylim(0, 2)
    ax.set_xticks([0.5, 1.5])
    ax.set_yticks([0.5, 1.5])
    ax.set_xticklabels(['Female', 'Male'], fontsize=14)
    ax.set_yticklabels(['Female', 'Male'], fontsize=14)
    ax.invert_yaxis()
    plt.xlabel('Predicted', fontsize=16, fontweight='bold')
    plt.ylabel('Actual', fontsize=16, fontweight='bold')
    plt.title(f"Confusion Matrix (Accuracy = {accuracy:.3f})", fontsize=18, fontweight='bold', pad=20)
    # Add cell labels (y axis is inverted, so -0.15 is above the top row)
    ax.text(0.5, -0.15, 'TN', ha='center', va='center', fontsize=10, color='blue', fontweight='bold', transform=ax.transData)
    ax.text(1.5, -0.15, 'FP', ha='center', va='center', fontsize=10, color='red', fontweight='bold', transform=ax.transData)
    ax.text(0.5, 2.15, 'FN', ha='center', va='center', fontsize=10, color='red', fontweight='bold', transform=ax.transData)
    ax.text(1.5, 2.15, 'TP', ha='center', va='center', fontsize=10, color='blue', fontweight='bold', transform=ax.transData)
    plt.tight_layout()
    plt.savefig(os.path.join(pred_dir, "ConfusionMatrix.png"), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)


def _plot_roc_pr_curves(true_labels, pred_probs, pred_dir):
    """Plot and save the ROC curve (ROC_Curve.png) and precision-recall curve (PR_Curve.png)."""
    # Plot ROC curve
    fpr, tpr, _ = roc_curve(true_labels, pred_probs)
    roc_auc = auc(fpr, tpr)
    plt.figure(figsize=(8, 8))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random classifier')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=12)
    plt.ylabel('True Positive Rate', fontsize=12)
    plt.title('ROC Curve', fontsize=14, fontweight='bold')
    plt.legend(loc="lower right", fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(pred_dir, "ROC_Curve.png"), dpi=300)
    plt.show()
    plt.close()
    # Plot Precision-Recall curve
    precision_curve, recall_curve, _ = precision_recall_curve(true_labels, pred_probs)
    pr_auc = auc(recall_curve, precision_curve)
    plt.figure(figsize=(8, 8))
    plt.plot(recall_curve, precision_curve, color='blue', lw=2, label=f'PR curve (AUC = {pr_auc:.3f})')
    # A random classifier's precision equals the positive-class prevalence.
    plt.axhline(y=true_labels.sum()/len(true_labels), color='navy', linestyle='--', lw=2, label='Random classifier')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall', fontsize=12)
    plt.ylabel('Precision', fontsize=12)
    plt.title('Precision-Recall Curve', fontsize=14, fontweight='bold')
    plt.legend(loc="lower left", fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(pred_dir, "PR_Curve.png"), dpi=300)
    plt.show()
    plt.close()


def calculate_test_metric(metric, gt_dir, pred_dir):
    """Score ``PredictedSex.csv`` against ground truth, then save metrics and plots.

    Merges ``<gt_dir>/Subjects_GT.csv`` (columns ``ID``, ``Sex``) with
    ``<pred_dir>/PredictedSex.csv`` on ``ID``. It prints accuracy (from ``metric``),
    sensitivity, specificity, precision, F1 and the confusion-matrix counts, with class 1
    as positive. It writes ``TestResults.csv`` and ``TestSummary.csv`` and saves the
    confusion-matrix, ROC and PR figures into ``pred_dir``. Returns early (None) when no
    subject IDs match.

    Args:
        metric: Accuracy metric exposing ``reset()``, ``__call__(preds, target)`` and ``compute()``.
        gt_dir: Directory containing ``Subjects_GT.csv``.
        pred_dir: Directory containing ``PredictedSex.csv``; receives all outputs.
    """
    # Load ground truth and predictions
    gt_df = pd.read_csv(os.path.join(gt_dir, 'Subjects_GT.csv'))
    pred_df = pd.read_csv(os.path.join(pred_dir, 'PredictedSex.csv'))
    # Merge on subject ID
    merged_df = pd.merge(gt_df, pred_df, on='ID', how='inner')
    if len(merged_df) == 0:
        print("Error: No matching subjects found!")
        return
    # Extract values
    true_labels = merged_df['Sex'].values
    pred_labels = merged_df['PredictedSex'].values
    pred_probs = merged_df['Probability'].values
    # Compute accuracy
    metric.reset()
    y_pred = torch.from_numpy(pred_labels).long()
    y_true = torch.from_numpy(true_labels).long()
    metric(y_pred, y_true)
    accuracy = metric.compute().item()
    # Fix labels to [0, 1] so the matrix is always 2x2, even if one class is absent.
    cm = confusion_matrix(true_labels, pred_labels, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    # Calculate additional metrics
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0  # Recall for positive class
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0  # Recall for negative class
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    f1 = 2 * (precision * sensitivity) / (precision + sensitivity) if (precision + sensitivity) > 0 else 0 
    # Print results
    print(f"\n{'='*50}")
    print(f"Classification performance:")
    print(f"\u2022 Accuracy: {accuracy:.3f}")
    print(f"\u2022 Sensitivity (Recall): {sensitivity:.3f}")
    print(f"\u2022 Specificity: {specificity:.3f}")
    print(f"\u2022 Precision: {precision:.3f}")
    print(f"\u2022 F1-Score: {f1:.3f}")
    print(f"\nConfusion matrix:")
    print(f"\u2022 True negatives (TN): {tn}")
    print(f"\u2022 False positives (FP): {fp}")
    print(f"\u2022 False negatives (FN): {fn}")
    print(f"\u2022 True positives (TP): {tp}")
    print(f"{'='*50}\n")
    # Save detailed results
    results_df = merged_df[['ID', 'Sex', 'PredictedSex', 'Probability']].copy()
    results_df['Correct'] = (results_df['Sex'] == results_df['PredictedSex']).astype(int)
    results_df.to_csv(os.path.join(pred_dir, "TestResults.csv"), index=False)
    # Save summary statistics
    summary_df = pd.DataFrame({
        'Accuracy': [accuracy],
        'Sensitivity': [sensitivity],
        'Specificity': [specificity],
        'Precision': [precision],
        'F1_Score': [f1],
        'True_Negatives': [tn],
        'False_Positives': [fp],
        'False_Negatives': [fn],
        'True_Positives': [tp]
    })
    summary_df.to_csv(os.path.join(pred_dir, "TestSummary.csv"), index=False)
    _plot_confusion_matrix(cm, accuracy, pred_dir)
    _plot_roc_pr_curves(true_labels, pred_probs, pred_dir)

### Prepare Inputs

In [None]:
data_dir = os.path.join("SexClassification", "Datasets")
model_dir_prefix = "SexClassification"
model_name = "GAT" # Any supported model name: GCN, GraphSAGE (unimodal edge), GAT, TransformerConv (multimodal edge)
node_feature_modalities = ["PC", "GC"]
# Node feature files: a file of shape (N,) is used directly as the node feature;
# a file of shape (N, N) is reduced to one value per node by averaging its edges.
edge_feature_modalities = ["FA", "MD"]  

# Training parameters
test_size = 0.2
scaling = 'global' # Options: 'global', 'per_subject', 'none'
num_channels = [64, 64, 128]
batch_size = 16
max_epochs = 100
learning_rate = 1e-3
weight_decay = 1e-4
val_interval = 1
es_patience = 30

# Setup output directory and logging
node_str = '+'.join(node_feature_modalities) if node_feature_modalities else ''
edge_str = '+'.join(edge_feature_modalities) if edge_feature_modalities else ''
model_dir = f"{model_dir_prefix}_{model_name}_N[{node_str}]_E[{edge_str}]"
os.makedirs(model_dir, exist_ok=True)
log_file = os.path.join(model_dir, "Training.log")
# force=True replaces any handlers already on the root logger (e.g. from an earlier run
# of this cell); without it basicConfig is a no-op and logs keep going to the old file.
logging.basicConfig(filename=log_file, level=logging.INFO, format="%(message)s", force=True)
logger = logging.getLogger()

# Save configuration
config = {
    'model_name': model_name,
    'node_feature_modalities': node_feature_modalities,
    'edge_feature_modalities': edge_feature_modalities,
    'scaling': scaling,
    'num_channels': num_channels
}
config_file = os.path.join(model_dir, "Config.json")
with open(config_file, 'w') as f:
    json.dump(config, f, indent=2)

### Read Data

In [None]:
set_determinism(seed=0)
train_ds = BrainConnectivityDataset(
    os.path.join(data_dir, "train"),
    model_name=model_name,
    node_feature_modalities=node_feature_modalities,
    edge_feature_modalities=edge_feature_modalities,
    scaling=scaling
)
train_loader, val_loader = load_data(
    train_ds, batch_size=batch_size, inference=False, test_size=test_size
)

# Check data shape.
# Per-graph sizes divide by each batch's own graph count (num_graphs) rather than
# batch_size: a batch may hold fewer graphs, e.g. when a split is smaller than one batch.
tr = next(iter(train_loader))
num_nodes_per_graph = tr.num_nodes // tr.num_graphs
num_node_features = tr.x.shape[1]
num_edge_features = tr.edge_attr.shape[1]
print('\nData shape for training:')
print(f'\u2022 Node features: ({tr.num_graphs} \u00D7 {tr.x.shape[0] // tr.num_graphs}, {tr.x.shape[1]}) \u00D7 {len(train_loader)}')
print(f'\u2022 Edge features: ({tr.num_graphs} \u00D7 {tr.edge_attr.shape[0] // tr.num_graphs}, {tr.edge_attr.shape[1]}) \u00D7 {len(train_loader)}')
print(f'\u2022 Labels: {tuple(tr.y.shape)} \u00D7 {len(train_loader)}')
vl = next(iter(val_loader))
print('\nData shape for validation:')
print(f'\u2022 Node features: ({vl.num_graphs} \u00D7 {vl.x.shape[0] // vl.num_graphs}, {vl.x.shape[1]}) \u00D7 {len(val_loader)}')
print(f'\u2022 Edge features: ({vl.num_graphs} \u00D7 {vl.edge_attr.shape[0] // vl.num_graphs}, {vl.edge_attr.shape[1]}) \u00D7 {len(val_loader)}')
print(f'\u2022 Labels: {tuple(vl.y.shape)} \u00D7 {len(val_loader)}')

# Visualize data
sample = next(iter(train_loader))
visualize_conn_matrix(sample, edge_feature_modalities, title="Training First Sample")
print("Value ranges for connectivity matrices:")
for i, modality in enumerate(edge_feature_modalities):
    edge_attr = sample.edge_attr[:, i]
    print(f"\u2022 {modality}: [{edge_attr.min():.3f}, {edge_attr.max():.3f}]")

### Train Model

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = get_model(model_name, num_node_features, num_edge_features, num_channels=num_channels)
# NOTE(review): this branch runs only with more than one visible GPU, but the notebook
# header sets CUDA_VISIBLE_DEVICES='0'. torch.nn.DataParallel may not split PyG batches
# across devices (PyG ships torch_geometric.nn.DataParallel for that); confirm first.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model.to(device)
print(f"\nSelected model: {model_name}")
num_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Number of trainable parameters: {num_params:,}")
criterion = nn.CrossEntropyLoss()
metric = Accuracy(task='binary').to(device)
model, epoch_loss_values, epoch_metric_values, metric_values = train_model(
    model_dir, model, device, train_loader, val_loader, logger,
    criterion, metric, max_epochs, learning_rate, weight_decay, val_interval, es_patience,
)
plot_metric_values(model_dir, epoch_loss_values, epoch_metric_values, metric_values, val_interval)

# Visualize outcome: predict two 1-based training samples (Sex = 0, 1) and plot them.
model.eval()
sample_indices = [20, 50]
for sample_index in sample_indices:
    single_dataset = Subset(train_ds, [sample_index - 1])  # Subset positions are 0-based
    single_loader = DataLoader(single_dataset, batch_size=1)
    sample = next(iter(single_loader))
    with torch.no_grad():
        batch_data = sample.to(device)
        output = model(batch_data)
        pred_prob = torch.softmax(output, dim=1)
        pred_label = torch.argmax(pred_prob, dim=1)
    actual_label = sample.y.item()
    predicted_name = "Male" if pred_label.item() == 1 else "Female"
    actual_name = "Male" if actual_label == 1 else "Female"
    confidence = pred_prob[0, pred_label].item()
    title = (
        f'Training Sample {sample_index}: '
        f'Predicted = {predicted_name} '
        f'({confidence:.2f}), '
        f'Actual = {actual_name}'
    )
    visualize_conn_matrix(sample, edge_feature_modalities, title=title)

### Occlusion Sensitivity Analysis

In [None]:
# Occlusion sensitivity on the validation split. The analysis writes the
# Top5PercentEdges CSV into model_dir, which visualize_brain_network then plots
# on the atlas parcel coordinates.
labels_file = os.path.join(data_dir, "brainnetome_labels.csv")
atlas_path = os.path.join(data_dir, "BN_Atlas_246_1mm.nii.gz")
occlusion_sensitivity_analysis(
    model=model,
    loader=val_loader,
    model_dir=model_dir,
    edge_feature_modalities=edge_feature_modalities,
    style='diverging',
    directed=False,
    num_samples=20,
    labels_file=labels_file,
)
visualize_brain_network(
    model_dir=model_dir,
    atlas_path=atlas_path,
    edge_feature_modalities=edge_feature_modalities,
    directed=False,
)

### Inference

In [None]:
# Under global scaling the test set must reuse the statistics fitted on the training
# set; otherwise global_stats stays at the dataset's default of None.
test_global_stats = train_ds.global_stats if scaling == 'global' else None
test_ds = BrainConnectivityDataset(
    os.path.join(data_dir, "test"),
    model_name=model_name,
    node_feature_modalities=node_feature_modalities,
    edge_feature_modalities=edge_feature_modalities,
    scaling=scaling,
    global_stats=test_global_stats,
)
test_loader, subjects = load_data(test_ds, batch_size=1, inference=True)
pred_dir = os.path.join(model_dir, "Prediction")
apply_best_model(model_dir, model, device, test_loader, pred_dir, subjects)