In [29]:
import torch

# -- Input
# 4 nodes joined by 3 undirected edges (0-2, 1-2, 1-3). Every undirected edge
# is stored once per direction, which gives 6 directed edges in total.
# Nodes 1 and 2 have 2 connections, the rest just 1.
adj = torch.tensor(
    [[0, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 0], [0, 1, 0, 0]],
    dtype=torch.float32,
)
# One 3-dimensional feature vector per directed edge (all ones in this toy case).
edge_feat = torch.ones(6, 3)
# (src, dst) pair of every directed edge, in the same order as `edge_feat`.
# Stored as int64 so it can be used directly for indexing without casting.
edge_index = torch.tensor(
    [[0, 2], [1, 2], [1, 3], [2, 0], [2, 1], [3, 1]],
    dtype=torch.long,
)

# Sanity check: the edge list must describe exactly the non-zero entries of
# the adjacency matrix (row-major order), otherwise the inputs disagree.
assert torch.equal(adj.nonzero(), edge_index)


# -- Process
# Using edge_index and edge_feat matrices, aggregate the edge features
# so that each node has a feature vector based on the edges it is connected to
num_nodes = adj.size(0)
num_edges = edge_index.size(0)
feat_dim = edge_feat.size(1)

# -- Method 1
# Reference implementation: walk over the edges one by one and add each edge's
# features to the row of its source node.
node_feat2 = torch.zeros(num_nodes, feat_dim)
for (src, _dst), feat in zip(edge_index.tolist(), edge_feat):
    node_feat2[src] += feat

# -- Method 2
# Vectorised implementation: scatter-add all edge features into the rows
# selected by the source-node indices in a single call.
node_feat = torch.zeros(num_nodes, feat_dim)
src_nodes = edge_index[:, 0]
print(node_feat.shape, src_nodes.shape, edge_feat.shape)
node_feat.index_add_(0, src_nodes, edge_feat)

# Both methods must produce the same per-node aggregation.
assert torch.allclose(node_feat, node_feat2)

torch.Size([4, 3]) torch.Size([6]) torch.Size([6, 3])
