In [1]:
import time

import torch
import torch.nn.functional as F
from torch.nn import Linear, Parameter
import torch.nn.utils.parametrize as parametrize
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn as nn

import torch_geometric
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, homophily

from torch_geometric.datasets import WebKB, Planetoid

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


## The GCNConv cell explains how to implement a customized MPNN, with MessagePassing class.

In [58]:
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3-5: Start propagating messages.

        ''' self.propagate has to take as input at least 'edge_index' and 'x', then we can specify also other arguments,
            like ciao. '''

        return self.propagate(edge_index, ciao = 'ciao', size = (x.size(0), x.size(0)), x=x)

    def message(self, x_j, edge_index, ciao, size):
        # x_j has shape [E, out_channels]
        # Step 3: Normalize node features.
        ''' x_j contains the node features for row '''
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == 'inf'] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        print("1: ", (norm.view(-1, 1)*x_j).shape)
        print(x_j.shape)
        print(norm.view(-1, 1).shape)
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        # aggr_out has shape [N, out_channels]
        print("2: ", aggr_out.shape)
        # Step 5: Return new node embeddings.
        return aggr_out

In [59]:
dataset_texas = WebKB(root='/tmp/Texas', name='Texas')

In [60]:
x = dataset_texas.x
edges = dataset_texas.edge_index

In [61]:
test_network = GCNConv(dataset_texas.num_features, dataset_texas.num_classes)
t = test_network(x, edges)

1:  torch.Size([508, 5])
torch.Size([508, 5])
torch.Size([508, 1])
2:  torch.Size([183, 5])


## Customized implementation for the SAGEConv GNN layer

In [6]:
class SAGECv(MessagePassing):
    def __init__(self, input_dim, hidden_dim, project = False):
        super().__init__(aggr='mean')
        self.l1 = torch.nn.Linear(hidden_dim, hidden_dim, bias = False)
        self.l2 = torch.nn.Linear(hidden_dim, hidden_dim, bias = False)
        self.project = project
        if self.project:
            self.l3 = torch.nn.Linear(input_dim, hidden_dim, bias = True)

    def forward(self, x, edge_index):

        if self.project:
            x = F.relu(self.l3(x))

        return self.propagate(edge_index, x = x)
    
    def update(self, aggr_out, x):
        # aggr_out is always the output after the aggregation, and x are the nodes to update.
        aggr_out = self.l1(aggr_out)

        out = self.l2(x) + self.l1(aggr_out)

        return out
        


In [7]:
s = SAGECv(dataset_texas.num_features, dataset_texas.num_classes, project = True)

t = s(x, edges)

### Making a linear layer symmetric, info at [this link](https://pytorch.org/tutorials/intermediate/parametrizations.html).

In [8]:
# Given a square matrix:
matrix = torch.tensor([[1, 2, 44], [1, 2, 3111], [0, 1, 4]])
# We can get its upper triangular part as 
print(matrix.triu(0))
# and also its counter-version with zero diagonal terms
print(matrix.triu(1))
# Symmetric matrix is:
print(matrix.triu(0) + matrix.triu(1).transpose(-1, -2))

tensor([[   1,    2,   44],
        [   0,    2, 3111],
        [   0,    0,    4]])
tensor([[   0,    2,   44],
        [   0,    0, 3111],
        [   0,    0,    0]])
tensor([[   1,    2,   44],
        [   2,    2, 3111],
        [  44, 3111,    4]])


Since we want to parametrize the linear layers, but also to separate this process by the layer definition we should proceed by using the torch.nn.parametrize function.

In [9]:
import torch.nn.utils.parametrize as parametrize

class Symmetric(torch.nn.Module):
    def forward(self, w):
        # This class implements the method to define the symmetry in the squared matrices.
        return w.triu(0) + w.triu(1).transpose(-1, -2)

hidden_dimension = 5

# Let's notice that we need to define squared layers
layer = torch.nn.Linear(hidden_dimension, hidden_dimension)
print("BEFORE: ", layer.weight)
# LAYER DEFINITION & SYMMETRY are now separated processes.
parametrize.register_parametrization(layer, 'weight', Symmetric())
print("AFTER: ", layer.weight)


BEFORE:  Parameter containing:
tensor([[-0.0282, -0.2466,  0.1203, -0.0225, -0.2457],
        [ 0.1340,  0.0759, -0.1351, -0.2224, -0.0830],
        [-0.1430,  0.2537,  0.2474,  0.0577, -0.4088],
        [ 0.3516, -0.1308,  0.4388, -0.1365, -0.2749],
        [-0.1415,  0.2817,  0.4037,  0.0214, -0.0827]], requires_grad=True)
AFTER:  tensor([[-0.0282, -0.2466,  0.1203, -0.0225, -0.2457],
        [-0.2466,  0.0759, -0.1351, -0.2224, -0.0830],
        [ 0.1203, -0.1351,  0.2474,  0.0577, -0.4088],
        [-0.0225, -0.2224,  0.0577, -0.1365, -0.2749],
        [-0.2457, -0.0830, -0.4088, -0.2749, -0.0827]], grad_fn=<AddBackward0>)


### Let's see now the symmetry in the GRAFF paper
Here the parametrization is done as follows.

In [10]:
class PairwiseParametrization(torch.nn.Module):
    def forward(self, W):
        # Construct a symmetric matrix with zero diagonal
        # The weights are initialized to be non-squared, with 2 additional columns. We cut from two of these
        # two vectors q and r, and then we compute w_diag as described in the paper.
        # This procedure is done in order to easily distribute the mass in its spectrum through the values of q and r
        W0 = W[:, :-2].triu(1)

        W0 = W0 + W0.T

        # Retrieve the `q` and `r` vectors from the last two columns
        q = W[:, -2]
        r = W[:, -1]
        # Construct the main diagonal
        w_diag = torch.diag(q * torch.sum(torch.abs(W0), 1) + r)

        return W0 + w_diag


layer = torch.nn.Linear(hidden_dimension + 2, hidden_dimension)
print("BEFORE: ", layer.weight)
# LAYER DEFINITION & SYMMETRY are now separated processes.
parametrize.register_parametrization(
    layer, 'weight', PairwiseParametrization(), unsafe=True)
# unsafe = True is used to change the tensor dimension with the re-parametrization.
print("AFTER: ", layer.weight)

BEFORE:  Parameter containing:
tensor([[-0.1944,  0.2896, -0.1541,  0.1468,  0.0238,  0.3736, -0.0285],
        [-0.3449,  0.3421,  0.2561,  0.1855, -0.2920, -0.0639, -0.1927],
        [ 0.0495,  0.0405, -0.2633,  0.1034, -0.1770,  0.1188, -0.0787],
        [ 0.1491,  0.3262,  0.1329,  0.2281, -0.0370, -0.0072, -0.1882],
        [ 0.0835,  0.1011, -0.1975, -0.1815,  0.2231, -0.1395,  0.0717]],
       requires_grad=True)
AFTER:  tensor([[ 0.2010,  0.2896, -0.1541,  0.1468,  0.0238],
        [ 0.2896, -0.2581,  0.2561,  0.1855, -0.2920],
        [-0.1541,  0.2561,  0.0034,  0.1034, -0.1770],
        [ 0.1468,  0.1855,  0.1034, -0.1916, -0.0370],
        [ 0.0238, -0.2920, -0.1770, -0.0370, -0.0022]], grad_fn=<AddBackward0>)


In [99]:
class External_W(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.w = torch.nn.Parameter(torch.empty((1, input_dim)))
        self.reset_parameters()
    
    def reset_parameters(self):
        torch.nn.init.normal_(self.w)

    def forward(self, x):
        # x * self.w behave like a diagonal matrix op., we multiply each row of x by the element-wise w
        return x * self.w


class Source_b(nn.Module):
    def __init__(self):
        super().__init__()
        self.beta = torch.nn.Parameter(torch.empty(1))

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.normal_(self.beta)
    


    def forward(self, x):
        return x * self.beta


class PairwiseInteraction_w(nn.Module):
    def __init__(self, input_dim, symmetry_type='1'):
        super().__init__()
        self.W = torch.nn.Linear(input_dim + 2, input_dim)

        if symmetry_type == '1':
            symmetry = PairwiseParametrization()
        elif symmetry_type == '2':
            symmetry = Symmetric()

        parametrize.register_parametrization(
            self.W, 'weight', symmetry, unsafe=True)
        self.reset_parameters()
        
    def reset_parameters(self):
        self.W.reset_parameters()

    def forward(self, x):
        return self.W(x)


class GRAFFConv(MessagePassing):
    def __init__(self, input_dim, symmetry_type='1', self_loops=True):
        super().__init__(aggr='add')
        self.in_dim = input_dim
        self.self_loops = self_loops
        self.external_w = External_W(self.in_dim)
        self.beta = Source_b()
        self.pairwise_W = PairwiseInteraction_w(
            self.in_dim, symmetry_type=symmetry_type)

    def forward(self, x, edge_index, x0):

        # We set the source term, which corrensponds with the initial conditions of our system.

        if self.self_loops:
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.shape[0])

        out_p = self.pairwise_W(x)

        out = self.propagate(edge_index, x=out_p)

        out = out - (self.external_w(x) + self.beta(x0))

        return out

    def message(self, x_i, edge_index, x):
        # Does we need the degree of the row or from the columns?
        # x_i are the columns indices, whereas x_j are the row indices
        row, col = edge_index

        # Degree is specified by the row (outgoing edges)
        deg_matrix = degree(row, num_nodes=x.shape[0], dtype=x.dtype)
        deg_inv = deg_matrix.pow(-0.5)
        
        deg_inv[deg_inv == float('inf')] = 0

        denom_degree = deg_inv[row]*deg_inv[col]

        # Each row of denom_degree multiplies (element-wise) the rows of x_j
        return denom_degree.unsqueeze(-1) * x_i


class PhysicsGNN(nn.Module):
    def __init__(self, dataset, hidden_dim, num_layers, step = 0.1, symmetry_type='1', self_loops=False):
        super().__init__()

        self.enc = torch.nn.Linear(dataset.num_features, hidden_dim)
        self.dec = torch.nn.Linear(hidden_dim, dataset.num_classes)

        self.layers = [GRAFFConv(hidden_dim, symmetry_type=symmetry_type,
                            self_loops=self_loops) for i in range(num_layers)]
        self.step = step
    #     self.reset_parameters()

    # def reset_parameters(self):
    #     self.enc.reset_parameters()
    #     self.dec.reset_parameters()
    #     for layer in self.layers:
    #         layer.reset_parameters()


        
    def forward(self, data):

        x, edge_index = data.x, data.edge_index

        
        x = enc_out = self.enc(x)

        x0 = enc_out.clone()
        for layer in self.layers:
                
            x = x + self.step*F.relu(layer(x, edge_index, x0))

        output = self.dec(x)

        return F.log_softmax(output, dim=1)
        
            
    
        
        

In [96]:
g = PhysicsGNN(dataset_texas, 512, 3)

In [97]:
re = g(dataset_texas.x, dataset_texas.edge_index)