In [6]:
import torch
from torch import nn
import math

In [7]:
# Problem sizes shared by the Transformer and GNN sections below.
# Number of nodes/tokens; third axis of the 4-D input `h`.
num_nodes = 4
# Feature width; last axis of `h` and the in/out size of every nn.Linear.
num_hidden = 5
# First axis of `h`.
batch_size = 1
# Second axis of `h`.
attention_heads = 1

## Transformer Version

In [83]:
# Seed PyTorch's global RNG so the layer initialisation and the random input
# created in the following cells are reproducible across runs.
torch.manual_seed(0);

In [84]:
# Learnable projections for attention; each maps num_hidden -> num_hidden.
# Everything is moved to "cuda", so these cells need a CUDA device.
# NOTE: keep the creation order — each nn.Linear draws its initial weights
# from the RNG seeded above, so reordering changes the resulting weights.
linear_q = nn.Linear(num_hidden, num_hidden).to("cuda")
linear_k = nn.Linear(num_hidden, num_hidden).to("cuda")
linear_v = nn.Linear(num_hidden, num_hidden).to("cuda")
# Normalises over the last axis of its input.
softmax = nn.Softmax(dim=-1).to("cuda")
# Output projection; not used by any cell visible in this notebook.
linear_out = nn.Linear(num_hidden, num_hidden).to("cuda")

In [85]:
# Random input features of shape
# (batch_size, attention_heads, num_nodes, num_hidden), moved to the GPU.
h = torch.randn(batch_size, attention_heads, num_nodes, num_hidden).to("cuda")

In [86]:
# Project the input into queries, keys and values. nn.Linear acts on the
# last axis only, so q, k and v keep the 4-D shape of `h`.
q = linear_q(h)
k = linear_k(h)
v = linear_v(h)

In [87]:
# Scaled dot-product scores: swapping axes 2 and 3 of k (nodes <-> features)
# makes q @ k^T a (num_nodes x num_nodes) matrix per batch/head, which is
# then scaled by 1/sqrt(num_hidden).
dot_product = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(num_hidden)

In [88]:
dot_product

tensor([[[[ 0.2910, -0.1490, -0.5712,  0.1110],
          [ 0.0089,  0.8476,  1.3841, -0.5947],
          [-0.5246,  0.4699,  2.5532, -1.0397],
          [ 0.2927, -0.4655, -0.7424,  0.1140]]]], device='cuda:0',
       grad_fn=<DivBackward0>)

In [89]:
# `softmax` was built with dim=-1, so every row of the score matrix is
# normalised into attention weights that sum to 1.
attn_weights = softmax(dot_product)
# Weighted sum of the value vectors for each query position.
attn_output = torch.matmul(attn_weights, v)

In [90]:
attn_weights

tensor([[[[0.3446, 0.2220, 0.1455, 0.2879],
          [0.1279, 0.2960, 0.5061, 0.0700],
          [0.0384, 0.1039, 0.8347, 0.0230],
          [0.3759, 0.1761, 0.1335, 0.3144]]]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>)

In [5]:
def transformer_update(h):
    """Apply one scaled dot-product self-attention step to ``h``.

    Relies on the notebook-level globals ``linear_q``, ``linear_k``,
    ``linear_v`` (nn.Linear projections), ``softmax`` (nn.Softmax over the
    last axis) and ``num_hidden`` (used for the 1/sqrt scaling).

    Args:
        h: Tensor whose last two axes are (num_nodes, num_hidden). Any number
            of leading axes (e.g. batch and head) is accepted, including none.

    Returns:
        Tensor with the same shape as the projected values ``linear_v(h)``:
        for every position, the attention-weighted sum of the value vectors.
    """
    # apply linear layers to compute query, key, and value
    q = linear_q(h)
    k = linear_k(h)
    v = linear_v(h)

    # Swap the last two axes of k (rather than hard-coding axes 2 and 3) so
    # the same code works for 2-D, 3-D and 4-D inputs; identical for 4-D.
    dot_product = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(num_hidden)
    attn_weights = softmax(dot_product)

    # attention-weighted sum of the value vectors
    attn_output = torch.matmul(attn_weights, v)

    return attn_output

## GNN Version

In [92]:
from torch_scatter import scatter_add
from torch_geometric import utils
import warnings
warnings.filterwarnings("ignore")

In [93]:
# Reset PyTorch's global RNG to the same seed used in the Transformer section
# so the layers and input created below can be reproduced across runs.
torch.manual_seed(0);

In [94]:
# Re-create the projection layers (this rebinds the globals defined in the
# Transformer section). They are built in the same order right after the
# same seed, which should reproduce the same initial weights — verify if the
# two versions are compared numerically.
linear_q = nn.Linear(num_hidden, num_hidden).to("cuda")
linear_k = nn.Linear(num_hidden, num_hidden).to("cuda")
linear_v = nn.Linear(num_hidden, num_hidden).to("cuda")
# Normalises over the last axis of its input.
softmax = nn.Softmax(dim=-1).to("cuda")
# Output projection; not used by any cell visible in this notebook.
linear_out = nn.Linear(num_hidden, num_hidden).to("cuda")

In [95]:
# Random input drawn with the same shape as in the Transformer section.
h = torch.randn(batch_size, attention_heads, num_nodes, num_hidden).to("cuda")
# Drop only the leading batch and head axes (both of size 1 here). A bare
# squeeze() would also collapse a size-1 node or feature axis.
h = h.squeeze(0).squeeze(0)
# Fully connected graph including self-loops: every (i, j) node pair becomes
# one edge, giving a (2, num_nodes * num_nodes) index tensor. Passing
# indexing="ij" keeps the historical default explicitly and avoids the
# torch.meshgrid deprecation warning.
edges = torch.stack(
    torch.meshgrid(torch.arange(num_nodes), torch.arange(num_nodes), indexing="ij")
).reshape(2, -1).to("cuda")

In [96]:
# Per-node queries, keys and values. `h` is 2-D here
# (num_nodes, num_hidden) after the squeeze in the previous cell.
q = linear_q(h)
k = linear_k(h)
v = linear_v(h)

In [97]:
# Split the (2, num_edges) edge list into its two rows of node indices.
# In the cells below `incoming` indexes the query node that receives the
# aggregated result and `outgoing` the node that supplies the key and value.
incoming, outgoing = edges

In [98]:
# One scaled dot-product score per edge: q of the `incoming` node times k of
# the `outgoing` node, summed over features and scaled by 1/sqrt(num_hidden).
dot_product = (q[incoming]*k[outgoing]).sum(dim=-1) / math.sqrt(num_hidden)

In [99]:
# Sparse softmax: normalise the edge scores within each group of edges that
# share the same `incoming` node (the counterpart of a row-wise softmax).
attn_weights = utils.softmax(dot_product, incoming)

In [101]:
# Weight each `outgoing` node's value vector by its attention weight and sum
# the results into the row of the `incoming` node. dim_size fixes the output
# to one row per node.
attn_output = scatter_add(attn_weights[:, None]*v[outgoing], incoming, dim=0, dim_size=h.shape[0])

In [102]:
attn_output

tensor([[ 0.0539, -0.2314, -0.7193,  0.2311,  0.5558],
        [ 0.0022,  0.2890, -0.1338, -0.1758,  0.5741],
        [-0.1504,  0.6230,  0.3954, -0.5747,  0.5293],
        [ 0.0375, -0.2709, -0.7365,  0.2375,  0.5486]], device='cuda:0',
       grad_fn=<ScatterAddBackward0>)

In [8]:
def gnn_update(h, edges):
    """Apply one attention step over an explicit edge list.

    Mirrors the step-by-step GNN cells of this notebook: scores are scaled by
    1/sqrt(num_hidden), normalised with a softmax over the edges of each
    query node, and used to weight the value of the node at the other end.

    Relies on the notebook-level globals ``linear_q``, ``linear_k``,
    ``linear_v``, ``num_hidden``, ``utils`` (torch_geometric) and
    ``scatter_add`` (torch_scatter).

    Args:
        h: Node features; indexed by node along axis 0, features on the last
            axis (assumed 2-D, (num_nodes, num_hidden), as in the cells above).
        edges: Integer tensor with two rows. ``edges[0]`` holds the query node
            that receives the aggregated result, ``edges[1]`` the node that
            supplies key and value (same convention as ``incoming, outgoing``
            in the step-by-step cells).

    Returns:
        Tensor with one row per node: the attention-weighted sum of the value
        vectors over that node's edges.
    """
    start, end = edges

    q = linear_q(h)
    k = linear_k(h)
    v = linear_v(h)

    # Scaled dot-product score for every edge (query at `start`, key at `end`).
    dot_product = (q[start] * k[end]).sum(dim=-1) / math.sqrt(num_hidden)
    # Normalise the scores over all edges that share the same query node.
    attn_weights = utils.softmax(dot_product, start, num_nodes=h.shape[0])
    # Gather the value from the key-side node and accumulate it into the
    # query-side node, so row i is sum_j a_ij * v_j.
    attn_output = scatter_add(attn_weights[:, None] * v[end], start, dim=0, dim_size=h.shape[0])

    return attn_output