# Predicting Arousal from Brain Activity Using Spatio-Temporal GNNs
## Example model training

In this notebook, we show how to load the brain activity (resting fMRI) data and train an example model. The data is composed of 72, ten-minute brain scans. The brain data has been parcellated to include 116 or 630 general brain regions. 40 of the scans are from the fasted person while 32 of the scans are from the caffeinated person. We expect to see changes in brain connectivity between these two states due to caffeine influencing arousal.  

The model ran here is a graph classification GNN. It is composed of 3 GCNConv layers, each followed by a ReLU non-linearity. Then, we sum a mean and max pooling of all the nodes in one graph, and pass this output through a linear layer and softmax to get the probability of the graph belonging to the caffeinated or the non-caffeinated class.


### Install PyG and mount google drive

In [2]:
!pip install torch==2.4.0



In [3]:
! pip install torch-geometric-temporal



In [4]:
# Install torch geometric
import os
import torch

In [5]:
torch_version = str(torch.__version__)
scatter_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
sparse_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
!pip install torch-scatter -f $scatter_src
!pip install torch-sparse -f $sparse_src
!pip install torch-geometric
!pip install -q git+https://github.com/snap-stanford/deepsnap.git

Looking in links: https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
Looking in links: https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [6]:
import torch_geometric
torch_geometric.__version__

'2.6.1'

In [7]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Create graph dataset

In this case, we are creating a custom dataset with 72 graphs. We are including connectivity between the nodes and an attribute to hold the time series.

In [95]:
import os
import torch
import pandas as pd
import numpy as np
import networkx as nx
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data
import torch_geometric.utils as pyg_utils
import datetime
import matplotlib.pyplot as plt
from torch_geometric.data import InMemoryDataset, Data
import torch.nn.functional as F

In [96]:
def load_labels(path, sep = ","):
  labels= pd.read_csv(path)
  return labels

In [97]:
def create_connectivity_graphs(connectivity_path, timeseries_path, labels, sep = ",", normalize = True, one_hot = False, embedding_layer = None):
  """
  Takes in the path where all the connectivity matrices and timeseries data are and a DataFrame Labels with the subject codes
  Outputs a list of PyG Data graphs (undirected, weighted normalized)
  """

  graphs = []
  for sub in labels["subcode"]:
    connectivity_file_pth = os.path.join(connectivity_path, f"{sub}.txt")
    timeseries_file_pth = os.path.join(timeseries_path, f"{sub}.txt")

    if not os.path.exists(connectivity_file_pth) or not os.path.exists(timeseries_file_pth):
      continue

    # Load graph into networkx based on connectivity
    matrix = pd.read_csv(connectivity_file_pth, sep=sep, header=None).to_numpy()[1:, 1:]
    matrix -= np.identity(matrix.shape[0])

    timeseries = pd.read_csv(timeseries_file_pth, sep=sep, header=None).to_numpy()[1:, 1:].T

    if normalize:
      # Take absolute value of correlation and make sure each connection sums to 1.
      matrix = np.abs(matrix)
      row_sums = matrix.sum(axis=1, keepdims=True)
      matrix = matrix / (row_sums)
    G = nx.from_numpy_array(matrix)


    # Convert the graph to PyTorch Geometric format
    data = pyg_utils.from_networkx(G)

    # Initialize node features (random)
    num_nodes = data.num_nodes
    if embedding_layer:
        roi_ids = torch.arange(num_nodes)
        with torch.no_grad():
          data.x = F.normalize(embedding_layer(roi_ids), p=2, dim=1)
    else:
      one_hot = True
    if one_hot:
      # One-hot encoding for the brain regions
      data.x = torch.eye(num_nodes)

    # Assign label
    caffeinated = labels.query(f"subcode == '{sub}'")["caffeinated"].iloc[0]
    data.y = torch.tensor([caffeinated], dtype=torch.long)  # Binary labels: 0 or 1
    data.timeseries = timeseries
    graphs.append(data)

  return graphs

In [98]:
class MyConnectomeDataset(InMemoryDataset):
    def __init__(self, graphs, transform=None, pre_transform=None):
        super().__init__(None, transform, pre_transform)
        self.data, self.slices = self.collate(graphs)

In [99]:
embedding_dim = 32  # Dimension of each area embedding
num_brain_rois = 116
embedding_layer = torch.nn.Embedding(num_brain_rois, embedding_dim)

In [100]:
# Create graphs
connecivity_dir = "/content/drive/MyDrive/CS224W Project/data/connectivity_aa116"
timeseries_dir = "/content/drive/MyDrive/CS224W Project/data/timeseries_aa116"
labels_dir = "/content/drive/MyDrive/CS224W Project/data/labels.csv"

# Load labels
labels= load_labels(labels_dir)
# Load connectivity matrices
graphs = create_connectivity_graphs(connecivity_dir, timeseries_dir, labels, embedding_layer = embedding_layer)

In [101]:
dataset = MyConnectomeDataset(graphs)

In [102]:
# Print information about the dataset
print(f'Dataset: {dataset}')
print('-------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {dataset[0].x.shape[0]}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
print(f'Number of time points: {dataset[0].timeseries.shape[1]}')

Dataset: MyConnectomeDataset(72)
-------------------
Number of graphs: 72
Number of nodes: 116
Number of features: 32
Number of classes: 2
Number of time points: 518


### Divide into training and testing

In [103]:
torch.manual_seed(12345)
dataset = dataset.shuffle()

train_dataset = dataset[:50]
test_dataset = dataset[50:]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

Number of training graphs: 50
Number of test graphs: 22


In [104]:
from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

### Create simple GNN graph classifier
First, we implement a GNN Graph classifier model that only uses the connectivity matrices.

Ref: https://pytorch-geometric.readthedocs.io/en/2.4.0/get_started/introduction.html

In [108]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch.nn import Linear, ReLU
from torch_geometric.nn import global_max_pool, global_mean_pool, BatchNorm

class GCN(torch.nn.Module):
    def __init__(self, hidden_dim = 16):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_node_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        self.bn1 = BatchNorm(hidden_dim)
        self.bn2 = BatchNorm(hidden_dim)
        self.bn3 = BatchNorm(hidden_dim)
        self.lin = Linear(hidden_dim, dataset.num_classes)
        self.relu = ReLU()

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.weight

        x = self.conv1(x, edge_index, edge_weight = edge_weight)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x, edge_index, edge_weight = edge_weight)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x, edge_index, edge_weight = edge_weight)
        x = self.bn3(x)
        x = self.relu(x)
        x = global_max_pool(x, data.batch) + global_mean_pool(x, data.batch)
        x = self.lin(x)
        # x = F.softmax(x, dim = 1)

        return x

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device', device)

model = GCN(hidden_dim=16).to(device)
model = GCN()
print(model)

Device cpu
GCN(
  (conv1): GCNConv(32, 16)
  (conv2): GCNConv(16, 16)
  (conv3): GCNConv(16, 16)
  (bn1): BatchNorm(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lin): Linear(in_features=16, out_features=2, bias=True)
  (relu): ReLU()
)


### Train and test model
Notice that the performance is very poor (basically predicting the same category for all). This implies we could leverage using temporal data.

In [109]:
def evaluate_model(model, data_loader):
    model.eval()
    correct_per_class = {0: 0, 1: 0}
    incorrect_per_class = {0: 0, 1: 0}

    with torch.no_grad():  # Disable gradient calculations for evaluation
        for data in data_loader:
            data = data.to(device)
            out = model(data)
            pred = out.argmax(dim=1)  # Get predicted class

            # Compare predictions to actual labels
            for i in range(len(pred)):
                true_label = data.y[i].item()
                predicted_label = pred[i].item()
                if predicted_label == true_label:
                    correct_per_class[true_label] += 1
                else:
                    incorrect_per_class[true_label] += 1

    return correct_per_class, incorrect_per_class

In [110]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(100):
    loss_all = 0
    predicted, real = [], []
    for batch_data in train_loader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()  # Reset gradients from the previous iteration
        out = model(batch_data)
        pred = out.argmax(dim=1)
        loss = F.nll_loss(out, batch_data.y)
        loss.backward()  # Calculate gradients
        loss_all += batch_data.num_graphs * loss.item()
        optimizer.step()  # Update model parameters based on calculated gradients

    if epoch % 10 == 0:
        print(f'epoch: {epoch}, loss: {loss_all}')
        print()

epoch: 0, loss: -441.4660530090332

epoch: 10, loss: -42916.19177246094

epoch: 20, loss: -162427.21630859375

epoch: 30, loss: -365231.0126953125

epoch: 40, loss: -647005.12890625

epoch: 50, loss: -1001379.03125

epoch: 60, loss: -1428241.46875

epoch: 70, loss: -1922525.5078125

epoch: 80, loss: -2483750.015625

epoch: 90, loss: -3110322.28125



In [111]:
# Evaluation on the training set
train_correct, train_incorrect = evaluate_model(model, train_loader)
print("Training Set Evaluation:")
print(f"Correct predictions per class: {train_correct}")
print(f"Incorrect predictions per class: {train_incorrect}")

# Evaluation on the test set
test_correct, test_incorrect = evaluate_model(model, test_loader)
print("Test Set Evaluation:")
print(f"Correct predictions per class: {test_correct}")
print(f"Incorrect predictions per class: {test_incorrect}")

Training Set Evaluation:
Correct predictions per class: {0: 27, 1: 0}
Incorrect predictions per class: {0: 0, 1: 23}
Test Set Evaluation:
Correct predictions per class: {0: 13, 1: 0}
Incorrect predictions per class: {0: 0, 1: 9}
