# Install and import torch-geometric

In [1]:
import torch

# Record once whether a CUDA device is usable; the wheel flavour chosen below
# depends on this flag.
CUDA_AVAILABLE = torch.cuda.is_available()

try:
    import torch_geometric

except ImportError as import_exc:  # package missing: install it with pip from Python

    import subprocess
    import sys

    # Tags that select the wheel index page and the wheel build flavour.
    TORCH = torch.__version__.split('+')[0]
    if CUDA_AVAILABLE:
        CUDA = 'cu' + torch.version.cuda.replace('.', '')
    else:
        CUDA = 'cpu'

    # Assemble the pip invocation piece by piece, using the running interpreter.
    pip_command = [sys.executable, '-m', 'pip', 'install']
    pip_command.extend([
        f'torch-scatter==latest+{CUDA}',
        f'torch-sparse==latest+{CUDA}',
        f'torch-cluster==latest+{CUDA}',
        f'torch-spline-conv==latest+{CUDA}',
        'torch-geometric',
    ])
    pip_command.extend(['-f', f'https://pytorch-geometric.com/whl/torch-{TORCH}.html'])

    try:
        # check=True turns an unsuccessful installation into CalledProcessError.
        subprocess.run(pip_command, check=True)
    except subprocess.CalledProcessError as proc_exc:
        # Re-raise the original ImportError, chained to the pip failure.
        raise import_exc from proc_exc

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import accuracy, add_self_loops, to_dense_adj
import torch_geometric.nn as gnn
import torch_geometric.transforms as T

import os
import glob

# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if CUDA_AVAILABLE else torch.device('cpu')

# Load and preprocess data

In [2]:
# Load the Cora dataset and move its first graph onto the selected device
dataset = Planetoid(root='./datasets/Cora', name='Cora')
data = dataset[0].to(device)

# Model dimensions and the graph tensors used by later cells
in_dim, class_cardinality = dataset.num_node_features, dataset.num_classes
node_features, A = data.x, data.edge_index

# Node masks of the dataset's split, and the labels each mask selects
train_node, valid_node, test_node = data.train_mask, data.val_mask, data.test_mask
train_target = data.y[train_node]
valid_target = data.y[valid_node]
test_target = data.y[test_node]

In [3]:
# Build the dense adjacency from `data.edge_index` rather than from the previous
# value of `A`: this cell overwrites `A`, so reading `A` as input would make a
# second run feed an already-normalized dense matrix into `to_dense_adj`.
# `max_num_nodes` pins the size to the node count, so isolated nodes with the
# highest indices cannot shrink the matrix.
adj_dense = to_dense_adj(data.edge_index, max_num_nodes=data.num_nodes).squeeze(0)
self_loop = torch.eye(*adj_dense.shape).to(adj_dense)
adj_dense = adj_dense + self_loop  # add self-loops

# Symmetric normalization: D^{-1/2} (A + I) D^{-1/2}
deg = adj_dense.sum(1)
deg_inv = deg.pow(-0.5)
deg_inv.masked_fill_(deg_inv == torch.inf, 0.)  # zero-degree rows would give inf
deg_inv = deg_inv * self_loop  # broadcast the vector onto the diagonal

A = deg_inv.mm(adj_dense).mm(deg_inv)
A

tensor([[0.2500, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.2500, 0.2041,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.2041, 0.1666,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.4999, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.2000, 0.2000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.2000, 0.2000]],
       device='cuda:0')

# Implement GAT model

In [4]:
class GATLayer(nn.Module):
    """Dense single-head graph attention (GAT) layer.

    For every ordered node pair (i, j) the layer computes an attention logit
    ``LeakyReLU(a^T [h_i || h_j])`` with ``h = x @ W``, masks out pairs where
    ``adj`` is not positive, normalizes each row with a softmax and aggregates
    the transformed features with the resulting weights.

    Args:
        in_features: Size of each input node feature vector.
        out_features: Size of each output node feature vector.
        dropout: Dropout probability applied to the attention weights.
        alpha: Negative slope of the LeakyReLU applied to the attention logits.
        concat: If True, apply ELU to the output (used for hidden heads whose
            outputs are concatenated by the caller); otherwise return the raw
            aggregation.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):

        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        # W projects node features; a scores a concatenated pair [h_i || h_j].
        self.W = nn.Parameter(torch.zeros(in_features, out_features))
        self.a = nn.Parameter(torch.zeros(2 * out_features, 1))

        self.leaky_relu = nn.LeakyReLU(alpha)
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize ``W`` and ``a`` with Xavier-uniform values (gain 1.414)."""
        nn.init.xavier_uniform_(self.W, gain=1.414)
        nn.init.xavier_uniform_(self.a, gain=1.414)

    def forward(self, x, adj):
        """Apply attention-weighted neighborhood aggregation.

        Args:
            x: Node features of shape ``(N, in_features)``.
            adj: ``(N, N)`` tensor; entry ``(i, j)`` greater than zero means
                node ``i`` may attend to node ``j``.

        Returns:
            Tensor of shape ``(N, out_features)``.
        """
        h = torch.mm(x, self.W)  # (N, out_features)

        # a^T [h_i || h_j] splits into a_src^T h_i + a_dst^T h_j, so the full
        # (N, N) logit matrix is a broadcast sum of two (N, 1) projections.
        # This avoids materializing the (N, N, 2 * out_features) pair tensor.
        src_score = torch.mm(h, self.a[:self.out_features])  # (N, 1)
        dst_score = torch.mm(h, self.a[self.out_features:])  # (N, 1)
        e = self.leaky_relu(src_score + dst_score.t())  # (N, N)

        # A large negative logit drives masked pairs to ~0 after the softmax.
        attention = torch.where(adj > 0, e, torch.full_like(e, -9e15))
        attention = attention.softmax(dim=1)
        attention = self.dropout(attention)
        h_prime = torch.matmul(attention, h)
        if self.concat:
            return F.elu(h_prime)
        return h_prime


In [5]:
class GNNNet(nn.Module):
    """Two-layer attention network assembled from ``GATLayer`` blocks.

    The first layer runs ``nheads`` independent heads and concatenates their
    outputs; the second layer is a single head that maps to ``nclass`` scores.

    Args:
        nfeat: Size of each input node feature vector.
        nhid: Output size of every first-layer head.
        nheads: Number of first-layer heads.
        nclass: Number of output classes.
        dropout: Dropout probability, used on the features and inside the layers.
        alpha: LeakyReLU negative slope forwarded to every ``GATLayer``.
    """

    def __init__(self, nfeat, nhid, nheads, nclass, dropout, alpha=0.2):
        super().__init__()

        heads = []
        for _ in range(nheads):
            heads.append(GATLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True))
        self.layer1_heads = nn.ModuleList(heads)

        self.layer2 = GATLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)
        self.dropout = nn.Dropout(dropout)
        self.elu = nn.ELU()

    def forward(self, x, adj):
        """Return per-node log-probabilities over the classes."""
        hidden = self.dropout(x)

        # Run the heads in order and join their outputs along the feature axis.
        head_outputs = []
        for head in self.layer1_heads:
            head_outputs.append(head(hidden, adj))
        hidden = self.dropout(torch.cat(head_outputs, dim=1))

        scores = self.elu(self.layer2(hidden, adj))
        return F.log_softmax(scores, dim=1)


In [None]:
# Another Implementation with Torch Geometric

class GATNet(nn.Module):
    """Two-layer attention network built on ``gnn.GATConv``.

    Args:
        nfeat: Size of each input node feature vector.
        nhid: Output size of every first-layer head.
        num_head: Number of first-layer heads (their outputs are concatenated).
        nclass: Number of output classes.
        dropout: Dropout probability, used on the features and passed to both
            ``GATConv`` layers.
    """

    def __init__(self, nfeat, nhid, num_head, nclass, dropout):
        super().__init__()
        hidden_width = nhid * num_head  # width after concatenating the heads
        self.layer0 = gnn.GATConv(nfeat, nhid, heads=num_head, concat=True, dropout=dropout)
        self.layer1 = gnn.GATConv(hidden_width, nclass, concat=False, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x, adj):
        """Return per-node log-probabilities over the classes.

        ``adj`` is handed to ``GATConv`` unchanged.
        NOTE(review): presumably this must be an edge index rather than the
        dense matrix used by ``GNNNet`` -- confirm against the caller.
        """
        hidden = self.relu(self.layer0(self.dropout(x), adj))
        logits = self.layer1(self.dropout(hidden), adj)
        return F.log_softmax(logits, dim=1)


In [6]:
# Training hyper-parameters read by the model-construction and training cells.
hidden_dim = 32  # hidden units per first-layer head (the original note suggests 8 when running on Colab)
dropout_rate = 0.75  # dropout probability passed to the model
lr = 5e-3  # Adam learning rate
weight_decay = 5e-4  # Adam weight decay
epochs = 10000  # upper bound on the number of training epochs
patience = 200  # epochs without a new best validation loss before stopping
num_head = 16  # number of first-layer heads (the original note suggests 1 when running on Colab)

In [7]:
# Seed the RNG so weight initialization and dropout masks are repeatable
# across "Restart & Run All" (as far as the backend kernels are deterministic).
SEED = 42
torch.manual_seed(SEED)

# Early-stopping bookkeeping.
loss_values = []       # validation loss of every finished epoch
bad_counter = 0        # consecutive epochs without a new best validation loss
best = float('inf')    # lowest validation loss so far; any real loss beats it
best_epoch = 0         # epoch index at which `best` was reached

model = GNNNet(in_dim, hidden_dim, num_head, class_cardinality, dropout_rate).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

In [8]:
def step(epoch):
    """Run one full-graph training update, then score the validation nodes.

    Uses the notebook globals ``model``, ``optimizer``, ``node_features``,
    ``A`` and the train/validation masks and targets, and prints one log line.

    Args:
        epoch: Epoch index; only used in the printed log line.

    Returns:
        The validation loss as a Python float.
    """
    # Training pass (dropout active).
    model.train()
    model.zero_grad()

    train_out = model(node_features, A)
    train_loss = F.nll_loss(train_out[train_node], train_target)
    train_acc = accuracy(train_out[train_node].argmax(1), train_target)

    train_loss.backward()
    optimizer.step()

    # Validation pass (dropout off, no gradient tracking).
    model.eval()
    with torch.no_grad():
        valid_out = model(node_features, A)
        valid_loss = F.nll_loss(valid_out[valid_node], valid_target)
        valid_acc = accuracy(valid_out[valid_node].argmax(1), valid_target)

    print(
        "[Epoch] {:0>3} ".format(epoch)
        + " [Train] Loss: {:6.4f}, Accuracy: {:6.4f}".format(train_loss, train_acc)
        + " [Valid] Loss: {:6.4f}, Accuracy: {:6.4f}".format(valid_loss, valid_acc)
    )

    return valid_loss.item()


@torch.no_grad()
def compute_test():
    """Evaluate the current model on the held-out test nodes.

    Uses the notebook globals ``model``, ``node_features``, ``A``,
    ``test_node`` and ``test_target``. The metrics are taken on the test
    split; scoring the validation split here would report validation numbers
    under a "[Test]" label.

    Returns:
        Tuple ``(test_loss, test_acc)``: the NLL loss as a Python float and
        the value returned by ``accuracy`` for the test predictions.
    """
    model.eval()

    log_logits = model(node_features, A)
    test_loss = F.nll_loss(log_logits[test_node], test_target)
    test_acc = accuracy(log_logits[test_node].argmax(1), test_target)

    return test_loss.item(), test_acc

In [9]:
# Train with early stopping on the validation loss.
#
# The best weights are kept in memory instead of writing one checkpoint file
# per epoch. That avoids up to `epochs` files on disk and removes the need for
# a cleanup that deleted every `*.pth` file in the working directory,
# including files this notebook did not create.
best_state = None

for e in range(epochs):
    loss_values.append(step(e))

    improved = loss_values[-1] < best
    if improved:
        best = loss_values[-1]
        best_epoch = e
        bad_counter = 0
    else:
        bad_counter += 1

    # Snapshot on every improvement; the very first epoch is also kept so that
    # `best_epoch = 0` always has matching weights to restore.
    if improved or best_state is None:
        # state_dict() tensors share storage with the live parameters, so
        # clone them or later optimizer steps would overwrite the snapshot.
        best_state = {name: tensor.detach().clone()
                      for name, tensor in model.state_dict().items()}

    if bad_counter == patience:
        break

print("Optimization Finished!\nLoading %sth model\n" % best_epoch)
if best_state is not None:  # stays None only when no epoch was run
    model.load_state_dict(best_state)

best_test_loss, best_test_acc = compute_test()
print("[Test] Loss: {:6.4f}, Accuracy: {:6.4f}".format(best_test_loss, best_test_acc))


[Epoch] 000  [Train] Loss: 2.7472, Accuracy: 0.1786 [Valid] Loss: 1.9477, Accuracy: 0.1880
[Epoch] 001  [Train] Loss: 2.7749, Accuracy: 0.1500 [Valid] Loss: 1.9002, Accuracy: 0.2420
[Epoch] 002  [Train] Loss: 2.8936, Accuracy: 0.1143 [Valid] Loss: 1.8551, Accuracy: 0.2980
[Epoch] 003  [Train] Loss: 2.4070, Accuracy: 0.1857 [Valid] Loss: 1.8171, Accuracy: 0.3600
[Epoch] 004  [Train] Loss: 2.1712, Accuracy: 0.2214 [Valid] Loss: 1.7843, Accuracy: 0.3880
[Epoch] 005  [Train] Loss: 2.4434, Accuracy: 0.2214 [Valid] Loss: 1.7535, Accuracy: 0.4380
[Epoch] 006  [Train] Loss: 2.3732, Accuracy: 0.2214 [Valid] Loss: 1.7272, Accuracy: 0.4680
[Epoch] 007  [Train] Loss: 2.0766, Accuracy: 0.3000 [Valid] Loss: 1.7061, Accuracy: 0.4760
[Epoch] 008  [Train] Loss: 2.5431, Accuracy: 0.2286 [Valid] Loss: 1.6873, Accuracy: 0.4860
[Epoch] 009  [Train] Loss: 2.3309, Accuracy: 0.2643 [Valid] Loss: 1.6718, Accuracy: 0.4920
[Epoch] 010  [Train] Loss: 2.1841, Accuracy: 0.2357 [Valid] Loss: 1.6590, Accuracy: 0.5140