In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

import time
from datetime import datetime

import networkx as nx
import numpy as np
import torch
import torch.optim as optim

from torch_geometric.datasets import TUDataset, FakeDataset, GNNBenchmarkDataset
from torch_geometric.loader import DataLoader

import torch_geometric.transforms as T

from tensorboardX import SummaryWriter
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

In [2]:
# Set verbosity: when True, the convolution layer and the training loop
# print detailed debugging information (edge lists, rankings, embeddings, epochs).
VERBOSE = False

Debug tools

In [3]:
def print_edge_index(edge_index):
    """
    Print a human-readable listing of the edges in an edge index.

    Parameters
    ----------
    edge_index : torch.Tensor
        Edge index whose first row holds source nodes and whose second row
        holds target nodes; the number of columns is the number of edges.
    """
    num_edges = edge_index.size(1)
    print('Edges info:')
    print('Num edges=', num_edges)

    pieces = []
    for idx in range(num_edges):
        src = edge_index[0, idx].item()
        dst = edge_index[1, idx].item()
        # Four edges per output line; the others are separated by ';' and tabs.
        separator = '\n' if (idx + 1) % 4 == 0 else ';\t\t'
        pieces.append(f'edge #{idx}:\t{src}->{dst}{separator}')
    print(''.join(pieces))

# Defining the custom graph convolution layer

Here we define a custom instance of MessagePassing. It defines a single layer of graph convolution, which can be decomposed into:
* Message computation
* Aggregation
* Update
* Pooling

Here we give an example of how to subclass the PyTorch Geometric `MessagePassing` class to derive a new layer (rather than using the existing `GCNConv` and `GINConv`).

We make use of `MessagePassing`'s key building blocks:
- `aggr`: The aggregation method to use ("add", "mean" or "max").
- `propagate()`: The initial call to start propagating messages. Takes in the edge indices and any other data to pass along (e.g. to update node embeddings).
- `message()`: Constructs messages to node i. Takes any argument which was initially passed to propagate().
- `update()`: Updates node embeddings. Takes in the output of aggregation as first argument and any argument which was initially passed to propagate().


In [4]:
class CustomConv(pyg_nn.MessagePassing):
    """
    Graph convolution layer whose messages depend on the rank of each neighbor.

    For every target node, the in-neighbors are ranked by the value of their
    first feature; the neighbor with rank r (for r < max_neighbors_to_consider)
    is transformed by its own linear layer ``lin[r]``. The messages are combined
    with 'mean' aggregation, added to a linear transform of the center node, and
    finally smoothed by `averaging`.
    """

    def __init__(self, in_channels, out_channels, max_neighbors_to_consider=4):
        """
        Parameters
        ----------
        in_channels : int
            Number of input features per node.
        out_channels : int
            Number of output features per node.
        max_neighbors_to_consider : int
            The number of neighbor nodes (lowest ranks first) that get a
            dedicated linear layer during message computation.
        """
        super(CustomConv, self).__init__(aggr='mean')
        self.out_channels = out_channels
        self.max_neighbors_to_consider = max_neighbors_to_consider

        # Linear layer applied on the features of the center node.
        self.lin_self = nn.Linear(in_channels, out_channels)

        # One linear layer per neighbor rank considered during message computation.
        self.lin = nn.ModuleList()
        for i in range(max_neighbors_to_consider):
            self.lin.append(nn.Linear(in_channels, out_channels))
            nn.init.normal_(self.lin[i].weight, mean=0, std=1.0)

        # The epsilon that defines the "epsilon environment" used in `averaging`.
        # It is registered as a parameter.
        # NOTE(review): its only use is inside a `<` comparison, which yields a
        # boolean mask, so no gradient appears to reach it -- confirm whether it
        # is really meant to be learned.
        self.eps = torch.nn.Parameter(torch.rand(1))

    def forward(self, x, edge_index):
        """
        Forward computation.

        Parameters
        ----------
        x : torch.Tensor
            Input node embeddings. Has shape [NUM_NODES, in_channels]
        edge_index : torch.Tensor
            Edge index. Has shape [2, NUM_EDGES]

        Returns
        -------
        h : torch.Tensor
            Updated node embeddings. Has shape [NUM_NODES, out_channels]
        """

        # Remove self loops: the center node is handled separately by `lin_self`.
        edge_index, _ = pyg_utils.remove_self_loops(edge_index)

        if VERBOSE:
            print('self.eps= ', self.eps)
            print('num nodes=', x.size(0))
            print_edge_index(edge_index)

        x = self.lin_self(x) + self.propagate(edge_index, x=x, num_nodes=x.size(0))

        return self.averaging(x)

    def message(self, x_j, edge_index, num_nodes):
        """
        Compute messages.

        Parameters
        ----------
        x_j : torch.Tensor
            Source node embeddings of each edge. Meaning, x_j[17] will contain the features of the source node of edge #17.
            To see the edge's nodes indices, see 'edge_index'
            Has shape [NUM_EDGES, in_channels]
        edge_index : torch.Tensor
            Edge index. Has shape [2, NUM_EDGES]. The first entry in the first dimension is j
            and the second entry is i. For example, if edge_index[0,17]=30 and edge_index[1,17]=21,
            then there exists an edge from node 30 to node 21. The features data of node 30 will be
            found in x_j[17], while the feature data of node 21 will be found in x_i[17].
        num_nodes : int
            Number of nodes; every node index in range(num_nodes) is visited as a
            potential target node.

        Returns
        -------
        res : torch.Tensor
            The message from each source node to the target node, for each edge.
            res[17] will contain the message from the source node of edge #17 to the target node
            of edge #17. To see the edge's nodes indices, see 'edge_index'.
            Edges whose rank is >= max_neighbors_to_consider keep an all-zero message.
            Has shape [NUM_EDGES, out_channels]
        """

        # First feature of the source node of each edge. Plain indexing keeps the
        # result 1-D even when there is exactly one edge (squeeze() would have
        # collapsed it to a 0-d tensor and broken the indexing below).
        x_j_first_feature = x_j[:, 0]
        if VERBOSE:
            print('x_j_first_feature=', x_j_first_feature)

        # This variable will contain the order of the in-edges (of each node),
        # sorted by the values of the first feature of the source nodes.
        # For example, if neighborhood_ranking[17]=3, it means that edge #17 is
        # in the 4th place, out of all of the edges with the same target node
        # (when ordering these edges by the value of the first feature of
        # their respective source nodes).
        # Allocated on the same device as the features so the layer also works
        # when the data does not live on the CPU.
        neighborhood_ranking = torch.zeros(
            x_j_first_feature.shape, dtype=torch.long, device=x_j.device)

        # Target node of every edge (loop-invariant, so computed once).
        targets = edge_index[1]

        # We iterate over each target node
        for i in range(num_nodes):
            # We first collect the indices of the in-edges of the target node
            in_edges = torch.where(targets == i)[0]
            # argsort of argsort turns values into their rank within the group.
            neighborhood_ranking[in_edges] = x_j_first_feature[in_edges].argsort(dim=0).argsort(dim=0)

            if VERBOSE:
                print('in-edges of node ' + str(i) + ':', ''.join([str(elem.item()) + ', ' for elem in in_edges]))

        if VERBOSE:
            print('neighborhood_ranking=', neighborhood_ranking)

        # At this point we have the 'ranking' of every adjacent node.
        # We apply the corresponding linear layer on all edges with the same rank.
        # new_zeros keeps the dtype and device of x_j.
        res = x_j.new_zeros((x_j.shape[0], self.out_channels))
        for i in range(self.max_neighbors_to_consider):
            indices = torch.where(neighborhood_ranking == i)[0]
            res[indices] = self.lin[i](x_j[indices])

        return res

    def averaging(self, x):
        """
        Averages points along the first dimension. The purpose of this operation is to prevent "jumping" of the points
        when we sort them along the first dimension. For each node, we take its epsilon-neighborhood, and we update the
        feature of the node to be the average of the features of the nodes in its epsilon-neighborhood.

        The epsilon-neighborhood of a node is the set of rows of `x` whose first
        feature differs from the node's first feature by less than `self.eps`.
        It is computed over all rows of `x`; no graph/batch assignment is
        consulted here.

        Parameters
        ----------
        x : torch.Tensor
            The embedding of all nodes in the graph.
            Has shape [NUM_NODES, channels]

        Returns
        -------
        res : torch.Tensor
            The averaged embeddings. Has the same shape as `x`.
        """

        if VERBOSE:
            print('x before averaging', x)

        # Same shape, dtype and device as the input.
        res = torch.zeros_like(x)

        # Plain indexing keeps the values 1-D even for a single node.
        first_dim_values, indices = torch.sort(x[:, 0])
        for i, val in enumerate(first_dim_values):
            # Positions (in sorted order) that fall inside the epsilon-neighborhood.
            eps_env = torch.where(torch.abs(first_dim_values - val) < self.eps)[0]
            # `indices` maps sorted positions back to the original row numbers.
            res[indices[i]] = torch.mean(x[indices[eps_env]], dim=0)

        if VERBOSE:
            print('x after averaging', res)

        return res

# Define the neural network

In [5]:
class GNN(nn.Module):
    """
    Three-layer graph neural network for node or graph classification.

    Each layer is a graph convolution followed by ReLU and dropout, with
    LayerNorm after every layer except the last. For graph-level tasks the
    node embeddings are mean-pooled per graph before the final MLP.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, task='node', use_custom_conv=False):
        """
        Parameters
        ----------
        input_dim : int
            Number of input features per node.
        hidden_dim : int
            Width of every hidden layer.
        output_dim : int
            Number of classes.
        task : str
            Whether the task is to predict a label for each node (in the case of 'node'), or a label to the entire graph (in the case of 'graph').
        use_custom_conv : boolean
            Whether to use our custom convolution or the standard GCNConv of Torch Geometric.

        Raises
        ------
        RuntimeError
            If `task` is neither 'node' nor 'graph'.
        """
        super(GNN, self).__init__()

        # Validate first so that an invalid task fails before any layer is built.
        if task not in ('node', 'graph'):
            raise RuntimeError('Unknown task.')

        self.task = task
        self.use_custom_conv = use_custom_conv
        self.dropout = 0.25
        self.num_layers = 3

        # Note: this architecture is from some notebook I found https://colab.research.google.com/drive/1DIQm9rOx2mT1bZETEeVUThxcrP1RKqAn
        # from some Stanford GNN course lecture https://www.youtube.com/watch?v=-UjytpbqX4A
        self.convs = nn.ModuleList()
        self.convs.append(self.build_conv_model(input_dim, hidden_dim))
        for _ in range(self.num_layers - 1):
            self.convs.append(self.build_conv_model(hidden_dim, hidden_dim))

        # One LayerNorm after every layer except the last one.
        self.lns = nn.ModuleList()
        for _ in range(self.num_layers - 1):
            self.lns.append(nn.LayerNorm(hidden_dim))

        # post-message-passing
        self.post_mp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.Dropout(0.25),
            nn.Linear(hidden_dim, output_dim))

    def build_conv_model(self, input_dim, hidden_dim):
        """
        Build a single convolution layer.

        Returns a CustomConv when `use_custom_conv` is set; otherwise a GCNConv
        for node tasks and a GINConv (wrapping a two-layer MLP) for graph tasks.
        """
        if self.use_custom_conv:
            return CustomConv(input_dim, hidden_dim)

        # refer to pytorch geometric nn module for different implementation of GNNs.
        if self.task == 'node':
            return pyg_nn.GCNConv(input_dim, hidden_dim)
        else:
            return pyg_nn.GINConv(nn.Sequential(nn.Linear(input_dim, hidden_dim),
                                  nn.ReLU(), nn.Linear(hidden_dim, hidden_dim)))

    def forward(self, data):
        """
        Run the network on a (batched) graph.

        Parameters
        ----------
        data : object
            Must provide `x`, `edge_index`, `batch`, `num_node_features` and
            `num_nodes`.

        Returns
        -------
        emb : torch.Tensor
            Output of the last convolution, before ReLU and dropout.
        log_probs : torch.Tensor
            Log-softmax class scores (per node for 'node', per graph for 'graph').
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if data.num_node_features == 0:
            # No node features: use a constant feature, created on the same
            # device as the graph structure.
            x = torch.ones(data.num_nodes, 1, device=edge_index.device)

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            emb = x
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            if i != self.num_layers - 1:
                x = self.lns[i](x)

        if self.task == 'graph':
            x = pyg_nn.global_mean_pool(x, batch)

        x = self.post_mp(x)

        return emb, F.log_softmax(x, dim=1)

    def loss(self, pred, label):
        """Negative log-likelihood loss; `pred` must hold log-probabilities."""
        return F.nll_loss(pred, label)

# Training setup

We train the model in a standard way here, running it forwards to compute its predicted label distribution and backpropagating the error. Note the task setup in our graph setting: for node classification, we define a subset of nodes to be training nodes and the rest of the nodes to be test nodes, and mask out the test nodes during training via `batch.train_mask`. For graph classification, we use 80% of the graphs for training and the remainder for testing, as in other classification settings.

In [6]:
def train(dataset, task, use_custom = False, num_epochs=101, batch_size=64,
          hidden_dim=32, lr=0.01, eval_every=10):
    """
    Train a GNN on `dataset` and periodically report its accuracy.

    Parameters
    ----------
    dataset : torch_geometric dataset
        The dataset to train on. For 'graph' tasks the first 80% of the graphs
        are used for training and the rest for evaluation; for 'node' tasks the
        same loader is used for both and the split is done by the node masks.
    task : str
        Either 'node' or 'graph'.
    use_custom : boolean
        Whether to build the model with CustomConv layers.
    num_epochs : int
        Number of training epochs.
    batch_size : int
        Number of graphs per mini-batch.
    hidden_dim : int
        Hidden width of the model.
    lr : float
        Learning rate of the Adam optimizer.
    eval_every : int
        Evaluate and print progress every `eval_every` epochs.

    Returns
    -------
    model : GNN
        The trained model.
    """
    if task == 'graph':
        split = int(len(dataset) * 0.8)
        loader = DataLoader(dataset[:split], batch_size=batch_size, shuffle=True)
        # Evaluation does not depend on the order, so there is no need to shuffle.
        test_loader = DataLoader(dataset[split:], batch_size=batch_size, shuffle=False)
    else:
        test_loader = loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # build model (datasets without node features get a single constant feature)
    model = GNN(max(dataset.num_node_features, 1), hidden_dim, dataset.num_classes, task=task, use_custom_conv=use_custom)
    opt = optim.Adam(model.parameters(), lr=lr)

    # train
    for epoch in range(num_epochs):
        if VERBOSE:
            print('epoch #' + str(epoch))
        total_loss = 0
        # test() switches the model to eval mode, so re-enable training mode every epoch.
        model.train()
        for batch in loader:
            opt.zero_grad()
            embedding, pred = model(batch)
            label = batch.y
            if task == 'node':
                # Only the training nodes contribute to the loss.
                pred = pred[batch.train_mask]
                label = label[batch.train_mask]
            loss = model.loss(pred, label)
            loss.backward()
            opt.step()
            total_loss += loss.item() * batch.num_graphs
        total_loss /= len(loader.dataset)

        if epoch % eval_every == 0:
            # Evaluate on the held-out data (the test mask for node tasks), so
            # that the number printed as "Test accuracy" really is one.
            test_acc = test(test_loader, model)
            print("Epoch {}. Loss: {:.4f}. Test accuracy: {:.4f}".format(
                epoch, total_loss, test_acc))

    return model


At test time, for the CiteSeer/Cora node classification task there is only one graph, so we use the node masks (`train_mask` / `test_mask`) to select which nodes are evaluated.

For graph classification tasks, a held-out subset of the graphs serves as the test set.

In [7]:
def test(loader, model, is_train=False):
    model.eval()

    correct = 0
    for data in loader:
        with torch.no_grad():
            emb, pred = model(data)
            pred = pred.argmax(dim=1)
            label = data.y

        if model.task == 'node':
            mask = data.train_mask if is_train else data.test_mask
            # node classification: only evaluate on nodes in test set
            pred = pred[mask]
            label = data.y[mask]

        correct += pred.eq(label).sum().item()

    if model.task == 'graph':
        total = len(loader.dataset)
    else:
        total = 0
        for data in loader.dataset:
            mask = data.train_mask if is_train else data.test_mask
            total += torch.sum(mask).item()
    return correct / total

# Training the model

In [8]:
# Graph-level classification on the MNIST graphs from GNNBenchmarkDataset.
task = 'graph'

# Shuffle once so that the 80/20 split inside train() is not tied to the stored order.
dataset = GNNBenchmarkDataset(root='/tmp/MNIST', name='MNIST').shuffle()

# Baseline run with the built-in convolution layers.
model = train(dataset, task)

Downloading https://data.pyg.org/datasets/benchmarking-gnns/MNIST_v2.zip
Extracting /tmp/MNIST/MNIST/raw/MNIST_v2.zip
Processing...
Done!


Epoch 0. Loss: 1.9695. Test accuracy: 0.2705
Epoch 10. Loss: 1.8795. Test accuracy: 0.3547
Epoch 20. Loss: 1.8044. Test accuracy: 0.2196
Epoch 30. Loss: 1.8354. Test accuracy: 0.3343
Epoch 40. Loss: 1.8297. Test accuracy: 0.3317
Epoch 50. Loss: 1.8102. Test accuracy: 0.3738
Epoch 60. Loss: 1.7895. Test accuracy: 0.3854
Epoch 70. Loss: 1.8148. Test accuracy: 0.3450


And now with our custom model

In [None]:
# Same dataset and task as the previous run, but the model is built with
# CustomConv layers (use_custom=True). This rebinds `model`.
model = train(dataset, task, use_custom=True)

Here we try a graph classification task on a synthetic `FakeDataset` (the cell below is currently commented out):

In [None]:
# dataset = FakeDataset(num_graphs=200, avg_num_nodes=500, num_channels=16, avg_degree=4)
# task = 'graph'

# model = train(dataset, task)

And now with our custom model

In [None]:
# model = train(dataset, task, use_custom=True)