In [1]:
import time

import torch
import torch.nn.functional as F
from torch.nn import Linear, Parameter
import torch.nn.utils.parametrize as parametrize
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, homophily

from torch_geometric.datasets import WebKB, Planetoid

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

  from .autonotebook import tqdm as notebook_tqdm


## The GCNConv cell below shows how to implement a custom MPNN layer by subclassing PyG's `MessagePassing` class.

In [110]:
class GCNConv(MessagePassing):
    """Graph convolution layer with GCN-style symmetric degree normalization.

    Walks through the three hooks of a PyG ``MessagePassing`` subclass:
    ``forward`` prepares the inputs and calls ``propagate``, ``message`` builds
    one (normalized) message per edge, and ``update`` post-processes the
    summed messages.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of each output node feature vector.
        verbose: if True, print the shapes of the per-edge messages and of the
            aggregated output (handy while exploring the pipeline). Defaults to
            False so ordinary forward passes stay quiet.
    """

    def __init__(self, in_channels, out_channels, verbose=False):
        super().__init__(aggr='add')  # "Add" aggregation: sum incoming messages.
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.verbose = verbose

    def forward(self, x, edge_index):
        """Apply the layer.

        Args:
            x: node features, shape [N, in_channels].
            edge_index: graph connectivity, shape [2, E].

        Returns:
            Node embeddings of shape [N, out_channels].
        """
        num_nodes = x.size(0)

        # Step 1: add self-loops so every node also receives its own features.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # Step 2: linearly transform the node feature matrix.
        x = self.lin(x)

        # Steps 3-5: propagate messages.
        # `propagate` needs at least `edge_index` and the node features `x`; any
        # extra keyword argument (here the dummy `ciao`) is forwarded to
        # `message`/`update` when they declare a parameter with the same name.
        return self.propagate(edge_index, ciao='ciao', size=(num_nodes, num_nodes), x=x)

    def message(self, x_j, edge_index, ciao, size):
        """Build the normalized message sent along each edge.

        Args:
            x_j: features of the source node of each edge, shape [E, out_channels]
                (E includes the self-loops added in ``forward``).
            edge_index: the self-loop-augmented connectivity passed to ``propagate``.
            ciao: dummy extra argument, only here to show that custom kwargs given
                to ``propagate`` reach ``message``; it is not used.
            size: (num_nodes, num_nodes) as passed to ``propagate``.

        Returns:
            ``norm * x_j`` with norm = deg(row)^-1/2 * deg(col)^-1/2 per edge.
        """
        # Step 3: normalize node features.
        row, col = edge_index
        # NOTE(review): degree is counted over `row` (edge sources). With the default
        # source_to_target flow messages are summed at `col`; the current PyG tutorial
        # counts over `col`. Both agree on undirected graphs but differ on directed
        # ones such as WebKB/Texas — confirm which normalization is intended.
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Zero-degree nodes give 1/sqrt(0) = inf; zero them so they contribute
        # nothing. (Comparing against the string 'inf' never matches a tensor.)
        deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        out = norm.view(-1, 1) * x_j
        if self.verbose:
            print("1: ", out.shape)
        return out

    def update(self, aggr_out):
        """Step 5: return the aggregated embeddings unchanged.

        Args:
            aggr_out: summed messages per node, shape [N, out_channels].
        """
        if self.verbose:
            print("2: ", aggr_out.shape)
        return aggr_out

In [111]:
# Texas graph from the WebKB benchmark; downloaded/cached under /tmp/Texas.
dataset_texas = WebKB(name='Texas', root='/tmp/Texas')

In [112]:
# Node feature matrix and connectivity of the (single) Texas graph.
x, edges = dataset_texas.x, dataset_texas.edge_index

In [None]:
# Smoke-test the custom GCN layer: map node features straight to class logits.
test_network = GCNConv(
    in_channels=dataset_texas.num_features,
    out_channels=dataset_texas.num_classes,
)
t = test_network(x, edges)

## Custom implementation of the SAGEConv (GraphSAGE) GNN layer

In [122]:
class SAGECv(MessagePassing):
    """GraphSAGE-style layer with mean neighbour aggregation.

    For each node i computes ``l2(h_i) + l1(mean_{j in N(i)} h_j)``, where
    ``h = relu(l3(x))`` when ``project`` is True and ``h = x`` otherwise.

    Args:
        input_dim: size of each input node feature vector.
        hidden_dim: size of each output node feature vector.
        project: if True, first project the inputs to ``hidden_dim`` with a
            biased linear layer followed by ReLU.
    """

    def __init__(self, input_dim, hidden_dim, project = False):
        super().__init__(aggr='mean')
        # l1/l2 consume the (possibly projected) features: after projection these
        # have `hidden_dim` channels, otherwise they still have `input_dim`.
        # (Previously they were always hidden_dim -> hidden_dim, so project=False
        # crashed whenever input_dim != hidden_dim.)
        feat_dim = hidden_dim if project else input_dim
        self.l1 = torch.nn.Linear(feat_dim, hidden_dim, bias = False)  # neighbour branch
        self.l2 = torch.nn.Linear(feat_dim, hidden_dim, bias = False)  # root (self) branch
        self.project = project
        if self.project:
            self.l3 = torch.nn.Linear(input_dim, hidden_dim, bias = True)

    def forward(self, x, edge_index):
        """Apply the layer.

        Args:
            x: node features with ``input_dim`` columns.
            edge_index: graph connectivity, shape [2, E].

        Returns:
            Node embeddings with ``hidden_dim`` columns.
        """
        if self.project:
            x = F.relu(self.l3(x))

        return self.propagate(edge_index, x = x)

    def update(self, aggr_out, x):
        """Combine each node's own features with its aggregated neighbours.

        Args:
            aggr_out: mean of the neighbour features for each node (the default
                ``message`` forwards ``x_j`` unchanged).
            x: the node features passed to ``propagate`` (projected if enabled).
        """
        # Apply l1 exactly once: the previous version transformed aggr_out with l1
        # and then applied l1 again, computing l1(l1(mean)) instead of l1(mean).
        return self.l2(x) + self.l1(aggr_out)
        


In [124]:
# Smoke-test the custom SAGE layer, projecting inputs down to num_classes first.
s = SAGECv(
    input_dim=dataset_texas.num_features,
    hidden_dim=dataset_texas.num_classes,
    project=True,
)
t = s(x, edges)

torch.Size([183, 5])
torch.Size([183, 5])
torch.Size([183, 5])
