### Import Modules

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset
import time
import logging
import matplotlib.pyplot as plt
import sys
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, ChebConv, global_mean_pool
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torchmetrics import Accuracy
import seaborn as sns
import nibabel as nib
from nilearn import plotting
# Seed PyTorch's global RNG so torch-driven randomness (e.g. parameter
# initialisation) is repeatable. This does not seed NumPy or Python's `random`;
# the train/validation split below uses its own random_state.
torch.manual_seed(42)

### Define Functions and Classes

In [None]:
class BrainConnectivityDataset(Dataset):
    """Dataset that builds one graph per subject from per-modality CSV files.

    Layout expected under ``data_dir``:

    * ``Subjects.csv`` with a ``No`` column (integer subject IDs, used as
      zero-padded 3-digit file names), an optional ``Sex`` column used as the
      class label, and every column named in ``additional_variables``.
    * ``<modality>/<No:03d>.csv`` for each edge modality: a square,
      header-less connectivity matrix.
    * ``<modality>/<No:03d>.csv`` for each node modality: a header-less table
      with one row per node.

    Args:
        data_dir: Root directory holding ``Subjects.csv`` and modality folders.
        edge_modalities: Folders whose matrices become edge attributes (one
            column of ``edge_attr`` per modality).
        node_modalities: Folders whose tables become node features. When empty
            or None, the (scaled) connectivity matrices are used instead.
        additional_variables: ``Subjects.csv`` columns attached as graph-level
            features; ``'Sex'`` is always skipped because it is the label.
        scaling: If True, divide each array by its max absolute value
            (off-diagonal entries only for square matrices) and set the
            diagonal of square matrices to 1.
    """

    def __init__(self, data_dir, edge_modalities, node_modalities=None, additional_variables=None, scaling=True):
        super().__init__()
        self.data_dir = data_dir
        self.edge_modalities = edge_modalities
        self.node_modalities = node_modalities if node_modalities else []
        self.additional_variables = additional_variables if additional_variables else []
        self.scaling = scaling
        self.df = pd.read_csv(os.path.join(data_dir, 'Subjects.csv'))
        self.subjects = self.df['No'].to_numpy()

    def len(self):
        """Return the number of subjects (graphs)."""
        return len(self.subjects)

    def _read_subject_csv(self, modality, subject):
        """Load ``<data_dir>/<modality>/<subject:03d>.csv`` (no header) as an array."""
        path = os.path.join(self.data_dir, modality, f"{subject:03d}.csv")
        return pd.read_csv(path, header=None).values

    def get(self, idx):
        """Build the ``Data`` graph for the subject at position ``idx``.

        Returns a ``Data`` with ``x`` (node features, modalities concatenated
        along dim 1), ``edge_index`` (all off-diagonal pairs), ``edge_attr``
        (one column per edge modality), ``additional_features`` of shape
        (1, F), ``y`` (the ``Sex`` label, or None without that column) and
        ``num_nodes``.

        Raises:
            ValueError: on a non-square connectivity matrix, or when the edge
                modalities yield different graph structures.
        """
        subject = self.subjects[idx]
        conn_matrices = []
        edge_index_list = []
        edge_attr_list = []
        for modality in self.edge_modalities:
            conn_matrix = self._read_subject_csv(modality, subject)
            if conn_matrix.shape[0] != conn_matrix.shape[1]:
                raise ValueError(f"Non-square connectivity matrix found in {modality}")
            if self.scaling:
                conn_matrix = self._scale_data(conn_matrix)
            conn_matrices.append(conn_matrix)
            edge_index, edge_attr = self._matrix_to_edge_index_attr(conn_matrix)
            edge_index_list.append(edge_index)
            edge_attr_list.append(edge_attr)
        if self.node_modalities:
            node_arrays = []
            for modality in self.node_modalities:
                node_features = self._read_subject_csv(modality, subject)
                if self.scaling:
                    node_features = self._scale_data(node_features)
                node_arrays.append(node_features)
        else:
            # Reuse the already loaded (and scaled) connectivity matrices as node
            # features instead of reading and scaling the same files again.
            node_arrays = conn_matrices
        node_features = torch.cat(
            [torch.tensor(array, dtype=torch.float) for array in node_arrays], dim=1)
        if not all(torch.equal(edge_index_list[0], edge_idx) for edge_idx in edge_index_list[1:]):
            raise ValueError(f"Different graph structures detected for subject {subject}")
        edge_attr = torch.cat(edge_attr_list, dim=1)
        # Positional access (.iloc) so the lookup matches self.subjects[idx]
        # even if the DataFrame index is not a plain 0..n-1 range.
        additional_features = torch.tensor([
            self.df[var].iloc[idx] for var in self.additional_variables
            if var != 'Sex'
        ], dtype=torch.float).view(1, -1)
        label = torch.tensor(self.df['Sex'].iloc[idx], dtype=torch.long) if 'Sex' in self.df.columns else None
        return Data(
            x=node_features,
            edge_index=edge_index_list[0],
            edge_attr=edge_attr,
            additional_features=additional_features,
            y=label,
            num_nodes=node_features.shape[0]
        )

    def load_data(self, batch_size, inference=False, test_size=0.2):
        """Build DataLoaders over this dataset.

        Args:
            batch_size: Graphs per batch for the training/validation loaders.
            inference: If True, return a loader with one graph per batch over
                every subject, in file order, together with the subject IDs.
            test_size: Fraction of subjects held out for validation.

        Returns:
            ``(test_loader, subjects)`` when ``inference`` is True, otherwise
            ``(train_loader, val_loader, train_ds, val_ds)`` from a split
            stratified on ``Sex`` with ``random_state=42``.

        Raises:
            ValueError: if a stratified split is requested but ``Subjects.csv``
                has no ``Sex`` column.
        """
        if inference:
            test_loader = DataLoader(self, batch_size=1, num_workers=0, pin_memory=torch.cuda.is_available())
            subjects = self.subjects
            return test_loader, subjects
        if 'Sex' not in self.df.columns:
            raise ValueError("Subjects.csv needs a 'Sex' column for the stratified train/validation split")
        indices = list(range(len(self)))
        # Read labels straight from the subject table: building every graph via
        # self[i] just to access .y would load all of the CSV files.
        labels = self.df['Sex'].to_numpy()
        train_indices, val_indices = train_test_split(indices, test_size=test_size, stratify=labels, random_state=42)
        train_ds = Subset(self, train_indices)
        val_ds = Subset(self, val_indices)
        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=torch.cuda.is_available())
        val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=0, pin_memory=torch.cuda.is_available())
        return train_loader, val_loader, train_ds, val_ds

    def _calculate_degree(self, matrix, binary=True):
        """Return row sums as an (N, 1) array; with ``binary`` count non-zero entries."""
        if binary:
            matrix = (matrix != 0).astype(float)
        return np.sum(matrix, axis=1, keepdims=True)

    def _scale_data(self, data):
        """Scale ``data`` by its max absolute value.

        For square 2-D input the diagonal is excluded from the max and then set
        to 1. All other input is scaled by the max over every entry.
        NOTE(review): a node-feature table that happens to be square
        (num_nodes == num_features) is also treated as a matrix here; confirm
        that is intended.
        """
        is_matrix = data.ndim == 2 and data.shape[0] == data.shape[1]
        if is_matrix:
            values = data[~np.eye(data.shape[0], dtype=bool)]
        else:
            values = data.flatten()
        # A 1x1 matrix has no off-diagonal values; avoid max() of an empty array.
        absmax = np.abs(values).max() if values.size else 0.0
        scaled_data = data / (absmax + 1e-8)
        if is_matrix:
            scaled_data[np.eye(scaled_data.shape[0], dtype=bool)] = 1
        return scaled_data

    def _matrix_to_edge_index_attr(self, matrix):
        """Turn a square matrix into a fully connected edge list without self-loops.

        Returns ``(edge_index, edge_attr)``. ``edge_index`` is a (2, E) long
        tensor listing every (i, j) with i != j in row-major order.
        ``edge_attr`` is an (E, 1) float tensor of the matching ``matrix[i, j]``.
        """
        num_nodes_per_graph = matrix.shape[0]
        off_diagonal = ~np.eye(num_nodes_per_graph, dtype=bool)
        # np.nonzero yields indices in row-major order, i.e. the same order as a
        # nested "for i: for j:" loop.
        rows, cols = np.nonzero(off_diagonal)
        edge_index = torch.as_tensor(np.stack([rows, cols]), dtype=torch.long)
        edge_attr = torch.tensor(np.asarray(matrix)[rows, cols], dtype=torch.float).view(-1, 1)
        return edge_index, edge_attr

class BrainGCNConv(GCNConv):
    """GCN layer that convolves each edge modality separately and averages the results.

    ``modality_weights`` is a learnable vector with one entry per edge-feature
    column; its softmax scales the matching column of ``edge_weight``.

    Args:
        in_channels, out_channels: Passed to ``GCNConv``.
        num_nodes_per_graph: Width of one modality's block of node features
            (used to slice ``x`` per modality in ``forward``).
        num_edge_features: Number of edge-weight columns (modalities).
        modality_weights: Optional initial values for the modality weights
            (defaults to ones). The values are copied.
    """
    def __init__(self, in_channels, out_channels, num_nodes_per_graph, num_edge_features, modality_weights=None):
        super().__init__(in_channels, out_channels)
        self.num_nodes_per_graph = num_nodes_per_graph
        if modality_weights is None:
            initial_weights = torch.ones(num_edge_features)
        else:
            # nn.Parameter(t) shares t's storage. BrainGNN hands the same tensor
            # to every layer, so copy it to give each layer its own parameter.
            initial_weights = torch.as_tensor(modality_weights, dtype=torch.float).detach().clone()
        self.modality_weights = nn.Parameter(initial_weights)

    def forward(self, x, edge_index, edge_weight=None):
        """Apply the convolution.

        If ``edge_weight`` is 2-D (one column per modality), run one GCN pass
        per column, each scaled by ``softmax(modality_weights)[i]``, and return
        the element-wise mean of the per-modality outputs. Otherwise defer to
        ``GCNConv.forward``.
        """
        if edge_weight is not None and edge_weight.dim() > 1:
            modality_weights = F.softmax(self.modality_weights, dim=0)
            outputs = []
            num_modalities = edge_weight.size(1)
            num_nodes_per_graph = self.num_nodes_per_graph
            for i in range(num_modalities):
                modality_weight = edge_weight[:, i] * modality_weights[i]
                # Select modality i's node features: a num_nodes_per_graph-wide
                # block per modality, one column per modality, or else all of x.
                # NOTE(review): the selected width must equal in_channels;
                # confirm against the num_node_features the caller passes.
                if x.size(1) == num_nodes_per_graph * num_modalities:
                    x_modality = x[:, i*num_nodes_per_graph:(i+1)*num_nodes_per_graph]
                elif x.size(1) == num_modalities:
                    x_modality = x[:, i:i+1]
                else:
                    x_modality = x
                outputs.append(super().forward(x_modality, edge_index, modality_weight))
            return torch.mean(torch.stack(outputs), dim=0)
        else:
            return super().forward(x, edge_index, edge_weight)
        
class BrainGATConv(GATConv):
    """GAT layer (heads concatenated) that runs once per edge modality and averages the results.

    Args:
        in_channels, out_channels: Passed to ``GATConv`` (``out_channels`` is
            per head; the output width is ``out_channels * heads``).
        num_nodes_per_graph: Width of one modality's block of node features.
        num_edge_features: Number of edge-weight columns (modalities).
        modality_weights: Optional initial modality weights (defaults to ones);
            the values are copied.
        **kwargs: ``heads`` (default 1) and optional ``edge_dim`` (default
            None), both forwarded to ``GATConv``.
    """
    def __init__(self, in_channels, out_channels, num_nodes_per_graph, num_edge_features, modality_weights=None, **kwargs):
        heads = kwargs.get('heads', 1)
        # NOTE(review): per the PyG GATConv docs, edge_attr is only used when the
        # layer is built with edge_dim. With the default (None), the per-modality
        # edge weights passed in forward() presumably do not affect attention and
        # modality_weights get no gradient. edge_dim=1 lets them participate,
        # but it adds parameters, so older checkpoints will not load strictly.
        edge_dim = kwargs.get('edge_dim')
        super().__init__(in_channels, out_channels, heads=heads, concat=True, edge_dim=edge_dim)
        self.num_nodes_per_graph = num_nodes_per_graph
        if modality_weights is None:
            initial_weights = torch.ones(num_edge_features)
        else:
            # nn.Parameter(t) shares t's storage; copy so that layers built from
            # one shared initial tensor get independent parameters.
            initial_weights = torch.as_tensor(modality_weights, dtype=torch.float).detach().clone()
        self.modality_weights = nn.Parameter(initial_weights)

    def forward(self, x, edge_index, edge_weight=None):
        """Apply attention per modality when ``edge_weight`` is 2-D, else plain GAT.

        In the 2-D case, modality i passes ``edge_weight[:, i] *
        softmax(modality_weights)[i]`` as a one-column ``edge_attr``, and the
        per-modality outputs are averaged. A 1-D ``edge_weight`` is not used.
        """
        if edge_weight is not None and edge_weight.dim() > 1:
            modality_weights = F.softmax(self.modality_weights, dim=0)
            outputs = []
            num_modalities = edge_weight.size(1)
            num_nodes_per_graph = self.num_nodes_per_graph
            for i in range(num_modalities):
                modality_weight = edge_weight[:, i] * modality_weights[i]
                # Select modality i's node features (block / column / all of x).
                if x.size(1) == num_nodes_per_graph * num_modalities:
                    x_modality = x[:, i*num_nodes_per_graph:(i+1)*num_nodes_per_graph]
                elif x.size(1) == num_modalities:
                    x_modality = x[:, i:i+1]
                else:
                    x_modality = x
                outputs.append(super().forward(x_modality, edge_index, edge_attr=modality_weight.unsqueeze(1)))
            return torch.mean(torch.stack(outputs), dim=0)
        else:
            return super().forward(x, edge_index)

class BrainSAGEConv(SAGEConv):
    """GraphSAGE layer (L2-normalised output) with edge-weighted, per-modality messages.

    Neighbour messages are multiplied by their edge weight. With a 2-D
    ``edge_weight`` (one column per modality), each column is scaled by
    ``softmax(modality_weights)[i]``, one SAGE pass runs per modality, and the
    outputs are averaged.

    Args:
        in_channels, out_channels: Passed to ``SAGEConv``.
        num_nodes_per_graph: Width of one modality's block of node features.
        num_edge_features: Number of edge-weight columns (modalities).
        modality_weights: Optional initial modality weights (defaults to ones);
            the values are copied.
    """
    def __init__(self, in_channels, out_channels, num_nodes_per_graph, num_edge_features, modality_weights=None):
        super().__init__(in_channels, out_channels, normalize=True)
        self.num_nodes_per_graph = num_nodes_per_graph
        if modality_weights is None:
            initial_weights = torch.ones(num_edge_features)
        else:
            # nn.Parameter(t) shares t's storage; copy so that layers built from
            # one shared initial tensor get independent parameters.
            initial_weights = torch.as_tensor(modality_weights, dtype=torch.float).detach().clone()
        self.modality_weights = nn.Parameter(initial_weights)

    def message(self, x_j, edge_weight=None):
        """Scale each neighbour feature row by its edge weight (if given)."""
        # SAGEConv's own message() takes only x_j, so the edge_weight handed to
        # propagate() would otherwise be dropped and the modality weights would
        # never influence the output.
        if edge_weight is None:
            return x_j
        return x_j * edge_weight.view(-1, 1)

    def _weighted_sage(self, x, edge_index, edge_weight):
        """One SAGE pass: lin_l(aggregated neighbours) + lin_r(x), optionally L2-normalised."""
        root = self.lin_r(x)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight)
        out = self.lin_l(out) + root
        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)
        return out

    def forward(self, x, edge_index, edge_weight=None):
        """Apply the layer (per modality, then averaged, when ``edge_weight`` is 2-D)."""
        if edge_weight is not None and edge_weight.dim() > 1:
            modality_weights = F.softmax(self.modality_weights, dim=0)
            outputs = []
            num_modalities = edge_weight.size(1)
            num_nodes_per_graph = self.num_nodes_per_graph
            for i in range(num_modalities):
                modality_weight = edge_weight[:, i] * modality_weights[i]
                # Select modality i's node features (block / column / all of x).
                if x.size(1) == num_nodes_per_graph * num_modalities:
                    x_modality = x[:, i*num_nodes_per_graph:(i+1)*num_nodes_per_graph]
                elif x.size(1) == num_modalities:
                    x_modality = x[:, i:i+1]
                else:
                    x_modality = x
                outputs.append(self._weighted_sage(x_modality, edge_index, modality_weight))
            return torch.mean(torch.stack(outputs), dim=0)
        return self._weighted_sage(x, edge_index, edge_weight)

class BrainChebConv(ChebConv):
    """Chebyshev spectral convolution applied per edge modality and averaged.

    Args:
        in_channels, out_channels: Passed to ``ChebConv``.
        num_nodes_per_graph: Width of one modality's block of node features.
        num_edge_features: Number of edge-weight columns (modalities).
        modality_weights: Optional initial modality weights (defaults to ones);
            the values are copied.
        K: Chebyshev filter size, passed to ``ChebConv``.
    """
    def __init__(self, in_channels, out_channels, num_nodes_per_graph, num_edge_features, modality_weights=None, K=2):
        super().__init__(in_channels, out_channels, K=K)
        # Stored explicitly because __repr__ reads self.K and ChebConv does not
        # set that attribute itself.
        self.K = K
        self.num_nodes_per_graph = num_nodes_per_graph
        if modality_weights is None:
            initial_weights = torch.ones(num_edge_features)
        else:
            # nn.Parameter(t) shares t's storage; copy so that layers built from
            # one shared initial tensor get independent parameters.
            initial_weights = torch.as_tensor(modality_weights, dtype=torch.float).detach().clone()
        self.modality_weights = nn.Parameter(initial_weights)

    def forward(self, x, edge_index, edge_weight=None):
        """Convolve per modality (averaging the outputs) when ``edge_weight`` is 2-D."""
        if edge_weight is not None and edge_weight.dim() > 1:
            modality_weights = F.softmax(self.modality_weights, dim=0)
            outputs = []
            num_modalities = edge_weight.size(1)
            num_nodes_per_graph = self.num_nodes_per_graph
            for i in range(num_modalities):
                modality_weight = edge_weight[:, i] * modality_weights[i]
                # Select modality i's node features (block / column / all of x).
                if x.size(1) == num_nodes_per_graph * num_modalities:
                    x_modality = x[:, i*num_nodes_per_graph:(i+1)*num_nodes_per_graph]
                elif x.size(1) == num_modalities:
                    x_modality = x[:, i:i+1]
                else:
                    x_modality = x
                outputs.append(super().forward(x_modality, edge_index, modality_weight))
            return torch.mean(torch.stack(outputs), dim=0)
        else:
            return super().forward(x, edge_index, edge_weight)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels}, {self.out_channels}, K={self.K})'

class BrainGNN(torch.nn.Module):
    """Graph classifier: stacked modality-aware conv layers, mean pooling, and an MLP head.

    Each conv layer is followed by BatchNorm, ReLU and dropout (p=0.2). From
    the second layer on, a residual connection is added whenever input and
    output shapes match. Node embeddings are mean-pooled per graph,
    concatenated with ``data.additional_features`` and mapped to 2 logits.

    Args:
        conv_layer: One of "GCN", "GAT", "GraphSAGE", "ChebNet".
        num_nodes_per_graph: Forwarded to every conv layer.
        num_node_features: Input feature width of the first conv layer.
        num_edge_features: Number of edge-attribute columns (modalities).
        num_additional_features: Width of ``data.additional_features``.
        num_channels: Output width of each conv layer. For GAT this is the
            total width across heads, so each value must divide by ``heads``.
        modality_weights: Optional initial modality weights for every layer.
        **kwargs: ``K`` (ChebNet, default 2) and ``heads`` (GAT, default 4).

    Raises:
        ValueError: unsupported ``conv_layer``, empty ``num_channels``, or GAT
            widths not divisible by ``heads``.
    """
    def __init__(self, conv_layer, num_nodes_per_graph, num_node_features, num_edge_features, num_additional_features,
                 num_channels=(32, 32, 64, 64, 128), modality_weights=None, **kwargs):
        super().__init__()
        K = kwargs.get('K', 2)  # ChebNet
        heads = kwargs.get('heads', 4)  # GAT
        self.conv_layers = nn.ModuleList()
        ConvLayer = self._get_conv_layer(conv_layer)
        num_channels = list(num_channels)
        if not num_channels:
            raise ValueError("num_channels must contain at least one layer width")
        is_gat = conv_layer == "GAT"
        if is_gat:
            if any(channel % heads != 0 for channel in num_channels):
                raise ValueError(f"All channel numbers must be divisible by number of heads ({heads})")
            channels_to_use = [c // heads for c in num_channels]
        else:
            channels_to_use = num_channels
        # GAT concatenates its heads, so a GAT layer really outputs
        # channels * heads features; other layer types output `channels`.
        width_multiplier = heads if is_gat else 1
        conv_params = {
            'num_nodes_per_graph': num_nodes_per_graph,
            'num_edge_features': num_edge_features,
            'modality_weights': modality_weights,
        }
        if conv_layer == "ChebNet":
            conv_params['K'] = K
        if is_gat:
            conv_params['heads'] = heads
        in_channels = num_node_features
        for channels in channels_to_use:
            self.conv_layers.append(ConvLayer(in_channels, channels, **conv_params))
            in_channels = channels * width_multiplier
        self.batch_norms = nn.ModuleList(
            nn.BatchNorm1d(channels * width_multiplier) for channels in channels_to_use
        )
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        num_final_channels = channels_to_use[-1] * width_multiplier
        self.fc_layers = nn.Sequential(
            nn.Linear(num_final_channels + num_additional_features, num_final_channels // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(num_final_channels // 2, num_final_channels // 4),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(num_final_channels // 4, 2)
        )

    def forward(self, data):
        """Return class logits of shape (num_graphs, 2) for a (batched) PyG ``Data``."""
        x = data.x
        for i, (conv, bn) in enumerate(zip(self.conv_layers, self.batch_norms)):
            identity = x
            x = conv(x, data.edge_index, data.edge_attr)
            x = bn(x)
            x = self.relu(x)
            x = self.dropout(x)
            # Residual connection only when widths line up (never on layer 0).
            if i > 0 and x.size() == identity.size():
                x = x + identity
        x = global_mean_pool(x, data.batch)
        x = torch.cat([x, data.additional_features], dim=1)
        x = self.fc_layers(x)
        return x

    def _get_conv_layer(self, conv_layer):
        """Map a layer name to its conv class; raise ValueError if unknown."""
        conv_layers = {
            "GCN": BrainGCNConv,
            "GAT": BrainGATConv,
            "GraphSAGE": BrainSAGEConv,
            "ChebNet": BrainChebConv,
        }
        if conv_layer not in conv_layers:
            raise ValueError(f"Unsupported convolution layer: {conv_layer}")
        return conv_layers[conv_layer]

def get_model(model_name, num_nodes_per_graph, num_node_features, num_edge_features, num_additional_features,
              num_channels=(32, 32, 64, 64, 128), modality_weights=None):
    """Construct a ``BrainGNN`` for one of the supported architectures.

    Args:
        model_name: "GCN", "GAT" (4 heads), "GraphSAGE" or "ChebNet" (K=2).
        num_nodes_per_graph, num_node_features, num_edge_features,
        num_additional_features, num_channels, modality_weights:
            Forwarded unchanged to ``BrainGNN``.

    Returns:
        The constructed ``BrainGNN``.

    Raises:
        ValueError: if ``model_name`` is not supported (previously this
            silently returned None).
    """
    # Architecture-specific keyword arguments for BrainGNN.
    extra_kwargs_by_model = {
        "GCN": {},
        "GAT": {"heads": 4},
        "GraphSAGE": {},
        "ChebNet": {"K": 2},
    }
    if model_name not in extra_kwargs_by_model:
        raise ValueError(f"Unsupported model name: {model_name}")
    return BrainGNN(
        conv_layer=model_name,
        num_nodes_per_graph=num_nodes_per_graph,
        num_node_features=num_node_features,
        num_edge_features=num_edge_features,
        num_additional_features=num_additional_features,
        num_channels=num_channels,
        modality_weights=modality_weights,
        **extra_kwargs_by_model[model_name],
    )

def visualize_conn_matrix(sample, edge_modalities, title=None):
    """Show the first graph of a batch as one connectivity heatmap per edge modality.

    The graph's edges are taken as the first ``E / batch_size`` entries of
    ``sample.edge_index`` and ``sample.edge_attr``. Entries without an edge
    (the diagonal) are drawn as 1. The colour range is symmetric around 0 when
    any value is negative, otherwise it starts at 0.

    Args:
        sample: A batched PyG object exposing ``ptr``, ``num_nodes``,
            ``edge_index`` and ``edge_attr``.
        edge_modalities: Names used as panel titles, one per edge_attr column.
        title: Optional figure-level title.
    """
    n_panels = len(edge_modalities)
    _, axs = plt.subplots(1, n_panels, figsize=(n_panels * 4, 4), squeeze=False)
    n_graphs = len(sample.ptr) - 1
    n_nodes = sample.num_nodes // n_graphs
    n_edges = sample.edge_index.size(1) // n_graphs
    first_graph_edges = sample.edge_index[:, :n_edges]
    src, dst = first_graph_edges[0], first_graph_edges[1]
    for col, modality in enumerate(edge_modalities):
        weights = sample.edge_attr[:n_edges, col]
        dense = torch.ones((n_nodes, n_nodes)).to(weights.device)
        dense[src, dst] = weights
        values = dense.cpu().numpy()
        peak = np.abs(values).max()
        floor = -peak if np.any(values < 0) else 0
        ax = axs[0, col]
        image = ax.imshow(values, cmap='hot', vmin=floor, vmax=peak)
        ax.set_title(f'{modality}')
        plt.colorbar(image, ax=ax)
        ax.set_aspect('equal')
        ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    if title:
        plt.suptitle(title, fontsize=12, y=1.05)
    plt.tight_layout()
    plt.show()

def train_one_epoch(model, device, train_loader, optimizer, criterion, scaler, metric):
    """Run one training epoch.

    Args:
        model: Module mapping a batch to (N, 2) logits.
        device: Device each batch is moved to.
        train_loader: Iterable of batches exposing ``.to(device)`` and ``.y``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss taking ``(logits, targets)``.
        scaler: A CUDA GradScaler for mixed precision, or None for full precision.
        metric: Object with ``reset()``, ``__call__(preds, target)`` and
            ``compute()`` (e.g. a binary torchmetrics metric). It receives the
            softmax probability of class 1.

    Returns:
        ``(mean per-batch loss, metric.compute() as a Python float)``.
    """
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    metric.reset()
    for batch_data in train_loader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        with torch.autocast(device_type="cuda", enabled=scaler is not None):
            outputs = model(batch_data)
            loss = criterion(outputs, batch_data.y)
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
        # Give the metric P(class 1) = softmax(logits)[:, 1]. The raw class-1
        # logit alone ignores the class-0 logit, so thresholding it does not
        # match the model's argmax prediction.
        probs = F.softmax(outputs.detach().float(), dim=1)[:, 1]
        metric(probs, batch_data.y)
    epoch_metric = metric.compute().item()
    return epoch_loss / max(num_batches, 1), epoch_metric

def validate_one_epoch(model, device, val_loader, metric):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Args:
        model: Module mapping a batch to (N, 2) logits.
        device: Device each batch is moved to.
        val_loader: Iterable of batches exposing ``.to(device)`` and ``.y``.
        metric: Object with ``reset()``, ``__call__(preds, target)`` and
            ``compute()``. It receives the softmax probability of class 1.

    Returns:
        ``metric.compute()`` as a Python float.
    """
    model.eval()
    metric.reset()
    with torch.no_grad():
        for batch_data in val_loader:
            batch_data = batch_data.to(device)
            outputs = model(batch_data)
            # P(class 1) rather than the raw class-1 logit, so the metric's
            # 0.5 threshold matches the argmax prediction.
            probs = F.softmax(outputs.float(), dim=1)[:, 1]
            metric(probs, batch_data.y)
    return metric.compute().item()

class EarlyStopping:
    """Track a higher-is-better score and flag when it stops improving.

    Call the instance with each new score. The first call only records a
    baseline. A later score strictly below ``best_score + delta`` increments
    ``counter``; any other score becomes the new ``best_score`` and resets
    ``counter``. Once ``counter`` reaches ``patience``, ``early_stop`` is set
    to True and stays True.
    """
    def __init__(self, patience=30, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0

    def __call__(self, metric):
        if self.best_score is None:
            self.best_score = metric
            return
        stalled = metric < self.best_score + self.delta
        if not stalled:
            self.best_score = metric
            self.counter = 0
            return
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

def train_model(model_dir, model, device, train_loader, val_loader, logger,
        criterion, metric, max_epochs=100, learning_rate=1e-4, weight_decay=1e-5, val_interval=1, es_patience=30):
    """Train ``model`` with Adam and a cosine LR schedule, keeping the best validation weights.

    Validation runs every ``val_interval`` epochs. Each new best score is saved
    to ``<model_dir>/BestMetricModel.pth`` and feeds ``EarlyStopping`` with
    ``es_patience``.

    Returns:
        ``(model, epoch_loss_values, epoch_metric_values, metric_values)``.
        If any validation ran, ``model`` has the best-validation weights loaded.
    """
    import copy  # local: only needed to snapshot the best state_dict

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
    scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None
    start_time = time.time()
    best_metric = -1
    best_metric_epoch = -1
    best_model_state = None  # stays None if validation never improves / never runs
    early_stopping = EarlyStopping(patience=es_patience, delta=0)
    epoch_loss_values, epoch_metric_values, metric_values = [], [], []
    for epoch in range(max_epochs):
        epoch_start_time = time.time()
        epoch_loss, epoch_metric = train_one_epoch(
            model, device, train_loader, optimizer, criterion, scaler, metric)
        epoch_loss_values.append(epoch_loss)
        epoch_metric_values.append(epoch_metric)
        val_metric = None
        if (epoch + 1) % val_interval == 0:
            val_metric = validate_one_epoch(model, device, val_loader, metric)
            metric_values.append(val_metric)
            if val_metric > best_metric:
                best_metric = val_metric
                best_metric_epoch = epoch + 1
                # state_dict() holds references to the live tensors; without a
                # deep copy later epochs would overwrite the "best" snapshot.
                best_model_state = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), os.path.join(model_dir, "BestMetricModel.pth"))
                logger.info(f"Best accuracy: {best_metric:.4f} at epoch {best_metric_epoch}")
            early_stopping(val_metric)
            if early_stopping.early_stop:
                logger.info(f"Early stopping triggered at epoch {epoch + 1}")
                print(f"; Early stopping triggered at epoch {epoch + 1}", end="")
                break
        epoch_end_time = time.time()
        # Only report a validation score for epochs that actually ran validation.
        val_text = f"{val_metric:.4f}" if val_metric is not None else "n/a"
        logger.info(f"Epoch {epoch + 1} computed for {(epoch_end_time - epoch_start_time)/60:.2f} mins - Training loss: {epoch_loss:.4f}, Training accuracy: {epoch_metric:.4f}, Validation accuracy: {val_text}")
        lr_scheduler.step()
        sys.stdout.write(f"\rEpoch {epoch + 1}/{max_epochs} completed")
        sys.stdout.flush()
    end_time = time.time()
    total_time = end_time - start_time
    logger.info(f"Best accuracy: {best_metric:.3f} at epoch {best_metric_epoch}; Total time consumed: {total_time/60:.2f} mins")
    print(f"\nBest accuracy: {best_metric:.3f} at epoch {best_metric_epoch}; Total time consumed: {total_time/60:.2f} mins")
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model, epoch_loss_values, epoch_metric_values, metric_values

def plot_metric_values(model_dir, epoch_loss_values, epoch_metric_values, metric_values, val_interval=1):
    """Plot training curves side by side and save them to ``<model_dir>/Performance.png``.

    Left panel: training loss per epoch. Right panel: training accuracy per
    epoch, plus validation accuracy at every ``val_interval``-th epoch. The
    figure is saved at 300 dpi; it is neither shown nor closed here.
    """
    loss_epochs = list(range(1, len(epoch_loss_values) + 1))
    train_epochs = list(range(1, len(epoch_metric_values) + 1))
    val_epochs = [val_interval * k for k in range(1, len(metric_values) + 1)]
    _, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 6))

    loss_ax.plot(loss_epochs, epoch_loss_values, label='Training Loss', color='red')
    loss_ax.set_title('Training Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')

    acc_ax.plot(train_epochs, epoch_metric_values, label='Training Accuracy', color='red')
    acc_ax.plot(val_epochs, metric_values, label='Validation Accuracy', color='blue')
    acc_ax.set_title('Training Accuracy vs. Validation Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(model_dir, "Performance.png"), dpi=300)

def occlusion_based_sensitivity(model, sample, style='diverging', directed=False, window_size=10):
    """Estimate how much each node pair's edges drive the class-1 probability.

    Node pairs are processed in windows of ``window_size`` consecutive pairs.
    For each window, the ``edge_attr`` of all edges between those pairs
    (both directions when ``directed`` is False) is set to 0 and the model
    re-run. Every pair in the window is assigned the resulting change in
    P(class 1): signed for ``style='diverging'``, absolute for ``'absolute'``.

    Args:
        model: Module returning (num_graphs, 2) logits for a ``Data`` object.
        sample: A single graph; output row 0 is read.
        style: ``'diverging'`` or ``'absolute'``.
        directed: Treat (i, j) and (j, i) as separate pairs.
        window_size: Number of node pairs occluded together.

    Returns:
        An (N, N) tensor on the model's device. The diagonal stays 0; in the
        undirected case the matrix is symmetric.

    Raises:
        ValueError: for an unsupported ``style``.
    """
    if style not in ('diverging', 'absolute'):
        raise ValueError(f"Unsupported style: {style!r} (expected 'diverging' or 'absolute')")
    model.eval()
    device = next(model.parameters()).device
    sample = sample.to(device)
    num_nodes_per_graph = sample.num_nodes
    sensitivity_matrix = torch.zeros((num_nodes_per_graph, num_nodes_per_graph), device=device)
    if directed:
        node_pairs = [(i, j) for i in range(num_nodes_per_graph) for j in range(num_nodes_per_graph) if i != j]
    else:
        node_pairs = [(i, j) for i in range(num_nodes_per_graph) for j in range(i+1, num_nodes_per_graph)]
    # Encode every edge (src, dst) as one integer so a window's edges can be
    # found with a single isin() instead of one comparison per pair.
    edge_keys = sample.edge_index[0] * num_nodes_per_graph + sample.edge_index[1]
    with torch.no_grad():
        baseline_pred = F.softmax(model(sample), dim=1)[0, 1].item()
        for window_start in range(0, len(node_pairs), window_size):
            current_pairs = node_pairs[window_start:window_start + window_size]
            pair_keys = [i * num_nodes_per_graph + j for i, j in current_pairs]
            if not directed:
                pair_keys += [j * num_nodes_per_graph + i for i, j in current_pairs]
            mask = torch.isin(edge_keys, torch.tensor(pair_keys, dtype=edge_keys.dtype, device=device))
            temp_data = Data(
                x=sample.x,
                edge_index=sample.edge_index,
                edge_attr=sample.edge_attr.clone(),
                additional_features=sample.additional_features,
                y=sample.y,
                num_nodes=sample.num_nodes,
                batch=sample.batch
            ).to(device)
            temp_data.edge_attr[mask] = 0
            pred = F.softmax(model(temp_data), dim=1)[0, 1].item()
            change = pred - baseline_pred
            sensitivity = change if style == 'diverging' else abs(change)
            for i, j in current_pairs:
                sensitivity_matrix[i, j] = sensitivity
                if not directed:
                    sensitivity_matrix[j, i] = sensitivity
    return sensitivity_matrix

def _extract_single_graph(batch_data, graph_idx):
    """Slice graph ``graph_idx`` out of a PyG mini-batch as a standalone ``Data``."""
    start_idx = batch_data.ptr[graph_idx].item()
    end_idx = batch_data.ptr[graph_idx + 1].item()
    edge_mask = (batch_data.edge_index[0] >= start_idx) & (batch_data.edge_index[0] < end_idx)
    num_nodes = end_idx - start_idx
    return Data(
        x=batch_data.x[start_idx:end_idx, :],
        edge_index=batch_data.edge_index[:, edge_mask] - start_idx,
        edge_attr=batch_data.edge_attr[edge_mask, :],
        # Keep a leading graph dimension, giving shape (1, F). BrainGNN.forward
        # concatenates this with the (1, C) pooled embedding along dim=1.
        # unsqueeze(-1) would give (F, 1) and fail whenever F != 1.
        additional_features=batch_data.additional_features[graph_idx:graph_idx + 1, :],
        y=batch_data.y[graph_idx],
        num_nodes=num_nodes,
        batch=torch.zeros(num_nodes, dtype=torch.long)
    )


def _high_sensitivity_pairs(mean_sensitivity, directed, quantile=0.95):
    """Return pairs whose |sensitivity| reaches the ``quantile`` threshold, as 1-based dicts.

    Directed: every off-diagonal entry is a candidate. Undirected: only the
    upper triangle, with the threshold computed from its non-zero entries.
    Pairs are listed in row-major order; an empty list means no candidates.
    """
    num_nodes = mean_sensitivity.shape[0]
    device = mean_sensitivity.device
    if directed:
        candidates = ~torch.eye(num_nodes, dtype=torch.bool, device=device)
        pool = mean_sensitivity[candidates]
    else:
        candidates = torch.triu(torch.ones((num_nodes, num_nodes), device=device), diagonal=1).bool()
        pool = mean_sensitivity[candidates]
        pool = pool[pool != 0]
    if pool.numel() == 0:
        return []
    threshold = torch.quantile(torch.abs(pool), quantile)
    selected = candidates & (torch.abs(mean_sensitivity) >= threshold)
    rows, cols = torch.nonzero(selected, as_tuple=True)
    return [
        {'Node_i': i + 1, 'Node_j': j + 1, 'Sensitivity': mean_sensitivity[i, j].item()}
        for i, j in zip(rows.tolist(), cols.tolist())
    ]


def _plot_sensitivity_map(sensitivity_numpy, model_dir, modalities, style):
    """Draw the sensitivity matrix as a heatmap and save it to ``Sensitivity_<mods>.png``."""
    plt.figure(figsize=(10, 8))
    heatmap_params = {
        'xticklabels': False,
        'yticklabels': False,
        'square': True,
    }
    if style == 'diverging':
        heatmap_params.update({
            'cmap': 'RdBu_r',
            'center': 0,
            'vmin': -np.abs(sensitivity_numpy).max(),
            'vmax': np.abs(sensitivity_numpy).max()
        })
    elif style == 'absolute':
        heatmap_params.update({
            'cmap': 'hot',
            'vmin': 0,
            'vmax': sensitivity_numpy.max()
        })
    sns.heatmap(sensitivity_numpy, **heatmap_params)
    plt.title(f"Sensitivity Map: {'+'.join(modalities)}", fontsize=16, pad=20)
    plt.tight_layout()
    plt.savefig(os.path.join(model_dir, f"Sensitivity_{'+'.join(modalities)}.png"), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()


def occlusion_sensitivity_analysis(model, loader, model_dir, modalities, style='diverging', directed=False, num_samples=20):
    """Average occlusion sensitivity over up to ``num_samples`` graphs and save the results.

    Writes three files to ``model_dir``:

    * ``SensitivityMap_<mods joined by _>.csv``: the mean sensitivity matrix.
    * ``Top5PercentEdges_<mods joined by +>.csv``: pairs at or above the 95th
      percentile of |sensitivity|, as 1-based node indices, sorted by
      |sensitivity| descending.
    * ``Sensitivity_<mods joined by +>.png``: a heatmap of the mean matrix.

    Raises:
        ValueError: if no graph could be analysed (empty loader or
            ``num_samples <= 0``).
    """
    accumulated_sensitivity = None
    accumulated_samples = 0
    for batch_data in loader:
        if accumulated_samples >= num_samples:
            break
        for graph_idx in range(len(batch_data.ptr) - 1):
            if accumulated_samples >= num_samples:
                break
            sample = _extract_single_graph(batch_data, graph_idx)
            sensitivity_matrix = occlusion_based_sensitivity(model, sample, style=style, directed=directed)
            accumulated_samples += 1
            if accumulated_sensitivity is None:
                accumulated_sensitivity = sensitivity_matrix
            else:
                accumulated_sensitivity += sensitivity_matrix
    if accumulated_samples == 0:
        raise ValueError("No graphs were analysed; check that the loader is non-empty and num_samples > 0")
    # Average over the graphs actually analysed; the loader may hold fewer
    # than num_samples graphs.
    mean_sensitivity = accumulated_sensitivity / accumulated_samples
    df_sensitivity = pd.DataFrame(mean_sensitivity.cpu().numpy())
    df_sensitivity.to_csv(os.path.join(model_dir, f"SensitivityMap_{'_'.join(modalities)}.csv"), index=False, header=False)
    high_sensitivity_pairs = _high_sensitivity_pairs(mean_sensitivity, directed)
    df_top_pairs = pd.DataFrame(high_sensitivity_pairs, columns=['Node_i', 'Node_j', 'Sensitivity'])
    if not df_top_pairs.empty:
        df_top_pairs = df_top_pairs.sort_values('Sensitivity', key=abs, ascending=False)
    df_top_pairs.to_csv(os.path.join(model_dir, f"Top5PercentEdges_{'+'.join(modalities)}.csv"), index=False)
    _plot_sensitivity_map(mean_sensitivity.cpu().numpy(), model_dir, modalities, style)

def visualize_brain_network(model_dir, atlas_path, modalities, directed=False):
    """Plot the top-sensitivity edges and derived node sensitivities on a glass brain.

    Reads ``Top5PercentEdges_<m1+m2+...>.csv`` from ``model_dir``. The file must
    have columns ``Node_i`` and ``Node_j`` (1-based node indices) and
    ``Sensitivity``. The function builds a connectivity matrix over the parcels
    of the atlas and plots it with ``nilearn.plotting.plot_connectome``. It then
    saves the figure as ``..._BrainNetwork.tif`` in ``model_dir`` and prints a
    short summary.

    Args:
        model_dir: Directory that holds the edges CSV; the figure is written here too.
        atlas_path: Parcellation image loaded with nibabel; node coordinates come
            from ``plotting.find_parcellation_cut_coords`` in its label order.
        modalities: Modality names, joined with '+' to build the file names.
        directed: If False, every edge is mirrored (symmetric matrix) and a node's
            sensitivity is its row sum. If True, it is the mean of the row (out)
            and column (in) sums.

    Raises:
        ValueError: If an edge references a node index outside ``1..n_nodes``.
    """
    suffix = '+'.join(modalities)
    edges_path = os.path.join(model_dir, f"Top5PercentEdges_{suffix}.csv")
    edges_df = pd.read_csv(edges_path)
    # The CSV stores 1-based node numbers; convert to 0-based matrix indices.
    edges_df['Node_i'] = edges_df['Node_i'] - 1
    edges_df['Node_j'] = edges_df['Node_j'] - 1

    atlas_img = nib.load(atlas_path)
    coords = plotting.find_parcellation_cut_coords(atlas_img)
    n_nodes = len(coords)

    # Reject indices that would silently wrap around (negative) or overflow the matrix.
    node_pairs = edges_df[['Node_i', 'Node_j']].to_numpy()
    if node_pairs.size and (node_pairs.min() < 0 or node_pairs.max() >= n_nodes):
        raise ValueError(
            f"Edge node indices must lie in 1..{n_nodes} (atlas parcels); "
            f"found {node_pairs.min() + 1}..{node_pairs.max() + 1} in {edges_path}"
        )

    # Scale edge weights into [-1, 1]. Guard the all-zero / empty case, where
    # dividing by the maximum would fill the matrix with NaN.
    max_abs_sensitivity = np.abs(edges_df['Sensitivity']).max()
    edge_scale = max_abs_sensitivity if max_abs_sensitivity > 0 else 1.0
    edges_df['Scaled_Sensitivity'] = edges_df['Sensitivity'] / edge_scale

    connectivity_matrix = np.zeros((n_nodes, n_nodes))
    for i, j, weight in zip(edges_df['Node_i'].astype(int),
                            edges_df['Node_j'].astype(int),
                            edges_df['Scaled_Sensitivity']):
        connectivity_matrix[i, j] = weight
        if not directed:
            connectivity_matrix[j, i] = weight

    if directed:
        out_sensitivity = np.sum(connectivity_matrix, axis=1)
        in_sensitivity = np.sum(connectivity_matrix, axis=0)
        node_sensitivity = np.mean([out_sensitivity, in_sensitivity], axis=0)
    else:
        node_sensitivity = np.sum(connectivity_matrix, axis=1)
    max_abs_node_sensitivity = np.abs(node_sensitivity).max()
    if max_abs_node_sensitivity > 0:
        node_sensitivity = node_sensitivity / max_abs_node_sensitivity

    # Red for positive, blue for negative; brightness scales with |sensitivity|.
    node_colors = np.zeros((n_nodes, 3))
    node_colors[node_sensitivity > 0] = np.array([1, 0, 0])
    node_colors[node_sensitivity < 0] = np.array([0, 0, 1])
    node_colors *= np.abs(node_sensitivity)[:, np.newaxis]
    alpha = 0.3
    node_colors = [plt.cm.colors.to_rgba(c, alpha) for c in node_colors]

    # plot_connectome creates (and makes current) its own figure when no
    # `figure=` is given, so no separate plt.figure() is opened here; the
    # plt.* calls below act on that figure.
    plotting.plot_connectome(
        connectivity_matrix, coords,
        node_color=node_colors,
        node_size=100 * np.abs(node_sensitivity) + 20,
        edge_cmap='seismic',
        edge_vmin=-1,
        edge_vmax=1,
        colorbar=True,
        edge_threshold=None,
        annotate=False,
        black_bg=False
    )
    plt.suptitle("Connectional and regional sensitivity", bbox=dict(facecolor='white', edgecolor=None))
    # Empty scatters act as legend handles; `s` matches the node_size formula above.
    sizes = [0.3, 0.6, 0.9]
    legend_elements = [plt.scatter([], [], c='black', s=(100 * s + 20), label=f'Node sensitivity: {s:.1f}') for s in sizes]
    plt.legend(handles=legend_elements, loc='best',
               title='Node size scale', frameon=True,
               facecolor='white', edgecolor='black', labelcolor='black')
    bn_path = os.path.join(model_dir, f"Top5PercentEdges_{suffix}_BrainNetwork.tif")
    plt.savefig(bn_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.show()
    plt.close()

    print("\nNetwork summary:")
    print(f"Number of nodes: {n_nodes}")
    print(f"Number of edges: {len(edges_df)}")
    print(f"Maximum absolute sensitivity: {max_abs_sensitivity:.6f}")
    print("Nodes with highest positive sensitivity:")
    order = np.argsort(node_sensitivity)
    for node in order[-5:][::-1]:
        print(f"Node {node+1}: {node_sensitivity[node]:.3f}")
    print("Nodes with lowest negative sensitivity:")
    for node in order[:5]:
        print(f"Node {node+1}: {node_sensitivity[node]:.3f}")

def apply_best_model(model_dir, model, device, test_loader, pred_dir, subjects):
    """Load the best checkpoint, predict on ``test_loader`` and write ``PredSex.csv``.

    Args:
        model_dir: Directory containing ``BestMetricModel.pth`` (a state dict).
        model: Model whose architecture matches the checkpoint. Its weights are
            replaced in place and it is left in eval mode.
        device: Device the batches and the loaded weights are mapped to.
        test_loader: Iterable of batches accepted by ``model``. It is assumed to
            yield samples in the same order as ``subjects``.
        pred_dir: Output directory (created if missing).
        subjects: Subject identifiers, one per predicted sample, in loader order.

    Writes ``pred_dir/PredSex.csv`` with these columns:
        ``No``: the subject identifier.
        ``Pr``: softmax probability of class 1.
        ``PredSex``: the argmax class as an integer.

    Raises:
        ValueError: If the number of predictions differs from ``len(subjects)``.
    """
    checkpoint_path = os.path.join(model_dir, "BestMetricModel.pth")
    # map_location allows a checkpoint saved on GPU to be loaded on a CPU-only host.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    os.makedirs(pred_dir, exist_ok=True)

    prob_chunks = []
    label_chunks = []
    with torch.no_grad():
        for batch_data in test_loader:
            batch_data = batch_data.to(device)
            outputs = model(batch_data)
            pred_probs = F.softmax(outputs, dim=1)
            prob_chunks.append(pred_probs[:, 1].cpu().numpy())
            label_chunks.append(torch.argmax(pred_probs, dim=1).cpu().numpy())

    # Concatenate once instead of np.append per batch (which re-copies every time).
    # Probabilities stay float64, as before. Labels stay integers rather than
    # being upcast to float.
    if prob_chunks:
        prob_values = np.concatenate(prob_chunks).astype(np.float64)
        pred_values = np.concatenate(label_chunks).astype(np.int64)
    else:
        prob_values = np.array([], dtype=np.float64)
        pred_values = np.array([], dtype=np.int64)

    if len(prob_values) != len(subjects):
        raise ValueError(
            f"Got {len(prob_values)} predictions for {len(subjects)} subjects; "
            "test_loader and subjects must align one-to-one"
        )

    df = pd.DataFrame({
        'No': subjects,
        'Pr': prob_values,
        'PredSex': pred_values
    })
    df.to_csv(os.path.join(pred_dir, "PredSex.csv"), index=False)

### Prepare Inputs

In [None]:
data_dir = "SexClassification"
model_dir_prefix = "SexClassification"
model_name = "GCN" # any supported model name: GCN, GAT, GraphSAGE, ChebNet
edge_modalities = ["PC", "Count"]
node_modalities = None
additional_variables = ["Sex", "Age"]
test_size = 0.2
scaling = True
num_channels = [32, 32, 64, 64, 128]
batch_size = 32
max_epochs = 100
learning_rate = 1e-3
weight_decay = 5e-4
val_interval = 1
es_patience = 30

model_dir = f"{model_dir_prefix}_{model_name}"
os.makedirs(model_dir, exist_ok=True)
# Same device object type as used when building the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device:', device)
log_file = os.path.join(model_dir, "Prediction.log")
# force=True replaces any handlers already on the root logger (e.g. from an
# earlier run of this cell). Without it basicConfig is a silent no-op and
# logging keeps going to the previous file.
logging.basicConfig(filename=log_file, level=logging.INFO, format="%(message)s", force=True)
logger = logging.getLogger()

### Read Data

In [None]:
ds = BrainConnectivityDataset(
    data_dir=os.path.join(data_dir, "train"),
    edge_modalities=edge_modalities,
    node_modalities=node_modalities,
    additional_variables=additional_variables,
    scaling=scaling
)
train_loader, val_loader, train_ds, val_ds = ds.load_data(batch_size=batch_size, inference=False, test_size=test_size)

# Check data shape.
# Divide by the number of graphs actually in the batch, not by batch_size.
# A batch holds fewer graphs when a split is smaller than batch_size.
tr = next(iter(train_loader))
num_nodes_per_graph = tr.num_nodes // tr.num_graphs
if node_modalities:
    if len(node_modalities) == len(edge_modalities):
        num_node_features = tr.x.shape[1] // len(node_modalities)
    else:
        num_node_features = tr.x.shape[1]
else:
    num_node_features = num_nodes_per_graph
print('Data shape for training:')
print(f'\u2022 Node features: ({tr.num_graphs} \u00D7 {tr.x.shape[0] // tr.num_graphs}, {tr.x.shape[1]}) \u00D7 {len(train_loader)}')
print(f'\u2022 Edge features: ({tr.num_graphs} \u00D7 {tr.edge_attr.shape[0] // tr.num_graphs}, {tr.edge_attr.shape[1]}) \u00D7 {len(train_loader)}')
print(f'\u2022 Additional features: {tuple(tr.additional_features.shape)} \u00D7 {len(train_loader)}')
print(f'\u2022 Labels: {tuple(tr.y.shape)} \u00D7 {len(train_loader)}')
vl = next(iter(val_loader))
print('\nData shape for validation:')
print(f'\u2022 Node features: ({vl.num_graphs} \u00D7 {vl.x.shape[0] // vl.num_graphs}, {vl.x.shape[1]}) \u00D7 {len(val_loader)}')
print(f'\u2022 Edge features: ({vl.num_graphs} \u00D7 {vl.edge_attr.shape[0] // vl.num_graphs}, {vl.edge_attr.shape[1]}) \u00D7 {len(val_loader)}')
print(f'\u2022 Additional features: {tuple(vl.additional_features.shape)} \u00D7 {len(val_loader)}')
print(f'\u2022 Labels: {tuple(vl.y.shape)} \u00D7 {len(val_loader)}')

# Visualize data
sample = next(iter(train_loader))
visualize_conn_matrix(sample, edge_modalities)
print("Value ranges for connectivity matrices:")
for i, modality in enumerate(edge_modalities):
    edge_attr = sample.edge_attr[:, i]
    print(f"\u2022 {modality}: [{edge_attr.min():.3f}, {edge_attr.max():.3f}]")

### Train Model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_edge_features = len(edge_modalities)
num_additional_features = 1 # Age
model = get_model(model_name, num_nodes_per_graph, num_node_features, num_edge_features, num_additional_features, num_channels=num_channels).to(device)
print(f"Selected model: {model_name}")
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_feature_extractor_params = sum(p.numel() for name, p in model.named_parameters() if p.requires_grad and not name.startswith('fc_layers'))
num_fclayers_params = sum(p.numel() for p in model.fc_layers.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_params} = {num_feature_extractor_params} + {num_fclayers_params} for feature extractor and FC layers")

criterion = nn.CrossEntropyLoss()
metric = Accuracy(task='binary').to(device)
model, epoch_loss_values, epoch_metric_values, metric_values = train_model(model_dir, model, device, train_loader, val_loader,
    logger, criterion, metric, max_epochs, learning_rate, weight_decay, val_interval, es_patience)
plot_metric_values(model_dir, epoch_loss_values, epoch_metric_values, metric_values, val_interval)

# Visualize outcome
sample_indices = [20, 40]  # 1-based positions within the validation split
model.eval()
for sample_index in sample_indices:
    # Skip positions beyond a small validation split instead of crashing with IndexError.
    if not 1 <= sample_index <= len(val_ds):
        print(f"Skipping sample {sample_index}: validation set has only {len(val_ds)} samples")
        continue
    single_loader = DataLoader(Subset(val_ds, [sample_index - 1]), batch_size=1)
    sample = next(iter(single_loader))
    with torch.no_grad():
        batch_data = sample.to(device)
        output = model(batch_data)
        pred_prob = torch.softmax(output, dim=1)
        pred_label = torch.argmax(pred_prob, dim=1).item()
    actual_label = sample.y.item()
    title = (f'Sample {sample_index}: Predicted = {"Male" if pred_label == 1 else "Female"} '
             f'({pred_prob[0, pred_label].item():.2f}), Actual = {"Male" if actual_label == 1 else "Female"}')
    visualize_conn_matrix(sample, edge_modalities, title=title)

### Occlusion Sensitivity Analysis

In [None]:
# The network plot reads back the edge CSV written by the sensitivity analysis,
# so both steps share a single directedness setting.
use_directed_edges = False

occlusion_sensitivity_analysis(model, val_loader, model_dir, edge_modalities,
                               style='diverging', directed=use_directed_edges, num_samples=20)

atlas_path = os.path.join(data_dir, "BN_Atlas_246_1mm.nii.gz")
visualize_brain_network(model_dir, atlas_path, edge_modalities, directed=use_directed_edges)

### Inference

In [None]:
# Build the held-out test set with the same options as the training data,
# predict one graph at a time, and write the predictions under model_dir.
test_data_dir = os.path.join(data_dir, "test")
test_ds = BrainConnectivityDataset(
    data_dir=test_data_dir, edge_modalities=edge_modalities, node_modalities=node_modalities,
    additional_variables=additional_variables, scaling=scaling,
)
test_loader, subjects = test_ds.load_data(batch_size=1, inference=True)
pred_dir = os.path.join(model_dir, "Prediction")
apply_best_model(model_dir, model, device, test_loader, pred_dir, subjects)