### GCN => GAT 

- GCN: Graph Convolutional Networks; GAT: Graph Attention Networks
- Both learn a representation of the target node
- GCN
    - $d_v$: the degree of node $v$ (the number of edges connected to it), counting the node's self-loop
$$
\begin{align}
\mathbf{h}_{\mathcal{N}(v)} &= \sum_{u \in \mathcal{N}(v)} w_{u,v} \mathbf{h}_u \\
&= \sum_{u \in \mathcal{N}(v)} \sqrt{\frac{1}{d_v}} \sqrt{\frac{1}{d_u}} \mathbf{h}_u \\
&= \sqrt{\frac{1}{d_v}} \sum_{u \in \mathcal{N}(v)} \sqrt{\frac{1}{d_u}} \mathbf{h}_u
\end{align}
$$

- GAT
    - https://www.youtube.com/watch?v=SnRfBfXwLuY

$$
\begin{align}
\mathbf{h}_{\mathcal{N}(v)} &= \sum_{u \in \mathcal{N}(v)} \underbrace{\text{softmax}_u \left( a(\mathbf{h}_u, \mathbf{h}_v) \right)}_{\alpha_{u,v}} \mathbf{h}_u \\
\alpha_{u,v} &= \frac{\exp\left(a(\mathbf{h}_u, \mathbf{h}_v)\right)}{\sum_{k \in \mathcal{N}(v)} \exp\left(a(\mathbf{h}_k, \mathbf{h}_v)\right)}
\end{align}
$$
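
The two aggregation rules above can be compared on a toy graph. This is a minimal sketch: the 3-node adjacency matrix and the features are made up for illustration, and the dot product is only a stand-in for the learned scoring function $a$.

```python
import torch

# Hypothetical toy graph: 3 nodes, self-loops included so d_v counts the node itself.
adj = torch.tensor([[1., 1., 0.],
                    [1., 1., 1.],
                    [0., 1., 1.]])
h = torch.tensor([[1., 0.],
                  [0., 1.],
                  [1., 1.]])  # one 2-d feature row per node

# GCN: fixed weights w_{u,v} = 1 / sqrt(d_u * d_v), determined by the graph alone.
deg = adj.sum(dim=1)                                   # d_v, self-loop included
w_gcn = adj / torch.sqrt(deg[:, None] * deg[None, :])  # zero for non-neighbours
h_gcn = w_gcn @ h                                      # row v: sum_u w_{u,v} h_u

# GAT: weights alpha_{u,v} come from a softmax over the neighbours of v.
# ASSUMPTION: a(h_u, h_v) is replaced by a plain dot product here.
scores = h @ h.T
scores = scores.masked_fill(adj == 0, float("-inf"))   # non-neighbours get weight 0
alpha = torch.softmax(scores, dim=1)                   # row v: weights over u in N(v)
h_gat = alpha @ h

print(w_gcn)
print(alpha)
```

The GCN weights depend only on node degrees, while the GAT weights depend on the features and sum to 1 over each node's neighbourhood.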

### GAT

https://github.com/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial3/Tutorial3.pdf

- input: a set of node features, $\mathbf{h} = \{ \bar{h}_1, \bar{h}_2, \dots, \bar{h}_n \}, \quad \bar{h}_i \in \mathbb{R}^{F}$
- output: a set of node features, $\mathbf{h'} = \{ \bar{h'}_1, \bar{h'}_2, \dots, \bar{h'}_n \}, \quad \bar{h'}_i \in \mathbb{R}^{F'}$
- GAT
    - apply a shared, parameterized linear transformation to every node
        - $\mathbf W\cdot \bar h_i, \quad \mathbf W\in \mathbb R^{F'\times F}$
    - self-attention:
        - $a: \mathbb R^{F'}\times \mathbb R^{F'} \rightarrow \mathbb R$
        - $e_{ij}=a(\mathbf W\cdot \bar h_i,\mathbf W\cdot \bar h_j)$: specifies the importance of node $j$'s features to node $i$
    - normalization
        - $\alpha_{ij}=\text{softmax}_j(e_{ij})=\frac{\exp(e_{ij})}{\sum_{k\in N(i)}\exp(e_{ik})}$
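
The list above leaves the scoring function $a$ abstract. One common concrete choice, assumed here because it matches the concatenation, the `(2 * out_features, 1)` parameter `a`, and the LeakyReLU used in the coding section, is a single-layer network applied to the concatenated pair. The output formula with a nonlinearity $\sigma$ is likewise an assumption not stated in these notes:

$$
\begin{align}
e_{ij} &= \text{LeakyReLU}\left( \mathbf{a}^{\top} \left[ \mathbf W \bar h_i \,\|\, \mathbf W \bar h_j \right] \right), \quad \mathbf a \in \mathbb R^{2F'} \\
\bar h'_i &= \sigma\left( \sum_{j \in N(i)} \alpha_{ij} \mathbf W \bar h_j \right)
\end{align}
$$

Here $\|$ denotes concatenation, so the argument of $\mathbf a^{\top}$ lies in $\mathbb R^{2F'}$.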

### coding

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class GATLayer(nn.Module):
    """
    Simple PyTorch Implementation of the Graph Attention layer.
    """
    def __init__(self):
        super(GATLayer, self).__init__()
      
    def forward(self, input, adj):
        print("")

In [4]:
in_features = 5
out_features = 2
nb_nodes = 3

X = torch.rand(nb_nodes, in_features) 
W = nn.Parameter(torch.zeros(size=(in_features, out_features)))  # weight matrix, Xavier-initialized on the next line
nn.init.xavier_uniform_(W.data, gain=1.414)

X.shape, W.shape

(torch.Size([3, 5]), torch.Size([5, 2]))

In [6]:
h = torch.mm(X, W)
N = h.shape[0]
h.shape

torch.Size([3, 2])

In [8]:
h.repeat(1, N).view(N*N, -1)

tensor([[-0.8332, -1.1838],
        [-0.8332, -1.1838],
        [-0.8332, -1.1838],
        [-0.4022, -0.7821],
        [-0.4022, -0.7821],
        [-0.4022, -0.7821],
        [ 0.3184, -1.1586],
        [ 0.3184, -1.1586],
        [ 0.3184, -1.1586]], grad_fn=<ViewBackward0>)

In [10]:
h.repeat(N, 1)

tensor([[-0.8332, -1.1838],
        [-0.4022, -0.7821],
        [ 0.3184, -1.1586],
        [-0.8332, -1.1838],
        [-0.4022, -0.7821],
        [ 0.3184, -1.1586],
        [-0.8332, -1.1838],
        [-0.4022, -0.7821],
        [ 0.3184, -1.1586]], grad_fn=<RepeatBackward0>)

In [11]:
torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1)

tensor([[-0.8332, -1.1838, -0.8332, -1.1838],
        [-0.8332, -1.1838, -0.4022, -0.7821],
        [-0.8332, -1.1838,  0.3184, -1.1586],
        [-0.4022, -0.7821, -0.8332, -1.1838],
        [-0.4022, -0.7821, -0.4022, -0.7821],
        [-0.4022, -0.7821,  0.3184, -1.1586],
        [ 0.3184, -1.1586, -0.8332, -1.1838],
        [ 0.3184, -1.1586, -0.4022, -0.7821],
        [ 0.3184, -1.1586,  0.3184, -1.1586]], grad_fn=<CatBackward0>)

In [12]:
torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * out_features)

tensor([[[-0.8332, -1.1838, -0.8332, -1.1838],
         [-0.8332, -1.1838, -0.4022, -0.7821],
         [-0.8332, -1.1838,  0.3184, -1.1586]],

        [[-0.4022, -0.7821, -0.8332, -1.1838],
         [-0.4022, -0.7821, -0.4022, -0.7821],
         [-0.4022, -0.7821,  0.3184, -1.1586]],

        [[ 0.3184, -1.1586, -0.8332, -1.1838],
         [ 0.3184, -1.1586, -0.4022, -0.7821],
         [ 0.3184, -1.1586,  0.3184, -1.1586]]], grad_fn=<ViewBackward0>)

In [13]:
a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * out_features)
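
The repeat/view construction in the cell above builds every pair $[h_i \,\|\, h_j]$. As a cross-check, the same tensor can be built with broadcasting. This is a sketch with a random stand-in for `h`; the names `h_i`, `h_j` and `a_input_alt` are introduced here for illustration only.

```python
import torch

N, out_features = 3, 2
h = torch.rand(N, out_features)  # stand-in for h = torch.mm(X, W)

# Construction from the cell above: repeat rows, concatenate, reshape.
a_input = torch.cat(
    [h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1
).view(N, -1, 2 * out_features)

# Broadcasting version: entry [i, j] is the concatenation [h_i || h_j].
h_i = h.unsqueeze(1).expand(N, N, out_features)  # h_i[i, j] = h[i]
h_j = h.unsqueeze(0).expand(N, N, out_features)  # h_j[i, j] = h[j]
a_input_alt = torch.cat([h_i, h_j], dim=-1)

print(torch.equal(a_input, a_input_alt))
```

Both versions materialize an `(N, N, 2 * out_features)` tensor, so memory grows quadratically with the number of nodes.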


In [14]:
leakyrelu = nn.LeakyReLU(0.2)  # LeakyReLU with negative slope 0.2


In [15]:
a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))  # attention vector, Xavier-initialized on the next line
nn.init.xavier_uniform_(a.data, gain=1.414)
print(a.shape)

torch.Size([4, 1])


In [16]:
print(a_input.shape,a.shape)
print("")
print(torch.matmul(a_input,a).shape)
print("")
print(torch.matmul(a_input,a).squeeze(2).shape)

torch.Size([3, 3, 4]) torch.Size([4, 1])

torch.Size([3, 3, 1])

torch.Size([3, 3])
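
The `(N, N)` score matrix above still has to be masked by the graph, normalized, and used to aggregate `h`. The sketch below puts the steps from these cells into one single-head layer. Assumptions: the class name `GATLayerSketch`, the dense adjacency matrix `adj` with self-loops, and the masking with `-inf` are illustrative choices, not taken from the cells above. Dropout and the output nonlinearity are left out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GATLayerSketch(nn.Module):
    """Single-head graph attention layer (illustrative sketch, dense adjacency)."""

    def __init__(self, in_features, out_features, alpha=0.2):
        super().__init__()
        self.out_features = out_features
        self.W = nn.Parameter(torch.empty(in_features, out_features))
        self.a = nn.Parameter(torch.empty(2 * out_features, 1))
        nn.init.xavier_uniform_(self.W, gain=1.414)
        nn.init.xavier_uniform_(self.a, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(alpha)

    def forward(self, x, adj):
        h = torch.mm(x, self.W)                                  # (N, F')
        N = h.shape[0]
        # All pairs [h_i || h_j], as in the cells above.
        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1)
        a_input = a_input.view(N, -1, 2 * self.out_features)     # (N, N, 2F')
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))  # (N, N)
        # ASSUMPTION: adj[i, j] != 0 marks j as a neighbour of i (self-loops included).
        e = e.masked_fill(adj == 0, float("-inf"))
        attention = F.softmax(e, dim=1)                          # rows sum to 1
        return torch.matmul(attention, h)                        # (N, F')


torch.manual_seed(0)
X = torch.rand(3, 5)
adj = torch.tensor([[1, 1, 0],
                    [1, 1, 1],
                    [0, 1, 1]])  # hypothetical 3-node graph
layer = GATLayerSketch(in_features=5, out_features=2)
out = layer(X, adj)
print(out.shape)
```

Because of the mask, the output for node 0 depends only on nodes 0 and 1, which is a quick way to check that the adjacency is being respected.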
