### Cora dataset

In [34]:
# Cora: a single citation-network graph used for node classification
import dgl.data
dataset=dgl.data.CoraGraphDataset()
g=dataset[0]

# train_mask, val_mask, test_mask = boolean node masks selecting the train / val / test nodes
train_mask=g.ndata['train_mask']
val_mask=g.ndata['val_mask']
test_mask=g.ndata['test_mask']

# Tensor.sum() counts the True entries in one vectorised call; Python's builtin
# sum() would iterate over every node one element at a time.
print(f"num_train_nodes: {int(train_mask.sum())} | num_val_nodes: {int(val_mask.sum())} | num_test_nodes: {int(test_mask.sum())}")

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
num_train_nodes: 140 | num_val_nodes: 500 | num_test_nodes: 1000


In [35]:
# hyperparameters shared by the models defined below
in_dim = g.ndata['feat'].shape[-1]   # size of each node's input feature vector
out_dim = dataset.num_classes        # one output logit per class
hidden_dim = 16                      # per-head hidden size of the first attention layer
num_heads = 2                        # number of attention heads in the first layer
print(f"in_dim: {in_dim} | hidden_dim: {hidden_dim} | out_dim: {out_dim} | num_heads: {num_heads}")

in_dim: 1433 | hidden_dim: 16 | out_dim: 7 | num_heads: 2


### 1. Graph Attention Network
Recap of the GCN and GraphSAGE update rules, for comparison with GAT:
1. GCN update equation
$$h_i^{(l+1)}=W\sum_{j\in N(i)}\dfrac{1}{\sqrt{\deg(i)\deg(j)}}h_j^{(l)}+b$$
2. GraphSage update equation:
$$h_i^{(l+1)}= W\text{concat}(h_i^{(l)},h_{N(i)}^{(l+1)})+b \ \text{with} \ h_{N(i)}^{(l+1)}=\text{Mean}\{h_j^{(l)}: j\in N(i)\}$$
3. **GAT** update equation
\begin{align}
h_i^{(l+1)} &=W\sum_{j\in N(i)}\alpha_{ij}^{(l)}h_j^{(l)} \ \text{with} \\
\alpha_{ij}^{(l)}&=\text{softmax}_j\left(e_{ij}^{(l)}\right)=\dfrac{\exp\left(e_{ij}^{(l)}\right)}{\sum_{k\in N(i)}\exp\left(e_{ik}^{(l)}\right)} \\
e_{ij}^{(l)}&= \text{LeakyReLU}\left(a^{(l)T}\cdot\text{concat}(Wh_i^{(l)} , Wh_j^{(l)})\right)
\end{align}
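where $e_{ij}$ is the unnormalized attention score of the edge from neighbor $j$ to node $i$, and $\alpha_{ij}$ is the normalized attention coefficient (a softmax over all neighbors $k\in N(i)$ of the receiving node $i$).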

### 2. GAT layer from scratch

#### (a) Single-head GAT Layer

In [36]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F 

# flow: (1) Apply linear projection to all nodes via matrix W
#       (2) Compute attention e_ij for all edges {i,j}
#       (3) Send info through edges and update node features 

class GAT_layer(nn.Module):
    """Single-head graph attention layer built from DGL user-defined functions.

    For every edge j -> i of the input graph (source j, destination i) it computes

        e_ij     = LeakyReLU(a . concat(W h_j, W h_i))          (edge_attention)
        alpha_ij = softmax of e_ij over the in-edges of node i   (reduce_func)
        h_i'     = sum_j alpha_ij * W h_j                        (reduce_func)

    The concatenation order is (source, destination), i.e. the reverse of the
    markdown formula; since ``a`` is learned this only relabels the two halves
    of ``a``. No bias, dropout or self-loops are added: node i attends only over
    the in-edges already present in the graph it is given.

    Args:
        in_dim: size of the input node features.
        out_dim: size of the projected (and output) node features.
        negative_slope: LeakyReLU slope for the attention scores. Defaults to
            0.01 (PyTorch's default, the value this layer always used); the GAT
            paper and DGL's ``GATConv`` use 0.2.
    """

    def __init__(self,in_dim,out_dim,negative_slope=0.01):
        super().__init__()

        # matrix W: linear projection shared by all nodes
        self.linear_proj=nn.Linear(in_dim,out_dim,bias=False)

        # attention vector a, applied to the concatenation of two projected nodes
        self.attn_param=nn.Linear(2*out_dim,1,bias=False)

        # slope of the LeakyReLU applied to the raw attention scores e_ij
        self.negative_slope=negative_slope

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise W and a with Xavier-normal init scaled by the ReLU gain."""
        nn.init.xavier_normal_(self.linear_proj.weight, gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_normal_(self.attn_param.weight,  gain=nn.init.calculate_gain('relu'))

    def edge_attention(self,edges):
        """Equation (3): raw attention score for every edge, used with ``g.apply_edges``.

        Returns ``{'e': e}`` with one score per edge (``attn_param`` outputs a
        single value); DGL stores it in ``g.edata['e']``.
        """
        # [W h_src || W h_dst] for each edge
        concat=torch.cat([edges.src['W_h'], edges.dst['W_h']],dim=1)
        e=self.attn_param(concat)
        return {'e': F.leaky_relu(e, negative_slope=self.negative_slope)}

    def message_func(self,edges):
        """Send each source's projected features and the edge score to the destination.

        In ``reduce_func`` these arrive as ``nodes.mailbox['W_h']`` / ``nodes.mailbox['e']``,
        batched per destination node with the incoming edges along dim 1.
        """
        return {'W_h': edges.src['W_h'], 'e': edges.data['e']}

    def reduce_func(self,nodes):
        """Equation (1): attention-weighted sum of the incoming neighbour features.

        Returns ``{'h_N': h_N}``, stored by DGL in ``g.ndata['h_N']``.
        """
        # normalise the scores over each node's incoming edges (mailbox dim 1)
        alpha=F.softmax(nodes.mailbox['e'],dim=1)

        # weighted sum of the neighbours' projected features
        h_N=torch.sum(alpha * nodes.mailbox['W_h'],dim=1)

        return {'h_N': h_N}

    def forward(self,g,h):
        """Apply the layer to node features ``h`` (one row per node, ``in_dim`` columns).

        Returns the new node features, one row of size ``out_dim`` per node.
        """
        # NOTE(review): nodes without in-edges never reach reduce_func; verify what
        # DGL stores in 'h_N' for them if the graph may contain such nodes.
        with g.local_scope():
            # project every node once: W h, shape [N, out_dim]
            g.ndata['W_h']=self.linear_proj(h)

            # compute the attention score e_ij of every edge
            g.apply_edges(self.edge_attention)

            # send messages along the edges and aggregate them per node
            g.update_all(self.message_func,self.reduce_func)

            return g.ndata['h_N']
 

In [37]:
# sanity check: one attention head maps the [N, in_dim] features to [N, out_dim]
lay=GAT_layer(in_dim,out_dim)
out=lay(g,g.ndata['feat'])
assert out.shape == (g.ndata['feat'].shape[0], out_dim), out.shape
print(out.shape)

torch.Size([2708, 7])


#### (b) Multi-head GAT Layer

In [38]:
class MultiHeadGAT_Layer(nn.Module):
    """Run ``num_heads`` independent GAT_layer heads and concatenate their outputs.

    Every head sees the same graph and input features; the per-head outputs are
    joined along dim 1, giving [N, out_dim * num_heads].
    """
    def __init__(self,in_dim,out_dim,num_heads):
        super().__init__()
        self.num_heads=num_heads
        heads=[]
        for _ in range(num_heads):
            heads.append(GAT_layer(in_dim,out_dim))
        self.heads=nn.ModuleList(heads)

    def forward(self,g,h):
        per_head=[]
        for idx in range(self.num_heads):
            per_head.append(self.heads[idx](g,h))
        # concatenate the head outputs feature-wise
        return torch.cat(per_head,dim=1)
    

In [39]:
# sanity check: num_heads outputs of size out_dim are concatenated along dim 1
layer_multihead=MultiHeadGAT_Layer(in_dim,out_dim,num_heads)
output=layer_multihead(g,g.ndata['feat'])
assert output.shape == (g.ndata['feat'].shape[0], out_dim*num_heads), output.shape
print(output.shape)

torch.Size([2708, 14])


#### (c) Model
Architecture: input → `MultiHeadGAT_Layer` (`num_heads` heads, outputs concatenated) → ELU → `MultiHeadGAT_Layer` (1 head) → class logits

In [40]:
class GAT_Net(nn.Module):
    """Two-layer GAT: multi-head attention -> ELU -> single-head attention.

    ``gat1`` concatenates ``num_heads`` heads of size ``hidden_dim``; ``gat2``
    maps that [N, hidden_dim*num_heads] representation to [N, out_dim] logits.
    """
    def __init__(self,in_dim,hidden_dim,out_dim,num_heads):
        super().__init__()
        concat_dim=hidden_dim*num_heads
        self.gat1=MultiHeadGAT_Layer(in_dim,hidden_dim,num_heads)
        self.gat2=MultiHeadGAT_Layer(concat_dim,out_dim,1)

    def forward(self,g,h):
        hidden=F.elu(self.gat1(g,h))
        return self.gat2(g,hidden)

    

### 3. Train

In [41]:
# train loop

def train(model, graph, loss_fn, optimizer, mask=None):
    """Run one full-batch optimisation step on the nodes selected by ``mask``.

    The model is applied to the whole graph (message passing needs every node),
    but only the masked nodes contribute to the loss and the reported accuracy.

    Args:
        model: module called as ``model(graph, features)`` returning per-node logits.
        graph: graph whose ``ndata`` holds 'feat', 'label' and, when ``mask`` is
            None, 'train_mask'.
        loss_fn: callable ``loss_fn(logits, labels)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        mask: boolean node mask; defaults to ``graph.ndata['train_mask']``.

    Returns:
        (loss, acc): detached scalar tensors for the masked nodes, computed from
        the logits *before* the optimizer step.
    """
    # take the mask from the graph itself instead of a notebook-global variable
    if mask is None:
        mask=graph.ndata['train_mask']

    model.train()
    features=graph.ndata['feat']
    labels=graph.ndata['label']

    optimizer.zero_grad()

    # prediction on the whole graph, loss only on the masked nodes
    logits=model(graph, features)
    loss=loss_fn(logits[mask],labels[mask])
    loss.backward()
    optimizer.step()

    preds=logits.argmax(dim=-1)
    acc=(preds[mask]==labels[mask]).float().mean()
    # detach so a caller holding on to `loss` does not keep the autograd graph alive
    return loss.detach(), acc

@torch.no_grad()
def evaluate(model, graph, loss_fn, mask=None):
    """Compute loss and accuracy on the nodes selected by ``mask`` without gradients.

    Args:
        model: module called as ``model(graph, features)`` returning per-node logits.
        graph: graph whose ``ndata`` holds 'feat', 'label' and, when ``mask`` is
            None, 'val_mask'.
        loss_fn: callable ``loss_fn(logits, labels)`` returning a scalar tensor.
        mask: boolean node mask; defaults to ``graph.ndata['val_mask']``. Pass
            ``graph.ndata['test_mask']`` to score the held-out test nodes.

    Returns:
        (loss, acc): scalar tensors for the masked nodes.
    """
    # take the mask from the graph itself instead of a notebook-global variable
    if mask is None:
        mask=graph.ndata['val_mask']

    model.eval()
    features=graph.ndata['feat']
    labels=graph.ndata['label']

    # forward on the whole graph, score only the masked nodes
    logits=model(graph,features)
    loss=loss_fn(logits[mask],labels[mask])

    preds=logits.argmax(dim=-1)
    acc=(preds[mask]==labels[mask]).float().mean()
    return loss, acc


In [44]:
torch.manual_seed(1442)   # fixes weight initialisation so the run is reproducible

num_epochs=50

# model and optimizer

model=GAT_Net(in_dim,hidden_dim,out_dim,num_heads)
print(f"{sum(p.numel() for p in model.parameters())/1e6} million parameters")

loss_fn=F.cross_entropy
optimizer=torch.optim.AdamW(model.parameters(),lr=0.01)


# train on the training nodes, monitor the validation nodes every epoch
for epoch in range(num_epochs):
    train_loss, train_acc=train(model,g,loss_fn,optimizer)
    val_loss, val_acc=evaluate(model,g,loss_fn)
    if epoch%5==0 or epoch == num_epochs-1:
        print(f"Epoch : {epoch} | train_loss : {train_loss:.4f} | train_acc : {train_acc*100:.2f}% | "
            f" val_loss : {val_loss:.4f} | val_acc : {val_acc*100:.2f}%")

# final accuracy on the held-out test nodes (not used during training or validation)
model.eval()
with torch.no_grad():
    test_logits=model(g,g.ndata['feat'])
test_nodes=g.ndata['test_mask']
test_acc=(test_logits.argmax(dim=-1)[test_nodes]==g.ndata['label'][test_nodes]).float().mean()
print(f"test_acc : {test_acc*100:.2f}%")

0.046158 million parameters
Epoch : 0 | train_loss : 1.9467 | train_acc : 12.86% |  val_loss : 1.9308 | val_acc : 56.40%
Epoch : 5 | train_loss : 1.7944 | train_acc : 95.71% |  val_loss : 1.8511 | val_acc : 73.40%
Epoch : 10 | train_loss : 1.6095 | train_acc : 95.71% |  val_loss : 1.7515 | val_acc : 74.60%
Epoch : 15 | train_loss : 1.3920 | train_acc : 95.71% |  val_loss : 1.6309 | val_acc : 74.40%
Epoch : 20 | train_loss : 1.1520 | train_acc : 96.43% |  val_loss : 1.4909 | val_acc : 75.40%
Epoch : 25 | train_loss : 0.9071 | train_acc : 96.43% |  val_loss : 1.3380 | val_acc : 75.60%
Epoch : 30 | train_loss : 0.6785 | train_acc : 97.14% |  val_loss : 1.1865 | val_acc : 76.20%
Epoch : 35 | train_loss : 0.4849 | train_acc : 98.57% |  val_loss : 1.0557 | val_acc : 76.60%
Epoch : 40 | train_loss : 0.3351 | train_acc : 98.57% |  val_loss : 0.9618 | val_acc : 76.80%
Epoch : 45 | train_loss : 0.2252 | train_acc : 99.29% |  val_loss : 0.9103 | val_acc : 77.20%
Epoch : 49 | train_loss : 0.1639 |

### 4. GAT with DGL's built-in `GATConv`

In [45]:
from dgl.nn.pytorch import GATConv

class GAT_DGL(nn.Module):
    """Same two-layer architecture as GAT_Net, built from DGL's GATConv.

    ``gat1`` yields one ``hidden_dim`` vector per head, the heads are concatenated,
    passed through ELU, and ``gat2`` (a single head) produces [N, out_dim] logits.
    """
    def __init__(self,in_dim,hidden_dim,out_dim,num_heads):
        super().__init__()
        # GATConv(in_feats, out_feats, num_heads) -> [num_nodes, num_heads, out_feats]
        self.gat1=GATConv(in_dim,hidden_dim,num_heads)
        self.gat2=GATConv(hidden_dim*num_heads,out_dim,1)

    def forward(self,g,h):
        # merge heads by concatenation: [N, num_heads, hidden_dim] -> [N, num_heads*hidden_dim]
        merged=F.elu(self.gat1(g,h).flatten(start_dim=1))
        # the output layer has a single head: [N, 1, out_dim] -> [N, out_dim]
        return self.gat2(g,merged).squeeze(dim=-2)
    
# sanity check: GAT_DGL should return one logit per class for every node
gat_dgl=GAT_DGL(in_dim,hidden_dim,out_dim,num_heads)
print(g.ndata['feat'].shape)
output=gat_dgl(g,g.ndata['feat'])
assert output.shape == (g.ndata['feat'].shape[0], out_dim), output.shape
print(output.shape)

torch.Size([2708, 1433])
torch.Size([2708, 7])


In [49]:
torch.manual_seed(1442)   # same seed as the from-scratch run, for a like-for-like comparison

num_epochs=50

# model and optimizer

model=GAT_DGL(in_dim,hidden_dim,out_dim,num_heads)
print(f"{sum(p.numel() for p in model.parameters())/1e6} million parameters")

loss_fn=F.cross_entropy
optimizer=torch.optim.AdamW(model.parameters(),lr=0.01)


# train on the training nodes, monitor the validation nodes every epoch
for epoch in range(num_epochs):
    train_loss, train_acc=train(model,g,loss_fn,optimizer)
    val_loss, val_acc=evaluate(model,g,loss_fn)
    if epoch%5==0 or epoch == num_epochs-1:
        print(f"Epoch : {epoch} | train_loss : {train_loss:.4f} | train_acc : {train_acc*100:.2f}% | "
            f" val_loss : {val_loss:.4f} | val_acc : {val_acc*100:.2f}%")

# final accuracy on the held-out test nodes (not used during training or validation)
model.eval()
with torch.no_grad():
    test_logits=model(g,g.ndata['feat'])
test_nodes=g.ndata['test_mask']
test_acc=(test_logits.argmax(dim=-1)[test_nodes]==g.ndata['label'][test_nodes]).float().mean()
print(f"test_acc : {test_acc*100:.2f}%")

0.046197 million parameters
Epoch : 0 | train_loss : 1.9469 | train_acc : 15.00% |  val_loss : 1.9190 | val_acc : 34.00%
Epoch : 5 | train_loss : 1.7943 | train_acc : 97.86% |  val_loss : 1.8468 | val_acc : 76.40%
Epoch : 10 | train_loss : 1.6067 | train_acc : 96.43% |  val_loss : 1.7410 | val_acc : 77.60%
Epoch : 15 | train_loss : 1.3828 | train_acc : 97.86% |  val_loss : 1.6176 | val_acc : 77.20%
Epoch : 20 | train_loss : 1.1312 | train_acc : 97.86% |  val_loss : 1.4723 | val_acc : 77.00%
Epoch : 25 | train_loss : 0.8702 | train_acc : 97.86% |  val_loss : 1.3191 | val_acc : 77.00%
Epoch : 30 | train_loss : 0.6273 | train_acc : 98.57% |  val_loss : 1.1743 | val_acc : 77.80%
Epoch : 35 | train_loss : 0.4271 | train_acc : 99.29% |  val_loss : 1.0554 | val_acc : 77.60%
Epoch : 40 | train_loss : 0.2775 | train_acc : 99.29% |  val_loss : 0.9740 | val_acc : 77.00%
Epoch : 45 | train_loss : 0.1740 | train_acc : 100.00% |  val_loss : 0.9282 | val_acc : 75.60%
Epoch : 49 | train_loss : 0.1182 