### Re-implementing PyTorch Geometric's `GCNConv` layer by hand (for personal understanding)

In [13]:
# imports
import torch_geometric
import networkx as nx
import torch_geometric.nn
from data_gen import get_graph
import collections
import torch
import math

In [14]:
# Load the graph from the project's data_gen.get_graph helper
# (presumably the first graph of the dataset — see data_gen).
graph = get_graph()

# The node positions (graph.pos) are used directly as the input node features.
node_features = graph.pos
# edge_index holds one edge per column: row 0 is the source node, row 1 the destination.
edge_index = graph.edge_index

n_nodes, n_edges = node_features.shape[0], edge_index.shape[1]
print(f"n_nodes: {n_nodes}, n_edges: {n_edges}")

n_nodes: 2518, n_edges: 15108


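Convert the graph into a more convenient adjacency-list form, mapping each node to the list of its incoming neighbours (the source nodes of the edges that point to it):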

In [15]:
# Adjacency list keyed by destination node: graph_dict[dst] lists every source node
# with an edge src -> dst, in edge_index column order (duplicate edges are kept).
# Converting each row to a Python list once avoids a .item() call per tensor element.
graph_dict = collections.defaultdict(list)
for src, dst in zip(edge_index[0].tolist(), edge_index[1].tolist()):
    graph_dict[dst].append(src)

Define the reference `torch_geometric` `GCNConv` layer and run it on the graph to get the target output:

In [16]:
# Reference implementation; seeding first makes GCNConv's random weight init reproducible.
torch.manual_seed(1)
target_layer = torch_geometric.nn.GCNConv(
    3,                    # input features per node
    10,                   # output features per node
    normalize=True,       # symmetric degree normalisation of every edge
    add_self_loops=True,  # each node also receives its own (projected) features
    bias=False,           # no bias, so the output is a pure normalised aggregation
    aggr="sum",
)
target_output = target_layer(node_features, edge_index)
print(f"target_output {target_output.shape} \n{target_output[0]}")

target_output torch.Size([2518, 10]) 
tensor([ 0.0014,  0.0064,  0.0056, -0.0176, -0.0161,  0.0093, -0.0175, -0.0085,
        -0.0165,  0.0037], grad_fn=<SelectBackward0>)


Build an equivalent `torch.nn.Linear` layer that shares the `GCNConv` projection weight:

In [17]:
# Pull the learnable parameters out of the torch_geometric layer.
# NOTE(review): in recent PyG versions GCNConv keeps its projection in `lin.weight`
# (shape (out_channels, in_channels)) and an optional top-level `bias`; confirm for your version.
target_weight = None
target_bias = None  # stays None when the layer was built with bias=False
for name, param in target_layer.named_parameters():
    if "weight" in name:
        target_weight = param
    if name == "bias":
        target_bias = param
assert target_weight is not None, "no weight parameter found on target_layer"

# Mirror the GCN projection with a plain torch.nn.Linear that shares the same weight.
# GCNConv adds its bias only *after* neighbourhood aggregation, so the Linear layer must
# be bias-free (bias=False; the previous `bias=None` only worked because None is falsy).
# Its dimensions are taken from the copied weight instead of being hard-coded.
out_features, in_features = target_weight.shape
layer = torch.nn.Linear(in_features, out_features, bias=False)
layer.weight = target_weight

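Loop through the nodes and manually apply the linear projection and the neighbourhood aggregation. With `normalize=True` and `add_self_loops=True`, `GCNConv` computes $h_i = \sum_{j \in \mathcal{N}(i) \cup \{i\}} \frac{1}{\sqrt{\hat d_i \hat d_j}} W x_j$, where $\hat d_k$ is the number of incoming edges of node $k$ plus one for its self-loop. A plain mean over the neighbourhood coincides with this only when every node has the same degree.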

In [18]:
# Manual re-implementation of GCNConv(normalize=True, add_self_loops=True, bias=False).
# For every node i, GCNConv computes
#     out_i = sum over j in N_in(i) ∪ {i} of  W x_j / sqrt(deg_j * deg_i)
# where deg_k = (number of incoming edges of k) + 1 for the added self-loop.
# Dividing by (n_neighbours + 1), i.e. a plain mean, only reproduces this when every
# node has the same in-degree, so the symmetric per-edge weights are computed explicitly.
num_nodes = node_features.shape[0]

# Incoming neighbours of each node (graph_dict is keyed by destination). Existing
# self-loops are dropped because GCNConv replaces them with exactly one self-loop per node.
in_neighbours = [[src for src in graph_dict[i] if src != i] for i in range(num_nodes)]

# Degree including the self-loop (always >= 1, so the inverse square root is finite).
degree = torch.tensor([len(nbrs) + 1 for nbrs in in_neighbours], dtype=node_features.dtype)
deg_inv_sqrt = degree.pow(-0.5)

# GCNConv projects the features first and aggregates afterwards; following the same
# order keeps floating-point rounding as close as possible to the reference output.
projected = layer(node_features)

output = []
for i in range(num_nodes):
    # Self-loop contribution: weight 1 / sqrt(deg_i * deg_i).
    new_node_feature = projected[i] * (deg_inv_sqrt[i] * deg_inv_sqrt[i])
    for j in in_neighbours[i]:
        new_node_feature = new_node_feature + projected[j] * (deg_inv_sqrt[j] * deg_inv_sqrt[i])
    output.append(new_node_feature)
output = torch.stack(output)


Check that both methods produce the same output (up to floating-point tolerance):

In [19]:
# Guard against silent broadcasting: torch.isclose accepts tensors of different but
# broadcastable shapes, which could otherwise yield a misleading "Success".
assert output.shape == target_output.shape, (
    f"shape mismatch: manual {tuple(output.shape)} vs target {tuple(target_output.shape)}"
)

# Fraction of entries within torch.isclose's default tolerances (rtol=1e-5, atol=1e-8).
result = torch.isclose(target_output, output).float().mean()
if result == 1:
    print("Success")
else:
    print(f"Failed - only {result} of the values are close to the target output.")


Success
