In [2]:
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

2.1.0+cu118
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m95.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.9/4.9 MB[0m [31m66.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for torch_geometric (pyproject.toml) ... [?25l[?25hdone


In [3]:
import os

import torch
import torch_scatter
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric
import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

from torch import Tensor
from typing import Union, Tuple, Optional
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
                                    OptTensor)

from torch.nn import Parameter, Linear
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax

import numpy as np

## Custom Graph Layer Implementation

References:
- http://web.stanford.edu/class/cs224w/

The job of a message passing layer is to update the current feature representation or embedding of each node in a graph by propagating and transforming information within the graph. Overall, the general paradigm of a message passing layers is: 1) pre-processing -> 2) **message passing** / propagation -> 3) post-processing.

The `forward` fuction that we will implement for our message passing layer captures this execution logic. Namely, the `forward` function handles the pre and post-processing of node features / embeddings, as well as initiates message passing by calling the `propagate` function.

The `propagate` function encapsulates the message passing process! It does so by calling three important functions: 1) `message`, 2) `aggregate`, and 3) `update`. Our implementation will vary slightly from this, as we will not explicitly implement `update`, but instead place the logic for updating node embeddings after message passing and within the `forward` function. To be more specific, after information is propagated (message passing), we can further transform the node embeddings outputed by `propagate`. Therefore, the output of `forward` is exactly the node embeddings after one GNN layer.

Lastly, before starting to implement our own layer, let us dig a bit deeper into each of the functions described above:

1.

```
def propagate(edge_index, x=(x_i, x_j), extra=(extra_i, extra_j), size=size):
```
Calling `propagate` initiates the message passing process. Looking at the function parameters, we highlight a couple of key parameters.

  - `edge_index` is passed to the forward function and captures the edge structure of the graph.
  - `x=(x_i, x_j)` represents the node features that will be used in message passing. In order to explain why we pass the tuple `(x_i, x_j)`, we first look at how our edges are represented. For every edge $(i, j) \in {E}$, we can differentiate $i$ as the source or central node ($x_{central}$) and j as the neighboring node ($x_{neighbor}$).
  
    Taking the example of message passing above, for a central node $u$ we will aggregate and transform all of the messages associated with the nodes $v$ s.t. $(u, v) \in {E}$ (i.e. $v \in \mathscr{N}_{u}$). Thus we see, the subscripts `_i` and `_j` allow us to specifcally differenciate features associated with central nodes (i.e. nodes  recieving message information) and neighboring nodes (i.e. nodes passing messages).

    This is definitely a somewhat confusing concept; however, one key thing to remember / wrap your head around is that depending on the perspective, a node $x$ acts as a central node or a neighboring node. In fact, in undirected graphs we store both edge directions (i.e. $(i, j)$ and $(j, i)$). From the central node perspective, `x_i`, x is collecting neighboring information to update its embedding. From a neighboring node perspective, `x_j`, x is passing its message information along the edge connecting it to a different central node.

  - `extra=(extra_i, extra_j)` represents additional information that we can associate with each node beyond its current feature embedding. In fact, we can include as many additional parameters of the form `param=(param_i, param_j)` as we would like. Again, we highlight that indexing with `_i` and `_j` allows us to differentiate central and neighboring nodes.

  The output of the `propagate` function is a matrix of node embeddings after the message passing process and has shape $[N, d]$.

2.
```
def message(x_j, ...):
```
The `message` function is called by propagate and constructs the messages from
neighboring nodes $j$ to central nodes $i$ for each edge $(i, j)$ in *edge_index*. This function can take any argument that was initially passed to `propagate`. Furthermore, we can again differentiate central nodes and neighboring nodes by appending `_i` or `_j` to the variable name, .e.g. `x_i` and `x_j`. Looking more specifically at the variables, we have:

  - `x_j` represents a matrix of feature embeddings for all neighboring nodes passing their messages along their respective edge (i.e. all nodes $j$ for edges $(i, j) \in {E}$). Thus, its shape is $[|{E}|, d]$!

  Critically, we see that the output of the `message` function is a matrix of neighboring node embeddings ready to be aggregated, having shape $[|{E}|, d]$.

3.
```
def aggregate(self, inputs, index, dim_size = None):
```
Lastly, the `aggregate` function is used to aggregate the messages from neighboring nodes. Looking at the parameters we highlight:

  - `inputs` represents a matrix of the messages passed from neighboring nodes (i.e. the output of the `message` function).
  - `index` has the same shape as `inputs` and tells us the central node that corresponding to each of the rows / messages $j$ in the `inputs` matrix. Thus, `index` tells us which rows / messages to aggregate for each central node.

  The output of `aggregate` is of shape $[N, d]$.


For additional resources refer to the PyG documentation for implementing custom message passing layers: https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html

In [4]:
class CustomGraphSage(MessagePassing):

    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, **kwargs):
        super(CustomGraphSage, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        self.lin_l = None
        self.lin_r = None

        ############################################################################
        # TODO: Your code here!
        # Define the layers needed for the message and update functions below.
        # self.lin_l is the linear transformation that you apply to embedding
        #            for central node.
        # self.lin_r is the linear transformation that you apply to aggregated
        #            message from neighbors.
        # Don't forget the bias!
        # Our implementation is ~2 lines, but don't worry if you deviate from this.
        self.lin_l = nn.Linear(in_channels, out_channels)
        self.lin_r = nn.Linear(in_channels, out_channels)
        ############################################################################

        self.reset_parameters()

    def reset_parameters(self):
        # TODO: initialize your parameters
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()

    def forward(self, x, edge_index, size = None):

        out = None

        ############################################################################
        # TODO: Your code here!
        # Implement message passing, as well as any post-processing (our update rule).
        # 1. Call the propagate function to conduct the message passing.
        #    1.1 See the description of propagate above or the following link for more information:
        #        https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html
        #    1.2 We will only use the representation for neighbor nodes (x_j), so by default
        #        we pass the same representation for central and neighbor nodes as x=(x, x).
        # 2. Update our node embedding with skip connection from the previous layer.
        # 3. If normalize is set, do L-2 normalization (defined in
        #    torch.nn.functional)
        aggr = self.propagate(edge_index, size=size, x=(x,x))
        out = self.lin_l(x) + self.lin_r(aggr)
        out = F.normalize(out, p=2) if self.normalize else out
        ############################################################################

        return out

    def message(self, x_j):

        out = None

        ############################################################################
        # TODO: Your code here!
        # Implement your message function here.
        # Hint: Look at the formulation of the mean aggregation function, focusing on
        # what message each neighboring node passes.
        out = x_j
        ############################################################################

        return out

    def aggregate(self, inputs, index, dim_size = None):

        out = None

        # The axis along which to index number of nodes.
        node_dim = self.node_dim

        ############################################################################
        # TODO: Your code here!
        # Implement your aggregate function here.
        # See here as how to use torch_scatter.scatter:
        # https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html#torch_scatter.scatter

        out = torch_scatter.scatter(inputs, index, dim=node_dim, dim_size=dim_size, reduce="mean")
        ############################################################################

        return out

In [5]:
node_embedings = torch.ones(4, 8) # A graph with 4 nodes and 8 dimensional node features
edge_index = torch.tensor([[0, 1, 2, 0, 3],
                           [1, 0, 1, 3, 2]], dtype=torch.long) # Example edge index

In [6]:
custom_layer = CustomGraphSage(in_channels = 8, out_channels = 2, normalize = True)

In [7]:
transformed_embeddings = custom_layer(node_embedings, edge_index)
assert transformed_embeddings.shape == torch.Size([4, 2]), "Incorrect shape returned."

## Train and Evaluate  

In [8]:
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset

def shuffle(dataset, seed):
    torch.manual_seed(seed)
    return dataset.shuffle()

def train_test_val_split(num_test, batch_size, dataset):
    test_dataset = dataset[:num_test]
    val_dataset = dataset[num_test:2*num_test]
    train_dataset = dataset[2*num_test:]
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)
    return test_loader, val_loader, train_loader

def calculate_accuracy(model, loader):
    model.eval()
    correct = 0
    total = 0
    for i, data_batch in enumerate(loader):
        x, edge_index, batch, y = data_batch.x, data_batch.edge_index, data_batch.batch, data_batch.y
        pred = model(x, edge_index, batch)
        correct += (pred.max(dim=1)[1]).eq(y).sum().item()
        total += len(data_batch)
    return correct / total

seed = [42, 1999, 12581]
batch_size = 32
num_test = 100

1. Build a training loop and evaluate the custom graph layer implementation.

In [9]:
from torch_geometric.nn import global_add_pool, global_mean_pool

class CustomGNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
        super(CustomGNN, self).__init__()
        self.gnn_layers = nn.ModuleList()
        self.gnn_layers.append(CustomGraphSage(input_dim, hidden_dim))
        for _ in range(num_layers-1):
            self.gnn_layers.append(CustomGraphSage(hidden_dim, hidden_dim))
        self.num_layers = num_layers
        self.pool = global_mean_pool
        self.post = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch):
        for i, layer in enumerate(self.gnn_layers):
            x = layer(x, edge_index)
            x = F.relu(x)

        x = self.pool(x, batch)
        x = self.post(x)
        x = F.log_softmax(x, dim=1)
        return x

In [10]:
def use_custom(run):
    # dataset
    dataset = shuffle(TUDataset(name='ENZYMES', root='data/TUDataset'), seed[run])
    test_loader, valid_loader, train_loader = train_test_val_split(num_test, batch_size, dataset)

    # model and optimizer
    model = CustomGNN(dataset.num_features, 32, dataset.num_classes, 3)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # train
    for epoch in range(100):
        model.train()
        for i, data_batch in enumerate(train_loader):
            feature, edge_index, batch, label = data_batch.x, data_batch.edge_index, data_batch.batch, data_batch.y
            pred = model(feature, edge_index, batch)
            loss = F.nll_loss(pred, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # valid
        if (epoch + 1) % 20 == 0:
            accuracy = calculate_accuracy(model, valid_loader)
            print("Run: {}\tEpoch: {}\tValidation accuracy: {}".format(run + 1, epoch + 1, accuracy))

    # test
    accuracy = calculate_accuracy(model, test_loader)
    accuracies.append(accuracy)

num_runs = 3
accuracies = []
for run in range(num_runs):
    use_custom(run)
print("Model: GNN with custom GraphSAGE")
print("Accuracies: {}".format(accuracies))
print("Mean: {}".format(np.mean(accuracies)))
print("Standard deviation: {}".format(np.std(accuracies)))

Downloading https://www.chrsmrrs.com/graphkerneldatasets/ENZYMES.zip
Extracting data/TUDataset/ENZYMES/ENZYMES.zip
Processing...
Done!


Run: 1	Epoch: 20	Validation accuracy: 0.28
Run: 1	Epoch: 40	Validation accuracy: 0.26
Run: 1	Epoch: 60	Validation accuracy: 0.27
Run: 1	Epoch: 80	Validation accuracy: 0.28
Run: 1	Epoch: 100	Validation accuracy: 0.3
Run: 2	Epoch: 20	Validation accuracy: 0.27
Run: 2	Epoch: 40	Validation accuracy: 0.36
Run: 2	Epoch: 60	Validation accuracy: 0.39
Run: 2	Epoch: 80	Validation accuracy: 0.43
Run: 2	Epoch: 100	Validation accuracy: 0.37
Run: 3	Epoch: 20	Validation accuracy: 0.29
Run: 3	Epoch: 40	Validation accuracy: 0.3
Run: 3	Epoch: 60	Validation accuracy: 0.32
Run: 3	Epoch: 80	Validation accuracy: 0.32
Run: 3	Epoch: 100	Validation accuracy: 0.34
Model: GNN with custom GraphSAGE
Accuracies: [0.36, 0.29, 0.27]
Mean: 0.30666666666666664
Standard deviation: 0.03858612300930075


2. Use SAGEConv layer instead of the custom layer.

In [11]:
from torch_geometric.nn import global_add_pool, global_mean_pool

class SAGEConvGNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
        super(SAGEConvGNN, self).__init__()
        self.gnn_layers = nn.ModuleList()
        self.gnn_layers.append(pyg_nn.SAGEConv(input_dim, hidden_dim))
        for _ in range(num_layers-1):
            self.gnn_layers.append(pyg_nn.SAGEConv(hidden_dim, hidden_dim))
        self.num_layers = num_layers
        self.pool = global_mean_pool
        self.post = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch):
        for i, layer in enumerate(self.gnn_layers):
            x = layer(x, edge_index)
            x = F.relu(x)

        x = self.pool(x, batch)
        x = self.post(x)
        x = F.log_softmax(x, dim=1)
        return x

In [13]:
def use_SAGEConv(run):
    # dataset
    dataset = shuffle(TUDataset(name='ENZYMES', root='data/TUDataset'), seed[run])
    test_loader, valid_loader, train_loader = train_test_val_split(num_test, batch_size, dataset)

    # model and optimizer
    model = SAGEConvGNN(dataset.num_features, 32, dataset.num_classes, 3)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # train
    for epoch in range(100):
        model.train()
        for i, data_batch in enumerate(train_loader):
            feature, edge_index, batch, label = data_batch.x, data_batch.edge_index, data_batch.batch, data_batch.y
            pred = model(feature, edge_index, batch)
            loss = F.nll_loss(pred, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # valid
        if (epoch + 1) % 20 == 0:
            accuracy = calculate_accuracy(model, valid_loader)
            print("Run: {}\tEpoch: {}\tValidation accuracy: {}".format(run + 1, epoch + 1, accuracy))

    # test
    accuracy = calculate_accuracy(model, test_loader)
    accuracies.append(accuracy)

num_runs = 3
accuracies = []
for run in range(num_runs):
    use_SAGEConv(run)
print("Model: GNN with SAGEConv layer")
print("Accuracies: {}".format(accuracies))
print("Mean: {}".format(np.mean(accuracies)))
print("Standard deviation: {}".format(np.std(accuracies)))

Run: 1	Epoch: 20	Validation accuracy: 0.2
Run: 1	Epoch: 40	Validation accuracy: 0.28
Run: 1	Epoch: 60	Validation accuracy: 0.3
Run: 1	Epoch: 80	Validation accuracy: 0.28
Run: 1	Epoch: 100	Validation accuracy: 0.27
Run: 2	Epoch: 20	Validation accuracy: 0.37
Run: 2	Epoch: 40	Validation accuracy: 0.41
Run: 2	Epoch: 60	Validation accuracy: 0.37
Run: 2	Epoch: 80	Validation accuracy: 0.4
Run: 2	Epoch: 100	Validation accuracy: 0.4
Run: 3	Epoch: 20	Validation accuracy: 0.38
Run: 3	Epoch: 40	Validation accuracy: 0.32
Run: 3	Epoch: 60	Validation accuracy: 0.34
Run: 3	Epoch: 80	Validation accuracy: 0.39
Run: 3	Epoch: 100	Validation accuracy: 0.37
Model: GNN with SAGEConv layer
Accuracies: [0.4, 0.29, 0.31]
Mean: 0.3333333333333333
Standard deviation: 0.04784233364802443


3. Use GRU aggregation instead of Mean aggregation.

In [14]:
from torch_geometric.nn import aggr, global_add_pool, global_mean_pool
from torch_geometric.utils import sort_edge_index

class CustomGraphSageGRU(MessagePassing):

    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, **kwargs):
        super(CustomGraphSageGRU, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.lin_l = nn.Linear(in_channels, out_channels)
        self.lin_r = nn.Linear(in_channels, out_channels)
        self.aggr = aggr.GRUAggregation(in_channels, out_channels)
        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()

    def forward(self, x, edge_index, size = None):
        aggr = self.propagate(edge_index, size=size, x=(x,x))
        out = self.lin_l(x) + self.lin_r(aggr)
        out = F.normalize(out, p=2) if self.normalize else out
        return out

    def message(self, x_j):
        out = x_j
        return out

class CustomGNNGRU(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
        super(CustomGNNGRU, self).__init__()
        self.gnn_layers = nn.ModuleList()
        self.gnn_layers.append(CustomGraphSageGRU(input_dim, hidden_dim))
        for _ in range(num_layers-1):
            self.gnn_layers.append(CustomGraphSageGRU(hidden_dim, hidden_dim))
        self.num_layers = num_layers
        self.pool = global_mean_pool
        self.post = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch):
        edge_index = sort_edge_index(edge_index)
        for i, layer in enumerate(self.gnn_layers):
            x = layer(x, edge_index)
            x = F.relu(x)

        x = self.pool(x, batch)
        x = self.post(x)
        x = F.log_softmax(x, dim=1)
        return x

In [15]:
def use_GRU(run):
    # dataset
    dataset = shuffle(TUDataset(name='ENZYMES', root='data/TUDataset'), seed[run])
    test_loader, valid_loader, train_loader = train_test_val_split(num_test, batch_size, dataset)

    # model and optimizer
    model = CustomGNNGRU(dataset.num_features, 32, dataset.num_classes, 3)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # train
    for epoch in range(100):
        model.train()
        for i, data_batch in enumerate(train_loader):
            feature, edge_index, batch, label = data_batch.x, data_batch.edge_index, data_batch.batch, data_batch.y
            pred = model(feature, edge_index, batch)
            loss = F.nll_loss(pred, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # valid
        if (epoch + 1) % 20 == 0:
            accuracy = calculate_accuracy(model, valid_loader)
            print("Run: {}\tEpoch: {}\tValidation accuracy: {}".format(run + 1, epoch + 1, accuracy))

    # test
    accuracy = calculate_accuracy(model, test_loader)
    accuracies.append(accuracy)

num_runs = 3
accuracies = []
for run in range(num_runs):
    use_GRU(run)
print("Model: GNN with GRU layer")
print("Accuracies: {}".format(accuracies))
print("Mean: {}".format(np.mean(accuracies)))
print("Standard deviation: {}".format(np.std(accuracies)))

Run: 1	Epoch: 20	Validation accuracy: 0.22
Run: 1	Epoch: 40	Validation accuracy: 0.3
Run: 1	Epoch: 60	Validation accuracy: 0.28
Run: 1	Epoch: 80	Validation accuracy: 0.33
Run: 1	Epoch: 100	Validation accuracy: 0.37
Run: 2	Epoch: 20	Validation accuracy: 0.4
Run: 2	Epoch: 40	Validation accuracy: 0.42
Run: 2	Epoch: 60	Validation accuracy: 0.47
Run: 2	Epoch: 80	Validation accuracy: 0.52
Run: 2	Epoch: 100	Validation accuracy: 0.46
Run: 3	Epoch: 20	Validation accuracy: 0.27
Run: 3	Epoch: 40	Validation accuracy: 0.35
Run: 3	Epoch: 60	Validation accuracy: 0.36
Run: 3	Epoch: 80	Validation accuracy: 0.42
Run: 3	Epoch: 100	Validation accuracy: 0.44
Model: GNN with GRU layer
Accuracies: [0.47, 0.35, 0.44]
Mean: 0.42
Standard deviation: 0.050990195135927854
