# Creating GNN layers

One of the primary advantages of this library is that GNN layers are independent of the model architecture, which makes it easy to swap layer types as hyper-parameters. However, this requires that the layers be implemented with the DGL library and inherit from the class `BaseDGLLayer`, which standardizes the inputs, outputs and properties of the layers. The architecture can then be handled independently by the class `FeedForwardDGL`, or by any similar custom class.

We will start with a simple layer that does not use edge features, then move on to a more complex layer that does.

Since these examples are built on top of DGL, we recommend looking at the [DGL documentation](https://docs.dgl.ai/en/0.5.x/index.html) for more info.
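To illustrate what swapping layer types as hyper-parameters can look like, here is a minimal pure-Python sketch. The names `MeanLayerSketch`, `SumLayerSketch`, `LAYER_REGISTRY` and `build_layers` are hypothetical and are not part of the library; they only show the idea of keeping the architecture fixed while the layer class is chosen by a string.

```python
# Hypothetical stand-ins for GNN layer classes; real layers would inherit
# from `BaseDGLLayer` and implement `forward`.
class MeanLayerSketch:
    def __init__(self, in_dim, out_dim):
        self.in_dim, self.out_dim = in_dim, out_dim


class SumLayerSketch:
    def __init__(self, in_dim, out_dim):
        self.in_dim, self.out_dim = in_dim, out_dim


# Hypothetical registry mapping a hyper-parameter string to a layer class
LAYER_REGISTRY = {"mean": MeanLayerSketch, "sum": SumLayerSketch}


def build_layers(layer_type, dims):
    """Stack layers of the requested type, one per consecutive pair of dims."""
    layer_cls = LAYER_REGISTRY[layer_type]
    pairs = zip(dims[:-1], dims[1:])
    return [layer_cls(in_dim=d_in, out_dim=d_out) for d_in, d_out in pairs]


# The architecture (the list of dimensions) stays fixed; only the layer type changes
for layer_type in ("mean", "sum"):
    layers = build_layers(layer_type, dims=[5, 16, 11])
    print(layer_type, [(type(l).__name__, l.in_dim, l.out_dim) for l in layers])
```

This only works if every layer class accepts the same constructor arguments and exposes the same properties, which is the role a common base class plays.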

In [1]:
%load_ext autoreload
%autoreload 2

import torch
import dgl
from copy import deepcopy

from goli.nn.dgl_layers import BaseDGLLayer
from goli.nn.base_layers import FCLayer
from goli.utils.decorators import classproperty


_ = torch.manual_seed(42)

Using backend: pytorch


## Pre-defining test variables

We define below a small batched graph on which we can test the created layers.

In [2]:
in_dim = 5          # Input node-feature dimensions
out_dim = 11        # Desired output node-feature dimensions
in_dim_edges = 13   # Input edge-feature dimensions

# Let's create 2 simple graphs. Each pair of tensors lists the source and destination nodes of the edges
g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))

# We add some node features to the graphs
g1.ndata["h"] = torch.rand(g1.num_nodes(), in_dim, dtype=float)
g2.ndata["h"] = torch.rand(g2.num_nodes(), in_dim, dtype=float)

# We also add some edge features to the graphs
g1.edata["e"] = torch.rand(g1.num_edges(), in_dim_edges, dtype=float)
g2.edata["e"] = torch.rand(g2.num_edges(), in_dim_edges, dtype=float)

# Finally we batch the graphs in a way compatible with the DGL library,
# and add a self-loop to every node of the batched graph
bg = dgl.batch([g1, g2])
bg = dgl.add_self_loop(bg)

# The batched graph will show as a single graph with 7 nodes
print(bg)

Graph(num_nodes=7, num_edges=14,
      ndata_schemes={'h': Scheme(shape=(5,), dtype=torch.float64)}
      edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float64)})
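The node and edge counts printed above can be checked by hand. The following pure-Python sketch redoes the arithmetic from the edge lists of `g1` and `g2`; the helper `num_nodes` is a hypothetical name, and it assumes the node count is inferred as the largest node ID plus one, which is what DGL does when no explicit count is given.

```python
# Edge lists copied from the cell above, as (source, destination) pairs
g1_edges = list(zip([0, 1, 2], [1, 2, 3]))
g2_edges = list(zip([0, 0, 0, 1], [0, 1, 2, 0]))


def num_nodes(edges):
    # With no explicit node count, the graph has (largest node ID + 1) nodes
    return max(max(u, v) for u, v in edges) + 1


total_nodes = num_nodes(g1_edges) + num_nodes(g2_edges)
# `dgl.add_self_loop` adds one edge per node, on top of the existing edges
total_edges = len(g1_edges) + len(g2_edges) + total_nodes
print(total_nodes, total_edges)  # 7 14
```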


## Creating a simple layer

Here, we will show how to create a GNN layer that does a mean aggregation of the features of neighbouring nodes.

First, for the layer to be fully compatible with the flexible architecture provided by `FeedForwardDGL`, it needs to inherit from the class `BaseDGLLayer`. This base layer has several virtual properties that must be implemented in any class that inherits from it.

The virtual properties are listed below:

- `layer_supports_edges`: We want to return `False` since our layer doesn't support edges.
- `layer_inputs_edges`: We want to return `False` since our layer doesn't input edges.
- `layer_outputs_edges`: We want to return `False` since our layer doesn't output edges.
- `out_dim_factor`: We want to return `1` since the output dimension does not depend on internal parameters.
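As a rough illustration of how a base class can force subclasses to implement such properties, here is a minimal sketch using Python's standard `abc` module. The classes `BaseLayerSketch`, `NodeOnlyLayerSketch` and `IncompleteLayerSketch` are hypothetical and do not reproduce the actual `BaseDGLLayer` code; `layer_supports_edges` is left out because it relies on the library's own `classproperty` decorator.

```python
import abc


class BaseLayerSketch(abc.ABC):
    """Hypothetical stand-in for a base layer; not the actual `BaseDGLLayer`."""

    @property
    @abc.abstractmethod
    def layer_inputs_edges(self) -> bool:
        """Whether the layer takes edge features as input."""

    @property
    @abc.abstractmethod
    def layer_outputs_edges(self) -> bool:
        """Whether the layer returns edge features."""

    @property
    @abc.abstractmethod
    def out_dim_factor(self) -> int:
        """Output-dimension factor (1 in the examples of this tutorial)."""


class NodeOnlyLayerSketch(BaseLayerSketch):
    # Every abstract property is overridden, so the class can be instantiated
    @property
    def layer_inputs_edges(self) -> bool:
        return False

    @property
    def layer_outputs_edges(self) -> bool:
        return False

    @property
    def out_dim_factor(self) -> int:
        return 1


class IncompleteLayerSketch(BaseLayerSketch):
    # `layer_outputs_edges` and `out_dim_factor` are missing on purpose
    @property
    def layer_inputs_edges(self) -> bool:
        return False


layer = NodeOnlyLayerSketch()
print(layer.layer_inputs_edges, layer.out_dim_factor)

try:
    IncompleteLayerSketch()
    instantiation_failed = False
except TypeError:
    # `abc` refuses to instantiate a class with unimplemented abstract members
    instantiation_failed = True
print(instantiation_failed)
```

With this pattern, forgetting one of the properties is caught when the layer is created rather than later during training.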

The `SimpleMeanLayer` example is given below.

In [3]:
class SimpleMeanLayer(BaseDGLLayer):
    def __init__(self, in_dim, out_dim, activation, dropout, normalization):
        # Initialize the parent class
        super().__init__(in_dim=in_dim, out_dim=out_dim, activation=activation,
                         dropout=dropout, normalization=normalization)

        # Create the layer with learned parameters
        self.layer = FCLayer(in_dim=in_dim, out_dim=out_dim)

    def forward(self, g, h):
        # We first apply the mean aggregation
        g.ndata["h"] = h
        g.update_all(message_func=dgl.function.copy_u("h", "m"), 
                    reduce_func=dgl.function.mean("m", "h"))

        # Then we apply the FCLayer, and the non-linearities
        h = g.ndata["h"]
        h = self.layer(h)
        h = self.apply_norm_activation_dropout(h)
        return h

    # Finally, we define all the virtual properties according to how
    # the class works
    @classproperty
    def layer_supports_edges(cls):
        return False

    @property
    def layer_inputs_edges(self):
        return False

    @property
    def layer_outputs_edges(self):
        return False

    @property
    def out_dim_factor(self):
        return 1   

Now, we are ready to test the `SimpleMeanLayer` on some DGL graphs. Note that in this example, we **ignore** the edge features since they are not supported.

In [4]:
graph = deepcopy(bg)
h_in = graph.ndata["h"]
layer = SimpleMeanLayer(
            in_dim=in_dim, out_dim=out_dim, 
            activation="relu", dropout=.3, normalization="batch_norm").to(float)
h_out = layer(graph, h_in)

print(h_in.shape)
print(h_out.shape)

torch.Size([7, 5])
torch.Size([7, 11])
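To make the mean aggregation concrete, the sketch below reproduces the message and reduce steps in pure Python on the edges of `g1` with self-loops added. The scalar node features are hypothetical values chosen so the averages are easy to verify; this mimics what `copy_u` followed by `mean` computes, without the learned `FCLayer`.

```python
# Edges of `g1` from the cells above, plus one self-loop per node
edges = [(0, 1), (1, 2), (2, 3)] + [(n, n) for n in range(4)]
# Hypothetical scalar feature per node, chosen to make the means easy to check
h = [1.0, 3.0, 5.0, 7.0]

# Message step: each edge carries a copy of its source node's feature
inbox = {n: [] for n in range(len(h))}
for src, dst in edges:
    inbox[dst].append(h[src])

# Reduce step: each node averages the messages it received
h_new = [sum(msgs) / len(msgs) for _, msgs in sorted(inbox.items())]
print(h_new)  # [1.0, 2.0, 4.0, 6.0]
```

Because of the self-loops, each node's own feature is part of its average, and node 0, which has no other incoming edge, simply keeps its value.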


## Creating a complex layer with edges

Here, we will show how to create a GNN layer that concatenates the features of each neighbouring node with the features of the connecting edge, then does a mean aggregation. In that case, only the node features will change, and the network will not update the edge features.

The virtual properties now return different values:

- `layer_supports_edges`: We want to return `True` since our layer does support edges.
- `layer_inputs_edges`: We want to return `True` since our layer does input edges.
- `layer_outputs_edges`: We want to return `False` since our layer will not output new edges.
- `out_dim_factor`: We want to return `1` since the output dimension does not depend on internal parameters.

The `ComplexMeanLayer` example is given below.

In [5]:
class ComplexMeanLayer(BaseDGLLayer):
    def __init__(self, in_dim, out_dim, in_dim_edges, activation, dropout, normalization):
        # Initialize the parent class
        super().__init__(in_dim=in_dim, out_dim=out_dim, activation=activation,
                         dropout=dropout, normalization=normalization)

        # Create the layer with learned parameters. Note that the input dimension
        # is `in_dim + in_dim_edges`, since node and edge features are concatenated
        self.layer = FCLayer(in_dim=in_dim + in_dim_edges, out_dim=out_dim)

    def cat_nodes_edges(self, edges):
        # Create a message "m" for each edge by concatenating the features "h"
        # of its source node with the edge features "e"
        nodes_edges = torch.cat([edges.src["h"], edges.data["e"]], dim=-1)
        return {"m": nodes_edges}

    def get_edges_messages(self, edges): # Simply return the messages on the edges
        return {"m": edges.data["m"]}

    def forward(self, g, h, e):

        # We first concatenate the node and edge features on each edge, stored as "m"
        g.ndata["h"] = h
        g.edata["e"] = e
        g.apply_edges(self.cat_nodes_edges)

        # Then we average the messages "m" arriving at each node into "h"
        g.update_all(message_func=self.get_edges_messages, 
                    reduce_func=dgl.function.mean("m", "h"))

        # Then we apply the FCLayer, and the non-linearities
        h = g.ndata["h"]
        h = self.layer(h)
        h = self.apply_norm_activation_dropout(h)
        return h

    # Finally, we define all the virtual properties according to how
    # the class works
    @classproperty
    def layer_supports_edges(cls):
        return True

    @property
    def layer_inputs_edges(self):
        return True

    @property
    def layer_outputs_edges(self):
        return False

    @property
    def out_dim_factor(self):
        return 1   

Now, we are ready to test the `ComplexMeanLayer` on some DGL graphs. Note that in this example, we **use** the edge features since the layer requires them.

In [6]:
graph = deepcopy(bg)
h_in = graph.ndata["h"]
e_in = graph.edata["e"]
layer = ComplexMeanLayer(
            in_dim=in_dim, out_dim=out_dim, in_dim_edges=in_dim_edges,
            activation="relu", dropout=.3, normalization="batch_norm").to(float)
h_out = layer(graph, h_in, e_in)

print(h_in.shape)
print(h_out.shape)

torch.Size([7, 5])
torch.Size([7, 11])
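The concatenate-then-average logic of `ComplexMeanLayer` can also be traced in pure Python. The sketch below uses a tiny hypothetical graph with 2 features per node and 3 per edge, and plain lists instead of tensors; it shows why the aggregated features have width `in_dim + in_dim_edges` before the `FCLayer` is applied.

```python
# One small graph: edges as (source, destination) pairs, including self-loops
edges = [(0, 1), (1, 1), (0, 0)]
# Hypothetical features: 2 values per node, 3 values per edge
h = [[1.0, 2.0], [3.0, 4.0]]
e = [[10.0, 10.0, 10.0], [20.0, 20.0, 20.0], [30.0, 30.0, 30.0]]

# Step 1 (as in `cat_nodes_edges`): message = source node features + edge features
messages = [h[src] + e[i] for i, (src, _) in enumerate(edges)]

# Step 2 (mean reduce): average the messages arriving at each destination node
aggregated = []
for node in range(len(h)):
    incoming = [m for m, (_, dst) in zip(messages, edges) if dst == node]
    aggregated.append([sum(col) / len(incoming) for col in zip(*incoming)])

print(len(aggregated[0]))  # 5 = 2 node features + 3 edge features
print(aggregated[1])       # [2.0, 3.0, 15.0, 15.0, 15.0]
```

With the dimensions used in this tutorial, the same reasoning gives 5 + 13 = 18 input features for the `FCLayer`.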
