In [None]:
# Show the installed PyTorch build. The wheel index URL passed to pip in the
# install cell below embeds a torch/CUDA version string, so it has to match
# the value printed here.
import torch
print(torch.__version__)

1.12.0+cu113


In [None]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu113.html

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.12.0+cu113.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl (7.9 MB)
[K     |████████████████████████████████| 7.9 MB 2.9 MB/s 
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_sparse-0.6.14-cp37-cp37m-linux_x86_64.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 39.1 MB/s 
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_cluster-1.6.0-cp37-cp37m-linux_x86_64.whl (2.4 MB)
[K     |████████████████████████████████| 2.4 MB 60.2 MB/s 
[?25hCollecting torch-spline-conv
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl (709 kB)
[K     |████████████████████████████████| 709 kB 60.6 MB/s

In [None]:
# torch lib
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data

# others
import numpy as np
import networkx as nx
import os
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# GraphAttentionLayer
class GraphAttentionLayer(nn.Module):
    """Single-head graph attention layer operating on a dense adjacency matrix.

    Node features are projected with ``W``; every ordered node pair (i, j) is
    scored with ``LeakyReLU(a^T [h_i || h_j])``; pairs with ``adj[i, j] <= 0``
    are masked out; the scores are softmax-normalised per row and used to
    aggregate the projected features.

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each output node feature vector.
        dropout: dropout probability applied to the attention weights.
        alpha: negative slope of the LeakyReLU applied to the raw scores.
        concat: when True, ELU is applied to the aggregated features before
            returning; when False they are returned without activation.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.alpha = alpha
        self.concat = concat

        # Feature projection matrix, Xavier-uniform initialised.
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        # Attention vector: rows [0, out_features) weight the source node h_i,
        # rows [out_features, 2*out_features) weight the target node h_j.
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, inp, adj):
        """Compute attention-weighted node features.

        Args:
            inp: 2-D tensor of node features, ``[N, in_features]``
                (``torch.mm`` only accepts 2-D operands).
            adj: tensor compared element-wise against 0 and broadcast against
                the ``[N, N]`` score matrix; entries ``> 0`` mark pairs that
                may attend to each other.

        Returns:
            Tensor of shape ``[N, out_features]``.
        """
        h = torch.mm(inp, self.W)   # [N, out_features]

        # a^T [h_i || h_j] == a_src^T h_i + a_dst^T h_j, so the pairwise logits
        # are obtained from two [N, 1] projections and a broadcast sum instead
        # of materialising the [N, N, 2*out_features] tensor of all
        # concatenated pairs.
        a_src = self.a[:self.out_features]    # [out_features, 1]
        a_dst = self.a[self.out_features:]    # [out_features, 1]
        score_src = torch.mm(h, a_src)        # [N, 1], contribution of h_i
        score_dst = torch.mm(h, a_dst)        # [N, 1], contribution of h_j
        e = self.leakyrelu(score_src + score_dst.t())   # [N, N]

        # Non-edges get a very large negative (but finite) logit so that they
        # receive ~0 weight after the softmax; a finite value keeps a row with
        # no edges at uniform weights instead of NaN.
        masked_out = torch.full_like(e, -1e12)
        attention = torch.where(adj > 0, e, masked_out)   # [N, N]
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)  # [N, N] @ [N, out_features] => [N, out_features]
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'

In [None]:
# GAT Model
class GAT(nn.Module):
    """Two-stage graph attention network.

    ``n_heads`` independent ``GraphAttentionLayer`` heads map ``n_feat`` input
    features to ``n_hid`` features each; their outputs are concatenated along
    the feature dimension and fed to one output attention layer producing
    ``n_class`` scores per node.

    Args:
        n_feat: number of input features per node.
        n_hid: number of output features of each hidden head.
        n_class: number of output scores per node.
        dropout: dropout probability, used on the inputs of both stages and
            passed to every attention layer.
        alpha: LeakyReLU negative slope passed to every attention layer.
        n_heads: number of hidden attention heads.
    """

    def __init__(self, n_feat, n_hid, n_class, dropout, alpha, n_heads):
        super().__init__()
        self.dropout = dropout

        # The heads are kept in a plain Python list, so each one is registered
        # explicitly as a submodule (attention_0, attention_1, ...) to have its
        # parameters tracked by this module.
        heads = []
        for idx in range(n_heads):
            head = GraphAttentionLayer(
                n_feat, n_hid, dropout=dropout, alpha=alpha, concat=True
            )
            self.add_module('attention_{}'.format(idx), head)
            heads.append(head)
        self.attentions = heads

        # Output layer consumes the concatenation of all head outputs.
        self.out_att = GraphAttentionLayer(
            n_hid * n_heads, n_class, dropout=dropout, alpha=alpha, concat=False
        )

    def forward(self, x, adj):
        """Return per-node log-probabilities (log_softmax over dim 1).

        Args:
            x: node feature tensor passed to every hidden head.
            adj: adjacency tensor forwarded unchanged to every attention layer.
        """
        features = F.dropout(x, self.dropout, training=self.training)
        head_outputs = [head(features, adj) for head in self.attentions]
        hidden = torch.cat(head_outputs, dim=1)
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        scores = F.elu(self.out_att(hidden, adj))
        return F.log_softmax(scores, dim=1)