In [2]:
import torch.nn as nn
import torch
from torch_geometric.utils import dense_to_sparse
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

Let's build a graph convolutional neural network. I will use the definition of a GCN layer given in the torch_geometric documentation.

$x_{i}^{(k)} = \sum_{j \in N(i) \cup \{i\}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left(W^{T} \cdot x_{j}^{(k-1)}\right) + b$

Here $k$ denotes the graph convolution layer. It can be thought of as one pass through the graph convolution operation, i.e. a single application of the above equation to the current set of node (and/or edge) features.

Let's understand what happens here. In this equation, we update the node features of node $i$. The subscript of the summation symbol indicates that we consider the neighbors of node $i$ and node $i$ itself. For a given neighbor $j$, we transform its node feature $x_{j}$ using the weight matrix $W$. We divide this value by the product of the square roots of the degrees of nodes $i$ and $j$.

We do this operation for all the neighboring nodes of $i$.

Then we collect all these transformed node features of the neighboring nodes corresponding to node $i$ and add them all together.
Optionally, we can add a bias term $b$ to this summation. This summation becomes the new node feature of node $i$, $x_{i}^{(k)}$, at iteration $k$.

![title](assets/gcn1.png)


Now let's build a graph convolution network using torch_geometric. This implementation is based on the code segment provided in their documentation. I made some modifications to make it easy to compare with another implementation. Specifically, I removed `self.lin.reset_parameters()`, added my own weight matrix $W$,
and set the bias values to zero.

In [275]:
class GCNConv(MessagePassing):
    """Graph convolution layer with a caller-supplied weight matrix.

    For every target node ``i`` the layer computes::

        out_i = sum_{j -> i} 1 / (sqrt(deg(i)) * sqrt(deg(j))) * lin(x_j) + b

    where the sum runs over the edges listed in ``edge_index`` and ``deg`` is
    counted from the second row of ``edge_index``.

    Self-loops are NOT added by this layer: if a node should include its own
    features in the sum, ``edge_index`` must already contain the ``(i, i)`` edge.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
        weight_data: Tensor of shape ``(out_channels, in_channels)`` used as
            the weight of the linear transform. Its values are copied, so the
            caller's tensor is never modified by later parameter updates.
        verbose: When true (the default), ``forward`` prints the per-edge
            normalisation coefficients and ``message`` prints the gathered
            neighbour features. Pass ``False`` to silence the layer.

    Raises:
        ValueError: If ``weight_data`` does not have shape
            ``(out_channels, in_channels)``.
    """

    def __init__(self, in_channels, out_channels, weight_data, verbose=True):
        # 'add' aggregation implements the summation over neighbours.
        super().__init__(aggr='add')

        expected_shape = (out_channels, in_channels)
        if tuple(weight_data.shape) != expected_shape:
            raise ValueError(
                f"weight_data must have shape {expected_shape}, "
                f"got {tuple(weight_data.shape)}"
            )

        self.verbose = verbose
        self.lin = Linear(in_channels, out_channels, bias=True)

        # Copy the supplied weights instead of aliasing the caller's tensor and
        # zero the linear bias for any number of output channels.
        with torch.no_grad():
            self.lin.weight.copy_(weight_data)
            self.lin.bias.zero_()

        # Separate bias that is added after aggregation; starts at zero.
        self.b = Parameter(torch.zeros(out_channels))

    def forward(self, x, edge_index):
        """Run one round of message passing.

        Args:
            x: Node features whose second-to-last axis indexes the nodes and
                whose last axis has size ``in_channels`` (works for both
                ``[N, C]`` and batched ``[B, N, C]`` layouts).
            edge_index: Integer tensor with two rows; each column is one edge.

        Returns:
            Updated node features with ``out_channels`` in the last axis.
        """
        x = self.lin(x)

        row, col = edge_index
        # ``node_dim`` is the axis MessagePassing gathers/scatters along
        # (default -2), so this is the node count for 2-D and 3-D inputs alike.
        num_nodes = x.size(self.node_dim)
        deg = degree(col, num_nodes, dtype=x.dtype)

        deg_inv_sqrt = deg.pow(-0.5)

        # Nodes with zero degree produce inf; treat their coefficient as zero.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        if self.verbose:
            print(norm)

        # propagate is implemented in MessagePassing. This is where message
        # passing takes place: it receives the edge indices and node features,
        # plus any extra values (here ``norm``) needed to build the message
        # from node j to node i.
        out = self.propagate(edge_index, x=x, norm=norm)
        return out + self.b

    def message(self, x_j, norm):
        """Build the message sent along each edge from node j to node i.

        Args:
            x_j: Transformed features of the sending node of every edge.
            norm: One normalisation coefficient per edge.

        Returns:
            ``x_j`` scaled edge-wise by ``norm``.
        """
        if self.verbose:
            print(x_j)
        # Reshape to a column so each edge's coefficient scales its whole
        # feature vector.
        return norm.view(-1, 1) * x_j

#### Define the node features and the adjacency matrix (edges)

In [276]:
# Four nodes with two features each, plus a leading batch axis of size 1.
# The values equal torch.arange(8, dtype=torch.float32).view(1, 4, 2); they are
# written out explicitly so they are easy to read.
node_feats = torch.tensor(
    [[0., 1.], [2., 3.], [4., 5.], [6., 7.]],
    dtype=torch.float32,
).unsqueeze(0)

# Batched 4x4 adjacency matrix; the ones on the diagonal are self-loops.
adj_matrix = torch.tensor(
    [[1, 1, 0, 0],
     [1, 1, 1, 1],
     [0, 1, 1, 1],
     [0, 1, 1, 1]],
    dtype=torch.float32,
).unsqueeze(0)

In [277]:
# 2x2 identity matrix used as the layer's weight.
weight_data = torch.eye(2)

In [279]:
# Convert the dense adjacency matrix into an edge-index tensor; the second
# return value of dense_to_sparse is not needed and is discarded.
edge_index, _ = dense_to_sparse(adj_matrix)

In [280]:
# Show the edge list: two rows, one column per edge (self-loops included).
edge_index

tensor([[0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3],
        [0, 1, 0, 1, 2, 3, 1, 2, 3, 1, 2, 3]])

In [281]:
# Two input and two output channels, using the weight matrix defined above.
conv = GCNConv(in_channels=2, out_channels=2, weight_data=weight_data)

In [283]:
# Apply the layer once. forward() prints the per-edge normalisation
# coefficients and message() prints the gathered neighbour features; the
# cell's value is the updated node-feature tensor.
conv(node_feats, edge_index)

tensor([0.5000, 0.3536, 0.3536, 0.2500, 0.2887, 0.2887, 0.2887, 0.3333, 0.3333,
        0.2887, 0.3333, 0.3333])
tensor([[[0., 1.],
         [0., 1.],
         [2., 3.],
         [2., 3.],
         [2., 3.],
         [2., 3.],
         [4., 5.],
         [4., 5.],
         [4., 5.],
         [6., 7.],
         [6., 7.],
         [6., 7.]]], grad_fn=<IndexSelectBackward0>)


tensor([[[0.7071, 1.5607],
         [3.3868, 4.5677],
         [3.9107, 4.8660],
         [3.9107, 4.8660]]], grad_fn=<AddBackward0>)