In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from torch_geometric.data import Data, Dataset, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

from utilities.utils import get_n_ROI, vec_to_symmetric_matrix

In [None]:
# Load configuration file for filepath.
# The encoding is given explicitly so the file is read the same way on every platform
# instead of depending on the locale's default encoding.
with open("config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
my_filepath = config["simulation_filepath"]

# Loading dataset for 1 task paradigm assessed by 1 method for all subjects (1 run)
# SECURITY NOTE: allow_pickle=True lets np.load unpickle arbitrary Python objects, which can
# execute code while loading. Only point `simulation_filepath` at files from a trusted source.
dFC = np.load(my_filepath, allow_pickle=True)
dFC_dict = dFC.item()  # extract the dictionary stored inside the 0-d object array

# Pull the individual entries out of the loaded dictionary
X = dFC_dict["X"]
y = dFC_dict["y"]
subj_label = dFC_dict["subj_label"]
method = dFC_dict["measure_name"]

In [None]:
def dfc_to_graph(dfc_matrix, label):
    """
    Convert a dFC matrix into a fully connected weighted graph.

    Parameters:
        dfc_matrix: square 2D array-like (e.g. NumPy array or torch tensor) of shape
            (num_nodes, num_nodes). Entry [i, j] becomes the weight of the edge i -> j.
        label: scalar label of the graph.

    Returns:
        A `Data` object with
            x:          (num_nodes, 1) float tensor of ones (identical node features),
            edge_index: (2, num_nodes * (num_nodes - 1)) long tensor holding every ordered
                        pair (i, j) with i != j, in row-major order,
            edge_attr:  (num_edges,) float tensor with the matching matrix entries,
            y:          (1,) float tensor holding `label`.

    Raises:
        ValueError: if `dfc_matrix` is not a square 2D matrix.
    """
    # Convert up front so edge_attr is always a float torch tensor, matching the dtype of x,
    # regardless of whether the caller passed a NumPy array or a tensor.
    weights = torch.as_tensor(dfc_matrix, dtype=torch.float)
    if weights.dim() != 2 or weights.shape[0] != weights.shape[1]:
        raise ValueError(
            f"dfc_matrix must be a square 2D matrix, got shape {tuple(weights.shape)}"
        )
    num_nodes = weights.shape[0]

    # Node features: all nodes get identical 1D feature (e.g., [[1.0], [1.0], ..., [1.0]])
    x = torch.ones((num_nodes, 1), dtype=torch.float)

    # All ordered pairs (i, j) excluding self-loops: the off-diagonal mask's non-zero
    # positions come back in row-major order as an (E, 2) tensor, which is transposed to the
    # (2, E) layout. This also yields a well-formed (2, 0) tensor for graphs with < 2 nodes.
    off_diagonal = ~torch.eye(num_nodes, dtype=torch.bool)
    edge_index = off_diagonal.nonzero().t().contiguous()

    # Get corresponding edge weights from the dFC matrix
    edge_weights = weights[edge_index[0], edge_index[1]]

    return Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_weights,
        y=torch.tensor([label], dtype=torch.float),
    )


In [None]:
class dFCGraphDataset(Dataset):
    """
    Dataset that lazily turns vectorised dFC samples into graph `Data` objects.

    Each sample vector is expanded with `vec_to_symmetric_matrix` and converted with
    `dfc_to_graph` at access time. Only the `len()` / `get()` hooks are implemented here;
    `len(dataset)` and `dataset[idx]` are provided by the `Dataset` base class on top of them.
    """

    def __init__(self, X, y):
        """
        Parameters:
            X: 2D numpy array of shape (n_samples, num_features)
            y: 1D array-like of binary labels

        Raises:
            ValueError: if X and y do not contain the same number of samples.
        """
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same number of samples, got {len(X)} and {len(y)}"
            )
        # Initialise the base class so the state its indexing machinery relies on exists.
        super().__init__()
        self.X = X
        self.y = y

    def len(self):
        """Return the number of samples in the dataset."""
        return len(self.y)

    def get(self, idx):
        """Build and return the graph for the sample at position `idx`."""
        vec = self.X[idx]
        label = self.y[idx]

        # Convert vector to symmetric matrix
        dfc_matrix = vec_to_symmetric_matrix(vec)

        # Convert symmetric matrix to graph data object
        return dfc_to_graph(dfc_matrix, label)