# Graph neural network (GNN) basics

## Table of contents

1. [Understanding graph neural networks (GNNs)](#understanding-graph-neural-networks-gnns)
2. [Setting up the environment](#setting-up-the-environment)
3. [Defining graph data](#defining-graph-data)
4. [Building a basic message-passing mechanism](#building-a-basic-message-passing-mechanism)
5. [Implementing a simple graph convolution layer](#implementing-a-simple-graph-convolution-layer)
6. [Building a basic GNN model](#building-a-basic-gnn-model)
7. [Training the GNN on a node classification task](#training-the-gnn-on-a-node-classification-task)
8. [Evaluating the GNN model](#evaluating-the-gnn-model)
9. [Experimenting with different configurations](#experimenting-with-different-configurations)

## Understanding graph neural networks (GNNs)

### **Key concepts**
Graph neural networks (GNNs) are a class of neural networks designed to process data represented as graphs, where nodes represent entities and edges represent relationships. Unlike convolutional or recurrent networks, which assume a regular layout such as a pixel grid or a token sequence, GNNs operate on non-Euclidean data in which each node can have a different number of neighbors and the neighbors have no fixed order. This makes them suitable for tasks involving irregular, interconnected data structures.

Key elements of GNNs include:
- **Node Features**: Represent attributes or characteristics of individual nodes.
- **Edge Features**: Capture the relationships or interactions between nodes.
- **Message Passing**: Nodes exchange information with their neighbors to update their representations iteratively.
- **Graph Representation**: The model learns embeddings for nodes, edges, or the entire graph, depending on the task.

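Popular GNN architectures include Graph Convolutional Networks (GCNs), which average neighbor features with degree-based normalization; Graph Attention Networks (GATs), which learn attention weights over each node's neighbors; and GraphSAGE, which samples and aggregates neighborhoods so that it can be applied to nodes not seen during training.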
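
To make the message-passing idea from the list above concrete, here is a minimal sketch of a single round on a hypothetical three-node path graph (0 - 1 - 2). The graph, the feature values, and the choice of sum aggregation are assumptions made purely for illustration.

```python
import torch

# Hypothetical path graph 0 - 1 - 2, stored as directed edges in both directions.
edge_index = torch.tensor([[0, 1, 1, 2],   # source nodes
                           [1, 0, 2, 1]])  # target nodes
x = torch.tensor([[1.0], [10.0], [100.0]])  # one scalar feature per node

# One round of message passing: every node sums the features of its neighbors.
messages = x[edge_index[0]]  # feature sent along each directed edge
neighbor_sum = torch.zeros_like(x).index_add_(0, edge_index[1], messages)
h = x + neighbor_sum  # combine with the node's own feature

print(h.squeeze(1).tolist())  # [11.0, 111.0, 110.0]
```

After one round, the middle node has seen both end nodes, while each end node has only seen the middle one. Stacking more rounds lets information travel further along the graph.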

### **Applications**
GNNs have a wide range of applications across various domains:
- **Social networks**: Analyzing user interactions for recommendations, influence detection, or community detection.
- **Molecular biology**: Predicting molecular properties or drug interactions based on chemical structure graphs.
- **Knowledge graphs**: Enhancing link prediction, node classification, and graph completion tasks.
- **Recommendation systems**: Personalizing content or product suggestions by modeling user-item interaction graphs.
- **Traffic networks**: Analyzing and predicting traffic flow in road or transportation networks.

### **Advantages**
- **Flexible data handling**: Processes graph-structured data of varying sizes and connectivity.
- **Relational reasoning**: Captures relationships and dependencies between entities effectively.
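- **Generalization**: Inductive architectures such as GraphSAGE learn aggregation functions rather than per-node lookup tables, so they can produce embeddings for nodes or subgraphs not seen during training.
- **Scalability**: On a sparse graph, one message-passing layer costs time roughly proportional to the number of edges, so the same layer definition applies to small and large graphs (within the limits listed under the challenges below).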

### **Challenges**
- **Scalability for large graphs**: Training on massive graphs requires substantial computational resources and optimization techniques.
- **Over-smoothing**: Node representations may become indistinguishable with excessive message-passing layers.
- **Irregular data**: Processing graphs with heterogeneous structures or dynamic topologies can be complex.
- **Data dependency**: Performance depends heavily on the quality and completeness of the graph structure and features.
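
Over-smoothing from the list above can be observed without any training. The sketch below repeatedly applies mean aggregation with self-loops to a small four-node graph and tracks how far apart the node features remain. The `spread` helper and the choice of ten propagation steps are assumptions made for illustration.

```python
import torch

adj = torch.tensor([[0, 1, 1, 0],
                    [1, 0, 1, 1],
                    [1, 1, 0, 1],
                    [0, 1, 1, 0]], dtype=torch.float32)
x = torch.tensor([[1, 0], [0, 1], [1, 1], [0, 0]], dtype=torch.float32)

# Mean aggregation with self-loops: each row of P sums to 1.
a_hat = adj + torch.eye(4)
P = a_hat / a_hat.sum(dim=1, keepdim=True)

def spread(h):
    # Largest per-feature gap between any two nodes.
    return (h.max(dim=0).values - h.min(dim=0).values).max().item()

h = x.clone()
spreads = [spread(h)]
for _ in range(10):  # 10 parameter-free propagation steps
    h = P @ h
    spreads.append(spread(h))

print([round(s, 4) for s in spreads])
```

The spread starts at 1.0 and shrinks towards zero, which means all four nodes end up with nearly the same vector. A classifier on top of such features has little left to separate the nodes with.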

## Setting up the environment


##### **Q1: How do you install the necessary libraries for building a GNN in PyTorch?**


In [28]:
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# !pip install torch-geometric
# !pip install numpy matplotlib seaborn

##### **Q2: How do you import the required modules for constructing a GNN and handling graph data in PyTorch?**


In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data  # PyG graph container; the cells below use a small GraphData class instead

##### **Q3: How do you configure the environment to use GPU for training the GNN model in PyTorch?**

In [30]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


## Defining graph data


##### **Q4: How do you represent a graph using an adjacency matrix in PyTorch?**


In [31]:
adj_matrix = torch.tensor([
    [0, 1, 1, 0],
    [1, 0, 1, 1],
    [1, 1, 0, 1],
    [0, 1, 1, 0]
], dtype=torch.float32)
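
A few quick checks on the adjacency matrix can catch data-entry mistakes early. The sketch below verifies that the matrix is symmetric (that is, the graph is undirected) and derives node degrees and the edge count. The variable names other than `adj_matrix` are assumptions.

```python
import torch

adj_matrix = torch.tensor([
    [0, 1, 1, 0],
    [1, 0, 1, 1],
    [1, 1, 0, 1],
    [0, 1, 1, 0]
], dtype=torch.float32)

is_undirected = torch.equal(adj_matrix, adj_matrix.t())  # symmetric matrix <=> undirected graph
degrees = adj_matrix.sum(dim=1)                           # number of neighbors per node
num_edges = int(adj_matrix.sum().item()) // 2             # each undirected edge is counted twice

print(is_undirected, degrees.tolist(), num_edges)  # True [2.0, 3.0, 3.0, 2.0] 5
```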

##### **Q5: How do you define node features for each node in the graph as input to the GNN?**


In [32]:
x = torch.tensor([
    [1, 0],
    [0, 1],
    [1, 1],
    [0, 0]
], dtype=torch.float32)

##### **Q6: How do you convert graph edges into an edge list to represent the connections between nodes?**

In [33]:
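# one column per directed edge (top row: source, bottom row: target);
# every undirected edge of adj_matrix appears once in each direction
edge_index = torch.tensor([
    [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
    [1, 2, 0, 2, 3, 0, 1, 3, 1, 2]
], dtype=torch.long)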
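
Instead of typing the edge list by hand, it can be derived from the adjacency matrix. This is a minimal sketch using `torch.Tensor.nonzero`; the `rebuilt` round-trip check is an addition for illustration.

```python
import torch

adj_matrix = torch.tensor([
    [0, 1, 1, 0],
    [1, 0, 1, 1],
    [1, 1, 0, 1],
    [0, 1, 1, 0]
], dtype=torch.float32)

# Each nonzero entry (i, j) becomes one directed edge i -> j.
edge_index = adj_matrix.nonzero().t().contiguous()  # shape [2, num_directed_edges]

# Round trip back to a dense matrix as a sanity check.
rebuilt = torch.zeros_like(adj_matrix)
rebuilt[edge_index[0], edge_index[1]] = 1.0

print(edge_index.tolist())
print(torch.equal(rebuilt, adj_matrix))
```

For this graph the result has ten columns, two per undirected edge. Deriving the list this way avoids accidentally dropping one direction of an edge.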

## Building a basic message-passing mechanism


##### **Q7: How do you implement a basic message-passing mechanism between neighboring nodes in a graph?**


In [34]:
def message_passing(x, edge_index):
    source_nodes = edge_index[0]  # sender nodes
    messages = x[source_nodes]  # fetch features of source nodes
    return messages

##### **Q8: How do you aggregate messages from neighboring nodes using operations in PyTorch?**


In [35]:
def aggregate_messages(messages, edge_index, num_nodes):
    target_nodes = edge_index[1]  # receiver nodes
    # allocate the output directly with the messages' dtype and device
    aggregated = torch.zeros(num_nodes, messages.size(1), dtype=messages.dtype, device=messages.device)
    aggregated.index_add_(0, target_nodes, messages)  # sum messages for each target
    return aggregated

##### **Q9: How do you implement node updates by combining aggregated messages with the node's own features?**

In [36]:
def update_nodes(x, aggregated):
    # add each node's own features to the summed neighbor messages (acts like a self-loop)
    return x + aggregated
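
The three steps above (gather, sum, update) can be traced on the toy graph. The sketch below restates the functions in compact form so that it runs on its own, and assumes the symmetric edge list derived from the adjacency matrix (all ten directed edges). It also checks the result against the dense formula `(A + I) @ x`.

```python
import torch

def message_passing(x, edge_index):
    return x[edge_index[0]]  # features of the sender of each edge

def aggregate_messages(messages, edge_index, num_nodes):
    out = torch.zeros(num_nodes, messages.size(1), dtype=messages.dtype, device=messages.device)
    return out.index_add_(0, edge_index[1], messages)  # sum per receiver

def update_nodes(x, aggregated):
    return x + aggregated

adj = torch.tensor([[0, 1, 1, 0],
                    [1, 0, 1, 1],
                    [1, 1, 0, 1],
                    [0, 1, 1, 0]], dtype=torch.float32)
x = torch.tensor([[1, 0], [0, 1], [1, 1], [0, 0]], dtype=torch.float32)
edge_index = adj.nonzero().t()  # all ten directed edges

messages = message_passing(x, edge_index)
updated = update_nodes(x, aggregate_messages(messages, edge_index, x.size(0)))
dense = (adj + torch.eye(4)) @ x  # same computation in matrix form

print(updated.tolist())  # [[2.0, 2.0], [2.0, 2.0], [2.0, 2.0], [1.0, 2.0]]
```

On this symmetric graph, nodes 0, 1, and 2 end up with identical vectors after one step. This shows that plain sum aggregation with the node's own features cannot always tell structurally similar nodes apart.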

## Implementing a simple graph convolution layer


##### **Q10: How do you define a simple graph convolution layer using `torch.nn.Module` in PyTorch?**


In [37]:
class GraphConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)

##### **Q11: How do you implement the forward pass of the graph convolution layer to compute new node embeddings?**


In [38]:
class GraphConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        messages = message_passing(x, edge_index)  # fetch messages from neighbors
        aggregated = aggregate_messages(messages, edge_index, x.size(0))  # sum by target
        updated = update_nodes(x, aggregated)  # residual update
        return self.linear(updated)  # linear transformation

##### **Q12: How do you apply a non-linearity, such as ReLU, after computing the graph convolution to update node features?**

In [39]:
class GraphConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        messages = message_passing(x, edge_index)  # fetch messages from neighbors
        aggregated = aggregate_messages(messages, edge_index, x.size(0))  # sum by target
        updated = update_nodes(x, aggregated)  # residual update
        out = self.linear(updated)  # linear transformation
        return F.relu(out)  # apply relu
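
Because the layer above always ends with ReLU, using it as the last layer of a classifier forces the class scores to be non-negative. One possible workaround is a switchable activation, sketched below. The class name `GraphConvOptionalAct` and its `activation` flag are hypothetical and not part of the original notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphConvOptionalAct(nn.Module):
    """Hypothetical variant of GraphConv with a switchable ReLU."""

    def __init__(self, in_channels, out_channels, activation=True):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)
        self.activation = activation

    def forward(self, x, edge_index):
        messages = x[edge_index[0]]  # gather sender features
        aggregated = torch.zeros_like(x).index_add_(0, edge_index[1], messages)  # sum per receiver
        out = self.linear(x + aggregated)  # combine with own features, then transform
        return F.relu(out) if self.activation else out

torch.manual_seed(0)
x = torch.tensor([[1, 0], [0, 1], [1, 1], [0, 0]], dtype=torch.float32)
edge_index = torch.tensor([[0, 0, 1, 1, 1, 2, 2, 3, 3],
                           [1, 2, 0, 2, 3, 0, 3, 1, 2]])

hidden_layer = GraphConvOptionalAct(2, 8)                    # ReLU on hidden features
output_layer = GraphConvOptionalAct(8, 2, activation=False)  # raw logits for the loss
h = hidden_layer(x, edge_index)
logits = output_layer(h, edge_index)

print(bool((h >= 0).all()), tuple(logits.shape))
```

With `activation=False` on the output layer, the scores passed to a loss such as `nn.CrossEntropyLoss` can take either sign.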

## Building a basic GNN model


##### **Q13: How do you stack multiple graph convolution layers to build a simple GNN model in PyTorch?**


In [40]:
class SimpleGNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GraphConv(in_channels, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, out_channels)

##### **Q14: How do you define the forward pass of the GNN model to process node features through multiple graph convolution layers?**


In [41]:
class SimpleGNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GraphConv(in_channels, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)  # first graph convolution
        # second graph convolution; GraphConv ends with ReLU, so these class scores are non-negative
        x = self.conv2(x, edge_index)
        return x

##### **Q15: How do you implement dropout and batch normalization in the GNN model to improve generalization?**

In [42]:
class SimpleGNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, dropout_p=0.5):
        super().__init__()
        self.conv1 = GraphConv(in_channels, hidden_channels)
        self.bn1 = nn.BatchNorm1d(hidden_channels)
        self.conv2 = GraphConv(hidden_channels, out_channels)
        self.dropout_p = dropout_p

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)  # first graph convolution
        x = self.bn1(x)  # batch norm after first layer
        x = F.dropout(x, p=self.dropout_p, training=self.training)  # dropout
        x = self.conv2(x, edge_index)  # second graph convolution
        return x
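
Both dropout and batch normalization behave differently in training and evaluation mode, which is why the model passes `training=self.training` and why `model.train()` and `model.eval()` matter. This is a small sketch of that difference on stand-in tensors; the shapes and values are arbitrary assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.ones(4, 8)  # stand-in for hidden node features (4 nodes, 8 channels)
p = 0.5

eval_out = F.dropout(x, p=p, training=False)  # identity at inference time
train_out = F.dropout(x, p=p, training=True)  # random zeros, survivors scaled by 1 / (1 - p)
only_zero_or_scaled = bool(((train_out == 0) | (train_out == 1 / (1 - p))).all())

bn = nn.BatchNorm1d(8)
bn.train()
normalized = bn(torch.randn(4, 8))  # in training mode, statistics are taken over the 4 nodes
channel_means = normalized.mean(dim=0)

print(torch.equal(eval_out, x), only_zero_or_scaled)
print(channel_means.abs().max().item())
```

With only four nodes, the batch statistics are estimated from very few samples, so batch normalization can be noisy on a graph this small.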

## Training the GNN on a node classification task


##### **Q16: How do you define the loss function for training the GNN model on a node classification task?**


In [43]:
criterion = nn.CrossEntropyLoss()
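
`nn.CrossEntropyLoss` expects raw, unnormalized scores (logits) and integer class indices, and applies log-softmax internally. The sketch below checks this against a manual computation. The logit values are made up for illustration.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss()
labels = torch.tensor([0, 1, 0, 1])

# Raw scores: one row per node, one column per class.
logits = torch.tensor([[2.0, -1.0], [0.5, 0.5], [1.0, 3.0], [-2.0, 0.0]])
loss = criterion(logits, labels)

# Same value by hand: mean negative log-probability of the true class.
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(4), labels].mean()

# Uninformative logits (all equal) give ln(2) for two classes.
uniform_loss = criterion(torch.zeros(4, 2), labels)

print(loss.item(), manual.item())
print(uniform_loss.item(), math.log(2))
```

A two-class loss that hovers around ln 2 ≈ 0.6931 therefore suggests that the model is producing nearly equal scores for both classes.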

##### **Q17: How do you set up the optimizer to update the GNN model parameters during training?**


In [44]:
model = SimpleGNN(in_channels=2, hidden_channels=8, out_channels=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

##### **Q18: How do you implement the training loop for the GNN, including the forward pass, loss computation, and backpropagation?**


In [45]:
class GraphData:
    def __init__(self, x, edge_index):
        self.x = x
        self.edge_index = edge_index

x_input = torch.tensor([[1, 0], [0, 1], [1, 1], [0, 0]], dtype=torch.float32).to(device)
# 9 directed edges; unlike adj_matrix, this list contains 1 -> 2 but not 2 -> 1
edge_index_input = torch.tensor([[0, 0, 1, 1, 1, 2, 2, 3, 3],
                                 [1, 2, 0, 2, 3, 0, 3, 1, 2]], dtype=torch.long).to(device)
labels = torch.tensor([0, 1, 0, 1], dtype=torch.long).to(device)

data = GraphData(x_input, edge_index_input)

In [46]:
def train(model, data, labels, optimizer, criterion):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)  # forward pass
    loss = criterion(out, labels)  # compute loss
    loss.backward()  # backprop
    optimizer.step()
    return loss.item()
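
The `train` function above computes the loss over every node. In node classification it is common to hold some nodes out and compute the loss only on a training subset. The following sketch shows that with a boolean mask. The `train_mask` split and the random stand-in for the model output are assumptions.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
torch.manual_seed(0)
out = torch.randn(4, 2)  # stand-in for model(data.x, data.edge_index)
labels = torch.tensor([0, 1, 0, 1])
train_mask = torch.tensor([True, True, False, False])  # hypothetical split: nodes 0 and 1 train

# The full graph is still used in the forward pass; only the loss is restricted.
train_loss = criterion(out[train_mask], labels[train_mask])

# Held-out nodes are scored separately.
val_preds = out[~train_mask].argmax(dim=1)
val_acc = (val_preds == labels[~train_mask]).float().mean().item()

print(train_loss.item(), val_acc)
```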

##### **Q19: How do you track and log the accuracy and loss over training epochs to monitor the GNN model’s performance?**

In [47]:
def evaluate(model, data, labels):
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index)
        preds = out.argmax(dim=1)
        correct = preds.eq(labels).sum().item()
        acc = correct / labels.size(0)
    return acc

In [48]:
losses, accs = [], []
for epoch in range(301):
    loss = train(model, data, labels, optimizer, criterion)
    acc = evaluate(model, data, labels)
    losses.append(loss)
    accs.append(acc)
    if epoch % 100 == 0:
        print(f"Epoch {epoch:02d}  Loss: {loss:.4f}  Accuracy: {acc:.4f}")

Epoch 00  Loss: 0.8202  Accuracy: 0.5000
Epoch 100  Loss: 0.6931  Accuracy: 1.0000
Epoch 200  Loss: 0.6443  Accuracy: 1.0000
Epoch 300  Loss: 0.5450  Accuracy: 1.0000
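
The numbers in a log like the one above depend on the random weight initialization, so they can change from run to run. Fixing the seed before building the model is one way to make runs repeatable. This is a minimal sketch in which `make_model` is a hypothetical helper and `nn.Linear` stands in for the GNN.

```python
import torch
import torch.nn as nn

def make_model(seed):
    torch.manual_seed(seed)  # fixes the random weight initialization
    return nn.Linear(2, 8)   # stand-in for SimpleGNN(...)

a, b, c = make_model(0), make_model(0), make_model(1)
same = torch.equal(a.weight, b.weight)           # same seed, same weights
different = not torch.equal(a.weight, c.weight)  # different seed, different weights

print(same, different)  # True True
```

Seeding covers initialization and dropout masks. Some GPU kernels may still introduce small run-to-run differences.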


## Evaluating the GNN model


##### **Q20: How do you evaluate the GNN model on a validation or test dataset and calculate its accuracy for node classification?**


In [49]:
test_x = torch.tensor([[1, 0], [0, 1], [1, 1], [0, 0]], dtype=torch.float32).to(device)
test_edge_index = torch.tensor([[0, 1, 2, 3, 0, 1],
                                [1, 2, 3, 0, 2, 3]], dtype=torch.long).to(device)
test_labels = torch.tensor([0, 1, 0, 1], dtype=torch.long).to(device)
test_data = GraphData(test_x, test_edge_index)

In [50]:
test_acc = evaluate(model, test_data, test_labels)
print(f"Test Accuracy: {test_acc:.4f}")

Test Accuracy: 0.7500


##### **Q21: How do you implement a function to perform inference using the trained GNN model on new graph data?**

In [51]:
def infer(model, data):
    model.eval()
    with torch.no_grad():
        out = model(data.x.to(device), data.edge_index.to(device))
        preds = out.argmax(dim=1)
    return preds

In [52]:
predictions = infer(model, test_data)
print(f"Predicted labels: {predictions.tolist()}")

Predicted labels: [0, 0, 0, 1]
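
A single accuracy number hides which class the errors fall on. The sketch below builds a confusion matrix from the test labels and the predicted labels printed above, using `torch.bincount`. The variable names are assumptions.

```python
import torch

labels = torch.tensor([0, 1, 0, 1])  # test labels
preds = torch.tensor([0, 0, 0, 1])   # predicted labels printed above
num_classes = 2

# Encode each (true, predicted) pair as one index, then count occurrences.
pair_index = labels * num_classes + preds
confusion = torch.bincount(pair_index, minlength=num_classes ** 2).reshape(num_classes, num_classes)

accuracy = confusion.diag().sum().item() / labels.numel()
per_class_recall = confusion.diag().float() / confusion.sum(dim=1).float()

print(confusion.tolist())         # [[2, 0], [1, 1]]  rows: true class, columns: predicted
print(accuracy)                   # 0.75
print(per_class_recall.tolist())  # [1.0, 0.5]
```

Here both class-0 nodes are classified correctly, while one of the two class-1 nodes is predicted as class 0.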


## Experimenting with different configurations


##### **Q22: How do you experiment with different numbers of graph convolution layers and observe the effect on model performance?**


In [53]:
class DeeperGNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GraphConv(in_channels, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = self.conv2(x, edge_index)
        x = self.conv3(x, edge_index)
        return x

In [54]:
model = DeeperGNN(2, 8, 2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [55]:
for epoch in range(201):
    loss = train(model, data, labels, optimizer, criterion)
acc = evaluate(model, data, labels)
print(f"Accuracy with 3-layer GNN: {acc:.4f}")

Accuracy with 3-layer GNN: 0.5000
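
Writing one class per depth becomes repetitive when comparing several depths. One option is to make the number of layers a constructor argument, as sketched below with `nn.ModuleList`. `StackedGNN` is a hypothetical helper, and `GraphConv` is restated in compact form so the block runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphConv(nn.Module):
    """Compact restatement of the sum-aggregation layer with ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        aggregated = torch.zeros_like(x).index_add_(0, edge_index[1], x[edge_index[0]])
        return F.relu(self.linear(x + aggregated))

class StackedGNN(nn.Module):
    """Hypothetical helper: a GNN whose depth is a constructor argument."""

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = nn.ModuleList([GraphConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])])

    def forward(self, x, edge_index):
        for conv in self.convs:
            x = conv(x, edge_index)
        return x

x = torch.tensor([[1, 0], [0, 1], [1, 1], [0, 0]], dtype=torch.float32)
edge_index = torch.tensor([[0, 0, 1, 1, 1, 2, 2, 3, 3],
                           [1, 2, 0, 2, 3, 0, 3, 1, 2]])

shapes = {}
for depth in [1, 2, 3, 4]:
    model = StackedGNN(2, 8, 2, num_layers=depth)
    shapes[depth] = (len(model.convs), tuple(model(x, edge_index).shape))
print(shapes)
```

Each configuration can then be trained with the same `train` and `evaluate` functions, ideally over several random seeds so that depth effects are not confused with initialization noise.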


##### **Q23: How do you adjust the hidden dimension size in the GNN layers to analyze its impact on training time and accuracy?**


In [56]:
hidden_sizes = [4, 8, 16]
for h in hidden_sizes:
    model = SimpleGNN(2, h, 2).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    for epoch in range(201):
        train(model, data, labels, optimizer, criterion)
    acc = evaluate(model, data, labels)
    print(f"Hidden size: {h}  Accuracy: {acc:.4f}")

Hidden size: 4  Accuracy: 0.5000
Hidden size: 8  Accuracy: 0.5000
Hidden size: 16  Accuracy: 1.0000
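
Hidden size affects cost mainly through the number of trainable parameters. The sketch below counts them for the three sizes tried above. It assumes the layers of `SimpleGNN(2, h, 2)` are two linear maps and one batch-norm layer, as in the dropout and batch-normalization version, and `count_parameters` is a hypothetical helper.

```python
import torch.nn as nn

def count_parameters(module):
    """Number of trainable scalars in a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

counts = {}
for h in [4, 8, 16]:
    # Trainable layers of SimpleGNN(2, h, 2): Linear(2, h), BatchNorm1d(h), Linear(h, 2).
    layers = nn.ModuleList([nn.Linear(2, h), nn.BatchNorm1d(h), nn.Linear(h, 2)])
    counts[h] = count_parameters(layers)

print(counts)  # {4: 30, 8: 58, 16: 114}
```

The count grows as 7h + 2 for this architecture. Doubling the hidden size roughly doubles the number of parameters.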


##### **Q24: How do you experiment with different aggregation functions in the message-passing mechanism?**


In [60]:
def aggregate_mean(messages, edge_index, num_nodes):
    target_nodes = edge_index[1]
    aggregated = torch.zeros(num_nodes, messages.size(1)).to(messages.device)
    counts = torch.bincount(target_nodes, minlength=num_nodes).clamp(min=1).unsqueeze(1).float()
    aggregated.index_add_(0, target_nodes, messages)
    return aggregated / counts

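def aggregate_max(messages, edge_index, num_nodes):
    target_nodes = edge_index[1]
    # nodes without incoming edges keep this zero initial value instead of -inf
    aggregated = torch.zeros(num_nodes, messages.size(1), dtype=messages.dtype, device=messages.device)
    # scatter_reduce needs one target index per feature entry
    index = target_nodes.unsqueeze(1).expand(-1, messages.size(1))
    # element-wise maximum over the messages arriving at each target (requires PyTorch >= 1.12)
    return aggregated.scatter_reduce(0, index, messages, reduce="amax", include_self=False)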
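
To see how the aggregation functions differ before any training, the sketch below applies sum, mean, and max to the same neighbor messages on the toy graph. It assumes the symmetric edge list derived from the adjacency matrix and uses `scatter_reduce` for the maximum, which needs a reasonably recent PyTorch (1.12 or later).

```python
import torch

adj = torch.tensor([[0, 1, 1, 0],
                    [1, 0, 1, 1],
                    [1, 1, 0, 1],
                    [0, 1, 1, 0]], dtype=torch.float32)
x = torch.tensor([[1, 0], [0, 1], [1, 1], [0, 0]], dtype=torch.float32)

src, dst = adj.nonzero().t()  # sender and receiver of each directed edge
messages = x[src]
n, f = x.shape

agg_sum = torch.zeros(n, f).index_add_(0, dst, messages)
counts = torch.bincount(dst, minlength=n).clamp(min=1).unsqueeze(1)  # in-degree per node
agg_mean = agg_sum / counts
index = dst.unsqueeze(1).expand(-1, f)  # one receiver index per feature entry
agg_max = torch.zeros(n, f).scatter_reduce(0, index, messages, reduce="amax", include_self=False)

print(agg_sum.tolist())   # [[1.0, 2.0], [2.0, 1.0], [1.0, 1.0], [1.0, 2.0]]
print(agg_mean.tolist())
print(agg_max.tolist())   # [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
```

On this graph, max aggregation yields the same vector for every node, while sum and mean keep some differences. The choice of aggregator determines which neighborhood information survives.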

In [61]:
class GraphConvMean(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        messages = message_passing(x, edge_index)
        aggregated = aggregate_mean(messages, edge_index, x.size(0))
        updated = update_nodes(x, aggregated)
        out = self.linear(updated)
        return F.relu(out)

class GraphConvMax(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        messages = message_passing(x, edge_index)
        aggregated = aggregate_max(messages, edge_index, x.size(0))
        updated = update_nodes(x, aggregated)
        out = self.linear(updated)
        return F.relu(out)

class GNNWithAggregation(nn.Module):
    def __init__(self, conv_layer_class, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = conv_layer_class(in_channels, hidden_channels)
        self.conv2 = conv_layer_class(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = self.conv2(x, edge_index)
        return x

In [62]:
print("Testing sum aggregation (baseline)")
model = SimpleGNN(2, 8, 2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(20):
    train(model, data, labels, optimizer, criterion)
acc = evaluate(model, data, labels)
print(f"Sum agg accuracy: {acc:.4f}")

print("Testing mean aggregation")
model = GNNWithAggregation(GraphConvMean, 2, 8, 2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(20):
    train(model, data, labels, optimizer, criterion)
acc = evaluate(model, data, labels)
print(f"Mean agg accuracy: {acc:.4f}")

print("Testing max aggregation")
model = GNNWithAggregation(GraphConvMax, 2, 8, 2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(20):
    train(model, data, labels, optimizer, criterion)
acc = evaluate(model, data, labels)
print(f"Max agg accuracy: {acc:.4f}")

Testing sum aggregation (baseline)
Sum agg accuracy: 0.5000
Testing mean aggregation
Mean agg accuracy: 1.0000
Testing max aggregation
Max agg accuracy: 0.5000


##### **Q25: How do you tune learning rates and dropout rates to improve the generalization of the GNN model?**

In [63]:
class TuningGNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, dropout_p):
        super().__init__()
        self.conv1 = GraphConv(in_channels, hidden_channels)
        self.bn1 = nn.BatchNorm1d(hidden_channels)
        self.conv2 = GraphConv(hidden_channels, out_channels)
        self.dropout_p = dropout_p

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        x = self.conv2(x, edge_index)
        return x

In [64]:
dropouts = [0.2, 0.5, 0.8]
learning_rates = [0.001, 0.01, 0.1]

for dr in dropouts:
    for lr in learning_rates:
        model = TuningGNN(2, 8, 2, dropout_p=dr).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        for epoch in range(20):
            train(model, data, labels, optimizer, criterion)
        acc = evaluate(model, data, labels)
        print(f"Dropout: {dr:.1f}  LR: {lr:.3f}  Accuracy: {acc:.4f}")

Dropout: 0.2  LR: 0.001  Accuracy: 0.5000
Dropout: 0.2  LR: 0.010  Accuracy: 1.0000
Dropout: 0.2  LR: 0.100  Accuracy: 0.5000
Dropout: 0.5  LR: 0.001  Accuracy: 0.5000
Dropout: 0.5  LR: 0.010  Accuracy: 0.5000
Dropout: 0.5  LR: 0.100  Accuracy: 1.0000
Dropout: 0.8  LR: 0.001  Accuracy: 0.5000
Dropout: 0.8  LR: 0.010  Accuracy: 0.5000
Dropout: 0.8  LR: 0.100  Accuracy: 0.5000
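
With a single run per setting, it is hard to tell a better configuration from a lucky initialization. One possible refinement is to average each configuration over a few seeds, as sketched below. `TinyGNN` is a compact stand-in for `TuningGNN`, and `run_config`, the seed list, and the three-seed average are assumptions that are not part of the original sweep.

```python
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyGNN(nn.Module):
    """Compact stand-in for TuningGNN (same layer order, sum aggregation)."""

    def __init__(self, hidden, dropout_p):
        super().__init__()
        self.lin1 = nn.Linear(2, hidden)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.lin2 = nn.Linear(hidden, 2)
        self.dropout_p = dropout_p

    @staticmethod
    def propagate(x, edge_index):
        # own features plus the sum of neighbor features
        return x + torch.zeros_like(x).index_add_(0, edge_index[1], x[edge_index[0]])

    def forward(self, x, edge_index):
        x = F.relu(self.lin1(self.propagate(x, edge_index)))
        x = F.dropout(self.bn1(x), p=self.dropout_p, training=self.training)
        return F.relu(self.lin2(self.propagate(x, edge_index)))

x = torch.tensor([[1, 0], [0, 1], [1, 1], [0, 0]], dtype=torch.float32)
edge_index = torch.tensor([[0, 0, 1, 1, 1, 2, 2, 3, 3],
                           [1, 2, 0, 2, 3, 0, 3, 1, 2]])
labels = torch.tensor([0, 1, 0, 1])

def run_config(dropout_p, lr, seed, epochs=20):
    torch.manual_seed(seed)  # fixes initialization and dropout masks for this run
    model = TinyGNN(8, dropout_p)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        F.cross_entropy(model(x, edge_index), labels).backward()
        optimizer.step()
    model.eval()
    with torch.no_grad():
        preds = model(x, edge_index).argmax(dim=1)
    return (preds == labels).float().mean().item()

results = {}
for dropout_p, lr in itertools.product([0.2, 0.5, 0.8], [0.001, 0.01, 0.1]):
    accs = [run_config(dropout_p, lr, seed) for seed in range(3)]
    results[(dropout_p, lr)] = sum(accs) / len(accs)  # mean accuracy over seeds

best = max(results, key=results.get)
print(best, results[best])
```

Accuracy here is still measured on the training nodes, as in the loop above. Judging generalization would require scoring each configuration on held-out nodes or a separate graph.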
