# Attention Encoder: Two Equivalent Approaches

The following is a guide to the self-attention encoder update. If you are most comfortable with Transformers, start at the top and work down to the `Transformer output` result. If you are most comfortable with Graph Neural Networks, start at the bottom and work up to the `GNN output` result. Spoiler alert: the two are mathematically equivalent, and the printed outputs agree to every displayed digit. That is exactly what we should expect. The only difference is that the Transformer computes attention as a dense operation over every pair of nodes, while the GNN computes it as a sparse operation over an explicit edge list (here, a fully connected graph including self-loops). Check out the `2-PerformanceComparison` notebook for a comparison of the two approaches in terms of memory and timing.
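As a minimal sketch of that equivalence claim, the snippet below computes both versions for a small random example and checks that they agree. It assumes only plain PyTorch 1.12 or later on the CPU (no CUDA, `torch_scatter` or `torch_geometric`). The random matrices `W_q`, `W_k`, `W_v` are stand-ins for the notebook's `nn.Linear` layers (biases omitted), and the grouped softmax is written by hand with `scatter_reduce` and `index_add` instead of `torch_geometric.utils.softmax`.

```python
import math
import torch

torch.manual_seed(0)
num_nodes, num_hidden = 4, 5

# Node features and stand-in projection matrices (not the notebook's layers).
h = torch.randn(num_nodes, num_hidden)
W_q, W_k, W_v = (torch.randn(num_hidden, num_hidden) for _ in range(3))
q, k, v = h @ W_q, h @ W_k, h @ W_v

# Dense (Transformer-style): one (N, N) score matrix, softmax over each row.
scores = q @ k.T / math.sqrt(num_hidden)
dense_out = torch.softmax(scores, dim=-1) @ v

# Sparse (GNN-style): one score per edge of a fully connected graph.
# target[e] receives a message from source[e].
target, source = torch.meshgrid(
    torch.arange(num_nodes), torch.arange(num_nodes), indexing="ij"
)
target, source = target.reshape(-1), source.reshape(-1)
edge_scores = (q[target] * k[source]).sum(dim=-1) / math.sqrt(num_hidden)

# Softmax grouped by target node (subtract the per-node max for stability).
node_max = torch.full((num_nodes,), float("-inf")).scatter_reduce(
    0, target, edge_scores, reduce="amax"
)
exp = torch.exp(edge_scores - node_max[target])
denom = torch.zeros(num_nodes).index_add(0, target, exp)
alpha = exp / denom[target]

# Sum the weighted value messages into each target node.
sparse_out = torch.zeros(num_nodes, num_hidden).index_add(
    0, target, alpha[:, None] * v[source]
)

print(torch.allclose(dense_out, sparse_out, atol=1e-5))
```

The dense version materialises all N² scores at once. The sparse version only ever touches the scores of edges that exist, which is why it pays off when the graph is not fully connected.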

## Transformer Version

In [None]:
import torch
from torch import nn
import math

In [None]:
num_nodes = 4
num_hidden = 5
batch_size = 1
attention_heads = 1

In [3]:
# Set pytorch random seed
torch.manual_seed(0);

In [4]:
linear_q = nn.Linear(num_hidden, num_hidden).to("cuda")
linear_k = nn.Linear(num_hidden, num_hidden).to("cuda")
linear_v = nn.Linear(num_hidden, num_hidden).to("cuda")
softmax = nn.Softmax(dim=-1).to("cuda")
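# Output projection: unused in this notebook, but creating it advances the
# random number generator and therefore affects the values of h drawn next.
linear_out = nn.Linear(num_hidden, num_hidden).to("cuda")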

In [5]:
h = torch.randn(batch_size, attention_heads, num_nodes, num_hidden).to("cuda")

In [6]:
q = linear_q(h)
k = linear_k(h)
v = linear_v(h)

In [7]:
dot_product = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(num_hidden)

In [9]:
attn_weights = softmax(dot_product)
attn_output = torch.matmul(attn_weights, v)
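Written per node, the attention cells above compute the following, with $d$ = `num_hidden`. The second line is the same sum restricted to an edge set $E$, which is the form the GNN version evaluates. For the fully connected graph used here, $E$ contains every pair $(i, j)$, so the two lines coincide.

$$
\alpha_{ij} = \frac{\exp\left(q_i \cdot k_j / \sqrt{d}\right)}{\sum_{j'} \exp\left(q_i \cdot k_{j'} / \sqrt{d}\right)},
\qquad
o_i = \sum_{j} \alpha_{ij}\, v_j
$$

$$
\alpha_{ij} = \frac{\exp\left(q_i \cdot k_j / \sqrt{d}\right)}{\sum_{j' :\, (i, j') \in E} \exp\left(q_i \cdot k_{j'} / \sqrt{d}\right)},
\qquad
o_i = \sum_{j :\, (i, j) \in E} \alpha_{ij}\, v_j
$$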

In [13]:
print("------------------- Transformer output ----------------------\n", attn_output)

------------------- Transformer output ----------------------
 tensor([[[[ 0.0539, -0.2314, -0.7193,  0.2311,  0.5558],
          [ 0.0022,  0.2890, -0.1338, -0.1758,  0.5741],
          [-0.1504,  0.6230,  0.3954, -0.5747,  0.5293],
          [ 0.0375, -0.2709, -0.7365,  0.2375,  0.5486]]]], device='cuda:0',
       grad_fn=<UnsafeViewBackward0>)
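As a cross-check, the manual dense computation can be compared against PyTorch's built-in attention function. This is a sketch assuming PyTorch 2.0 or later, where `torch.nn.functional.scaled_dot_product_attention` is available. Random tensors stand in for the projected `q`, `k` and `v`, and everything runs on the CPU.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, attention_heads, num_nodes, num_hidden = 1, 1, 4, 5
shape = (batch_size, attention_heads, num_nodes, num_hidden)

# Random stand-ins for the projected queries, keys and values.
q, k, v = torch.randn(shape), torch.randn(shape), torch.randn(shape)

# The manual computation from the cells above.
weights = torch.softmax(q @ k.transpose(2, 3) / math.sqrt(num_hidden), dim=-1)
manual = weights @ v

# Built-in version; its default scale is 1 / sqrt(q.size(-1)).
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, fused, atol=1e-5))
```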


In [16]:
print(attn_output, "\n ---------------------- GNN output -------------------------")

tensor([[ 0.0539, -0.2314, -0.7193,  0.2311,  0.5558],
        [ 0.0022,  0.2890, -0.1338, -0.1758,  0.5741],
        [-0.1504,  0.6230,  0.3954, -0.5747,  0.5293],
        [ 0.0375, -0.2709, -0.7365,  0.2375,  0.5486]], device='cuda:0',
       grad_fn=<ScatterAddBackward0>) 
 ---------------------- GNN output -------------------------


In [15]:
attn_output = scatter_add(attn_weights[:, None]*v[outgoing], incoming, dim=0, dim_size=h.shape[0])
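The aggregation cell sums each edge's weighted value into its receiving node. As a minimal sketch assuming only core PyTorch, `Tensor.index_add` along dim 0 performs the same sum as `torch_scatter.scatter_add` with `dim=0`. On a fully connected graph the result matches the dense `attn_weights @ v` product. The tensors below are random stand-ins for the notebook's.

```python
import torch

torch.manual_seed(0)
num_nodes, num_hidden = 4, 5

# Fully connected edge list: incoming[e] receives a message from outgoing[e].
incoming, outgoing = torch.meshgrid(
    torch.arange(num_nodes), torch.arange(num_nodes), indexing="ij"
)
incoming, outgoing = incoming.reshape(-1), outgoing.reshape(-1)

# Random row-stochastic attention matrix and values as stand-ins.
dense_weights = torch.softmax(torch.randn(num_nodes, num_nodes), dim=-1)
v = torch.randn(num_nodes, num_hidden)

# Per-edge weights, in the same (incoming, outgoing) order as the edge list.
attn_weights = dense_weights[incoming, outgoing]

# Sum the weighted messages into their receiving nodes.
sparse_out = torch.zeros(num_nodes, num_hidden).index_add(
    0, incoming, attn_weights[:, None] * v[outgoing]
)
dense_out = dense_weights @ v

print(torch.allclose(sparse_out, dense_out, atol=1e-5))
```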

In [14]:
attn_weights = utils.softmax(dot_product, incoming, dim=0)
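`torch_geometric.utils.softmax` normalises the edge scores separately for each group of edges that share the same `incoming` node. The snippet below is a hand-written sketch of such a grouped softmax using core PyTorch only (`Tensor.scatter_reduce`, available from PyTorch 1.12). It is not the library's actual implementation, and `grouped_softmax` is a hypothetical helper name. On a fully connected graph it reproduces a row-wise dense softmax.

```python
import torch

def grouped_softmax(scores, index, num_groups):
    """Hypothetical helper: softmax over entries of `scores` sharing an `index` value."""
    # Per-group maximum, subtracted for numerical stability.
    group_max = torch.full((num_groups,), float("-inf")).scatter_reduce(
        0, index, scores, reduce="amax"
    )
    exp = torch.exp(scores - group_max[index])
    # Per-group sum of the exponentials.
    group_sum = torch.zeros(num_groups).index_add(0, index, exp)
    return exp / group_sum[index]

torch.manual_seed(0)
num_nodes = 4
# Same order as the notebook's `incoming`: 0,0,0,0,1,1,1,1,...
incoming = torch.arange(num_nodes).repeat_interleave(num_nodes)
dense_scores = torch.randn(num_nodes, num_nodes)

edge_weights = grouped_softmax(dense_scores.reshape(-1), incoming, num_nodes)
print(torch.allclose(edge_weights.reshape(num_nodes, num_nodes),
                     torch.softmax(dense_scores, dim=-1), atol=1e-5))
```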

In [12]:
dot_product

tensor([ 0.2910, -0.1490, -0.5712,  0.1110,  0.0089,  0.8476,  1.3841, -0.5947,
        -0.5246,  0.4699,  2.5532, -1.0397,  0.2927, -0.4655, -0.7424,  0.1140],
       device='cuda:0', grad_fn=<DivBackward0>)

In [11]:
# Edge e carries a message from node outgoing[e] (key/value) to node incoming[e] (query).
incoming, outgoing = edges
dot_product = (q[incoming]*k[outgoing]).sum(dim=-1) / math.sqrt(num_hidden)
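Gathering rows with `q[incoming]` and `k[outgoing]` and summing their elementwise product gives one score per edge. For the fully connected edge list built with `torch.meshgrid` (default `"ij"` indexing), edge `e` is the pair `(e // num_nodes, e % num_nodes)`. The edge scores are therefore exactly the dense score matrix flattened row by row. Here is a small CPU-only check, with random `q` and `k` as stand-ins:

```python
import math
import torch

torch.manual_seed(0)
num_nodes, num_hidden = 4, 5
q = torch.randn(num_nodes, num_hidden)
k = torch.randn(num_nodes, num_hidden)

incoming, outgoing = torch.meshgrid(
    torch.arange(num_nodes), torch.arange(num_nodes), indexing="ij"
)
incoming, outgoing = incoming.reshape(-1), outgoing.reshape(-1)

# Sparse: one dot product per edge (num_nodes ** 2 of them here).
edge_scores = (q[incoming] * k[outgoing]).sum(dim=-1) / math.sqrt(num_hidden)
# Dense: the full (num_nodes, num_nodes) score matrix.
dense_scores = q @ k.T / math.sqrt(num_hidden)

print(edge_scores.shape)  # torch.Size([16])
print(torch.allclose(edge_scores, dense_scores.reshape(-1), atol=1e-5))
```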

In [9]:
q = linear_q(h)
k = linear_k(h)
v = linear_v(h)

In [8]:
h = torch.randn(num_nodes, num_hidden).to("cuda")
# Fully connected graph: row 0 = receiving node, row 1 = sending node, for every node pair
edges = torch.stack(torch.meshgrid(torch.arange(num_nodes), torch.arange(num_nodes), indexing="ij")).reshape(2, -1).to("cuda")
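The `edges` tensor lists every ordered pair of nodes, including self-loops, so the graph is fully connected and has `num_nodes ** 2` edges. `torch.cartesian_prod` produces the same pairs in the same order, which makes a handy cross-check. A sparser graph would simply keep a subset of these columns, although the equivalence with the Transformer relies on keeping all of them, self-loops included. Here is a minimal CPU sketch for 3 nodes:

```python
import torch

num_nodes = 3

# Row 0: receiving ("incoming") node, row 1: sending ("outgoing") node.
edges = torch.stack(
    torch.meshgrid(torch.arange(num_nodes), torch.arange(num_nodes), indexing="ij")
).reshape(2, -1)
print(edges)
# tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],
#         [0, 1, 2, 0, 1, 2, 0, 1, 2]])

# Same pairs, same order: cartesian_prod returns (num_edges, 2), so transpose.
pairs = torch.cartesian_prod(torch.arange(num_nodes), torch.arange(num_nodes)).T
print(torch.equal(edges, pairs))

# Illustration only: dropping self-loops leaves num_nodes * (num_nodes - 1) edges.
no_self_loops = edges[:, edges[0] != edges[1]]
print(no_self_loops.shape)  # torch.Size([2, 6])
```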

In [7]:
linear_q = nn.Linear(num_hidden, num_hidden).to("cuda")
linear_k = nn.Linear(num_hidden, num_hidden).to("cuda")
linear_v = nn.Linear(num_hidden, num_hidden).to("cuda")
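softmax = nn.Softmax(dim=-1).to("cuda")  # unused here; the GNN uses utils.softmax instead
# Output projection: unused in this notebook, but creating it advances the
# random number generator and therefore affects the values of h drawn next.
linear_out = nn.Linear(num_hidden, num_hidden).to("cuda")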

In [6]:
# Set pytorch random seed
torch.manual_seed(0);

In [5]:
num_nodes = 4
num_hidden = 5

In [4]:
import torch
import torch.nn as nn
from torch_scatter import scatter_add
from torch_geometric import utils
import math
import warnings
warnings.filterwarnings("ignore")

## GNN Version