![MLU Logo](../images/MLU_Logo.png)

# Customize Graph Convolution using Message Passing APIs

In previous sessions, we learned to use the built-in [graph convolution modules](https://docs.dgl.ai/api/python/nn.pytorch.html#module-dgl.nn.pytorch.conv) to build a multi-layer graph neural network. Sometimes, however, you need a new way of aggregating neighbor information. DGL's message passing APIs are designed for this scenario.

In this notebook, you will learn:
* What is under the hood of the `nn.SAGEConv` module in DGL.
* How DGL's message passing APIs work.
* How to design a new graph convolution module.

In [1]:
!pip install dgl
!pip install torch



In [2]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

Using backend: pytorch


## A gentle explanation of the `SAGEConv` module

Recall that a `SAGEConv` module aggregates neighbor information and generates new node representations as follows:


$$
h_{\mathcal{N}(v)}^k\leftarrow \text{AGGREGATE}_k\{h_u^{k-1},\forall u\in\mathcal{N}(v)\}
$$

$$
h_v^k\leftarrow \text{ReLU}\left(W^k\cdot \text{CONCAT}(h_v^{k-1}, h_{\mathcal{N}(v)}^k) \right)
$$

Here is its implementation in DGL.

In [3]:
import dgl.function as fn

class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.
    
    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        # Dark Gray box 
        self.linear = nn.Linear(in_feat * 2, out_feat)
    
    def forward(self, g, h):
        """Forward computation
        
        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        """
        # Any `ndata` set within a local scope is discarded automatically on exit.
        with g.local_scope():
            g.ndata['h'] = h
            # update_all is a message passing API.
            g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
            h_neigh = g.ndata['h_neigh']
            h_total = torch.cat([h, h_neigh], dim=1)
            return F.relu(self.linear(h_total))

The central piece in this code is the `g.update_all` function, which gathers and averages the neighbor features. There are three concepts here:
* The message function `fn.copy_u('h', 'm')` copies the source node feature `'h'` into *messages* named `'m'`, which are sent along the edges to neighbors.
* The reduce function `fn.mean('m', 'h_neigh')` averages all the messages a node receives under the name `'m'` and saves the result as a new node feature `'h_neigh'`.
* `update_all` tells DGL to trigger the message and reduce functions for all the nodes and edges.
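
To make these three concepts concrete, here is a minimal plain-Python sketch of what `g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))` computes on a tiny graph. It is not DGL's implementation: the helper name `mean_aggregate`, the edge-list format, and the example graph are assumptions made for illustration. It also assumes that a node with no incoming edges ends up with a zero vector.

```python
def mean_aggregate(num_nodes, edges, h):
    """Average the features of each node's in-neighbors.

    edges : list of (src, dst) pairs (hypothetical format for this sketch)
    h     : list of per-node feature vectors (lists of floats)
    """
    dim = len(h[0])
    sums = [[0.0] * dim for _ in range(num_nodes)]
    counts = [0] * num_nodes
    for src, dst in edges:
        message = h[src]                  # message function: copy the source feature
        for i, value in enumerate(message):
            sums[dst][i] += value         # accumulate messages at the destination
        counts[dst] += 1
    # Reduce function: divide by the number of received messages.
    # Nodes that received nothing keep a zero vector (assumption of this sketch).
    return [
        [s / counts[v] for s in sums[v]] if counts[v] else [0.0] * dim
        for v in range(num_nodes)
    ]

# A 3-node graph with edges 0->2, 1->2 and 2->0.
edges = [(0, 2), (1, 2), (2, 0)]
h = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
h_neigh = mean_aggregate(3, edges, h)
print(h_neigh)
```

Node 2 receives messages from nodes 0 and 1 and stores their average, node 0 receives a single message from node 2, and node 1 receives nothing.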

**NOTE**: The implementation above is a *single layer* of graph unrolling. It is not an entire trainable network; to get one, use this module in place of the `GraphConv` layers in the Day 2 notebook.

## Message passing and GNNs

`update_all` is one of the **message passing APIs** in DGL, inspired by the Message Passing Neural Network proposed by [Gilmer et al.](https://arxiv.org/abs/1704.01212) Essentially, they found that many GNN models fit into the following framework:

$$
m_{u\sim v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\sim v}^{(l-1)}\right)
$$

$$
m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)}m_{u\sim v}^{(l)}
$$

$$
h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)
$$

where $M^{(l)}$ is called the message function, $\sum$ is the reduce function, and $U^{(l)}$ is the update function.

You can find more details in [the API doc](https://docs.dgl.ai/api/python/function.html).
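
The three equations of the framework can be sketched in a few lines of plain Python. This is an illustrative toy with scalar node and edge features, not DGL code. The names `mpnn_layer`, `message_fn`, and `update_fn` are hypothetical: `message_fn` plays the role of $M^{(l)}$ and `update_fn` the role of $U^{(l)}$.

```python
def mpnn_layer(h, edges, edge_feat, message_fn, update_fn):
    """One round of message passing with a sum reducer.

    h         : list of scalar node features, h[v] = h_v^{(l-1)}
    edges     : list of (u, v) pairs, meaning u sends a message to v
    edge_feat : list of scalar edge features aligned with `edges`
    """
    m = [0.0] * len(h)
    for (u, v), e in zip(edges, edge_feat):
        # Message on edge u -> v, summed into the destination node v.
        m[v] += message_fn(h[v], h[u], e)
    # Update every node from its old feature and its aggregated message.
    return [update_fn(h_v, m_v) for h_v, m_v in zip(h, m)]

h = [1.0, 2.0, 3.0]
edges = [(0, 1), (2, 1), (1, 0)]
edge_feat = [0.5, 2.0, 1.0]

h_new = mpnn_layer(
    h, edges, edge_feat,
    message_fn=lambda h_v, h_u, e: h_u * e,   # scale the sender feature by the edge feature
    update_fn=lambda h_v, m_v: h_v + m_v,     # add the aggregate to the old feature
)
print(h_new)
```

In this toy setting, the mean aggregation in `SAGEConv` corresponds to replacing the sum with an average, with concatenation, a linear layer, and ReLU as the update step.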

DGL's message passing APIs allow one to quickly implement new graph convolution modules. For example, the following implements a new `SAGEConv` that multiplies each neighbor representation by an edge weight before averaging.

In [4]:
class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.
    
    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)
    
    def forward(self, g, h, w):
        """Forward computation
        
        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        w : Tensor
            The edge weight.
        """
        # Any `ndata` set within a local scope is discarded automatically on exit.
        with g.local_scope():
            g.ndata['h'] = h
            g.edata['w'] = w
            # update_all is a message passing API.
            g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.mean('m', 'h_neigh'))
            h_neigh = g.ndata['h_neigh']
            h_total = torch.cat([h, h_neigh], dim=1)
            return F.relu(self.linear(h_total))
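
One detail worth checking is that `fn.u_mul_e` followed by `fn.mean` divides the weighted sum by the number of neighbors, not by the total weight. The plain-Python sketch below contrasts the two on scalar features for a single destination node. The function names and the numbers are assumptions for illustration only.

```python
def mean_of_weighted(neighbors):
    """What u_mul_e + mean computes for one node: sum(w * h) / number of neighbors."""
    return sum(w * h for w, h in neighbors) / len(neighbors)

def normalized_weighted_average(neighbors):
    """A weight-normalized average: sum(w * h) / sum(w)."""
    total_weight = sum(w for w, _ in neighbors)
    return sum(w * h for w, h in neighbors) / total_weight

# (edge weight, neighbor feature) pairs for one destination node.
neighbors = [(1.0, 2.0), (3.0, 4.0)]
print(mean_of_weighted(neighbors))             # 7.0
print(normalized_weighted_average(neighbors))  # 3.5
```

The two agree only when the weights on a node's incoming edges sum to the number of its neighbors, so whether the edge weights need to be pre-normalized depends on the application.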

## Even more customization by user-defined function

DGL also allows user-defined message and reduce functions for maximal expressiveness. Here is a user-defined message function that is equivalent to `fn.u_mul_e('h', 'w', 'm')`.

In [5]:
def u_mul_e_udf(edges):
    return {'m' : edges.src['h'] * edges.data['w']}
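
A reduce function can be user-defined in the same way. The sketch below is a possible counterpart to `fn.mean('m', 'h_neigh')`. It relies on DGL passing a node batch whose `mailbox['m']` tensor has shape `(num_nodes, num_messages, feat_dim)`. The name `mean_udf` is hypothetical, and the `SimpleNamespace` object is only a stand-in for that node batch so the function can be tried without building a graph.

```python
from types import SimpleNamespace

import torch

def mean_udf(nodes):
    # Average over the message dimension (dim=1) of the mailbox tensor.
    return {'h_neigh': nodes.mailbox['m'].mean(dim=1)}

# Stand-in for a DGL node batch: one node that received two 2-dimensional messages.
fake_nodes = SimpleNamespace(
    mailbox={'m': torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])}
)
out = mean_udf(fake_nodes)
print(out['h_neigh'])
```

With a real graph, the pair would be passed as `g.update_all(u_mul_e_udf, mean_udf)`. Built-in functions are usually the better choice when one exists, since user-defined functions tend to run slower.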

## Recap

* `dgl.nn` provides many popular modules for quick bootstrapping.
* The built-in message and reduce functions in `dgl.function` let you customize a new NN module.
* User-defined functions provide even more flexibility.

## Exercise

In this notebook, we implemented `SAGEConv` with a single neural network layer (`nn.Linear`). Similarly, most layers defined within DGL, such as `GraphConv`, have a single layer by default. As we have seen in class, there are two notions of depth: one refers to the number of times we unroll the computation graph induced by the underlying network, and the other refers to the depth of the neural network in each `GraphConv` layer. Let us consider the second notion for this exercise and build a `SAGEConv` layer with a two-layer network.

To do this, we will define a new layer class called `DeepSageConv` and implement a simple message-passing-based `forward` method for it. Once it is defined, we can use our `DeepSageConv` layer in place of the `GraphConv` layer in the Day 2 notebook and run the training process.