# Making GNN Networks

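In this example, you will learn how to build a full GNN network from any kind of GNN layer. The network chains a feed-forward pre-processing network, a stack of GNN layers and a feed-forward post-processing network.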

```python
%load_ext autoreload
%autoreload 2

import torch
import dgl
from copy import deepcopy

from goli.dgl.dgl_layers import PNAMessagePassingLayer
from goli.dgl.architectures import FullDGLNetwork

_ = torch.manual_seed(42)
```

```
Using backend: pytorch
```


We first create two small graphs and batch them together. This batched graph is used throughout the example.

```python
in_dim = 5          # Input node-feature dimensions
out_dim = 11        # Desired output feature dimensions
in_dim_edges = 13   # Input edge-feature dimensions

# Let's create 2 simple graphs, each given as a pair of (source, destination) node-ID tensors
g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))        # 4 nodes, 3 edges
g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))  # 3 nodes, 4 edges

# We add some random node features to the graphs (dtype=float creates float64 tensors)
g1.ndata["h"] = torch.rand(g1.num_nodes(), in_dim, dtype=float)
g2.ndata["h"] = torch.rand(g2.num_nodes(), in_dim, dtype=float)

# We also add some random edge features to the graphs
g1.edata["e"] = torch.rand(g1.num_edges(), in_dim_edges, dtype=float)
g2.edata["e"] = torch.rand(g2.num_edges(), in_dim_edges, dtype=float)

# Finally, we batch the graphs into a single DGL graph
bg = dgl.batch([g1, g2])
# Add a self-loop to every node, so that each node also receives its own message
bg = dgl.add_self_loop(bg)

# The batched graph shows as a single graph with 7 nodes and 14 edges
# (7 original edges + 7 self-loops)
print(bg)
```

```
Graph(num_nodes=7, num_edges=14,
      ndata_schemes={'h': Scheme(shape=(5,), dtype=torch.float64)}
      edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float64)})
```
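Batching does not connect the two graphs. `dgl.batch` places them side by side in one graph, shifting the node IDs of `g2` by the number of nodes in `g1`, and `dgl.add_self_loop` then adds one `i -> i` edge per node, even where a self-loop already exists (`g2` already has `0 -> 0`). The pure-Python sketch below mimics this bookkeeping with plain lists to reproduce the counts printed above. The helpers `batch_edge_lists` and `add_self_loops` are hypothetical illustrations, not DGL functions.

```python
def batch_edge_lists(graphs):
    """Merge (src, dst, num_nodes) graphs into one disjoint graph.

    The node IDs of each graph are offset by the total number of nodes in
    the graphs before it, so the graphs stay disconnected from each other.
    """
    src_all, dst_all, offset = [], [], 0
    for src, dst, num_nodes in graphs:
        src_all += [s + offset for s in src]
        dst_all += [d + offset for d in dst]
        offset += num_nodes
    return src_all, dst_all, offset


def add_self_loops(src, dst, num_nodes):
    """Append one i -> i edge for every node (existing self-loops are kept)."""
    nodes = list(range(num_nodes))
    return src + nodes, dst + nodes


# Same connectivity as g1 (4 nodes) and g2 (3 nodes) in the cell above
g1 = ([0, 1, 2], [1, 2, 3], 4)
g2 = ([0, 0, 0, 1], [0, 1, 2, 0], 3)

src, dst, n = batch_edge_lists([g1, g2])
src, dst = add_self_loops(src, dst, n)
print(n, len(src))  # 7 14
```

The self-loop already present in `g2` is duplicated rather than skipped, which is why the batched graph has 14 edges and not 13.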


## Building the network

To build the network, we must define the keyword arguments of its three successive stages:

- `pre_nn_kwargs`: The parameters of a feed-forward neural network applied to the input node features before they reach the convolutional layers. See class `FeedForwardNN` for details on the required parameters. This stage is skipped if set to `None`.

- `gnn_kwargs`: The parameters of a feed-forward **graph** neural network applied to the features after they have passed through the pre-processing neural network. See class `FeedForwardDGL` for details on the required parameters.

- `post_nn_kwargs`: The parameters of a feed-forward neural network applied to the features after the GNN layers. See class `FeedForwardNN` for details on the required parameters. This stage is skipped if set to `None`.
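The three stages are applied in sequence, and a pre- or post-processing stage given as `None` is bypassed. The following minimal sketch shows that control flow, with plain callables on lists of floats standing in for the real `FeedForwardNN` and `FeedForwardDGL` modules. The function `run_full_network` and the toy stages are hypothetical, and goli's actual forward pass may differ in detail.

```python
from typing import Callable, List, Optional

Stage = Optional[Callable[[List[float]], List[float]]]


def run_full_network(h, pre_nn: Stage, gnn: Callable, post_nn: Stage):
    """Apply pre-NN -> GNN (including pooling) -> post-NN, skipping None stages."""
    if pre_nn is not None:
        h = pre_nn(h)
    h = gnn(h)  # the GNN stage is always applied
    if post_nn is not None:
        h = post_nn(h)
    return h


def double(h):
    return [2 * x for x in h]


def sum_pool(h):
    # Stands in for the GNN layers followed by a sum pooling
    return [sum(h)]


def add_one(h):
    return [x + 1 for x in h]


print(run_full_network([1.0, 2.0, 3.0], double, sum_pool, add_one))  # [13.0]
print(run_full_network([1.0, 2.0, 3.0], None, sum_pool, None))       # [6.0]
```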

```python
temp_dim_1 = 23  # Output dim of the pre-NN = input dim of the GNN
temp_dim_2 = 17  # Output dim of the GNN = input dim of the post-NN

# Feed-forward network applied to the node features before the GNN
pre_nn_kwargs = {
    "in_dim": in_dim,
    "out_dim": temp_dim_1,
    "hidden_dims": [4, 4, 4],
    "activation": "relu",
    "last_activation": "none",
    "batch_norm": True,
    "dropout": 0.2,
}

# Feed-forward network applied after the GNN
post_nn_kwargs = {
    "in_dim": temp_dim_2,
    "out_dim": out_dim,
    "hidden_dims": [6, 6],
    "activation": "relu",
    "last_activation": "sigmoid",
    "batch_norm": False,
    "dropout": 0.0,
}

# PNA-specific arguments, merged into gnn_kwargs below
layer_kwargs = {
    "aggregators": ["mean", "max", "sum"],
    "scalers": ["identity", "amplification"],
    "in_dim_edges": in_dim_edges,
}

gnn_kwargs = {
    "in_dim": temp_dim_1,
    "out_dim": temp_dim_2,
    "hidden_dims": [5, 5, 5, 5, 5, 5],
    "residual_type": "densenet",
    "residual_skip_steps": 2,
    "layer_type": PNAMessagePassingLayer,
    "pooling": ["sum"],
    "activation": "relu",
    "last_activation": "none",
    "batch_norm": False,
    "dropout": 0.2,
    **layer_kwargs,
}

# .to(float) casts the parameters to float64, matching the input features
gnn_net = FullDGLNetwork(
    pre_nn_kwargs=pre_nn_kwargs,
    gnn_kwargs=gnn_kwargs,
    post_nn_kwargs=post_nn_kwargs,
).to(float)
```
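The stages are wired together through their dimensions: `pre_nn_kwargs["out_dim"]` must equal `gnn_kwargs["in_dim"]` (`temp_dim_1`), and `gnn_kwargs["out_dim"]` must equal `post_nn_kwargs["in_dim"]` (`temp_dim_2`). The layer widths of each stage read as `[in_dim] + hidden_dims + [out_dim]`, consistent with the printed summary of `gnn_net` further down (for example `5 -> 4 -> 4 -> 4 -> 23`, with `depth=4`). The helpers `layer_dims` and `check_chain` below are hypothetical and not part of goli. They check the chaining and derive each stage's widths and depth from the dimension-related entries of the kwargs.

```python
def layer_dims(kwargs):
    """Widths of a stage: input, each hidden layer, then output."""
    return [kwargs["in_dim"], *kwargs["hidden_dims"], kwargs["out_dim"]]


def check_chain(*stages):
    """Raise if a stage's out_dim does not match the next stage's in_dim."""
    for prev, nxt in zip(stages, stages[1:]):
        if prev["out_dim"] != nxt["in_dim"]:
            raise ValueError(f"out_dim {prev['out_dim']} != in_dim {nxt['in_dim']}")


# Only the dimension-related entries of the kwargs defined above
pre = {"in_dim": 5, "out_dim": 23, "hidden_dims": [4, 4, 4]}
gnn = {"in_dim": 23, "out_dim": 17, "hidden_dims": [5] * 6}
post = {"in_dim": 17, "out_dim": 11, "hidden_dims": [6, 6]}

check_chain(pre, gnn, post)  # passes: 23 == 23 and 17 == 17
for name, stage in [("pre", pre), ("gnn", gnn), ("post", post)]:
    dims = layer_dims(stage)
    print(name, " -> ".join(map(str, dims)), "depth =", len(dims) - 1)
```

The depths this prints (4, 7 and 3) match the `depth=` values in the network summary.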

## Applying the network

Once the network is defined, we only need to run the forward pass on the batched graph to get a prediction.

The network handles the node and edge features according to its parameters and layer type.
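For `PNAMessagePassingLayer`, the `aggregators` and `scalers` listed in `layer_kwargs` control how each node combines its incoming messages. The sketch below follows the formulation of the PNA paper (Corso et al., 2020) rather than goli's code. Each aggregator is applied feature-wise to the messages, each result is multiplied by every degree scaler (`identity` is 1, `amplification` is log(d + 1) / δ, with δ the average of log(d + 1) over the training graphs), and all combinations are concatenated. It assumes the messages are already computed (in the real layer they presumably also use the edge features, given `in_dim_edges`) and that δ is known. All function names are hypothetical.

```python
import math
from typing import List


def aggregate(messages: List[List[float]], how: str) -> List[float]:
    """Feature-wise aggregation of a node's incoming messages."""
    cols = list(zip(*messages))  # one tuple per feature dimension
    if how == "mean":
        return [sum(c) / len(c) for c in cols]
    if how == "max":
        return [max(c) for c in cols]
    if how == "sum":
        return [sum(c) for c in cols]
    raise ValueError(f"unknown aggregator: {how}")


def scale(x: List[float], degree: int, how: str, delta: float) -> List[float]:
    """PNA degree scalers; delta is the mean of log(d + 1) over training nodes."""
    if how == "identity":
        factor = 1.0
    elif how == "amplification":
        factor = math.log(degree + 1) / delta
    else:
        raise ValueError(f"unknown scaler: {how}")
    return [factor * v for v in x]


def pna_aggregate(messages, aggregators, scalers, delta):
    """Concatenate every (scaler, aggregator) combination into one vector."""
    degree = len(messages)
    out: List[float] = []
    for s in scalers:
        for a in aggregators:
            out += scale(aggregate(messages, a), degree, s, delta)
    return out


# A node with 3 incoming 2-dimensional messages; with delta = log(4),
# the amplification factor of a degree-3 node is exactly 1
msgs = [[1.0, 0.0], [2.0, 4.0], [3.0, 2.0]]
h = pna_aggregate(msgs, ["mean", "max", "sum"], ["identity", "amplification"], delta=math.log(4))
print(len(h))  # 12 = 3 aggregators x 2 scalers x 2 features
print(h[:6])   # [2.0, 2.0, 3.0, 4.0, 6.0, 6.0] (identity block)
```

Concatenating all combinations is why the width of a PNA layer's aggregated vector grows with the number of aggregators and scalers, before the layer projects it back to its output dimension.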

```python
# Work on a copy so that `bg` itself stays untouched
graph = deepcopy(bg)
h_in = graph.ndata["h"]

h_out = gnn_net(graph)

print(h_in.shape)
print(h_out.shape)
print("\n")
print(gnn_net)
```

```
torch.Size([7, 5])
torch.Size([1, 11])


DGL_GNN
---------
    pre-trans-NN(depth=4, ResidualConnectionNone(skip_steps=1))
        [FCLayer[5 -> 4 -> 4 -> 4 -> 23] -> Linear({self.out_dim})

    main-GNN(depth=7, ResidualConnectionDenseNet(skip_steps=2))
        PNAMessagePassingLayer[23 -> 5 -> 5 -> 5 -> 5 -> 5 -> 5 -> 17]
        -> Pooling(['sum']) -> FCLayer(17 -> 17, activation=None)

    post-trans-NN(depth=3, ResidualConnectionNone(skip_steps=1))
        [FCLayer[17 -> 6 -> 6 -> 11] -> Linear({self.out_dim})
```
