In [41]:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops,degree
from torch_geometric.data import Data

In [42]:
class GCNConv(MessagePassing):
    r"""Graph convolution layer in the style of Kipf & Welling (2017).

    Each node's output is the symmetrically normalized sum of its own and its
    neighbours' linearly transformed features, followed by a bias:

    .. math::
        \mathbf{x}_i' = \sum_{j \in \mathcal{N}(i) \cup \{i\}}
            \frac{1}{\sqrt{\deg(i)}\,\sqrt{\deg(j)}}\,\mathbf{W}^\top \mathbf{x}_j
            + \mathbf{b}

    ``deg(.)`` counts how often a node appears as the target
    (``edge_index[1]``) once self-loops have been added.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5): sum incoming messages.
        # No bias in the linear map: a single bias is added after aggregation (Step 6).
        self.lin = Linear(in_channels, out_channels, bias=False)
        # torch.empty replaces the legacy torch.Tensor(...) constructor; the
        # actual values are assigned in reset_parameters().
        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the linear weights (Linear's own init) and zero the bias."""
        self.lin.reset_parameters()
        # init.zeros_ runs under no_grad; avoids the discouraged `.data` access.
        torch.nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Apply the graph convolution.

        Args:
            x: Node feature matrix of shape [num_nodes, in_channels].
            edge_index: Graph connectivity (COO) of shape [2, num_edges].

        Returns:
            Node feature matrix of shape [num_nodes, out_channels].
        """
        num_nodes = x.size(0)

        # Step 1: add self-loops to the adjacency matrix so every node also
        # receives its own features.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # Step 2: linearly transform the node feature matrix.
        x = self.lin(x)

        # Step 3: compute the symmetric normalization 1 / sqrt(deg(row) * deg(col)).
        row, col = edge_index
        deg = degree(col, num_nodes, dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Nodes with degree 0 would give inf; should not happen after Step 1,
        # but keep the guard so isolated nodes contribute 0 instead of NaN/inf.
        deg_inv_sqrt = deg_inv_sqrt.masked_fill(torch.isinf(deg_inv_sqrt), 0.0)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Steps 4-5: build normalized messages (see `message`) and sum them per node.
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: apply the final bias vector. Out-of-place so autograd never
        # sees an in-place write to the propagate() result.
        return out + self.bias

    def message(self, x_j: torch.Tensor, norm: torch.Tensor) -> torch.Tensor:
        """Step 4: scale each edge's neighbour features by its normalization.

        Args:
            x_j: Per-edge neighbour features gathered by ``propagate()``, shape
                [num_edges, out_channels] (PyG's ``_j`` suffix; presumably the
                ``edge_index[0]`` endpoint under the default flow).
            norm: Per-edge coefficients of shape [num_edges].
        """
        # view(-1, 1) broadcasts each edge's scalar across all channels.
        return norm.view(-1, 1) * x_j

In [43]:
# data: a 3-node path graph 0 - 1 - 2 with one scalar feature per node.
# Each undirected edge is stored in both directions.
SEED = 42
torch.manual_seed(SEED)  # makes GCNConv's randomly initialized weights reproducible

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
# Keep the raw input under its own name so it is still inspectable afterwards.
x_in = torch.tensor([[-1], [0], [1]], dtype=torch.float)
print("before message passing: ", x_in.shape)

# model initialization: 1 input feature -> 32 output channels
conv = GCNConv(1, 32)
x = conv(x_in, edge_index)
print("after message passing:", x.shape)

# Message passing keeps one row per node and maps 1 -> 32 channels.
assert x.shape == (x_in.size(0), 32)

before message passing:  torch.Size([3, 1])
after message passing: torch.Size([3, 32])
