In [2]:
import time

import torch
import torch.nn.functional as F
from torch.nn import Linear, Parameter
import torch.nn.utils.parametrize as parametrize
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, homophily

from torch_geometric.datasets import WebKB, Planetoid

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

  from .autonotebook import tqdm as notebook_tqdm


## The GCNConv cell shows how to implement a custom MPNN with the `MessagePassing` class.

In [110]:
class GCNConv(MessagePassing):
    """Didactic GCN layer built on ``MessagePassing`` with "add" aggregation.

    The forward pass adds self-loops, applies a linear map to the node
    features and then propagates degree-normalized messages along the edges.
    The two ``print`` calls are kept on purpose: they show the tensor shapes
    seen by ``message`` and ``update``.
    """

    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Run one message-passing step.

        Args:
            x: node feature matrix of shape [N, in_channels].
            edge_index: edge list of shape [2, E].

        Returns:
            The result of ``propagate``, i.e. the aggregated messages passed
            through ``update``.
        """
        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3-5: Start propagating messages.
        # self.propagate has to take as input at least 'edge_index' and 'x';
        # other keyword arguments (like `ciao`) can be specified as well and
        # are requested by name in the signature of `message`.
        return self.propagate(edge_index, ciao = 'ciao', size = (x.size(0), x.size(0)), x=x)

    def message(self, x_j, edge_index, ciao, size):
        """Build the normalized message carried by every edge.

        Args:
            x_j: per-edge node features, shape [E, out_channels]
                (the features of the `row` endpoint of each edge).
            edge_index: edge list (self-loops included), shape [2, E].
            ciao: demo-only extra argument forwarded by ``propagate``; unused.
            size: pair of node counts; ``size[0]`` sizes the degree vector.

        Returns:
            Tensor of shape [E, out_channels] with the scaled features.
        """
        # Step 3: Normalize node features.
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Zero-degree nodes give 0 ** -0.5 == inf. The comparison must be made
        # against the float infinity: comparing with the string 'inf' never
        # selects those entries, so they would leak into `norm`.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Compute the product once and reuse it for both the print and the result.
        messages = norm.view(-1, 1) * x_j
        print("1: ", messages.shape)

        return messages

    def update(self, aggr_out):
        """Return the aggregated messages unchanged as the new node embeddings.

        Args:
            aggr_out: aggregated messages, shape [N, out_channels].
        """
        print("2: ", aggr_out.shape)
        # Step 5: Return new node embeddings.
        return aggr_out

In [111]:
# Texas graph from the WebKB collection, stored under /tmp/Texas.
dataset_texas = WebKB(name='Texas', root='/tmp/Texas')

In [112]:
# Node features and edge list of the Texas graph, reused by the layer demos.
x, edges = dataset_texas.x, dataset_texas.edge_index

In [None]:
# Smoke-test the custom GCN layer: map the input features to one output per class.
test_network = GCNConv(
    in_channels=dataset_texas.num_features,
    out_channels=dataset_texas.num_classes,
)
t = test_network(x, edges)

## Customized implementation for the SAGEConv GNN layer

In [122]:
class SAGECv(MessagePassing):
    """Custom GraphSAGE-style layer with mean aggregation.

    The output is ``l2(x) + l1(mean of neighbour features)``. When
    ``project`` is True the inputs are first mapped to ``hidden_dim`` with a
    biased linear layer followed by a ReLU.

    Args:
        input_dim: width of the incoming node features.
        hidden_dim: width of the produced node embeddings.
        project: whether to apply the input projection before propagating.
    """

    def __init__(self, input_dim, hidden_dim, project = False):
        super().__init__(aggr='mean')
        self.project = project
        # l1/l2 act on the features that are actually propagated: hidden_dim
        # wide after the projection, input_dim wide without it. Using
        # hidden_dim unconditionally breaks project=False whenever
        # input_dim != hidden_dim.
        feature_dim = hidden_dim if self.project else input_dim
        self.l1 = torch.nn.Linear(feature_dim, hidden_dim, bias = False)
        self.l2 = torch.nn.Linear(feature_dim, hidden_dim, bias = False)
        if self.project:
            self.l3 = torch.nn.Linear(input_dim, hidden_dim, bias = True)

    def forward(self, x, edge_index):
        """Optionally project ``x``, then propagate it along ``edge_index``."""
        if self.project:
            x = F.relu(self.l3(x))

        return self.propagate(edge_index, x = x)

    def update(self, aggr_out, x):
        """Combine each node's own features with its aggregated neighbourhood.

        Args:
            aggr_out: output of the (mean) aggregation step.
            x: features of the nodes being updated.
        """
        # Apply l1 exactly once to the aggregated features; feeding the result
        # through l1 a second time would no longer match the layer described
        # in the class docstring.
        return self.l2(x) + self.l1(aggr_out)
        


In [124]:
# Build the custom SAGE layer with the input projection enabled and run it on the Texas graph.
s = SAGECv(
    input_dim=dataset_texas.num_features,
    hidden_dim=dataset_texas.num_classes,
    project=True,
)

t = s(x, edges)

torch.Size([183, 5])
torch.Size([183, 5])
torch.Size([183, 5])


### Making a linear layer symmetric (see the [PyTorch parametrizations tutorial](https://pytorch.org/tutorials/intermediate/parametrizations.html)).

In [None]:
# Start from a small square integer matrix.
matrix = torch.tensor([
    [1, 2, 44],
    [1, 2, 3111],
    [0, 1, 4],
])
# Upper-triangular part, main diagonal included.
print(torch.triu(matrix, diagonal=0))
# Strictly upper-triangular part (the diagonal is zeroed as well).
print(torch.triu(matrix, diagonal=1))
# Mirroring the strict upper triangle below the diagonal yields a symmetric matrix.
print(torch.triu(matrix, diagonal=0) + torch.triu(matrix, diagonal=1).transpose(-2, -1))

Since we want to parametrize the linear layers while keeping this step separate from the layer definition, we use the `torch.nn.utils.parametrize` module and its `register_parametrization` function.

In [14]:
import torch.nn.utils.parametrize as parametrize

class Symmetric(torch.nn.Module):
    """Parametrization that turns a square weight into a symmetric matrix.

    The upper triangle (diagonal included) of the input is kept and its
    strictly-upper part is mirrored below the diagonal; the original lower
    triangle is discarded.
    """

    def forward(self, w):
        upper_with_diag = torch.triu(w, diagonal=0)
        strict_upper = torch.triu(w, diagonal=1)
        # Transposing the last two axes reflects the strict upper triangle.
        return upper_with_diag + strict_upper.transpose(-1, -2)

# Side length of the square weight matrix used in the demos.
hidden_dimension = 5

# Symmetry requires a square weight, so in- and out-features coincide.
layer = torch.nn.Linear(in_features=hidden_dimension, out_features=hidden_dimension)
print("BEFORE: ", layer.weight)
# Defining the layer and imposing the symmetry are now two independent steps.
parametrize.register_parametrization(
    layer,
    'weight',
    Symmetric(),
)
print("AFTER: ", layer.weight)


BEFORE:  Parameter containing:
tensor([[-0.3834,  0.3018,  0.1072, -0.1389, -0.0130],
        [-0.2427, -0.0612, -0.0034, -0.0523, -0.3426],
        [ 0.3694,  0.2531,  0.3518,  0.0486,  0.0359],
        [-0.2783, -0.1788, -0.0249, -0.4132, -0.1983],
        [ 0.4021, -0.1767,  0.1618,  0.3949,  0.1506]], requires_grad=True)
AFTER:  tensor([[-0.3834,  0.3018,  0.1072, -0.1389, -0.0130],
        [ 0.3018, -0.0612, -0.0034, -0.0523, -0.3426],
        [ 0.1072, -0.0034,  0.3518,  0.0486,  0.0359],
        [-0.1389, -0.0523,  0.0486, -0.4132, -0.1983],
        [-0.0130, -0.3426,  0.0359, -0.1983,  0.1506]], grad_fn=<AddBackward0>)


### Let's see now the symmetry in the GRAFF paper
Here the parametrization is done as follows.

In [20]:
class PairwiseParametrization(torch.nn.Module):
    """Map a raw [n, n + 2] weight to a symmetric [n, n] matrix.

    The first n columns provide the off-diagonal entries (strict upper
    triangle, mirrored). The last two columns are the vectors `q` and `r`,
    which set the diagonal to ``q * sum_j |W0[i, j]| + r`` as described in the
    paper, so the spectrum can be steered through `q` and `r`.
    """

    def forward(self, W):
        # Symmetric, zero-diagonal part built from all but the last two columns.
        strict_upper = torch.triu(W[:, :-2], diagonal=1)
        off_diagonal = strict_upper + strict_upper.T

        # The two trailing columns hold the `q` and `r` vectors.
        q, r = W[:, -2], W[:, -1]

        # Diagonal entry i: q_i times the absolute row sum of the off-diagonal part, plus r_i.
        row_mass = off_diagonal.abs().sum(dim=1)
        diagonal = torch.diag(q * row_mass + r)

        return off_diagonal + diagonal

# The raw weight has shape [hidden_dimension, hidden_dimension + 2]: the two extra
# input columns hold the q and r vectors read by the parametrization.
layer = torch.nn.Linear(in_features=hidden_dimension + 2, out_features=hidden_dimension)
print("BEFORE: ", layer.weight)
# Defining the layer and imposing the symmetry are now two independent steps.
# unsafe=True is needed because the re-parametrization changes the tensor dimension.
parametrize.register_parametrization(
    layer,
    'weight',
    PairwiseParametrization(),
    unsafe=True,
)
print("AFTER: ", layer.weight)

BEFORE:  Parameter containing:
tensor([[-0.0695, -0.0465, -0.3562,  0.2027, -0.1504,  0.1841,  0.1395],
        [-0.0970,  0.2512,  0.2056, -0.3359, -0.1418,  0.3448,  0.2772],
        [-0.1398, -0.0160,  0.0813, -0.3447, -0.2584,  0.1737, -0.0010],
        [ 0.2093,  0.3520, -0.1950, -0.0436, -0.2272,  0.0019,  0.0587],
        [ 0.2048, -0.2496, -0.0404, -0.1384,  0.2391,  0.0524, -0.3009]],
       requires_grad=True)
tensor([[ 0.0000, -0.0465, -0.3562,  0.2027, -0.1504],
        [ 0.0000,  0.0000,  0.2056, -0.3359, -0.1418],
        [ 0.0000,  0.0000,  0.0000, -0.3447, -0.2584],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.2272],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],
       grad_fn=<TriuBackward0>)
AFTER:  tensor([[ 0.2787, -0.0465, -0.3562,  0.2027, -0.1504],
        [-0.0465,  0.5288,  0.2056, -0.3359, -0.1418],
        [-0.3562,  0.2056,  0.2013, -0.3447, -0.2584],
        [ 0.2027, -0.3359, -0.3447,  0.0609, -0.2272],
        [-0.1504, -0.1418, -0.2584, -0.2