### Re-implementing GATConv layer in PyTorch Geometric for personal understanding
- `GATConv` implements the graph attention layer from *Graph Attention Networks* (Veličković et al., https://arxiv.org/abs/1710.10903)

In [247]:
# imports
import torch_geometric
import networkx as nx
import torch_geometric.nn
from data_gen import get_graph
import collections
import torch

In [248]:
# load a graph via the project's data_gen helper (presumably the first graph of the dataset)
graph = get_graph()

# node features are the node positions, assumed (x,y,z) coordinates, shape = [n_nodes, 3]
node_features = graph.pos

# edge_index holds (src, dst) pairs in COO form: row 0 = source nodes, row 1 = target nodes,
# shape = [2, n_edges]; messages flow from src to dst
edge_index = graph.edge_index

# sanity-check the layouts the rest of the notebook relies on (GATConv(3, ...) needs 3 input channels)
assert node_features.dim() == 2 and node_features.shape[1] == 3, node_features.shape
assert edge_index.dim() == 2 and edge_index.shape[0] == 2, edge_index.shape

n_nodes = node_features.shape[0]
n_edges = edge_index.shape[1]
print(f"n_nodes: {n_nodes}, n_edges: {n_edges}")

n_nodes: 2518, n_edges: 15108


Convert `edge_index` into a more convenient adjacency-list form (for each node, the list of nodes with an edge into it):

In [249]:
# Build two lookups in a single pass over the edges:
#   graph_dict[dst]       -> list of source nodes with an edge into dst, in edge order
#                            (the incoming neighbours that dst aggregates from)
#   edge_dict[(src, dst)] -> column index of that edge in edge_index
#                            (if an edge appears more than once, the last occurrence wins)
# Converting edge_index to Python ints once avoids several .item() calls per edge.
graph_dict = collections.defaultdict(list)
edge_dict = {}
for i, (src, dst) in enumerate(edge_index.t().tolist()):
    graph_dict[dst].append(src)
    edge_dict[(src, dst)] = i

Define torch_geometric graph conv layer and process graph to get the target output:

In [250]:
# Reference layer: a single-head GAT with no self-loops, no bias and no dropout, so the output
# for each node is just the attention-weighted sum of its projected incoming neighbours.
torch.manual_seed(3)  # makes the layer's randomly initialised parameters reproducible
target_layer = torch_geometric.nn.GATConv(
    3, 10,
    heads=1,
    concat=True,
    negative_slope=0.2,
    dropout=0,
    add_self_loops=False,
    fill_value=0.0,  # only relevant for self-loop edge features, unused here
    bias=False,
    aggr="sum",
)
target_layer.eval()

# keep an untouched copy of the input to confirm the forward pass does not modify it
old_node_features = node_features.clone()
target_output, (new_edge_idxs, attn_weights) = target_layer(
    node_features, edge_index, return_attention_weights=True
)
assert torch.equal(node_features, old_node_features), "GATConv modified its input features"

Extract the layer's parameters and load its projection weight into an equivalent `torch.nn.Linear` layer:

In [251]:
# Collect the GATConv parameters by name (printed so their shapes can be inspected).
params = dict(target_layer.named_parameters())
for name, param in params.items():
    print(name, param.shape)

# Shared projection weight W, shape [out_channels, in_channels] (printed as [10, 3] here).
# Recent PyG releases register it as "lin.weight"; older ones name it "lin_src.weight".
weight_name = "lin.weight" if "lin.weight" in params else "lin_src.weight"
target_weight = params[weight_name]
# attention vectors for the source / destination side (printed as [1, 1, 10] here)
att_src = params["att_src"]
att_dst = params["att_dst"]

# Load the GATConv projection into an equivalent bias-free torch.nn.Linear. Assigning the
# Parameter object itself (not a copy) makes both layers share the same tensor.
out_features, in_features = target_weight.shape
layer = torch.nn.Linear(in_features, out_features, bias=False)
layer.weight = target_weight

att_src torch.Size([1, 1, 10])
att_dst torch.Size([1, 1, 10])
lin.weight torch.Size([10, 3])


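Loop through the nodes and, for each node $i$, manually compute the attention weights over its incoming neighbours $\mathcal{N}(i)$ and its new feature:

$$e_{ij} = \mathrm{LeakyReLU}\left(\mathbf{a}_{\text{src}}^\top W x_j + \mathbf{a}_{\text{dst}}^\top W x_i\right), \qquad \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}, \qquad x_i' = \sum_{j \in \mathcal{N}(i)} \alpha_{ij} W x_j$$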

In [252]:

# Manual single-head GAT forward pass, node by node:
#   e_ij     = LeakyReLU(a_src . W x_j + a_dst . W x_i)   for each incoming neighbour j of i
#   alpha_ij = softmax over the neighbours j of e_ij
#   x'_i     = sum_j alpha_ij * W x_j
# Gradients are not needed for the comparison, so autograd tracking is disabled.
with torch.no_grad():
    # Project every node once (W x) instead of re-projecting each neighbour on every visit.
    projected = layer(node_features)  # [n_nodes, out_channels]
    a_src = att_src.squeeze()         # [out_channels]
    a_dst = att_dst.squeeze()

    output = []
    for i in range(node_features.shape[0]):
        curr_feature = projected[i]
        neighbours = graph_dict.get(i, [])

        # With add_self_loops=False a node can have no incoming edges; sum aggregation over
        # an empty neighbourhood gives a zero vector (and torch.stack([]) would raise).
        if not neighbours:
            output.append(torch.zeros_like(curr_feature))
            continue

        neighbour_features = projected[neighbours]  # [n_neighbours, out_channels]

        # unnormalised attention scores (one per neighbour), then LeakyReLU + softmax
        scores = neighbour_features @ a_src + torch.dot(curr_feature, a_dst)
        scores = torch.nn.functional.leaky_relu(scores, negative_slope=0.2)
        weights = torch.softmax(scores, dim=0)

        # attention-weighted sum of neighbour features, accumulated in edge order
        new_feature = sum(w * f for w, f in zip(weights, neighbour_features))
        output.append(new_feature)

output = torch.stack(output)

Evaluate to make sure that both methods are the same:

In [253]:
# torch.isclose broadcasts its inputs, so confirm the shapes match first; otherwise a shape
# bug in the manual implementation could still be reported as "Success".
assert target_output.shape == output.shape, (
    f"Shape mismatch: target {tuple(target_output.shape)} vs manual {tuple(output.shape)}"
)
close = torch.isclose(target_output, output)
if close.all():
    print("Success")
else:
    percent_close = close.float().mean().item() * 100
    max_abs_diff = (target_output - output).abs().max().item()
    print(f"Failed - only {percent_close:.2f}% of the values are close to the target output "
          f"(max abs difference {max_abs_diff:.3e}).")


Success
