In [1]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1


Numerical Example with Python Code
Suppose we have a simple graph with 3 nodes connected as follows: node 1 is connected to nodes 2 and 3. Each node has a single feature initially set to 1. We’ll perform one message passing step using sum aggregation and a simple linear update function.

Graph Representation
*   Node features: X = [1,1,1]
*   Edges: (1, 2), (1, 3)
We'll implement a single GNN layer using PyTorch Geometric:



In [None]:
import torch
from torch_geometric.data import Data

# Define the edges (source and target)
edge_index = torch.tensor([[0, 0], [1, 2]], dtype=torch.long)  # 0-based indexing

# Define node features
x = torch.tensor([[1], [1], [1]], dtype=torch.float)

# Create a graph
graph = Data(x=x, edge_index=edge_index.t().contiguous())

print("Graph node features:")
print(graph.x)
print("Graph edges:")
print(graph.edge_index)

Graph node features:
tensor([[1.],
        [1.],
        [1.]])
Graph edges:
tensor([[0, 1],
        [0, 2]])


In [None]:
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing

# Set a random seed for reproducibility
torch.manual_seed(42)

# Define the SimpleGNNLayer using PyTorch Geometric
class SimpleGNNLayer(MessagePassing):
    def __init__(self):
        super(SimpleGNNLayer, self).__init__(aggr='add')  # "add" means sum aggregation
        self.lin = nn.Linear(1, 1)
        # Manually initialize weights and biases
        self.lin.weight.data.fill_(0.5)
        self.lin.bias.data.fill_(0.1)

    def forward(self, x, edge_index):
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def message(self, x_j):
        return x_j

    def update(self, aggr_out):
        return self.lin(aggr_out)

# Initialize the GNN layer
pg_gnn_layer = SimpleGNNLayer()

# Define the graph structure and node features
edge_index = torch.tensor([[0, 0], [1, 2]], dtype=torch.long)
x = torch.tensor([[2.], [1.], [1.]], dtype=torch.float)

# Perform forward pass
pg_updated_features = pg_gnn_layer(x, edge_index)

print("Updated node features from PyTorch Geometric:")
print(pg_updated_features)


Updated node features from PyTorch Geometric:
tensor([[0.1000],
        [1.1000],
        [1.1000]], grad_fn=<AddmmBackward0>)


In [None]:
import torch
import torch.nn as nn

# Set a random seed for reproducibility
torch.manual_seed(42)

# Define the ManualGNNLayer
class ManualGNNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(ManualGNNLayer, self).__init__()
        self.lin = nn.Linear(in_features, out_features)
        # Manually initialize weights and biases
        self.lin.weight.data.fill_(0.5)
        self.lin.bias.data.fill_(0.1)

    def forward(self, x, edge_index):
        num_nodes = x.size(0)
        messages = x[edge_index[0]]
        #print(messages)
        aggr = torch.zeros(num_nodes, x.size(1), device=x.device)
        aggr.index_add_(0, edge_index[1], messages)
        #print(aggr)
        return self.lin(aggr)

# Initialize the GNN layer
manual_gnn_layer = ManualGNNLayer(1, 1)

# Define the graph structure and node features
edge_index = torch.tensor([[0, 0], [1, 2]], dtype=torch.long)
x = torch.tensor([[2.], [1.], [1.]], dtype=torch.float)

# Perform forward pass
manual_updated_features = manual_gnn_layer(x, edge_index)

print("Updated node features from Manual GNN Layer:")
print(manual_updated_features)


Updated node features from Manual GNN Layer:
tensor([[0.1000],
        [1.1000],
        [1.1000]], grad_fn=<AddmmBackward0>)
