In [47]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GATConv
import matplotlib.pyplot as plt


In [41]:
# Load the three Planetoid citation benchmarks into the shared "./tmp" root,
# giving each dataset its own NormalizeFeatures transform instance.
dataset_cora, dataset_citeseer, dataset_pubmed = (
    Planetoid(root="./tmp", name=benchmark, transform=T.NormalizeFeatures())
    for benchmark in ("Cora", "CiteSeer", "Pubmed")
)

# Take the first (index 0) item of every dataset as the graph to work with.
data_cora, data_citeseer, data_pubmed = (
    dataset_cora[0],
    dataset_citeseer[0],
    dataset_pubmed[0],
)

# Print a short summary of each graph object.
print("Citation network information")
print("----------------------------\n")
for label, graph in (
    ("Cora: ", data_cora),
    ("Citeseer: ", data_citeseer),
    ("Pubmed: ", data_pubmed),
):
    print(label, graph, "\n")

In [60]:
def train(model, data, optimizer):
    """Run one full-batch optimisation step and return the training loss.

    Args:
        model: Module called as ``model(data)``; its output is fed to
            ``F.nll_loss``, so it is expected to be per-node log-probabilities.
        data: Object exposing ``y`` (labels) and ``train_mask``; only the
            entries selected by ``train_mask`` contribute to the loss.
        optimizer: Optimizer holding ``model``'s parameters.

    Returns:
        float: The negative log-likelihood loss on the training nodes,
        computed before the parameter update of this step. Callers that
        ignore the return value behave exactly as before.
    """
    # Training mode, so that training-only behaviour (e.g. dropout) is active.
    model.train()
    optimizer.zero_grad()
    log_softmax = model(data)
    labels = data.y
    nll_loss = F.nll_loss(log_softmax[data.train_mask], labels[data.train_mask])
    nll_loss.backward()
    optimizer.step()
    # .item() detaches the scalar so the caller never holds on to the graph.
    return nll_loss.item()

def compute_accuracy(model, data, mask):
    """Return the classification accuracy of ``model`` on the masked nodes.

    The model is switched to eval mode and the forward pass runs under
    ``torch.no_grad()`` so no autograd graph is built during evaluation.
    The model is left in eval mode on return.

    Args:
        model: Module called as ``model(data)`` returning one score row per
            node; the predicted class is the arg-max over dimension 1.
        data: Object exposing the ground-truth labels as ``data.y``.
        mask: Tensor used to index both predictions and labels. Assumed to be
            a boolean node mask, since ``mask.sum()`` is used as the number
            of selected nodes.

    Returns:
        float: Fraction of selected nodes whose prediction equals the label
        (NaN when the mask selects nothing, because of the 0/0 division).
    """
    model.eval()
    with torch.no_grad():
        logprob = model(data)
        y_pred = logprob[mask].argmax(dim=1)
        y_true = data.y[mask]
        acc = y_pred.eq(y_true).sum() / mask.sum().float()
    return acc.item()

def test(model, data):
    """Return ``(train_accuracy, validation_accuracy)`` for ``model`` on ``data``.

    Despite the name, only ``data.train_mask`` and ``data.val_mask`` are
    evaluated here; no held-out test split is touched.
    """
    train_score, val_score = (
        compute_accuracy(model, data, split)
        for split in (data.train_mask, data.val_mask)
    )
    return train_score, val_score

In [61]:
class GAT(torch.nn.Module):
    """Two-layer graph attention network producing per-node log-probabilities.

    Layer 1 maps the input features to ``hidden_channels`` per head and
    concatenates the ``heads_layer1`` heads (``concat=True``); layer 2 maps
    that to one score per class with ``concat=False``. Feature dropout is
    applied before each layer and ELU between them.

    Args:
        data: Graph object exposing ``num_features`` and the label tensor
            ``y``; the number of classes is the count of distinct labels.
        heads_layer1: Number of attention heads in the first layer.
        heads_layer2: Number of attention heads in the second layer.
        dropout: Dropout probability applied to the node features ahead of
            each convolution (active only in training mode).
        dropout_alphas: Value forwarded as ``dropout`` to both ``GATConv``
            layers.
        hidden_channels: Per-head output width of the first layer. Defaults
            to 8, the value that was previously hard-coded.
    """

    def __init__(self, data, heads_layer1, heads_layer2, dropout, dropout_alphas,
                 hidden_channels=8):
        super().__init__()

        self.dropout = dropout
        num_features = data.num_features
        num_classes = len(data.y.unique())

        self.conv1 = GATConv(in_channels=num_features, out_channels=hidden_channels,
                             heads=heads_layer1, concat=True, negative_slope=0.2,
                             dropout=dropout_alphas)

        # Heads of layer 1 are concatenated, so layer 2 sees
        # hidden_channels * heads_layer1 input features.
        self.conv2 = GATConv(in_channels=hidden_channels * heads_layer1,
                             out_channels=num_classes,
                             heads=heads_layer2, concat=False, negative_slope=0.2,
                             dropout=dropout_alphas)

    def forward(self, data):
        """Compute log-softmax class scores for every node in ``data``.

        Reads ``data.x`` (node features) and ``data.edge_index`` (graph
        connectivity) and returns the log-softmax over dimension 1.
        """
        x = data.x
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv1(x, data.edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, data.edge_index)

        return F.log_softmax(x, dim=1)

# Cora Implementation

In [62]:
# Train the GAT on Cora for 200 epochs, recording train/validation accuracy
# every 10 epochs (20 points, matching the x-axis used by the plot cell).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed before the model is built so weight initialisation and dropout masks
# repeat across "Restart & Run All".
# NOTE(review): some CUDA kernels may still be non-deterministic; confirm if
# bit-exact reruns on GPU are required.
torch.manual_seed(42)

model_cora_gat = GAT(data=data_cora, heads_layer1=8, heads_layer2=8, dropout=0.6,  dropout_alphas=0.6).to(device)
data_cora = data_cora.to(device)

optimizer = torch.optim.Adam(model_cora_gat.parameters(), lr=0.001, weight_decay=1e-4)
train_acc = []
val_acc = []
# Built once instead of on every logging step.
log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}'

for epoch in range(1, 200+1):
    train(model_cora_gat, data_cora, optimizer)
    if epoch % 10 == 0:
        acc_train, acc_val = test(model_cora_gat, data_cora)
        train_acc.append(acc_train)
        val_acc.append(acc_val)
        print(log.format(epoch, acc_train, acc_val))

In [63]:
# Draw the Cora accuracy curves; the x positions are the epochs at which
# accuracy was recorded (every 10th epoch from 10 to 200).
epochs = range(10, 201, 10)
ax = plt.gca()
for curve, legend_text in (
    (train_acc, 'Training Accuracy'),
    (val_acc, 'Validation Accuracy'),
):
    ax.plot(epochs, curve, label=legend_text)
ax.set_title('Cora: GAT Model Training Progress')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend()
plt.show()

# Pubmed Implementation

In [64]:
# Train the GAT on Pubmed for 200 epochs, recording train/validation accuracy
# every 10 epochs (20 points, matching the x-axis used by the plot cell).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed before the model is built so weight initialisation and dropout masks
# repeat across "Restart & Run All".
# NOTE(review): some CUDA kernels may still be non-deterministic; confirm if
# bit-exact reruns on GPU are required.
torch.manual_seed(42)

model_pubmed_gat = GAT(data=data_pubmed, heads_layer1=8, heads_layer2=8, dropout=0.6,  dropout_alphas=0.6).to(device)
data_pubmed = data_pubmed.to(device)

optimizer = torch.optim.Adam(model_pubmed_gat.parameters(), lr=0.001, weight_decay=1e-4)
train_acc = []
val_acc = []
# Built once instead of on every logging step.
log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}'

for epoch in range(1, 200+1):
    train(model_pubmed_gat, data_pubmed, optimizer)
    if epoch % 10 == 0:
        acc_train, acc_val = test(model_pubmed_gat, data_pubmed)
        train_acc.append(acc_train)
        val_acc.append(acc_val)
        print(log.format(epoch, acc_train, acc_val))

In [65]:
# Draw the Pubmed accuracy curves; the x positions are the epochs at which
# accuracy was recorded (every 10th epoch from 10 to 200).
epochs = range(10, 201, 10)
ax = plt.gca()
for curve, legend_text in (
    (train_acc, 'Training Accuracy'),
    (val_acc, 'Validation Accuracy'),
):
    ax.plot(epochs, curve, label=legend_text)
ax.set_title('Pubmed: GAT Model Training Progress')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend()
plt.show()
