# fNIRS Graph Algorithm Pipeline

## Section 1: Data Preparation

This section contains:
- Dataset class for loading fNIRS data
- Feature extraction (node features: 6 per-channel statistics — mean, min, max, skewness, kurtosis, variance; edge features: absolute Pearson correlation + coherence)
- Normalization statistics computation
- Data transformations (standardization, augmentation)
- K-fold data loader creation with subject-level stratification

In [1]:
# Deep Learning and PyTorch
import torch
import torch_geometric
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
import torchmetrics
from torch_geometric.utils import dense_to_sparse
from torchinfo import summary

# Data Processing
import numpy as np
import helper_utils

# System and File Operations
import os
import glob
from pathlib import Path

# Type Hints
from typing import Dict, Tuple, Union, List, Optional

# Utilities
import warnings
import configparser
import matplotlib.pyplot as plt
import pickle
import importlib
# Reload helper_utils so edits made to that module are picked up without restarting the kernel.
importlib.reload(helper_utils)
# Seed through the project helper; which RNGs it covers is defined in helper_utils (not shown here).
helper_utils.set_seed(42)
import optuna
from pprint import pprint
import inspect

# Configuration
# Suppresses every UserWarning for the rest of the session, not only those from a specific library.
warnings.filterwarnings("ignore", category=UserWarning)

# Version Information
print(f"Torch version: {torch.__version__}")
print(f"Cuda available: {torch.cuda.is_available()}")
print(f"Torch geometric version: {torch_geometric.__version__}")

Torch version: 2.6.0+cu124
Cuda available: True
Torch geometric version: 2.6.1


### 1.1 Feature Extraction Functions

In [2]:
def get_node_features(data: np.ndarray, channels_first: bool = True) -> torch.Tensor:
    """
    Build the per-channel node feature matrix for one fNIRS recording.

    The statistics come from ``helper_utils.compute_statistical_features``; this
    function only arranges them into columns, in the order
    mean, min, max, skewness, kurtosis, variance.

    Args:
        data: Input fNIRS data array
        channels_first: Forwarded to the helper; True means data is (C, T), False means (T, C)

    Returns:
        Float tensor of shape (C, 6) with one row of statistics per channel
    """
    stats = helper_utils.compute_statistical_features(data, channels_first=channels_first)

    column_order = ("mean", "min", "max", "skewness", "kurtosis", "variance")
    n_channels = len(stats["mean"])

    # One float64 column per statistic, limited to the channel count reported by "mean".
    columns = [
        np.asarray(stats[key], dtype=np.float64)[:n_channels]
        for key in column_order
    ]
    feature_matrix = np.column_stack(columns)

    return torch.tensor(feature_matrix, dtype=torch.float)

def get_edge_features(hb_data: np.ndarray, fs: float, directed: bool = False,
                     corr_threshold: float = 0.0, self_loops: bool = False,
                     n_channels: int = 23) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Extract edge features from fNIRS data using correlation and coherence.

    An edge (i, j) is kept when the absolute Pearson correlation between channels
    i and j is strictly greater than ``corr_threshold``. Edges are emitted in
    row-major order (sorted by i, then by j).

    Args:
        hb_data: 2D fNIRS hemoglobin data array, (C, T) or (T, C)
        fs: Sampling frequency, forwarded to ``helper_utils.coherence_matrix``
        directed: If True, every ordered pair (i, j) is a candidate edge.
                  If False, only pairs from the upper triangle are candidates, so
                  each pair is stored once.
                  NOTE(review): PyG layers treat ``edge_index`` as directed, so the
                  undirected mode yields one-way edges unless the caller adds the
                  reverse direction — confirm this is intended.
        corr_threshold: Minimum absolute correlation threshold for edge inclusion.
                       Only edges with |correlation| > corr_threshold are kept.
                       Default 0.0 keeps every pair with a non-zero correlation.
        self_loops: If True, include self-loops (i=j); if False, exclude them.
        n_channels: Channel count used to recognise a (T, C) array that must be
                    transposed. Defaults to 23, the value previously hard-coded.

    Returns:
        edge_index: Long tensor of shape (2, E) with edge indices
        edge_attr: Float tensor of shape (E, 2) with [abs_correlation, coherence] features

    Raises:
        ValueError: If ``hb_data`` is not 2D, or if the correlation and coherence
            matrices are not square matrices of the same shape.
    """
    if hb_data.ndim != 2:
        raise ValueError(f"hb_data must be a 2D array, got shape {hb_data.shape}")

    # Ensure data is in correct orientation (channels x time)
    if hb_data.shape[0] != n_channels and hb_data.shape[1] == n_channels:
        hb_data = hb_data.T

    # Absolute Pearson correlation between every pair of channels
    R_abs = np.abs(helper_utils.pearson_correlation_matrix(hb_data, channels_first=True))

    # Coherence matrix over the same channels
    coh_mat, _, _ = helper_utils.coherence_matrix(hb_data, fs=fs, coherence_ratio='1/3',
                                                  channels_first=True, return_spectrum=False)

    # Explicit raise instead of assert: assertions are stripped when Python runs with -O.
    if R_abs.ndim != 2 or R_abs.shape[0] != R_abs.shape[1] or R_abs.shape != coh_mat.shape:
        raise ValueError("R_abs and coh_mat must be square matrices of the same shape")

    # Boolean adjacency of candidate edges; NaN correlations compare False and are dropped.
    keep = R_abs > corr_threshold
    if not self_loops:
        np.fill_diagonal(keep, False)
    if not directed:
        # Keep each unordered pair once (upper triangle; the diagonal was handled above).
        keep = np.triu(keep)

    # np.nonzero walks the matrix row by row, matching the order of a nested i/j loop.
    src, dst = np.nonzero(keep)

    edge_index = np.vstack([src, dst])                                      # (2, E)
    all_edge_feats = np.column_stack([R_abs[src, dst], coh_mat[src, dst]])  # (E, 2)

    return torch.tensor(edge_index, dtype=torch.long), torch.tensor(all_edge_feats, dtype=torch.float)

### 1.2 Dataset Class

In [8]:
class fNIRSGraphDatasetNonRecurrent(Dataset):
    """
    In-memory PyG dataset that turns every fNIRS trial into one graph.

    Expected layout under ``root``::

        root/<healthy|anxiety>/<subject>/<subject>.data      (INI file, [metadata] sfreq)
        root/<healthy|anxiety>/<subject>/<data_type>/<n>.npy (one 2D array per trial)

    Graphs are built once, during construction, and kept in ``self.data_list``.
    ``transform`` is NOT applied while building: ``Dataset.__getitem__`` already
    applies it on every access, so applying it here as well would run it twice.

    Args:
        root: Dataset root directory
        data_type: Name of the per-subject sub-directory holding the trial files
        max_trials: Keep at most this many trials per subject (lowest trial numbers
            first); ``None`` or a value <= 0 keeps all trials
        directed, corr_threshold, self_loops: Forwarded to ``get_edge_features``
        transform: Applied on every access by the base class
        pre_transform: Applied once to each graph before it is stored
        pre_filter: Graphs for which this returns False are not stored
    """

    # Sampling frequency used when a subject has no metadata file or a non-finite value.
    DEFAULT_SFREQ = 10.1
    # Channel count used to recognise (T, C) arrays; same value get_edge_features checks for.
    N_CHANNELS = 23
    # Directory name -> class index; iteration order defines the processing order.
    LABELS = {"healthy": 0, "anxiety": 1}

    def __init__(
        self,
        root: Union[str, Path],
        data_type: str,
        max_trials: int,
        directed: bool = False,
        corr_threshold: float = 0.1,
        self_loops: bool = False,
        transform=None,
        pre_transform=None,
        pre_filter=None
    ):
        # These must exist before super().__init__ because it triggers process().
        self.root = root
        self.data_type = data_type
        self.max_trials = max_trials
        self.directed = directed
        self.corr_threshold = corr_threshold
        self.self_loops = self_loops
        self.data_list = []

        super().__init__(
            root,
            transform,
            pre_transform,
            pre_filter
        )

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        # Nothing is cached on disk, so the graphs are rebuilt on every construction.
        return []

    def _read_sampling_frequency(self, ini_path: Path) -> float:
        """Return ``sfreq`` from the subject's metadata file, or the default if unavailable."""
        if not ini_path.exists():
            return self.DEFAULT_SFREQ
        cfg = configparser.ConfigParser()
        cfg.read(ini_path)
        # A present file without [metadata]/sfreq is treated as a data error (KeyError).
        sfreq = float(cfg["metadata"]["sfreq"])
        return sfreq if np.isfinite(sfreq) else self.DEFAULT_SFREQ

    def _build_graph(self, trial_path: Path, fs: float, subj_id: str,
                     label_name: str, label_value: int) -> Data:
        """Load one trial file and convert it into a ``Data`` graph."""
        # Load one trial → (C, T) or (T, C)
        arr = np.load(trial_path)
        if arr.ndim != 2:
            raise ValueError(f"Expected 2D array, got shape {arr.shape} at {trial_path}")

        # Orient to (C, T) once so node and edge features describe the same channels;
        # previously only the edge features handled a (T, C) input.
        if arr.shape[0] != self.N_CHANNELS and arr.shape[1] == self.N_CHANNELS:
            arr = arr.T

        node_feats = get_node_features(arr, channels_first=True)
        edge_index, edge_feats = get_edge_features(
            arr, fs=fs, directed=self.directed,
            corr_threshold=self.corr_threshold, self_loops=self.self_loops
        )

        return Data(
            x=node_feats,
            edge_index=edge_index, edge_attr=edge_feats,
            y=torch.tensor([label_value], dtype=torch.long),
            subject_id=subj_id, trial_idx=int(trial_path.stem), label_str=label_name
        )

    def process(self):
        """Walk the directory tree and fill ``self.data_list`` with one graph per trial."""
        base = Path(self.root)

        for label_name, label_value in self.LABELS.items():
            label_dir = base / label_name

            for subj_dir in sorted(p for p in label_dir.iterdir() if p.is_dir()):
                subj_id = subj_dir.name
                fs = self._read_sampling_frequency(subj_dir / f"{subj_id}.data")

                hb_dir = subj_dir / self.data_type
                if not hb_dir.exists():
                    continue

                # Trial files are named by their integer index; sort numerically, not lexically.
                trial_files = sorted(hb_dir.glob("*.npy"), key=lambda p: int(p.stem))
                if self.max_trials is not None and self.max_trials > 0:
                    trial_files = trial_files[: self.max_trials]

                for trial_path in trial_files:
                    data = self._build_graph(trial_path, fs, subj_id, label_name, label_value)

                    # Standard PyG hooks, applied once at build time.
                    if self.pre_filter is not None and not self.pre_filter(data):
                        continue
                    if self.pre_transform is not None:
                        data = self.pre_transform(data)

                    self.data_list.append(data)

    def len(self) -> int:
        return len(self.data_list)

    def get(self, idx: int) -> Data:
        return self.data_list[idx]

In [10]:
# Smoke test: build the dataset with two trials per subject and inspect one graph.
root_dir = r'../data/processed_data_HOMER3/GNG'
DATA_TYPE = 'hbo'
MAX_TRIALS = 2
# Directed graph with self-loops; edges need |correlation| > 0.1 to be kept.
dataset = fNIRSGraphDatasetNonRecurrent(
    root=root_dir,
    data_type=DATA_TYPE,
    max_trials=MAX_TRIALS,
    directed=True,
    corr_threshold=0.1,
    self_loops=True
)

# Second graph in the dataset (index 1).
data = dataset[1]
print(f"Dataset Keys: {data.keys()}")
print("-" * 50)
print(f"Node Features Shape: {data.x.shape}")
print(f"Example of Node Features:\n{data.x}")
print("-" * 50)
print(f"Edge Features Shape: {data.edge_attr.shape}")
print(f"Example of Edge Features:\n{data.edge_attr}")

# Dump this graph's edge features to edge_attr.txt in the current working directory
# (one edge per row, fixed-point format) for manual inspection.
tx = data.edge_attr.numpy()
np.savetxt('edge_attr.txt', tx, fmt='%f')

Dataset Keys: ['edge_index', 'label_str', 'y', 'x', 'subject_id', 'trial_idx', 'edge_attr']
--------------------------------------------------
Node Features Shape: torch.Size([23, 6])
Example of Node Features:
tensor([[-0.1166, -0.8145,  0.6705,  0.4023,  2.0257,  0.1763],
        [-0.5415, -1.1249,  0.3457,  0.5409,  2.3534,  0.1793],
        [-0.3092, -2.2508,  1.4274, -0.2847,  1.8703,  1.4145],
        [-1.0122, -1.8888, -0.4865, -0.7387,  2.2695,  0.1774],
        [ 0.2517, -1.2487,  1.1942, -0.7473,  2.6944,  0.4717],
        [ 0.7518, -1.3187,  1.4822, -1.1384,  3.3323,  0.5758],
        [ 0.7366, -2.0114,  2.5724, -0.2671,  1.7591,  1.9606],
        [ 0.6973, -1.4736,  2.0757, -0.4062,  2.3358,  0.8226],
        [ 0.0898, -1.0489,  0.5864, -0.8987,  2.8573,  0.1802],
        [ 0.1556, -1.2961,  1.1905, -0.4876,  1.9181,  0.6016],
        [-0.8279, -1.3586,  0.0811,  0.7068,  2.3050,  0.1850],
        [ 0.3398, -0.2950,  0.7538, -0.5656,  1.7217,  0.1387],
        [-0.0591, -1.1

Processing...
Done!


### 1.3 Normalization Statistics

In [11]:
def get_mean_std(dataset):
    """
    Compute per-column mean and standard deviation over all graphs in a dataset.

    Node rows of every graph are pooled together, and so are edge rows; graphs
    whose ``edge_attr`` is missing or empty contribute no edge rows.

    Args:
        dataset: Iterable of graphs exposing ``x`` and ``edge_attr``

    Returns:
        tuple: (mean_dict, std_dict), each with keys 'node_features' and
        'edge_features'. The edge entries are None when no graph had edge features.
    """
    node_chunks, edge_chunks = [], []

    # Single pass over the dataset, collecting both kinds of rows.
    for sample in dataset:
        node_chunks.append(sample.x)
        edge_attr = sample.edge_attr
        if edge_attr is None or len(edge_attr) == 0:
            continue
        edge_chunks.append(edge_attr)

    node_matrix = torch.cat(node_chunks, dim=0)
    edge_matrix = torch.cat(edge_chunks, dim=0) if edge_chunks else None
    has_edges = edge_matrix is not None

    mean_dict = {
        'node_features': node_matrix.mean(dim=0),
        'edge_features': edge_matrix.mean(dim=0) if has_edges else None,
    }
    std_dict = {
        'node_features': node_matrix.std(dim=0),
        'edge_features': edge_matrix.std(dim=0) if has_edges else None,
    }

    return mean_dict, std_dict

### 1.4 Data Transformations

In [18]:
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform, AddRandomWalkPE, Compose
from torch_geometric.utils import dropout_edge, mask_feature

class StandardizeGraphFeatures(BaseTransform):
    """Z-score node and edge features with statistics computed beforehand.

    ``mean_dict`` and ``std_dict`` use the keys 'node_features' and
    'edge_features' (the layout returned by ``get_mean_std``). ``eps`` is added
    to the standard deviation to avoid division by zero.
    """

    # (attribute on the graph, key in the statistics dictionaries)
    _TARGETS = (('x', 'node_features'), ('edge_attr', 'edge_features'))

    def __init__(self, mean_dict, std_dict, eps=1e-8):
        self.eps = eps
        self.std_dict = std_dict
        self.mean_dict = mean_dict

    def forward(self, data):
        for attr, key in self._TARGETS:
            values = getattr(data, attr)
            # Leave the attribute alone when it is absent or no statistics exist for it.
            if values is None or self.mean_dict[key] is None:
                continue
            scaled = (values - self.mean_dict[key]) / (self.std_dict[key] + self.eps)
            setattr(data, attr, scaled)

        return data

class MaskFeatureAugmentation(BaseTransform):
    """Randomly overwrite node features with a fill value for data augmentation."""

    _MODES = ('all', 'col', 'row')

    def __init__(self, p=0.1, mode='all', fill_value=0):
        """
        Args:
            p: probability of masking
            mode: 'all' (individual values), 'col' (entire features), 'row' (entire nodes)
            fill_value: value to fill masked positions

        Raises:
            ValueError: if ``mode`` is not one of the supported modes. An unknown
                mode used to be accepted and silently disabled the augmentation.
        """
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        self.p = p
        self.mode = mode
        self.fill_value = fill_value

    def forward(self, data):
        if data.x is None or self.p == 0:
            return data

        # Work on a copy so the tensor held by the source graph is never modified.
        x = data.x.clone()

        if self.mode == 'all':
            # Mask individual values
            mask = torch.rand_like(x) < self.p
            x[mask] = self.fill_value
        elif self.mode == 'col':
            # Mask entire features (columns); mask is drawn on the same device as x
            mask = torch.rand(x.shape[1], device=x.device) < self.p
            x[:, mask] = self.fill_value
        else:  # 'row'
            # Mask entire nodes (rows)
            mask = torch.rand(x.shape[0], device=x.device) < self.p
            x[mask, :] = self.fill_value

        data.x = x
        return data

class DropoutEdgeAugmentation(BaseTransform):
    """Randomly drop edges (and their features) for data augmentation.

    Args:
        p: probability of dropping each edge; 0 disables the transform
        force_undirected: forwarded to ``dropout_edge``
    """

    def __init__(self, p=0.1, force_undirected=True):
        self.p = p
        self.force_undirected = force_undirected

    def forward(self, data):
        if data.edge_index is None or self.p == 0:
            return data

        # dropout_edge takes no edge_attr argument. It returns the retained
        # edge_index plus a mask (or, with force_undirected=True, the ids) of the
        # retained edges, which is used to select the matching edge_attr rows.
        edge_index, kept = dropout_edge(
            data.edge_index,
            p=self.p,
            force_undirected=self.force_undirected
        )

        data.edge_index = edge_index
        if data.edge_attr is not None:
            data.edge_attr = data.edge_attr[kept]
        return data

class RandomWalkPEAugmentation(BaseTransform):
    """Append a Random Walk Positional Encoding to the node features.

    Args:
        walk_length: number of random-walk steps; also the number of columns added to ``x``
        attr_name: graph attribute under which the encoding is additionally stored;
            ``None`` (or an empty string) selects 'random_walk_pe'
    """

    DEFAULT_ATTR_NAME = 'random_walk_pe'

    def __init__(self, walk_length=4, attr_name=None):
        self.walk_length = walk_length
        self.attr_name = attr_name

    def forward(self, data):
        from torch_geometric.transforms import AddRandomWalkPE

        # Always hand AddRandomWalkPE an explicit name. With attr_name=None it
        # concatenates the encoding into data.x itself and stores no separate
        # attribute, so the lookup below failed for this class's default.
        pe_name = self.attr_name or self.DEFAULT_ATTR_NAME
        data = AddRandomWalkPE(walk_length=self.walk_length, attr_name=pe_name)(data)
        pe = getattr(data, pe_name)

        # Concatenate PE to node features (or use it alone when there are none).
        data.x = pe if data.x is None else torch.cat([data.x, pe], dim=-1)

        return data

In [14]:
def get_transformations(mean, std, augment=False,
                       edge_dropout_p=0.0, feature_mask_p=0.0, feature_mask_mode='all',
                       use_positional_encoding=False, pe_walk_length=4):
    """
    Assemble the per-graph transformation pipeline.

    Order of the stages: standardization, then optional Random Walk PE, then the
    optional augmentations (edge dropout before feature masking).

    Args:
        mean: mean dictionary from get_mean_std
        std: std dictionary from get_mean_std
        augment: if True, the augmentation stages may be added
        edge_dropout_p: edge dropout probability; the stage is added only if > 0
        feature_mask_p: feature masking probability; the stage is added only if > 0
        feature_mask_mode: mode for feature masking ('all', 'col', 'row')
        use_positional_encoding: if True, add Random Walk PE
        pe_walk_length: walk length for positional encoding

    Returns:
        Compose object with transformations
    """
    # Standardization is always the first stage.
    pipeline = [StandardizeGraphFeatures(mean, std)]

    # Positional encoding comes before any augmentation.
    if use_positional_encoding:
        pipeline.append(RandomWalkPEAugmentation(walk_length=pe_walk_length))

    # Augmentations are meant for training pipelines only.
    if augment and edge_dropout_p > 0:
        pipeline.append(DropoutEdgeAugmentation(p=edge_dropout_p))
    if augment and feature_mask_p > 0:
        pipeline.append(MaskFeatureAugmentation(p=feature_mask_p, mode=feature_mask_mode))

    return Compose(pipeline)

### 1.5 Load Dataset and Create K-Fold Loaders

In [None]:
# Configuration
# Dataset construction settings (forwarded to fNIRSGraphDatasetNonRecurrent below).
DATA_ROOT = "../data/processed_data_HOMER3/GNG"
DATA_TYPE = "hbo"
MAX_TRIALS = 2
DIRECTED = True
CORR_THRESHOLD = 0.1
SELF_LOOPS = True
# Loader / cross-validation settings.
BATCH_SIZE = 8
N_SPLITS = 5
RANDOM_STATE = 42

# Augmentation parameters (recommended values)
EDGE_DROPOUT_P = 0.0
FEATURE_MASK_P = 0.1
FEATURE_MASK_MODE = 'all'
# USE_POSITIONAL_ENCODING also widens MODEL_CONFIG['in_channels'] by PE_WALK_LENGTH
# in the model configuration cell.
USE_POSITIONAL_ENCODING = False  # Set to True to enable RWPE
PE_WALK_LENGTH = 4

print("Loading dataset...")
# No transform is passed here; transforms are supplied later to the K-fold loader helper.
dataset = fNIRSGraphDatasetNonRecurrent(
    root=DATA_ROOT,
    data_type=DATA_TYPE,
    max_trials=MAX_TRIALS,
    directed=DIRECTED,
    corr_threshold=CORR_THRESHOLD,
    self_loops=SELF_LOOPS
)

# Summary taken from the first graph, plus the set of distinct labels over all graphs.
print(f"Dataset size: {len(dataset)}")
print(f"Number of node features: {dataset[0].x.shape[1]}")
print(f"Number of edge features: {dataset[0].edge_attr.shape[1] if dataset[0].edge_attr is not None else 0}")
print(f"Number of classes: {len(set([data.y.item() for data in dataset]))}")

Loading dataset...
Dataset size: 106
Number of node features: 6
Number of edge features: 2
Number of classes: 2


Processing...
Done!


In [16]:
# Compute normalization statistics
# NOTE(review): the statistics are computed over the whole dataset, before any
# train/validation split, and the same dictionaries are later given to both the
# train and the validation transform — so validation subjects contribute to the
# normalization. Confirm whether per-fold statistics are wanted instead.
print("Computing normalization statistics...")
mean_dict, std_dict = get_mean_std(dataset)

print(f"Node features mean: {mean_dict['node_features']}")
print(f"Node features std: {std_dict['node_features']}")
print(f"Edge features mean: {mean_dict['edge_features']}")
print(f"Edge features std: {std_dict['edge_features']}")

Computing normalization statistics...
Node features mean: tensor([ 0.0180, -1.0680,  0.9679, -0.1545,  2.4034,  0.5546])
Node features std: tensor([0.7270, 1.1150, 0.9750, 0.6541, 0.9562, 0.8711])
Edge features mean: tensor([0.7332, 0.5796])
Edge features std: tensor([0.2519, 0.3105])


In [20]:
# Create transformations for train and validation.
# Both pipelines receive the positional-encoding settings from the configuration
# cell, so the node feature width stays in step with MODEL_CONFIG['in_channels']
# (which is widened by PE_WALK_LENGTH when USE_POSITIONAL_ENCODING is True).
train_transform = get_transformations(
    mean_dict, std_dict,
    # NOTE(review): training augmentation is switched off here, so the three
    # augmentation settings below have no effect until this is set to True.
    augment=False,
    edge_dropout_p=EDGE_DROPOUT_P,
    feature_mask_p=FEATURE_MASK_P,
    feature_mask_mode=FEATURE_MASK_MODE,
    use_positional_encoding=USE_POSITIONAL_ENCODING,
    pe_walk_length=PE_WALK_LENGTH,
)

val_transform = get_transformations(
    mean_dict, std_dict,
    augment=False,  # No augmentation for validation
    use_positional_encoding=USE_POSITIONAL_ENCODING,
    pe_walk_length=PE_WALK_LENGTH,
)

print("Transformations created:")
print(f"Train: {train_transform}")
print(f"Validation: {val_transform}")

Transformations created:
Train: Compose([
  StandardizeGraphFeatures()
])
Validation: Compose([
  StandardizeGraphFeatures()
])


In [21]:
# Create K-fold data loaders (split by subject), driven by the values from the
# configuration cell instead of repeating the literals 5 / 8 / 42 here.
fold_loaders_v2 = helper_utils.get_kfold_subject_loaders_v2(
    dataset,
    n_splits=N_SPLITS,
    batch_size=BATCH_SIZE,
    shuffle_train=True,
    num_workers=0,
    pin_memory=False,
    random_state=RANDOM_STATE,
    show_subjects=True,
    train_transform=train_transform,
    val_transform=val_transform
)

print(f"\nTotal folds: {len(fold_loaders_v2)}")

# Test first fold: each entry unpacks into a (train_loader, val_loader) pair.
train_loader_fold1, val_loader_fold1 = fold_loaders_v2[0]
print("\nFold 1:")
print(f"  Train loader batches: {len(train_loader_fold1)}")
print(f"  Val loader batches: {len(val_loader_fold1)}")

# Test a batch from first fold
batch = next(iter(train_loader_fold1))
print("\n  Batch from fold 1 train_loader:")
print(f"    Number of graphs: {batch.num_graphs}")
print(f"    Node features shape: {batch.x.shape}")
print(f"    Edge features shape: {batch.edge_attr.shape}")
# Means of the first three node features in this single batch.
print(f"    Features are standardized (means near 0): {batch.x.mean(dim=0)[:3]}")

=== Dataset Overview ===
Total graphs           : 106
Unique subjects        : 53
Node feature x shape   : (23, 6) | dtype=torch.float32
Edge_attr shape        : (475, 2) | dtype=torch.float32
Per-graph label counts : {0: 66, 1: 40}
Per-subject label cnts : {0: 33, 1: 20}
------------------------
=== K-Fold 1/5 (by subject) with Transform ===
Train
  graphs: 84 | subjects: 42
  label counts (graphs): {0: 52, 1: 32}
  label counts (subjects): {1: 16, 0: 26}
  subjects: ['AA013', 'AA041', 'AA064', 'AH014', 'AH015', 'AH017', 'AH018', 'AH019', 'AH021', 'AH022', 'AH024', 'AH025', 'AH026', 'AH027', 'AH029', 'AH033', 'AH034', 'AH035', 'AH036', 'AH037', 'AH038', 'AH039', 'AH040', 'AH043', 'AH044', 'AH045', 'AH047', 'AH048', 'AH049', 'EA012', 'EA016', 'EA060', 'EA061', 'EA062', 'LA051', 'LA052', 'LA053', 'LA054', 'LA057', 'LA058', 'LA059', 'LA063']
Val  
  graphs: 22 | subjects: 11
  label counts (graphs): {0: 14, 1: 8}
  label counts (subjects): {1: 4, 0: 7}
  subjects: ['AA011', 'AA056', 'AH0

## Section 2: Graph Model Initialization

This section contains:
- Flexible GAT model architecture
- Class weight calculation for handling class imbalance
- Loss function options (Focal Loss and weighted Cross Entropy)

### 2.1 Flexible GAT Model

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, Sequential, BatchNorm1d, LayerNorm, ReLU
from torch_geometric.nn import GATv2Conv, global_mean_pool, GINEConv
from typing import Optional, Dict, Any, Union

class FlexibleGATNet(nn.Module):
    """
    Flexible GAT Network with configurable architecture.
    Supports both pure GAT and GINE-GAT hybrid architectures.

    Each graph layer outputs ``n_filters[i] * heads[i]`` features per node
    (attention heads are concatenated). After the last layer the node embeddings
    are mean-pooled per graph and passed through a two-layer classifier.
    """

    _NORM_TYPES = ('batch', 'layer')

    def __init__(self, in_channels, n_layers, n_filters, heads, fc_size,
                 dropout=0.6, edge_dim=None, n_classes=2,
                 use_residual=True, use_norm=False, norm_type='batch',
                 use_gine_first_layer=False, gine_train_eps=True):
        """
        Args:
            in_channels: input node feature dimension
            n_layers: number of graph convolution layers (at least 1)
            n_filters: hidden dimensions per layer (int or sequence of length n_layers)
            heads: attention heads per layer (int or sequence of length n_layers)
            fc_size: output dimension before classification
            dropout: dropout probability
            edge_dim: edge feature dimension
            n_classes: number of output classes
            use_residual: enable residual connections
            use_norm: enable normalization between layers
            norm_type: 'batch' or 'layer' normalization (checked only if use_norm)
            use_gine_first_layer: use GINEConv for first layer
            gine_train_eps: learnable epsilon for GINE

        Raises:
            ValueError: if n_layers < 1, if a per-layer sequence does not have
                n_layers entries, or if use_norm is set with an unknown norm_type.
        """
        super().__init__()

        if n_layers < 1:
            raise ValueError(f"n_layers must be at least 1, got {n_layers}")

        # Handle scalar or list configurations
        if isinstance(n_filters, int):
            n_filters = [n_filters] * n_layers
        if isinstance(heads, int):
            heads = [heads] * n_layers

        # Fail at construction time with a clear message, not with an IndexError
        # while building layers or during the first forward pass.
        for name, values in (('n_filters', n_filters), ('heads', heads)):
            if len(values) != n_layers:
                raise ValueError(
                    f"{name} must have {n_layers} entries (one per layer), got {len(values)}"
                )
        if use_norm and norm_type not in self._NORM_TYPES:
            raise ValueError(f"norm_type must be one of {self._NORM_TYPES}, got {norm_type!r}")

        self.in_channels = in_channels
        self.n_layers = n_layers
        self.dropout = dropout
        self.use_residual = use_residual
        self.use_norm = use_norm
        self.norm_type = norm_type
        self.use_gine_first_layer = use_gine_first_layer
        self.n_filters = n_filters
        self.heads = heads

        # Build graph convolution layers
        self.convs, self.residual_projections, self.norms = self._build_gat_layers(
            in_channels, n_filters, heads, edge_dim, gine_train_eps
        )

        # Final classifier
        final_dim = n_filters[-1] * heads[-1]
        self.fc1 = nn.Linear(final_dim, fc_size)
        self.fc2 = nn.Linear(fc_size, n_classes)

    def _build_gine_mlp(self, in_dim, out_dim):
        """Build MLP for GINEConv layer."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
            nn.ReLU()
        )

    def _build_gat_layers(self, in_channels, n_filters, heads, edge_dim, gine_train_eps):
        """
        Build GAT/GINE convolution layers with optional residual and normalization.

        Returns three ModuleLists of length n_layers. The residual and norm lists
        hold None at positions where no module is needed, so they can be indexed
        by layer number.
        """
        convs = nn.ModuleList()
        residual_projections = nn.ModuleList()
        norms = nn.ModuleList()

        for i in range(self.n_layers):
            in_dim = in_channels if i == 0 else n_filters[i-1] * heads[i-1]
            n_heads = heads[i]
            layer_width = n_filters[i] * n_heads  # heads are concatenated

            # First layer: GINE or GAT based on configuration
            if i == 0 and self.use_gine_first_layer:
                mlp = self._build_gine_mlp(in_dim, layer_width)
                conv = GINEConv(mlp, edge_dim=edge_dim, train_eps=gine_train_eps)
            else:
                conv = GATv2Conv(
                    in_dim, n_filters[i], heads=n_heads,
                    dropout=self.dropout, edge_dim=edge_dim, concat=True
                )
            convs.append(conv)

            # Residual projection (only if input and output dimensions differ)
            if self.use_residual and in_dim != layer_width:
                residual_projections.append(nn.Linear(in_dim, layer_width))
            else:
                residual_projections.append(None)

            # Normalization (norm_type was validated in __init__)
            if not self.use_norm:
                norms.append(None)
            elif self.norm_type == 'batch':
                norms.append(nn.BatchNorm1d(layer_width))
            else:
                norms.append(nn.LayerNorm(layer_width))

        return convs, residual_projections, norms

    def forward(self, x, edge_index, edge_attr, batch):
        """
        Forward pass.

        Args:
            x: node features [num_nodes, in_channels]
            edge_index: edge indices [2, num_edges]
            edge_attr: edge features [num_edges, edge_dim]
            batch: batch assignment [num_nodes]

        Returns:
            logits: [batch_size, n_classes]
        """
        # Graph convolution layers
        for i, conv in enumerate(self.convs):
            identity = x

            # GINEConv and GATv2Conv are called the same way
            x = conv(x, edge_index, edge_attr)

            # Apply normalization
            if self.norms[i] is not None:
                x = self.norms[i](x)

            # Apply residual connection
            if self.use_residual:
                if self.residual_projections[i] is not None:
                    identity = self.residual_projections[i](identity)
                x = x + identity

            # Apply activation and dropout (except for last layer)
            if i < self.n_layers - 1:
                x = F.elu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)

        # Global pooling
        x = global_mean_pool(x, batch)

        # Classifier
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.fc1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc2(x)

        return x

### 2.2 Loss Functions

In [None]:
class FocalLoss(nn.Module):
    """
    Focal Loss for handling class imbalance.

    FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

    p_t is the softmax probability of the true class. It is computed from the
    unweighted cross entropy, and alpha_t is applied afterwards as a per-sample
    factor. Passing alpha as the cross-entropy weight would make
    exp(-loss) equal p_t ** alpha_t instead of p_t.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        """
        Args:
            alpha: class weights tensor (one weight per class) or None
            gamma: focusing parameter (default: 2.0)
            reduction: 'mean', 'sum', or 'none' (any other value behaves like 'none')
        """
        super().__init__()
        if alpha is not None:
            alpha = torch.as_tensor(alpha, dtype=torch.float)
        # Registered as a non-persistent buffer: it follows .to(device)/.cuda()
        # without adding an entry to the state_dict.
        self.register_buffer('alpha', alpha, persistent=False)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs: logits [batch_size, n_classes]
            targets: labels [batch_size]

        Returns:
            loss: scalar or tensor depending on reduction
        """
        # Unweighted cross entropy is -log(p_t), so exp(-ce) recovers p_t exactly.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.alpha is not None:
            alpha_t = self.alpha.to(device=inputs.device, dtype=focal_loss.dtype)[targets]
            focal_loss = alpha_t * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

def calculate_class_weights(data_source, device, fold_idx=None, use_sqrt=False):
    """
    Calculate balanced class weights for handling class imbalance.

    The weight of a class is ``n_samples / (n_classes_present * count_of_class)``,
    i.e. the 'balanced' heuristic, computed with NumPy only.

    Args:
        data_source: DataLoader (iterable of batches exposing ``y``), or a list of
            folds when ``fold_idx`` is given
        device: torch device
        fold_idx: fold index if using K-fold loaders. A fold may be a
            ``(train_loader, val_loader)`` pair or a mapping with a 'train' entry.
        use_sqrt: if True, apply square root to reduce aggressive weighting

    Returns:
        torch.Tensor: float class weights on the specified device, one per label
        value found in the data, ordered by ascending label. If a class does not
        occur in the data the tensor is shorter than the number of classes.

    Raises:
        ValueError: if the loader yields no labels.
    """
    # Pick the loader that provides the training labels
    if fold_idx is None:
        loader = data_source
    else:
        fold = data_source[fold_idx]
        loader = fold['train'] if isinstance(fold, dict) else fold[0]

    # Extract labels
    chunks = [batch.y.detach().cpu().numpy().reshape(-1) for batch in loader]
    labels = np.concatenate(chunks) if chunks else np.empty(0, dtype=np.int64)
    if labels.size == 0:
        raise ValueError("Cannot compute class weights: no labels found in data_source")

    # Compute class weights
    classes, counts = np.unique(labels, return_counts=True)
    weights = labels.size / (classes.size * counts.astype(np.float64))

    # Optional: reduce aggressive weighting with square root
    if use_sqrt:
        weights = np.sqrt(weights)

    return torch.tensor(weights, dtype=torch.float).to(device)

### 2.3 Model Configuration and Initialization

In [None]:
# Model hyperparameters
# Keys match the keyword arguments of FlexibleGATNet.__init__.
MODEL_CONFIG = {
    'in_channels': 6,  # 6 statistical features (or 6 + PE_WALK_LENGTH if using RWPE)
    'n_layers': 2,
    'n_filters': [112, 32],
    'heads': [6, 4],
    'fc_size': 96,
    'dropout': 0.4,
    'edge_dim': 2,  # correlation + coherence
    'n_classes': 2,
    'use_residual': True,
    'use_norm': True,
    'norm_type': 'batch',
    'use_gine_first_layer': True,
    'gine_train_eps': True
}

# Adjust input channels if using positional encoding
# (the encoding adds PE_WALK_LENGTH columns to the node features).
if USE_POSITIONAL_ENCODING:
    MODEL_CONFIG['in_channels'] = 6 + PE_WALK_LENGTH

# Loss function configuration
USE_FOCAL_LOSS = False  # Set to True to use Focal Loss
FOCAL_GAMMA = 2.0
USE_CLASS_WEIGHTS = False
USE_SQRT_WEIGHTS = False

# Learning rate and optimization
LEARNING_RATE = 0.001
WEIGHT_DECAY = 1e-4

# Echo the configuration (FOCAL_GAMMA and USE_SQRT_WEIGHTS are not printed).
print("Model Configuration:")
for key, value in MODEL_CONFIG.items():
    print(f"  {key}: {value}")
print(f"\nLoss Configuration:")
print(f"  Use Focal Loss: {USE_FOCAL_LOSS}")
print(f"  Use Class Weights: {USE_CLASS_WEIGHTS}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Weight Decay: {WEIGHT_DECAY}")

Model Configuration:
  in_channels: 6
  n_layers: 2
  n_filters: [112, 32]
  heads: [6, 4]
  fc_size: 96
  dropout: 0.4
  edge_dim: 2
  n_classes: 2
  use_residual: True
  use_norm: True
  norm_type: batch
  use_gine_first_layer: True
  gine_train_eps: True

Loss Configuration:
  Use Focal Loss: False
  Use Class Weights: False
  Learning Rate: 0.001
  Weight Decay: 0.0001


## Section 3: 5-Fold Cross-Validation Training Pipeline

This section contains:
- Early stopping mechanism
- Training and validation functions
- K-fold cross-validation training loop
- Metrics tracking and visualization

### 3.1 Early Stopping

In [35]:
class EarlyStopping:
    """
    Signal that training should halt once a monitored metric plateaus.

    The first score passed in becomes the baseline. Every later score either
    beats the baseline by more than ``min_delta`` (which resets the counter)
    or counts as one more epoch without improvement. Once the counter reaches
    ``patience`` the instance latches ``early_stop`` to True.

    Args:
        patience: Number of epochs to wait for improvement
        min_delta: Minimum change to qualify as improvement
        mode: 'min' for loss; any other value is treated as 'max' (accuracy/F1)
        verbose: Print messages while counting and when stopping
    """
    def __init__(self, patience=10, min_delta=0.0001, mode='min', verbose=True):
        # Configuration
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.verbose = verbose
        # Running state
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_epoch = 0

    def _beats_best(self, score):
        """Return True when ``score`` improves on the best score by more than min_delta."""
        if self.mode == 'min':
            return score < (self.best_score - self.min_delta)
        return score > (self.best_score + self.min_delta)

    def __call__(self, current_score, epoch):
        """
        Record one epoch's score and report whether training should stop.

        Args:
            current_score: Current validation metric value
            epoch: Current epoch number

        Returns:
            bool: True if training should stop
        """
        # The very first score only establishes the baseline.
        if self.best_score is None:
            self.best_score, self.best_epoch = current_score, epoch
            return False

        if self._beats_best(current_score):
            self.best_score, self.best_epoch = current_score, epoch
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"  EarlyStopping: {self.counter}/{self.patience} epochs without improvement")

        # Checked on every call, so patience <= 0 stops on the second call.
        if self.counter >= self.patience:
            if self.verbose:
                print(f"  EarlyStopping: Stopping training at epoch {epoch}")
                print(f"  Best score: {self.best_score:.6f} at epoch {self.best_epoch}")
            self.early_stop = True

        return self.early_stop

### 3.2 Training and Validation Functions

In [36]:
import os
import torch
import torch.nn as nn
import numpy as np
import pickle
import shutil
from datetime import datetime
from sklearn import metrics
from sklearn.metrics import confusion_matrix
import torchmetrics

import numpy as np
import torch

def get_experiment_dir(experiment_name: str, base_dir: str = "experiments", overwrite: bool = False) -> str:
    """
    Create (if needed) and return ``<base_dir>/<YYYYMMDD>/<experiment_name>``.

    Args:
        experiment_name: Leaf directory name for this experiment
        base_dir: Root directory under which dated folders are created
        overwrite: If True and the directory already exists, delete it first

    Returns:
        str: Path to the experiment directory (guaranteed to exist on return)
    """
    today = datetime.now().strftime("%Y%m%d")
    target = os.path.join(base_dir, today, experiment_name)

    # Wipe a previous run only when explicitly requested
    if overwrite and os.path.exists(target):
        print(f"Removing existing directory: {target}")
        shutil.rmtree(target)

    os.makedirs(target, exist_ok=True)
    return target

def save_metrics(metrics_dict, save_dir, filename):
    """
    Pickle ``metrics_dict`` to ``<save_dir>/<filename>.pkl``.

    Args:
        metrics_dict: Object to serialize (any picklable value)
        save_dir: Destination directory; created if it does not exist
        filename: File name without the ``.pkl`` extension
    """
    os.makedirs(save_dir, exist_ok=True)
    target = os.path.join(save_dir, f"{filename}.pkl")
    with open(target, "wb") as handle:
        pickle.dump(metrics_dict, handle)

def save_best_model(model, save_dir, model_name):
    """
    Write the model's ``state_dict`` to ``<save_dir>/<model_name>_best.pt``.

    Args:
        model: Module whose ``state_dict()`` is saved
        save_dir: Destination directory; created if it does not exist
        model_name: Prefix of the checkpoint file name

    Returns:
        str: Path of the written checkpoint file
    """
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_name = f"{model_name}_best.pt"
    save_path = os.path.join(save_dir, checkpoint_name)
    weights = model.state_dict()
    torch.save(weights, save_path)
    print(f"• Saved best model weights to {save_path}")
    return save_path

def visualize_training_results(results, save_dir, experiment_name, training_type="holdout", best_epoch=None):
    """
    Plot loss / accuracy / F1 curves and the final confusion matrix as PNG files.

    Accepts results in three shapes:
        - holdout (new): tuple ``(holdout_metrics, final_summary)``; curves come
          from ``holdout_metrics`` and the matrix from
          ``final_summary["confusion_matrix"]``
        - holdout (old): dict; curves come from ``results["history"]`` (or the
          dict itself when that key is absent) and the matrix from
          ``results["final"]["confusion_matrix"]``
        - kfold: dict; curves come from ``results["overall"]`` and the matrix
          from ``results["overall"]["confusion_matrix_overall"]``

    Args:
        results: Training results in one of the structures above
        save_dir (str): Directory to save the figures (created if missing)
        experiment_name (str): Used in plot titles and as file-name prefix
        training_type (str): Either "holdout" or "kfold" (default: "holdout")
        best_epoch (int, optional): 0-indexed best epoch to mark on the plots.
            For the holdout tuple structure it is read from
            ``holdout_metrics["best_epoch"]`` when not given. It is always
            ignored for "kfold", since each fold has its own best epoch.

    Returns:
        dict: Paths of the saved figures under the keys "loss_plot",
        "accuracy_plot", "f1_plot" and "confusion_matrix_plot"

    Raises:
        ValueError: If ``training_type`` is neither "holdout" nor "kfold"

    Example:
        results = perform_kfold_training(...)
        exp_dir = get_experiment_dir("my_kfold")
        paths = visualize_training_results(results, exp_dir, "my_kfold", "kfold")
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import seaborn as sns

    # Ensure save directory exists
    os.makedirs(save_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Locate the per-epoch history and the final confusion matrix
    # ------------------------------------------------------------------
    if training_type == "kfold":
        history = results["overall"]
        final_cm = np.array(results["overall"]["confusion_matrix_overall"])
        # A single best epoch is meaningless for the fold-averaged curves
        best_epoch = None
    elif training_type == "holdout":
        if isinstance(results, tuple):
            # New structure: (holdout_metrics, final_summary)
            history, final_summary = results
            final_cm = np.array(final_summary["confusion_matrix"])
            if best_epoch is None:
                best_epoch = history.get("best_epoch", None)
        else:
            # Old structure (backward compatibility)
            history = results.get("history", results)
            final_cm = np.array(results["final"]["confusion_matrix"])
    else:
        raise ValueError(f"Invalid training_type: {training_type}. Must be 'holdout' or 'kfold'")

    # Per-epoch series as arrays, keyed by their name in the history dict
    series_names = ("train_loss", "val_loss", "train_accuracy",
                    "val_accuracy", "train_f1", "val_f1")
    curves = {name: np.array(history[name]) for name in series_names}

    # K-fold histories may carry NaN padding (folds can stop at different
    # epochs); keep everything up to the last non-NaN training loss.
    if training_type == "kfold":
        finite = ~np.isnan(curves["train_loss"])
        if not finite.all():
            keep = np.nonzero(finite)[0][-1] + 1
            curves = {name: values[:keep] for name, values in curves.items()}

    n_points = len(curves["train_loss"])
    epochs = np.arange(1, n_points + 1)  # 1-indexed for display

    # Title suffix describing early stopping (only when a best epoch is known)
    stopped_at = history.get("stopped_at_epoch", n_points - 1)
    if best_epoch is not None:
        early_stop_info = f"\n(Best: Epoch {best_epoch+1}, Stopped: Epoch {stopped_at+1})"
    else:
        early_stop_info = ""

    saved_paths = {}

    def _plot_curve_pair(train_vals, val_vals, label, star_label, file_stem, clamp_unit):
        """Draw one train-vs-validation figure, save it and return its path."""
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(epochs, train_vals, 'b-', label=f'Training {label}', linewidth=2)
        ax.plot(epochs, val_vals, 'r-', label=f'Validation {label}', linewidth=2)

        # Mark the best epoch with a vertical line and a star on the val curve
        if best_epoch is not None and best_epoch < len(val_vals):
            ax.axvline(x=best_epoch+1, color='green', linestyle='--',
                       linewidth=2, alpha=0.7, label=f'Best Epoch ({best_epoch+1})')
            ax.plot(best_epoch+1, val_vals[best_epoch],
                    'g*', markersize=15, label=star_label, zorder=5)

        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel(label, fontsize=12)
        ax.set_title(f'Training vs Validation {label} - {experiment_name}{early_stop_info}',
                     fontsize=14, fontweight='bold')
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)
        if clamp_unit:
            ax.set_ylim([0, 1])

        out_path = os.path.join(save_dir, f"{experiment_name}_{file_stem}.png")
        plt.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        return out_path

    # Figure 1: Training vs Validation Loss
    loss_path = _plot_curve_pair(curves["train_loss"], curves["val_loss"],
                                 'Loss', 'Best Val Loss', 'loss_curves', False)
    saved_paths["loss_plot"] = loss_path

    # Figure 2: Training vs Validation Accuracy
    acc_path = _plot_curve_pair(curves["train_accuracy"], curves["val_accuracy"],
                                'Accuracy', 'Best Val Acc', 'accuracy_curves', True)
    saved_paths["accuracy_plot"] = acc_path

    # Figure 3: Training vs Validation F1-Score
    f1_path = _plot_curve_pair(curves["train_f1"], curves["val_f1"],
                               'F1-Score', 'Best Val F1', 'f1_curves', True)
    saved_paths["f1_plot"] = f1_path

    # Figure 4: Confusion Matrix
    fig, ax = plt.subplots(figsize=(8, 6))
    class_names = ['Healthy (0)', 'Anxiety (1)']
    sns.heatmap(final_cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=list(class_names),
                yticklabels=list(class_names),
                cbar_kws={'label': 'Count'},
                ax=ax)

    cm_title = f'Final Confusion Matrix - {experiment_name}'
    if best_epoch is not None:
        cm_title += f'\n(From Best Epoch: {best_epoch+1})'
    ax.set_xlabel('Predicted Label', fontsize=12)
    ax.set_ylabel('True Label', fontsize=12)
    ax.set_title(cm_title, fontsize=14, fontweight='bold')

    cm_path = os.path.join(save_dir, f"{experiment_name}_confusion_matrix.png")
    plt.savefig(cm_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    saved_paths["confusion_matrix_plot"] = cm_path

    # Report what was written
    print(f"\nVisualization saved to: {save_dir}")
    print(f"  - Loss curves: {os.path.basename(loss_path)}")
    print(f"  - Accuracy curves: {os.path.basename(acc_path)}")
    print(f"  - F1-Score curves: {os.path.basename(f1_path)}")
    print(f"  - Confusion Matrix: {os.path.basename(cm_path)}")

    if best_epoch is not None:
        print(f"\nBest model from epoch {best_epoch+1} (marked with green line and star)")

    return saved_paths

def train(model, train_loader, optimizer, loss_fn, device,
          epoch=None, n_epochs=None, verbose=True, log_freq=10):
    """
    Run a single optimisation pass over every batch in ``train_loader``.

    The model is called with keyword arguments chosen by inspecting the
    parameter names of ``model.forward``: ``x`` and ``edge_index`` are passed
    whenever the signature names them, ``edge_attr`` and ``batch`` only when
    the signature names them and the batch object has that attribute.

    Args:
        model: The GNN model to train
        train_loader: DataLoader for the training set
        optimizer: Optimizer for updating model parameters
        loss_fn: Loss function called as ``loss_fn(logits, labels)``
        device: Device to run training on ('cpu' or 'cuda')
        epoch: Current epoch number (for logging, optional)
        n_epochs: Total number of epochs (for logging, optional)
        verbose: If True, print progress (default: True)
        log_freq: Log every N batches, and always on the last batch (default: 10)

    Returns:
        tuple: (avg_loss, train_acc, train_f1)
            - avg_loss: Mean of the per-batch losses
            - train_acc: Training accuracy (torchmetrics, binary task)
            - train_f1: Training F1 score (torchmetrics, binary task)
    """
    model.train()

    # Decide once which inputs this model's forward() accepts
    forward_params = inspect.signature(model.forward).parameters
    wants_x = 'x' in forward_params
    wants_edge_index = 'edge_index' in forward_params
    wants_edge_attr = 'edge_attr' in forward_params
    wants_batch = 'batch' in forward_params

    # Fresh metric accumulators for this epoch
    acc_metric = torchmetrics.Accuracy(task='binary').to(device)
    f1_metric = torchmetrics.F1Score(task='binary').to(device)

    n_batches = len(train_loader)
    if epoch is not None and n_epochs is not None:
        epoch_tag = f"{epoch}/{n_epochs}"
    else:
        epoch_tag = "1/1"

    running_loss = 0.0
    for step, graphs in enumerate(train_loader, start=1):
        graphs = graphs.to(device)
        optimizer.zero_grad()

        # Assemble only the inputs the model declared
        model_inputs = {}
        if wants_x:
            model_inputs['x'] = graphs.x
        if wants_edge_index:
            model_inputs['edge_index'] = graphs.edge_index
        if wants_edge_attr and hasattr(graphs, 'edge_attr'):
            model_inputs['edge_attr'] = graphs.edge_attr
        if wants_batch and hasattr(graphs, 'batch'):
            model_inputs['batch'] = graphs.batch
        logits = model(**model_inputs)

        labels = graphs.y.view(-1).long()

        # Backpropagate and update the parameters
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Metrics are fed hard class predictions (argmax over the last dim)
        predicted = torch.argmax(logits, dim=-1)
        acc_metric.update(predicted, labels)
        f1_metric.update(predicted, labels)

        # Periodic progress line with running averages
        if verbose and (step % log_freq == 0 or step == n_batches):
            mean_so_far = running_loss / step
            acc_so_far = acc_metric.compute().item()
            f1_so_far = f1_metric.compute().item()
            print(f"Epoch [{epoch_tag}], Step [{step}/{n_batches}], "
                  f"Loss: {mean_so_far:.4f}, Acc: {acc_so_far:.4f}, F1: {f1_so_far:.4f}")

    # Epoch-level results
    avg_loss = running_loss / n_batches
    train_acc = acc_metric.compute().item()
    train_f1 = f1_metric.compute().item()
    return avg_loss, train_acc, train_f1

def val(model, val_loader, loss_fn, device,
        epoch=None, n_epochs=None, verbose=True, log_final_only=True):
    """
    Evaluate the model on every batch of ``val_loader`` without gradients.

    Model inputs are selected the same way as in training: by inspecting the
    parameter names of ``model.forward``. Besides loss, accuracy and F1, the
    function also derives precision, recall and a 2x2 confusion matrix from
    the collected predictions.

    Args:
        model: The GNN model to evaluate
        val_loader: DataLoader for the validation set
        loss_fn: Loss function called as ``loss_fn(logits, labels)``
        device: Device to run evaluation on ('cpu' or 'cuda')
        epoch: Current epoch number (for logging, optional)
        n_epochs: Total number of epochs (for logging, optional)
        verbose: If True, print results (default: True)
        log_final_only: If True, print one line after all batches; if False,
            print a running line after every batch instead (default: True)

    Returns:
        tuple: (avg_loss, val_acc, val_f1, precision, recall, cm, preds, targets)
            - avg_loss: Mean of the per-batch losses
            - val_acc: Validation accuracy (torchmetrics, binary task)
            - val_f1: Validation F1 score (torchmetrics, binary task)
            - precision: Validation precision (torchmetrics, binary task)
            - recall: Validation recall (torchmetrics, binary task)
            - cm: Confusion matrix from sklearn with labels [0, 1]
            - preds: 1-D numpy array of predicted class indices
            - targets: 1-D numpy array of ground-truth labels
    """
    model.eval()

    # Decide once which inputs this model's forward() accepts
    forward_params = inspect.signature(model.forward).parameters
    wants_x = 'x' in forward_params
    wants_edge_index = 'edge_index' in forward_params
    wants_edge_attr = 'edge_attr' in forward_params
    wants_batch = 'batch' in forward_params

    acc_metric = torchmetrics.Accuracy(task='binary').to(device)
    f1_metric = torchmetrics.F1Score(task='binary').to(device)

    n_batches = len(val_loader)
    if epoch is not None and n_epochs is not None:
        epoch_tag = f"{epoch}/{n_epochs}"
    else:
        epoch_tag = "1/1"

    running_loss = 0.0
    pred_chunks, label_chunks = [], []
    log_each_batch = verbose and not log_final_only

    with torch.no_grad():
        for step, graphs in enumerate(val_loader, start=1):
            graphs = graphs.to(device)

            # Assemble only the inputs the model declared
            model_inputs = {}
            if wants_x:
                model_inputs['x'] = graphs.x
            if wants_edge_index:
                model_inputs['edge_index'] = graphs.edge_index
            if wants_edge_attr and hasattr(graphs, 'edge_attr'):
                model_inputs['edge_attr'] = graphs.edge_attr
            if wants_batch and hasattr(graphs, 'batch'):
                model_inputs['batch'] = graphs.batch
            logits = model(**model_inputs)

            labels = graphs.y.view(-1).long()
            running_loss += loss_fn(logits, labels).item()

            # Hard class predictions feed both the metrics and the returned arrays
            predicted = torch.argmax(logits, dim=-1)
            acc_metric.update(predicted, labels)
            f1_metric.update(predicted, labels)
            pred_chunks.append(predicted.detach().cpu().numpy().reshape(-1))
            label_chunks.append(labels.detach().cpu().numpy().reshape(-1))

            if log_each_batch:
                mean_so_far = running_loss / step
                acc_so_far = acc_metric.compute().item()
                f1_so_far = f1_metric.compute().item()
                print(f"Epoch [{epoch_tag}], Step [{step}/{n_batches}], "
                      f"Loss: {mean_so_far:.4f}, Acc: {acc_so_far:.4f}, F1: {f1_so_far:.4f}")

    # Epoch-level results
    avg_loss = running_loss / n_batches
    val_acc = acc_metric.compute().item()
    val_f1 = f1_metric.compute().item()
    preds = np.concatenate(pred_chunks)
    targets = np.concatenate(label_chunks)

    # Precision / recall are computed in one shot over the full prediction set
    preds_tensor = torch.from_numpy(preds).long().to(device)
    targets_tensor = torch.from_numpy(targets).long().to(device)
    precision_metric = torchmetrics.Precision(task='binary').to(device)
    recall_metric = torchmetrics.Recall(task='binary').to(device)
    precision = precision_metric(preds_tensor, targets_tensor).item()
    recall = recall_metric(preds_tensor, targets_tensor).item()

    # labels=[0, 1] keeps the matrix 2x2 even if one class is absent
    cm = confusion_matrix(targets, preds, labels=[0, 1])

    if verbose and log_final_only:
        print(f"Epoch [{epoch_tag}], Step [{n_batches}/{n_batches}], "
              f"Loss: {avg_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")

    return avg_loss, val_acc, val_f1, precision, recall, cm, preds, targets

### 3.3 K-Fold Cross-Validation Training

In [37]:
def perform_kfold_training(
    *,
    fold_loaders,
    model_fn,
    optimizer_fn,
    loss_fn,
    device,
    epochs: int,
    experiment_name: str,
    scheduler_fn=None,
    patience=10,
    min_delta=0.0001,
    early_stop_metric='loss'
):
    """
    Train and validate one fresh model per fold, with early stopping.

    For every fold a new model/optimizer (and optional scheduler) is created,
    trained for up to ``epochs`` epochs, and the weights of the epoch with the
    best ``early_stop_metric`` are saved. Per-fold histories and the aggregated
    summary are pickled into the experiment directory.

    Args:
        fold_loaders: Sized sequence of (train_loader, val_loader) pairs
        model_fn: Callable that returns a new model instance
        optimizer_fn: Callable that takes model and returns optimizer
        loss_fn: Loss function (e.g., nn.CrossEntropyLoss)
        device: Device to train on ('cpu' or 'cuda')
        epochs: Maximum number of epochs per fold (must be >= 1)
        experiment_name: Name for saving results
        scheduler_fn: Optional callable that takes optimizer and returns a
            scheduler. The scheduler is stepped with the validation F1 each
            epoch, so it must accept a metric in ``step()`` and treat higher
            as better (e.g. ReduceLROnPlateau with mode="max").
        patience: Number of epochs to wait for improvement before stopping
        min_delta: Minimum change to qualify as improvement for early stopping
        early_stop_metric: Metric to monitor ('loss', 'f1', 'accuracy')

    Returns:
        dict: ``{"folds": [...], "overall": {...}}`` where
            - "folds" holds one dict per fold with the per-epoch lists
              (train/val loss, accuracy, F1, precision, recall, confusion
              matrices) plus 0-indexed "best_epoch" and "stopped_at_epoch"
            - "overall" holds fold-averaged curves (NaN-aware mean), the
              1-indexed "best_epochs_per_fold", the mean of each fold's
              best-epoch metrics ("*_mean"), and metrics derived from the
              summed best-epoch confusion matrices ("*_overall")

    Raises:
        ValueError: If ``early_stop_metric`` is invalid, ``fold_loaders`` is
            empty, or ``epochs`` is less than 1
    """
    # Function-scope import: only needed to snapshot the best weights
    import copy

    # Validate up front so nothing is created on disk for a bad configuration
    if early_stop_metric not in ('loss', 'f1', 'accuracy'):
        raise ValueError(f"Invalid early_stop_metric: {early_stop_metric}. Must be 'loss', 'f1', or 'accuracy'")
    if len(fold_loaders) == 0:
        raise ValueError("fold_loaders must contain at least one (train_loader, val_loader) pair")
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    minimize = early_stop_metric == 'loss'
    metric_name = {'loss': 'Loss', 'accuracy': 'Accuracy', 'f1': 'F1'}[early_stop_metric]

    exp_dir = get_experiment_dir(experiment_name)
    all_fold_metrics = []
    final_cms = []
    best_epochs_per_fold = []

    for fold_idx, (train_loader, val_loader) in enumerate(fold_loaders, start=1):
        print(f"\n{'='*70}")
        print(f"Fold {fold_idx}/{len(fold_loaders)}")
        print(f"{'='*70}")

        # Every fold starts from a freshly initialised model/optimizer/scheduler
        model = model_fn().to(device)
        optimizer = optimizer_fn(model)
        scheduler = scheduler_fn(optimizer) if scheduler_fn else None

        # Per-epoch history for this fold
        train_losses, train_accs, train_f1s = [], [], []
        val_losses, val_accs, val_f1s = [], [], []
        precisions, recalls, confusion_matrices = [], [], []

        # Start from the worst possible value so the first epoch always
        # becomes the initial best, even if its F1/accuracy is 0.0.
        best_metric_value = float('inf') if minimize else float('-inf')
        best_model_state = None
        best_epoch = 0

        early_stopper = EarlyStopping(patience=patience, min_delta=min_delta,
                                      mode='min' if minimize else 'max', verbose=True)

        for epoch in range(epochs):
            # Training phase
            tr_loss, tr_acc, tr_f1 = train(
                model, train_loader, optimizer, loss_fn, device,
                epoch=epoch + 1, n_epochs=epochs, verbose=True, log_freq=10
            )

            # Validation phase
            vl_loss, vl_acc, vl_f1, precision, recall, cm, _, _ = val(
                model, val_loader, loss_fn, device,
                epoch=epoch + 1, n_epochs=epochs, verbose=False, log_final_only=True
            )

            # Scheduler is driven by validation F1 (see scheduler_fn in Args)
            if scheduler:
                scheduler.step(vl_f1)

            train_losses.append(tr_loss)
            train_accs.append(tr_acc)
            train_f1s.append(tr_f1)
            val_losses.append(vl_loss)
            val_accs.append(vl_acc)
            val_f1s.append(vl_f1)
            precisions.append(precision)
            recalls.append(recall)
            confusion_matrices.append(cm.tolist())

            # Track the best epoch by the same metric early stopping monitors
            current_metric = {'loss': vl_loss, 'accuracy': vl_acc, 'f1': vl_f1}[early_stop_metric]
            if minimize:
                is_better = current_metric < best_metric_value
            else:
                is_better = current_metric > best_metric_value

            if is_better:
                best_metric_value = current_metric
                # state_dict() returns references to the live parameters, so
                # it must be deep-copied; otherwise later optimizer steps
                # silently overwrite the "best" weights.
                best_model_state = copy.deepcopy(model.state_dict())
                best_epoch = epoch
                print(f"  ★ New best model! Val {metric_name}: {current_metric:.6f}")

            print(
                f"[Fold {fold_idx}] Epoch {epoch+1}/{epochs}  "
                f"TR L:{tr_loss:.4f} Acc:{tr_acc:.4f} F1:{tr_f1:.4f} | "
                f"VL L:{vl_loss:.4f} Acc:{vl_acc:.4f} F1:{vl_f1:.4f}"
            )

            if early_stopper(current_metric, epoch):
                print(f"\n[Fold {fold_idx}] Early stopping at epoch {epoch+1}")
                print(f"[Fold {fold_idx}] Best model was at epoch {best_epoch+1}")
                break

        # Save per-fold metrics
        fold_metrics = {
            "train_loss": train_losses,
            "train_accuracy": train_accs,
            "train_f1": train_f1s,
            "val_loss": val_losses,
            "val_accuracy": val_accs,
            "val_f1": val_f1s,
            "precision": precisions,
            "recall": recalls,
            "confusion_matrix": confusion_matrices,
            "best_epoch": best_epoch,
            "stopped_at_epoch": len(train_losses) - 1,
        }
        save_metrics(fold_metrics, exp_dir, f"{experiment_name}_fold_{fold_idx}")
        all_fold_metrics.append(fold_metrics)
        best_epochs_per_fold.append(best_epoch)

        # Restore the best-epoch weights before writing the checkpoint.
        # best_model_state stays None only if the metric was never comparable
        # (e.g. NaN every epoch); in that case nothing is saved.
        if best_model_state is not None:
            model.load_state_dict(best_model_state)
            save_best_model(model, exp_dir, f"{experiment_name}_fold_{fold_idx}")

        # The fold contributes the confusion matrix of its BEST epoch
        final_cms.append(np.array(confusion_matrices[best_epoch]))

        print(f"\n[Fold {fold_idx}] Summary:")
        print(f"  Best epoch: {best_epoch+1}/{len(train_losses)}")
        print(f"  Best val loss: {val_losses[best_epoch]:.4f}")
        print(f"  Best val accuracy: {val_accs[best_epoch]:.4f}")
        print(f"  Best val F1: {val_f1s[best_epoch]:.4f}")

    # ------------------------------------------------------------------
    # Aggregate across folds
    # ------------------------------------------------------------------
    # Folds may stop at different epochs; pad the shorter ones with NaN so
    # the curves can be stacked, then average while ignoring the padding.
    max_epochs = max(len(m["train_loss"]) for m in all_fold_metrics)

    def pad_to_max(arr, max_len):
        """Pad array with NaN to max length"""
        padded = np.full(max_len, np.nan)
        padded[:len(arr)] = arr
        return padded

    def mean_curve(key):
        """Fold-averaged curve for one history key, ignoring NaN padding."""
        stacked = np.array([pad_to_max(m[key], max_epochs) for m in all_fold_metrics])
        return np.nanmean(stacked, axis=0).tolist()

    def at_best_epoch(key):
        """Value of one history key at each fold's best epoch."""
        return [m[key][best] for m, best in zip(all_fold_metrics, best_epochs_per_fold)]

    overall_metrics = {
        "train_loss": mean_curve("train_loss"),
        "train_accuracy": mean_curve("train_accuracy"),
        "train_f1": mean_curve("train_f1"),
        "val_loss": mean_curve("val_loss"),
        "val_accuracy": mean_curve("val_accuracy"),
        "val_f1": mean_curve("val_f1"),
        "best_epochs_per_fold": [e+1 for e in best_epochs_per_fold],
    }

    best_fold_precisions = at_best_epoch("precision")
    best_fold_recalls = at_best_epoch("recall")
    best_fold_accuracies = at_best_epoch("val_accuracy")
    best_fold_f1s = at_best_epoch("val_f1")

    # Aggregate confusion matrix (sum across folds)
    agg_cm = np.sum(final_cms, axis=0) if final_cms else np.zeros((2, 2), int)

    # Pooled metrics from the aggregated matrix; every value defaults to 0.0
    # so the summary below is always well-defined.
    overall_accuracy = overall_precision = overall_recall = overall_f1 = 0.0
    if agg_cm.shape == (2, 2) and agg_cm.sum() > 0:
        tn, fp, fn, tp = agg_cm.ravel()
        overall_accuracy = (tp + tn) / (tp + tn + fp + fn)
        if (tp + fp) > 0:
            overall_precision = tp / (tp + fp)
        if (tp + fn) > 0:
            overall_recall = tp / (tp + fn)
        if (overall_precision + overall_recall) > 0:
            overall_f1 = (2 * (overall_precision * overall_recall)
                          / (overall_precision + overall_recall))

    overall_metrics.update({
        "accuracy_mean": np.mean(best_fold_accuracies),
        "f1_mean": np.mean(best_fold_f1s),
        "precision_mean": np.mean(best_fold_precisions),
        "recall_mean": np.mean(best_fold_recalls),
        "accuracy_overall": overall_accuracy,
        "precision_overall": overall_precision,
        "recall_overall": overall_recall,
        "f1_overall": overall_f1,
        "confusion_matrix_overall": agg_cm.tolist(),
    })
    save_metrics(overall_metrics, exp_dir, f"{experiment_name}_kfold_overall")

    # Print summary
    print(f"\n{'='*70}")
    print(f"K-Fold Training Complete: {experiment_name}")
    print(f"{'='*70}")
    print(f"Best epochs per fold: {overall_metrics['best_epochs_per_fold']}")
    print(f"Overall Metrics (from best epochs across {len(fold_loaders)} folds):")
    print(f"  Precision: {overall_precision:.4f}")
    print(f"  Recall:    {overall_recall:.4f}")
    print(f"  F1 Score:  {overall_f1:.4f}")
    print(f"  Confusion Matrix:\n{agg_cm}")

    return {"folds": all_fold_metrics, "overall": overall_metrics}

### 3.4 Execute 5-Fold Cross-Validation

In [38]:
# Training configuration
EPOCHS = 50
PATIENCE = 10
MIN_DELTA = 0.0001
EARLY_STOP_METRIC = 'loss'  # 'loss', 'f1', or 'accuracy'
EXPERIMENT_NAME = 'flexible_gat_5fold'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define model factory function
def create_model():
    return FlexibleGATNet(**MODEL_CONFIG)

# Optimizer factory (handed to perform_kfold_training as `optimizer_fn`)
def create_optimizer(model):
    """
    Build an Adam optimizer over all parameters of ``model``.

    Args:
        model: Module whose ``parameters()`` are optimized.

    Returns:
        A torch.optim.Adam instance using LEARNING_RATE as its step size.
    """
    params = model.parameters()
    return torch.optim.Adam(params, lr=LEARNING_RATE)

# Optional: Define scheduler factory function
def create_scheduler(optimizer):
    """
    Build a ReduceLROnPlateau scheduler for ``optimizer``.

    The learning rate is multiplied by 0.5 once the monitored value has
    stopped improving for longer than the configured patience (3 steps).

    Args:
        optimizer: Optimizer whose learning rate the scheduler adjusts.

    Returns:
        A torch.optim.lr_scheduler.ReduceLROnPlateau instance.
    """
    # NOTE(review): mode="max" treats a LARGER monitored value as an
    # improvement. EARLY_STOP_METRIC is 'loss' in this notebook; if the
    # training loop calls scheduler.step() with the validation loss, this
    # should be mode="min" -- confirm against perform_kfold_training.
    #
    # The `verbose` argument is deprecated for LR schedulers and is therefore
    # not passed; the scheduler is silent by default, so behaviour is the same.
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="max",
        factor=0.5,
        patience=3,
    )

# Calculate class weights if enabled
# NOTE(review): the weights are computed from fold 0 only (fold_idx=0) and the
# resulting loss_fn is shared by every fold -- confirm this is intended.
if USE_CLASS_WEIGHTS:
    class_weights = calculate_class_weights(fold_loaders_v2, device=DEVICE, fold_idx=0, use_sqrt=USE_SQRT_WEIGHTS)
    print(f"Class weights: {class_weights}")
else:
    class_weights = None

# Define loss function
if USE_FOCAL_LOSS:
    loss_fn = FocalLoss(alpha=class_weights, gamma=FOCAL_GAMMA)
    print(f"Using Focal Loss (gamma={FOCAL_GAMMA})")
else:
    # Forward the class weights so USE_CLASS_WEIGHTS also takes effect for
    # cross entropy. weight=None is CrossEntropyLoss's default, so the loss
    # stays unweighted when class weighting is disabled.
    # Assumes calculate_class_weights returns a 1-D float tensor with one
    # entry per class, already on DEVICE -- TODO confirm.
    loss_fn = nn.CrossEntropyLoss(weight=class_weights)
    print("Using Cross Entropy Loss")

# Echo the run settings so they are captured in the cell output.
print("\nTraining Configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Patience: {PATIENCE}")
print(f"  Early Stop Metric: {EARLY_STOP_METRIC}")
print(f"  Experiment Name: {EXPERIMENT_NAME}")

Using Cross Entropy Loss

Training Configuration:
  Epochs: 50
  Patience: 10
  Early Stop Metric: loss
  Experiment Name: flexible_gat_5fold


In [39]:
# Launch the K-fold cross-validation run. The returned dict (keys "folds" and
# "overall") is kept in `results` for later cells.
results = perform_kfold_training(
    # Data, naming and hardware
    fold_loaders=fold_loaders_v2,
    experiment_name=EXPERIMENT_NAME,
    device=DEVICE,
    # Factories and loss
    model_fn=create_model,
    optimizer_fn=create_optimizer,
    scheduler_fn=create_scheduler,
    loss_fn=loss_fn,
    # Training length and early stopping
    epochs=EPOCHS,
    patience=PATIENCE,
    min_delta=MIN_DELTA,
    early_stop_metric=EARLY_STOP_METRIC,
)


Fold 1/5
Epoch [1/50], Step [10/11], Loss: 0.7273, Acc: 0.6125, F1: 0.5373
Epoch [1/50], Step [11/11], Loss: 0.7573, Acc: 0.5952, F1: 0.5278
  ★ New best model! Val Loss: 0.472359
[Fold 1] Epoch 1/50  TR L:0.7573 Acc:0.5952 F1:0.5278 | VL L:0.4724 Acc:0.8182 F1:0.7143
Epoch [2/50], Step [10/11], Loss: 0.5213, Acc: 0.7875, F1: 0.7119
Epoch [2/50], Step [11/11], Loss: 0.5076, Acc: 0.7976, F1: 0.7213
  ★ New best model! Val Loss: 0.428159
[Fold 1] Epoch 2/50  TR L:0.5076 Acc:0.7976 F1:0.7213 | VL L:0.4282 Acc:0.7727 F1:0.6667
Epoch [3/50], Step [10/11], Loss: 0.6228, Acc: 0.6500, F1: 0.4615
Epoch [3/50], Step [11/11], Loss: 0.6985, Acc: 0.6429, F1: 0.4828
  ★ New best model! Val Loss: 0.402076
[Fold 1] Epoch 3/50  TR L:0.6985 Acc:0.6429 F1:0.4828 | VL L:0.4021 Acc:0.8182 F1:0.7778
Epoch [4/50], Step [10/11], Loss: 0.7794, Acc: 0.5875, F1: 0.4762
Epoch [4/50], Step [11/11], Loss: 0.7602, Acc: 0.5833, F1: 0.4615
[Fold 1] Epoch 4/50  TR L:0.7602 Acc:0.5833 F1:0.4615 | VL L:0.5040 Acc:0.7273

### 3.5 Visualize and Report Results

In [41]:
# Date component (YYYYMMDD) of the experiment folder to report on.
DATE = "20251219"

# Folder of the 'flexible_gat_5fold' experiment for that date; the report
# helper is given this path.
result_path = f"./experiments/{DATE}/flexible_gat_5fold"
helper_utils.report_training_results(result_path)


K-FOLD CROSS-VALIDATION RESULTS
Experiment: flexible_gat_5fold
Folder: ./experiments/20251219/flexible_gat_5fold
Number of Folds: 5

PER-FOLD BEST EPOCH PERFORMANCE
+--------+------------+----------+--------+-----------+--------+
| Fold   | Best Epoch | Accuracy |     F1 | Precision | Recall |
+--------+------------+----------+--------+-----------+--------+
| Fold 1 |         10 |   0.9091 | 0.8889 |    0.8000 | 1.0000 |
| Fold 2 |          4 |   0.5000 | 0.2667 |    0.2857 | 0.2500 |
| Fold 3 |          7 |   0.5000 | 0.5217 |    0.4000 | 0.7500 |
| Fold 4 |         12 |   0.7500 | 0.6154 |    0.8000 | 0.5000 |
| Fold 5 |         10 |   0.5500 | 0.4706 |    0.4444 | 0.5000 |
+--------+------------+----------+--------+-----------+--------+

PER-FOLD FINAL EPOCH PERFORMANCE
+--------+----------+--------+-----------+--------+
| Fold   | Accuracy |     F1 | Precision | Recall |
+--------+----------+--------+-----------+--------+
| Fold 1 |   0.7273 | 0.7273 |    0.5714 | 1.0000 |
| Fold 