# **Machine Learning for Neuroimaging: Connectome Analysis with Python**

## **Introduction**
### **Welcome and Overview**
Welcome everyone to this tutorial session on Machine Learning for Neuroimaging.

**Objective:** Today, we will explore how to apply machine learning to connectome analysis using Python libraries.

**Connectomes:** A connectome is a comprehensive map of neural connections in the brain. Studying connectomes can help us understand the intricate organization of neural (brain) networks.



## **Python Libraries for Connectome Analysis (10 minutes)**
### **Essential Libraries**
*  [NumPy](https://numpy.org/): For numerical computing and handling multi-dimensional data.
*  [Pandas](https://pandas.pydata.org/): For structured data operations and manipulations.
*  [Matplotlib](https://matplotlib.org/): For creating static, interactive, and animated visualizations in Python.
*  [scikit-learn](https://scikit-learn.org/stable/): For implementing machine learning algorithms.
*  [nibabel](https://nipy.org/nibabel/): For reading and writing neuroimaging data formats.
*  [nilearn](https://nilearn.github.io/stable/index.html): For advanced neuroimaging data manipulation and visualization.
*  [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/): For writing and training Graph Neural Networks (GNNs).
*  [DGL (Deep Graph Library)](https://www.dgl.ai/): For deep learning on graphs.

In [None]:
%%capture
# @title Run to install needed packages.
!pip install nilearn torch torchvision torchaudio torch-geometric

In [None]:
# Importing required libraries
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
import nilearn
from nilearn import plotting
from nilearn import datasets


# Load a sample dataset provided by nilearn
# Here we will use the MNI152 template which is a standard MRI brain template
mni152_template = datasets.load_mni152_template()

# Visualize the template
plotting.plot_img(mni152_template, title="MNI152 Template")

# Fetch a brain atlas
# We will use the nilearn datasets.fetch_atlas_msdl function to fetch the MSDL
# (Multi-Subject Dictionary Learning) atlas, a probabilistic group parcellation
msdl_atlas_dataset = datasets.fetch_atlas_msdl()
msdl_atlas = msdl_atlas_dataset.maps

# Visualize the atlas
plotting.plot_prob_atlas(msdl_atlas, title='MSDL Atlas (Multi-Subject Dictionary Learning)')

# Fetch a resting-state fMRI dataset
# We will use the nilearn datasets.fetch_adhd function to fetch resting-state fMRI data for one subject
adhd_dataset = datasets.fetch_adhd(n_subjects=1)

# Print basic information on the dataset
print('First subject functional nifti images (4D) are at: %s' %
      adhd_dataset.func[0])  # 4D data

# Load the 4D functional image of the first subject
func_filename = adhd_dataset.func[0]
func_data = nib.load(func_filename)

# Visualize the functional data
# We average the 4D image over time and plot the resulting mean EPI image
mean_func_img = nilearn.image.mean_img(func_filename)
plotting.plot_epi(mean_func_img, title='Mean resting-state fMRI image')

# Show plots (if not automatically displayed)
plotting.show()


# **Introduction to Graph Neural Networks (GNNs)**

Graph Neural Networks (GNNs) are a category of neural networks designed to perform inference on data that can be structured as graphs. They are particularly powerful for tasks where the data is inherently graph-structured, such as social networks, molecular structures, and, notably, neuroimaging data where the brain's functional and structural connectivity patterns can be represented as graphs.

## **Key Concepts**
### Node
A node in a graph represents an entity. For instance, in neuroimaging, a node could represent a brain region. In mathematical terms, a node is denoted as $v$ in a graph $G$.

### Edge
An edge represents a relationship or connection between two nodes. In neuroimaging, an edge could represent the functional or structural connectivity *(e.g. Pearson's correlation coefficient)* between two brain regions. Mathematically, an edge is denoted as $(u, v)$, connecting node $u$ to node $v$.

### Graph
A graph is defined by a set of nodes and a set of edges. Each edge connects a pair of nodes. A graph $G$ is typically defined as $G = (V, E)$, where $V$ is the set of nodes and $E$ is the set of edges.
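
To make the definition concrete, here is a minimal sketch of a graph $G = (V, E)$ stored as a node list, an edge list, and the equivalent adjacency matrix. The region names and connections are hypothetical and chosen only for illustration.

```python
import numpy as np

# Hypothetical toy graph: 4 brain regions (nodes) and 3 undirected connections (edges)
nodes = ["region_A", "region_B", "region_C", "region_D"]
edges = [(0, 1), (1, 2), (2, 3)]  # pairs (u, v) of node indices

# Build the symmetric adjacency matrix of the undirected graph
adjacency = np.zeros((len(nodes), len(nodes)), dtype=int)
for u, v in edges:
    adjacency[u, v] = 1
    adjacency[v, u] = 1

# The degree of a node is its number of neighbors
degrees = adjacency.sum(axis=1)
print(adjacency)
print(degrees)
```

The edge list and the adjacency matrix carry the same information; graph libraries usually prefer the edge list for sparse graphs, while connectivity matrices in neuroimaging are naturally dense.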





## **Examples of Graph Neural Networks**

### **Graph Convolutional Networks (GCNs)**
GCNs generalize convolutional neural networks (CNNs) to work on graph data. They operate by aggregating information from a node's neighbors to learn a representation of the node.

**GCN Operation**

The basic operation of a GCN on a node $v$ can be described by the following equation:

$$
h_v^{(l+1)} = \sigma \left( W^{(l)} \sum_{u \in \mathcal{N}(v)} \frac{1}{c_{uv}} h_u^{(l)} + b^{(l)} \right)
$$

where:
- $h_v^{(l+1)}$ is the feature representation of node $v$ at layer $l+1$.
- $\sigma$ is a non-linear activation function, such as ReLU.
- $W^{(l)}$ and $b^{(l)}$ are the trainable weight and bias at layer $l$.
- $\mathcal{N}(v)$ is the set of neighbors of $v$.
- $c_{uv}$ is a normalization constant based on node degrees (commonly $c_{uv} = \sqrt{\deg(u)\,\deg(v)}$).
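
The update rule above can be sketched in a few lines of NumPy. This is an illustrative sketch rather than the PyTorch Geometric implementation: it assumes the symmetric normalization $c_{uv} = \sqrt{\deg(u)\,\deg(v)}$, adds self-loops (a common convention that the equation does not show), and uses ReLU for $\sigma$. The function name `gcn_layer` and the toy inputs are hypothetical.

```python
import numpy as np

def gcn_layer(adjacency, features, weight, bias):
    """One GCN propagation step with symmetric degree normalization."""
    # Assumption: add self-loops so each node also keeps its own features
    a_hat = adjacency + np.eye(adjacency.shape[0])
    deg = a_hat.sum(axis=1)
    # Entry (v, u) holds 1 / c_uv with c_uv = sqrt(deg(u) * deg(v))
    norm = a_hat / np.sqrt(np.outer(deg, deg))
    # Aggregate neighbor features, apply the linear map, then ReLU
    aggregated = norm @ features
    return np.maximum(aggregated @ weight + bias, 0.0)

rng = np.random.default_rng(0)
adjacency = np.array([[0, 1, 0],
                      [1, 0, 1],
                      [0, 1, 0]], dtype=float)
features = rng.normal(size=(3, 4))   # 3 nodes, 4 input features each
weight = rng.normal(size=(4, 2))     # maps 4 input features to 2 hidden features
bias = np.zeros(2)

hidden = gcn_layer(adjacency, features, weight, bias)
print(hidden.shape)
```

Stacking two such layers lets each node see information from its two-hop neighborhood, which is what the two-layer model below does.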



In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv


def get_y(dataset):
    """
    Get the y values from a list of Data objects.
    """
    y = []
    for d in dataset:
        y.append(d.y.numpy())
    return np.array(y)

# Load a dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
print(f'Number of graphs: {len(dataset)}')
print("Number and count of classes: ", np.unique(get_y(dataset).flatten(), return_counts=True))
print("Number of node features: ", dataset.num_node_features)
print(f'Number of edge features: {dataset.num_edge_features}')


# Define a GCN model
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)

In [None]:
# Initialize the model
GCN_model = GCN()
print(GCN_model)

# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
GCN_model = GCN_model.to(device)
data = dataset[0].to(device)

# Define loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(GCN_model.parameters(), lr=0.01)


In [None]:
# Train the model
GCN_model.train()
loss_history =[]

for epoch in range(200):
    # Initialize variables to track the loss and accuracy for each epoch
    epoch_loss = 0.0
    epoch_correct_predictions = 0
    epoch_total_predictions = 0

    # Loop over the graphs in the dataset (Cora contains a single graph)
    for batch in dataset:
        # Move batch to device
        batch = batch.to(device)
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass
        out = GCN_model(batch)
        # Calculate loss on the training nodes only
        loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
        epoch_loss += loss.item()
        # Backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # Update the weights
        optimizer.step()
        # Calculate the number of correct predictions on the training nodes
        predictions = out.argmax(dim=1)
        epoch_correct_predictions += (predictions[batch.train_mask] == batch.y[batch.train_mask]).sum().item()
        epoch_total_predictions += int(batch.train_mask.sum())

    # Calculate the average loss and accuracy for the epoch
    epoch_loss /= len(dataset)
    loss_history.append(epoch_loss)
    epoch_accuracy = epoch_correct_predictions / epoch_total_predictions

    if epoch % 10 == 0: # Print every 10 epochs
        print(f'Epoch {epoch+1}, Loss: {epoch_loss:.3f}, Accuracy: {epoch_accuracy:.3f}')

# Plot training loss history over epochs
plt.plot(loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()

### **Graph Attention Networks (GATs)**
GATs introduce an attention mechanism to GNNs, allowing each node to weigh the contributions of its neighboring nodes differently.

**GAT Operation**

The key operation of a GAT can be described by the following equation:

$$
h_v^{(l+1)} = \sigma \left( \sum_{u \in \mathcal{N}(v)} \alpha_{uv} W^{(l)} h_u^{(l)} \right)
$$

where:
- $\alpha_{uv}$ is the attention coefficient that indicates the importance of node $u$'s features to node $v$.
- $W^{(l)}$ is the trainable weight matrix at layer $l$.
- $h_u^{(l)}$ is the feature representation of node $u$ at layer $l$.
- The attention coefficients $\alpha_{uv}$ are typically computed using a parametric function of the node features, which is learned during training.
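
As an illustration, the sketch below computes attention coefficients and the resulting node update for a single attention head in NumPy. It assumes one common parametrization, in which a raw score $\text{LeakyReLU}(a^\top [W h_v \,\|\, W h_u])$ is normalized with a softmax over each node's neighborhood, and it uses ELU for $\sigma$. The function name `gat_layer`, the attention vector `attn_vec`, and the toy graph (with self-loops included in the adjacency matrix) are hypothetical.

```python
import numpy as np

def gat_layer(adjacency, features, weight, attn_vec, slope=0.2):
    """Single-head attention layer; returns (alpha, new_features)."""
    z = features @ weight                       # W h_u for every node, shape (N, F')
    f_out = z.shape[1]
    # Raw score e[v, u] = LeakyReLU(a^T [z_v || z_u])
    scores = (z @ attn_vec[:f_out])[:, None] + (z @ attn_vec[f_out:])[None, :]
    scores = np.where(scores > 0, scores, slope * scores)
    # Keep only scores between connected nodes, then softmax over each neighborhood
    scores = np.where(adjacency > 0, scores, -np.inf)
    scores = scores - scores.max(axis=1, keepdims=True)
    alpha = np.exp(scores)
    alpha = alpha / alpha.sum(axis=1, keepdims=True)
    # Attention-weighted sum of transformed neighbor features, followed by ELU
    aggregated = alpha @ z
    new_features = np.where(aggregated > 0, aggregated,
                            np.expm1(np.minimum(aggregated, 0.0)))
    return alpha, new_features

rng = np.random.default_rng(0)
# Toy 3-node path graph; the diagonal encodes self-loops
adjacency = np.array([[1, 1, 0],
                      [1, 1, 1],
                      [0, 1, 1]], dtype=float)
features = rng.normal(size=(3, 4))
weight = rng.normal(size=(4, 2))
attn_vec = rng.normal(size=4)        # length 2 * F' with F' = 2

alpha, new_features = gat_layer(adjacency, features, weight, attn_vec)
print(alpha.round(3))
```

Row $v$ of `alpha` holds the coefficients $\alpha_{uv}$ for all neighbors $u$ of $v$; each row sums to 1 and is zero for unconnected pairs.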


In [None]:
from torch_geometric.nn import GATConv

# Define a GAT model
class GAT(torch.nn.Module):
    def __init__(self):
        super(GAT, self).__init__()
        self.conv1 = GATConv(dataset.num_node_features, 8, heads=8, dropout=0.6)
        self.conv2 = GATConv(8 * 8, dataset.num_classes, heads=1, concat=True, dropout=0.6)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# Initialize the GAT model
gat_model = GAT()
print(gat_model)

In [None]:
# Move to GPU if available
gat_model = gat_model.to(device)

# Define loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(gat_model.parameters(), lr=0.01)

# Train the GAT model
gat_model.train()
loss_history =[]

for epoch in range(200):
    # Initialize variables to track the loss and accuracy for each epoch
    epoch_loss = 0.0
    epoch_correct_predictions = 0
    epoch_total_predictions = 0

    # Loop over the graphs in the dataset (Cora contains a single graph)
    for batch in dataset:
        # Move batch to device
        batch = batch.to(device)
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass
        out = gat_model(batch)
        # Calculate loss on the training nodes only
        loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
        epoch_loss += loss.item()
        # Backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # Update the weights
        optimizer.step()
        # Calculate the number of correct predictions on the training nodes
        predictions = out.argmax(dim=1)
        epoch_correct_predictions += (predictions[batch.train_mask] == batch.y[batch.train_mask]).sum().item()
        epoch_total_predictions += int(batch.train_mask.sum())

    # Calculate the average loss and accuracy for the epoch
    epoch_loss /= len(dataset)
    loss_history.append(epoch_loss)
    epoch_accuracy = epoch_correct_predictions / epoch_total_predictions

    if epoch % 10 == 0: # Print every 10 epochs
        print(f'Epoch {epoch+1}, Loss: {epoch_loss:.3f}, Accuracy: {epoch_accuracy:.3f}')

# Plot training loss history over epochs
plt.plot(loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()

In [None]:
from sklearn.manifold import TSNE

# Test GAT model on the first sample in the dataset
gat_model.eval()
_, pred = gat_model(data).max(dim=1)

# Get the embeddings
embeddings = gat_model.conv1(data.x, data.edge_index)

# Reduce the embeddings to 2 dimensions using t-SNE
tsne = TSNE(n_components=2)
embeddings_2d = tsne.fit_transform(embeddings.detach().cpu().numpy())

# Plot
plt.figure(figsize=(10, 8))
scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=pred.cpu().numpy(), cmap='Set1')
plt.legend(*scatter.legend_elements(), title="Classes")
plt.show()


*Points are colored by predicted class. We observe good separation of all 7 classes, consistent with the high classification performance of the GAT model.*

## **Applications to Neuroimaging**: ABIDE and ADHD Case Studies
Preface: In neuroimaging, GNNs can be used to analyze brain connectivity patterns. Each brain region can be represented as a node, and the connections (either structural or functional) between regions are the edges. GNNs can help in tasks such as:

*  Classification of cognitive states or disorders.
*  Regression tasks to predict behavioral or genetic traits.
*  Clustering or community detection within brain networks.

Case:



# ABIDE Classification & Regression Tasks

The Autism Brain Imaging Data Exchange (ABIDE) dataset provides previously collected resting state functional magnetic resonance imaging datasets from 539 individuals with ASD and 573 typical controls for the purpose of data sharing in the broader scientific community. This grass-root initiative involved 16 international sites, sharing 20 samples yielding 1112 datasets composed of both MRI data and an extensive array of phenotypic information common across nearly all sites (see below).

Note that this is the preprocessed version of ABIDE provided by the Preprocessed Connectomes Project (PCP).

For more information about this dataset's structure: http://preprocessed-connectomes-project.github.io http://www.childmind.org/en/healthy-brain-network/abide/

Nielsen, Jared A., et al. "Multisite functional connectivity MRI classification of autism: ABIDE results." Frontiers in human neuroscience 7 (2013).



**Note:** The ABIDE dataset consumes a lot of storage space, so we provide pre-parcellated connectivity matrices below. First, let's run through the steps we used to compute them on a few subjects.

In [None]:
# Fetch the ABIDE dataset
abide = datasets.fetch_abide_pcp(n_subjects=5, pipeline="cpac",
                                 derivatives=['func_preproc'],
                                 quality_checked=True,
                                 legacy_format=False)


In [None]:
# Store the filenames of the functional scans
fmri_filenames = abide.func_preproc

# Check the number of subject functional scans fetched
print(f"Number of subjects: {len(fmri_filenames)}")

We need to decide which parcellation to use for rs-fMRI data. We are going to use the AAL atlas (...).

In [None]:
from nilearn import plotting

# Retrieve brain atlas for parcellation
parcellations = datasets.fetch_atlas_aal()
atlas_filename = parcellations.maps
labels = parcellations.labels
print(f"Number of ROIs: {len(labels)}")

# Plot atlas
plotting.plot_roi(atlas_filename, draw_cross=False)

Select a single subject's scan to check information.

In [None]:
sample_subject = abide.func_preproc[0]
sample_subject

Using NiftiLabelsMasker we will create a mask on our functional images with the labels of the chosen atlas and extract the time series in each ROI. Because the data is already preprocessed, we do not need to regress out any confounds.

In [None]:
from nilearn.maskers import NiftiLabelsMasker  # nilearn.input_data is the deprecated location

masker = NiftiLabelsMasker(labels_img=atlas_filename,
                           standardize='zscore_sample',  #z-score each sample to zero mean scaled to unit variance w.r.t. sample std
                           memory='nilearn_cache',
                           verbose=0)

time_series = masker.fit_transform(sample_subject)
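
Conceptually, the masker averages the voxel signals inside each atlas label and then standardizes each ROI time series. The NumPy sketch below mimics this on synthetic data; it is only a rough illustration of the idea, not nilearn's implementation, and the array names and sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_timepoints, n_voxels = 50, 12
# Toy scan flattened to (time, voxels) and a toy atlas with 3 ROIs of 4 voxels each
voxel_signals = rng.normal(size=(n_timepoints, n_voxels))
voxel_labels = np.repeat([1, 2, 3], 4)

# Average the voxel signals within each ROI to get one time series per ROI
roi_ids = np.unique(voxel_labels)
roi_series = np.column_stack(
    [voxel_signals[:, voxel_labels == roi].mean(axis=1) for roi in roi_ids]
)

# z-score each ROI time series using the sample standard deviation (ddof=1)
roi_series = (roi_series - roi_series.mean(axis=0)) / roi_series.std(axis=0, ddof=1)
print(roi_series.shape)  # (n_timepoints, n_rois)
```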

Next, we compute the connectivity matrix for this subject from the extracted ROI time series.

In [None]:
from nilearn.connectome import ConnectivityMeasure

correlation_measure = ConnectivityMeasure(kind='correlation')
connectivity_matrix = correlation_measure.fit_transform([time_series])[0]
print(f"Connectivity matrix shape: {connectivity_matrix.shape}")

...and plot it.

In [None]:
import numpy as np

np.fill_diagonal(connectivity_matrix, 0)

plotting.plot_matrix(connectivity_matrix, figure=(10, 8),
                     labels=range(time_series.shape[-1]),
                     vmax=0.8, vmin=-0.8, reorder=False)
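
For intuition, a correlation-based connectivity matrix is the matrix of Pearson correlations between every pair of ROI time series. The sketch below computes one with plain NumPy on synthetic data (the variable names are hypothetical). Values from nilearn's `ConnectivityMeasure` may differ slightly, since it estimates the covariance with its own estimator, which may apply shrinkage.

```python
import numpy as np

rng = np.random.default_rng(42)
# Synthetic time series with shape (n_timepoints, n_rois)
toy_time_series = rng.normal(size=(100, 5))

# Pearson correlation between every pair of ROI columns
toy_connectivity = np.corrcoef(toy_time_series, rowvar=False)

# Zero the diagonal, as done above before plotting
np.fill_diagonal(toy_connectivity, 0)
print(toy_connectivity.shape)
```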

In [None]:
# @title Load parcellated, connectivity matrices as our features, X
import numpy as np
import pickle as pkl

# Load the connectivity matrices
conn_matrices = './ABIDE_conn_matrices.npz'
X = np.load(conn_matrices)['a']

with open('./abide_preproc_300.pkl', 'rb') as f:
    abide = pkl.load(f)

# Store the filenames of the functional scans
fmri_filenames = abide.func_preproc

# Check the number of subject functional scans fetched
print(f"Number of subjects: {len(fmri_filenames)}")

Now we have the connectivity matrices for all subjects. Let's see the shape of our connectivity matrices.

In [None]:
X.shape

Accompanying the data set is a csv file containing the phenotypic data. According to the Phenotypic Data Legend which can be downloaded [here](http://fcon_1000.projects.nitrc.org/indi/abide/abide_I.html), the column DX_GROUP has the information about the diagnostic group each participant is in. It is coded as:

*   1 = Autism
*   2 = Control

Let's import the csv.

In [None]:
import pandas as pd
phenotypic = pd.read_csv("./Phenotypic_V1_0b_preprocessed1.csv")

Let's use the file names to get the right values from the DX_GROUP column.

In [None]:
file_ids = []

# Get the file IDs from the file names
for f in fmri_filenames:
    file_ids.append(f[-27:-20])

# Get labels for autism diagnosis => classification task
y_asd = []
for i in range(len(phenotypic)):
    for j in range(len(file_ids)):
        if file_ids[j] in phenotypic.FILE_ID[i]:
            y_asd.append(phenotypic.DX_GROUP[i])

# Get labels for full-scale IQ (FIQ) => regression task
y_fiq = []
for i in range(len(phenotypic)):
    for j in range(len(file_ids)):
        if file_ids[j] in phenotypic.FILE_ID[i]:
            y_fiq.append(phenotypic.FIQ[i])
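
The nested loops above collect labels in the row order of the phenotypic table, so they rely on the scans and the table being sorted the same way. As a hedged alternative, the sketch below looks up each scan's row by its file ID, which keeps the labels in the same order as the scans. The file paths, IDs, and phenotypic values are made up for illustration, and the suffix-based parsing is an assumption about the file naming.

```python
# Hypothetical scan paths and phenotypic records (values are made up)
toy_filenames = [
    "/data/ABIDE/Pitt_0050003_func_preproc.nii.gz",
    "/data/ABIDE/NYU_0051036_func_preproc.nii.gz",
]
toy_phenotypic = [
    {"FILE_ID": "NYU_0051036", "DX_GROUP": 2, "FIQ": 110.0},
    {"FILE_ID": "Pitt_0050003", "DX_GROUP": 1, "FIQ": 124.0},
]

# Index the phenotypic rows by FILE_ID for direct lookup
by_file_id = {row["FILE_ID"]: row for row in toy_phenotypic}

suffix = "_func_preproc.nii.gz"
toy_y_asd, toy_y_fiq = [], []
for path in toy_filenames:
    # Strip the directory and the derivative suffix to recover the file ID
    file_id = path.split("/")[-1][: -len(suffix)]
    row = by_file_id[file_id]
    toy_y_asd.append(row["DX_GROUP"])
    toy_y_fiq.append(row["FIQ"])

print(toy_y_asd, toy_y_fiq)
```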

Now we're ready to prepare our data for graph learning. We will first create a custom PyTorch Geometric dataset that converts each connectivity matrix in X into an edge index and edge attributes for a graph.

In [None]:
# @title Create custom PyG ConnectomeDataset
import torch
from torch_geometric.data import Data, Dataset
from torch_geometric.utils import dense_to_sparse
class ConnectomeDataset(Dataset):
    def __init__(self, connectivity_matrices, labels, task="classification", transform=None, pre_transform=None):
        super(ConnectomeDataset, self).__init__(None, transform, pre_transform)
        self.connectivity_matrices = connectivity_matrices
        self.labels = labels
        self.task = task

    def len(self):
        return len(self.connectivity_matrices)

    def get(self, idx):
        # Convert the connectivity matrix to edge index and edge attributes
        connectivity_matrix = torch.tensor(self.connectivity_matrices[idx])
        edge_index, edge_attr = dense_to_sparse(connectivity_matrix)

        # Create a data object
        data = Data(edge_index=edge_index, edge_attr=edge_attr)
        data.x = connectivity_matrix.to(torch.float)
        if self.task == "classification":
            # Integer class labels; these are expected to be zero-indexed already
            data.y = torch.tensor([self.labels[idx]], dtype=torch.long)
        else:
            data.y = torch.tensor([self.labels[idx]], dtype=torch.float)
        return data
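
To see what the dense-to-sparse conversion produces, the NumPy sketch below turns a small matrix into the index pairs of its non-zero entries plus their values. It also shows one optional step that the dataset above does not perform: thresholding weak connections. That step is included because a correlation matrix is almost fully dense. The toy matrix and the 0.5 threshold are hypothetical.

```python
import numpy as np

toy_matrix = np.array([[0.0, 0.8, 0.0],
                       [0.8, 0.0, -0.3],
                       [0.0, -0.3, 0.0]])

# Edge index: row and column indices of all non-zero entries, shape (2, num_edges)
rows, cols = np.nonzero(toy_matrix)
edge_index = np.stack([rows, cols])
# Edge attributes: the connectivity value of each edge
edge_attr = toy_matrix[rows, cols]
print(edge_index)
print(edge_attr)

# Optional (assumption, not part of the dataset above): drop weak connections
threshold = 0.5
sparse_matrix = np.where(np.abs(toy_matrix) >= threshold, toy_matrix, 0.0)
print(np.count_nonzero(sparse_matrix))
```

Each undirected connection appears twice in `edge_index`, once per direction, because the matrix is symmetric.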

The labels contain diagnostic group each participant is in. It is coded as:

*   1 = Autism Spectrum Disorder (ASD)
*   2 = Control

**Let's re-index this to 0=ASD and 1=Control, since PyTorch classification losses such as `CrossEntropyLoss` expect class indices that start at 0.**

In [None]:
from collections import Counter

# Adjust labels to start from 0
y_asd = np.array(y_asd)
y_asd = y_asd - 1

# Print label classes and counts
print(Counter(y_asd))

### Classification Task

In [None]:
import torch
import torch_geometric.transforms as T
import torch.nn.functional as F

from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader

# Instantiate the dataset
abide_dataset_asd = ConnectomeDataset(X, y_asd)

loader = DataLoader(abide_dataset_asd, batch_size=32, shuffle=True)

print(f'Number of graphs: {len(abide_dataset_asd)}')
print("Number and count of classes: ", np.unique(get_y(abide_dataset_asd).flatten(), return_counts=True))
print("Number of node features: ", abide_dataset_asd.num_node_features)
print(f'Number of edge features: {abide_dataset_asd.num_edge_features}')

num_classes = 2

# Define a GCN model
class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(abide_dataset_asd.num_node_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        # Apply global mean pooling to get a single vector for the whole graph
        x = global_mean_pool(x, batch=data.batch)
        x = F.log_softmax(x, dim=1)

        return x
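
Global mean pooling collapses the node embeddings of each graph in a batch into a single vector by averaging them. The NumPy sketch below illustrates the idea with a hypothetical batch of two graphs, where a `batch` vector records which graph each node belongs to.

```python
import numpy as np

# Hypothetical node embeddings for a batch of two graphs (3 nodes in total)
node_embeddings = np.array([[1.0, 2.0],
                            [3.0, 4.0],
                            [10.0, 20.0]])
batch = np.array([0, 0, 1])  # graph index of each node

num_graphs = batch.max() + 1
# Sum the embeddings of the nodes belonging to each graph
sums = np.zeros((num_graphs, node_embeddings.shape[1]))
np.add.at(sums, batch, node_embeddings)
# Divide by the number of nodes per graph to obtain the mean
counts = np.bincount(batch, minlength=num_graphs)[:, None]
graph_embeddings = sums / counts
print(graph_embeddings)
```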

In [None]:
# Initialize the model
GCN_model = GCN()
print(GCN_model)

# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
GCN_model = GCN_model.to(device)

# Define loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, GCN_model.parameters()),
            lr=0.0001,
            weight_decay=0.0001
        )

In [None]:
# Train the model
GCN_model.train()
loss_history =[]

for epoch in range(20):
    # Initialize variables to track the loss and accuracy for each epoch
    epoch_loss = 0.0
    epoch_correct_predictions = 0
    epoch_total_predictions = 0

    # Loop over each batch from the data loader
    for batch in loader:
        # Move batch to device
        batch = batch.to(device)
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass
        out = GCN_model(batch)
        # Calculate loss
        loss = criterion(out, batch.y)
        epoch_loss += loss.item()
        # Backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # Update the weights
        optimizer.step()
        # Calculate the number of correct predictions
        predictions = out.argmax(dim=1)
        epoch_correct_predictions += (predictions == batch.y).sum().item()
        epoch_total_predictions += batch.num_graphs  # number of graphs (subjects) in this batch

    # Calculate the average loss and accuracy for the epoch
    epoch_loss /= len(loader)
    loss_history.append(epoch_loss)
    epoch_accuracy = epoch_correct_predictions / epoch_total_predictions

    if epoch % 10 == 0: # Print every 10 epochs
        print(f'Epoch {epoch+1}, Loss: {epoch_loss:.3f}, Accuracy: {epoch_accuracy:.3f}')

# Plot training loss history over epochs
plt.plot(loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()

### Comparison to Baseline

Let's compare our model's performance to baseline traditional ML methods.

In [None]:
from sklearn.model_selection import train_test_split

# Flatten each connectivity matrix into a feature vector (scikit-learn expects 2D input)
X_flat = X.reshape(len(X), -1)

X_train, X_val, y_train, y_val = train_test_split(X_flat, # x
                                                  y_asd, # y
                                                  test_size = 0.4, # 60%/40% split
                                                  shuffle = True, # shuffle dataset before splitting
                                                  stratify = y_asd,  # keep distribution of ASD consistent between sets
                                                  random_state = 123 # same split each time
                                                 )

**Linear SVC**

The baseline classifier is a linear Support Vector Classifier (`LinearSVC`). We will use 10-fold cross-validation on the training set to estimate its accuracy.

In [None]:
from sklearn.svm import LinearSVC
from sklearn.model_selection import cross_val_predict, cross_val_score

l_svc = LinearSVC(max_iter=100000) # more iterations than the default
l_svc.fit(X_train, y_train)

# predict
y_pred_svc = cross_val_predict(l_svc, X_train, y_train, cv=10)
# scores
acc_svc = cross_val_score(l_svc, X_train, y_train, cv=10)

print("Accuracy:", acc_svc)
print("Mean accuracy:", acc_svc.mean())
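
Because connectivity matrices are symmetric, the upper triangle already contains every distinct connection. As a possible alternative feature representation for the baseline, the sketch below keeps only the upper triangle (excluding the diagonal), which roughly halves the number of features. The synthetic matrices and variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_subjects, n_rois = 4, 5
# Synthetic symmetric "connectivity" matrices, shape (n_subjects, n_rois, n_rois)
raw = rng.uniform(-1, 1, size=(n_subjects, n_rois, n_rois))
toy_X = (raw + raw.transpose(0, 2, 1)) / 2

# Indices of the strict upper triangle: one entry per unique ROI pair
upper = np.triu_indices(n_rois, k=1)
toy_features = toy_X[:, upper[0], upper[1]]
print(toy_features.shape)  # (n_subjects, n_rois * (n_rois - 1) / 2)
```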

### Regression Task

Now, we'll do a regression task: predicting the full-scale IQ (FIQ) of the subjects.

In [None]:
# Use the dataset with a DataLoader for batching
# task="regression" keeps the FIQ targets as floats
abide_dataset_fiq = ConnectomeDataset(X, y_fiq, task="regression")
train_loader = DataLoader(abide_dataset_fiq[:255], batch_size=32, shuffle=True)
test_loader = DataLoader(abide_dataset_fiq[255:], batch_size=32)

In [None]:
import torch

num_classes = 1
# Define a RegGNN model
class RegGNN(torch.nn.Module):
    '''Regression using GCNConv layers from PyTorch Geometric.

       Adapted from: https://github.com/basiralab/RegGNN/blob/main/proposed_method/RegGNN.py
    '''

    def __init__(self, hidden_dim=64):
        super(RegGNN, self).__init__()

        self.gc1 = GCNConv(abide_dataset_fiq.num_features, hidden_dim)
        self.gc2 = GCNConv(hidden_dim, hidden_dim)
        self.LinearLayer = torch.nn.Linear(hidden_dim, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = F.relu(self.gc1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.gc2(x, edge_index))
        # Apply global mean pooling to get a single vector for the whole graph
        x = global_mean_pool(x, data.batch)
        x = self.LinearLayer(x)

        return x

In [None]:
# Initialize the model
RegGNN_model = RegGNN()
print(RegGNN_model)

# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
RegGNN_model = RegGNN_model.to(device)

# Define loss and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(RegGNN_model.parameters(), lr=0.01)  # optimize the regression model's parameters
# optimizer = torch.optim.Adam(
#             filter(lambda p: p.requires_grad, RegGNN_model.parameters()),
#             lr=0.0001,
#             weight_decay=0.0001
#         )

In [None]:
# Train the model
RegGNN_model.train()
loss_history =[]

for epoch in range(50):
    # Initialize variables to track the loss and absolute error for each epoch
    epoch_loss = 0.0
    epoch_abs_error = 0.0
    epoch_total_predictions = 0

    # Loop over each batch from the data loader
    for batch in train_loader:
        # Move batch to device
        batch = batch.to(device)
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass; squeeze (batch_size, 1) to (batch_size,) to match batch.y
        out = RegGNN_model(batch).squeeze(-1)
        target = batch.y.float()
        # Calculate loss
        loss = criterion(out, target)
        epoch_loss += float(loss)
        # Backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # Update the weights
        optimizer.step()
        # Accumulate the absolute error (accuracy is not defined for regression)
        epoch_abs_error += (out.detach() - target).abs().sum().item()
        epoch_total_predictions += batch.num_graphs

    # Calculate the average loss and mean absolute error (MAE) for the epoch
    epoch_loss /= len(train_loader)
    loss_history.append(epoch_loss)
    epoch_mae = epoch_abs_error / epoch_total_predictions

    if epoch % 10 == 0: # Print every 10 epochs
        print(f'Epoch {epoch+1}, Loss: {epoch_loss:.3f}, MAE: {epoch_mae:.3f}')

# Plot training loss history over epochs
plt.plot(loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss over Epochs')
plt.show()
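
The `test_loader` defined earlier is not used in the training loop. One way to evaluate the regression model would be to collect its predictions on the held-out subjects and summarize them with standard regression metrics. The sketch below shows such metrics in NumPy; the helper `regression_report` and the toy values are hypothetical and stand in for real predictions.

```python
import numpy as np

def regression_report(y_true, y_pred):
    """Return MAE, RMSE, and Pearson correlation between targets and predictions."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    errors = y_pred - y_true
    return {
        "mae": np.abs(errors).mean(),
        "rmse": np.sqrt((errors ** 2).mean()),
        "pearson_r": np.corrcoef(y_true, y_pred)[0, 1],
    }

# Made-up FIQ targets and predictions, for illustration only
toy_true = [100.0, 110.0, 120.0, 130.0]
toy_pred = [102.0, 108.0, 123.0, 127.0]
report = regression_report(toy_true, toy_pred)
print(report)
```

Reporting the correlation alongside the error is useful because a model that always predicts the mean FIQ can reach a moderate MAE while carrying no subject-specific information.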