In [3]:
import torch
from torch.nn import functional as F

# ----- Input setup
# A triangle graph: every node is connected to the other two
# (0 -- 1, 0 -- 2, 1 -- 2).
adj = torch.tensor([[0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
# X_prime holds the node features that are treated as already transformed
# (X @ W). The linear transform itself (presumably step (1)) is not done here.
X_prime = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
n_nodes, n_features = X_prime.shape
# Attention vector S: a column vector of shape (2 * n_features, 1). Its size is
# tied to the features because it is dotted with [x'_i || x'_j] in step (2c).
S = torch.ones(2 * n_features, 1)
# Slope of the leaky ReLU applied to the raw scores (F.leaky_relu's default).
negative_slope = 0.01


# ----- AttentionConv
# Add a self-loop to every node so that each node also attends to itself.
adj = adj + torch.eye(n_nodes)

# (2) Compute attention scores for each neighbor
# (2a) Get all (target, neighbor) pairs.
# If adj[i, j] != 0, nbs contains the row (i, j). For a symmetric adjacency it
# therefore also contains (j, i). For the loop-free input above plus the added
# self-loops, nbs has shape (2 * num_edges + n_nodes, 2), i.e. (num_pairs, 2).
nbs = adj.nonzero()
i, j = nbs[:, 0], nbs[:, 1]

# (2b) Concatenate each target's transformed features with its neighbor's.
# E.g. if X_prime[i] = [1, 2] and X_prime[j] = [3, 4], the row for pair (i, j)
# is [1, 2, 3, 4] and the row for pair (j, i) is [3, 4, 1, 2].
X_prime_i = X_prime[i]  # target vectors, one row per pair
X_prime_j = X_prime[j]  # neighbor vectors, one row per pair
X_prime_concat_all = torch.cat([X_prime_i, X_prime_j], dim=1)  # (num_pairs, 2 * n_features)

# (2c) Score every pair: e_ij = S^T [x'_i || x'_j].
# (num_pairs, 2 * n_features) @ (2 * n_features, 1) -> (num_pairs, 1).
# Squeeze only the trailing dim, so a graph with a single pair still yields a
# 1-D tensor rather than a 0-d scalar.
E_i_nb = (X_prime_concat_all @ S).squeeze(-1)

# Apply leaky ReLU to get the raw attention scores.
raw_attention_scores = F.leaky_relu(E_i_nb, negative_slope=negative_slope)

# (2d) Softmax over each target node's neighborhood.
# Subtract the per-node maximum before exponentiating. Softmax is invariant to a
# shift that is constant within a neighborhood, and the shift keeps torch.exp
# from overflowing to inf (which would make the weights nan) for large scores.
# Every node has a self-loop, so no entry of node_max stays at -inf.
node_max = torch.full((n_nodes,), float("-inf")).scatter_reduce(
    0, i, raw_attention_scores, reduce="amax"
)
exp_attention_scores = torch.exp(raw_attention_scores - node_max[i])

# Sum the exponentiated scores over each target node's neighborhood.
# index_add_ works along dim 0 and groups entries by their target index i.
neighborhood_sum = torch.zeros(n_nodes).index_add_(0, i, exp_attention_scores)

# Normalize so that every target node's attention weights sum to 1.
attention_scores = exp_attention_scores / neighborhood_sum[i]

# (3) Aggregate: weight each neighbor's features by its attention score, then
# sum them per target node. index_add_ does the grouped sum directly (as in
# step (2d)), without building a dense (n_nodes, num_pairs) mask.
nbs_feat_weighted = attention_scores.view(-1, 1) * X_prime_j
weighted_sum = torch.zeros_like(X_prime).index_add_(0, i, nbs_feat_weighted)

# Finally, apply the sigmoid nonlinearity.
X_result = torch.sigmoid(weighted_sum)

X_result


tensor([[0.9991, 0.9997, 0.9999],
        [0.9991, 0.9997, 0.9999],
        [0.9991, 0.9997, 0.9999]])