# Train GAT models for OGBN_arxiv

In [1]:
import torch
import torch.nn.functional as F

def train(model, optimizer, data, epochs):
    """Train ``model`` full-batch on the nodes selected by ``data.train_mask``.

    Every epoch runs one forward pass over the whole graph
    (``model(data.x, data.edge_index)``), computes the cross-entropy loss on
    the masked rows only, and takes one optimizer step.

    Args:
        model: Module invoked as ``model(data.x, data.edge_index)``; its
            output is indexed with ``data.train_mask`` and treated as
            per-class logits.
        optimizer: Optimizer already bound to ``model``'s parameters.
        data: Object exposing ``x``, ``edge_index``, ``y`` and a boolean
            ``train_mask``.
        epochs: Number of optimization steps to run; must be at least 1.

    Returns:
        A ``(model, loss, acc)`` tuple: the trained model, the (detached)
        loss tensor of the final epoch, and the training accuracy as a float.
        Both metrics come from the final epoch's forward pass, i.e. they are
        measured *before* the last optimizer step is applied.

    Raises:
        ValueError: If ``epochs`` is smaller than 1 (there would be no loss
            or accuracy to report).
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    model.train()
    criterion = torch.nn.CrossEntropyLoss()

    # The mask and the training labels do not change between epochs, so the
    # boolean indexing is done once instead of on every iteration.
    train_mask = data.train_mask
    train_labels = data.y[train_mask]

    for _ in range(epochs):
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out[train_mask], train_labels)
        loss.backward()
        optimizer.step()

    # Only the last epoch's accuracy is reported, so it is computed once here
    # from the last forward pass rather than forcing two `.item()` host
    # synchronisations in every epoch.
    predictions = out[train_mask].argmax(dim=1)
    acc = (predictions == train_labels).sum().item() / train_mask.sum().item()

    # Detach so the returned scalar does not keep autograd history alive.
    return model, loss.detach(), acc

@torch.no_grad()
def test(model, data):
    model.eval()
    criterion = torch.nn.CrossEntropyLoss()
    x, edge_index, y = data.x, data.edge_index, data.y
    out = model(x, edge_index)
    loss = criterion(out[data.test_mask], y[data.test_mask])
    acc = (out[data.test_mask].argmax(dim=1) == y[data.test_mask]).sum().item() / data.test_mask.sum().item()
    return loss, acc

### OGBN-arxiv

In [1]:
import os.path as osp

import torch
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
import torch_geometric.transforms as T
from torch_geometric.loader import NeighborSampler

def _mask_from_index(index, num_nodes, device):
    """Return a boolean mask of length ``num_nodes`` that is True at ``index``."""
    mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    mask[index] = True
    return mask


# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# The dataset directory is resolved relative to this notebook's location:
# <notebook dir>/../data/arxiv
root = osp.join(
    osp.dirname(osp.realpath('[Dataset]OGBN_arxiv.ipynb')),
    '..',
    'data',
    'arxiv',
)

# No feature transform is applied (T.NormalizeFeatures() was tried and left disabled).
dataset = PygNodePropPredDataset(name='ogbn-arxiv', root=root)
split_idx = dataset.get_idx_split()
evaluator = Evaluator(name='ogbn-arxiv')
data = dataset[0].to(device)

# Turn the 'train' / 'test' entries of split_idx into boolean node masks
# stored on the data object, which is what train() and test() index with.
data.train_mask = _mask_from_index(split_idx['train'], data.num_nodes, device)
data.test_mask = _mask_from_index(split_idx['test'], data.num_nodes, device)

# Number of classes is taken as (largest label + 1); computed before the
# labels are squeezed in place (presumably from (N, 1) to (N,) — TODO confirm).
out_channels = data.y.max().item() + 1
data.y = data.y.squeeze_()

In [2]:
# torch.save(data, '/workspace/datasets/OGBN_arxiv.pt')

In [16]:
# Train and evaluate a GAT model on the OGBN-arxiv dataset.
# Relies on `data`, `device`, `train` and `test` defined in earlier cells.
from models import GAT_L2_intervention, GAT_L3_intervention

# One output channel per class: (largest label + 1).
out_channels = data.y.max().item() + 1

# Only the two-layer, single-head configuration is trained in this cell.
# The other variants explored with the same pattern are heads in {2, 4, 8}
# for GAT_L2_intervention and heads in {1, 2, 4, 8} for GAT_L3_intervention,
# named model<heads>_L<layers> and reported as GAT_Arxiv_<layers>L<heads>H.
# The model is built and moved to the GPU device (if available) in one step.
model1_L2 = GAT_L2_intervention(
    in_channels=data.num_node_features,
    hidden_channels=64,
    out_channels=out_channels,
    heads=1,
).to(device)

"""
Now we can train all the models and compare their performance.
Keep the number of epochs and the learning rate the same for all the models.
"""

# Shared hyper-parameters: number of epochs and learning rate.
epochs = 1600
lr = 0.01

# Adam without weight decay.
optimizer1_L2 = torch.optim.Adam(model1_L2.parameters(), lr=lr, weight_decay=0)

# Train, then evaluate on the test split.
model1_L2, loss1_L2, acc1_L2 = train(model1_L2, optimizer1_L2, data, epochs)
test_loss1_L2, test_acc1_L2 = test(model1_L2, data)

# Report the final training and test metrics.
print(f"Model: GAT_Arxiv_2L1H, Loss: {loss1_L2:.4f}, Train Accuracy: {acc1_L2:.4f}, Test Loss: {test_loss1_L2:.4f}, Test Accuracy: {test_acc1_L2:.4f}")

Model: GAT_Arxiv_2L1H, Loss: 1.3877, Train Accuracy: 0.6029, Test Loss: 1.6097, Test Accuracy: 0.5299


In [17]:
# Save the trained model locally.
# NOTE: torch.save is given the model object itself (not its state_dict), so
# the whole module is pickled; loading it back presumably requires the
# `models` module that defines the class to be importable — verify in the
# consuming notebooks. The /workspace/models directory must already exist.
torch.save(model1_L2, '/workspace/models/GAT_Arxiv_2L1H.pt')
# torch.save(model2_L2, '/workspace/models/GAT_Arxiv_2L2H.pt')
# torch.save(model4_L2, '/workspace/models/GAT_Arxiv_2L4H.pt')
# torch.save(model8_L2, '/workspace/models/GAT_Arxiv_2L8H.pt')

# torch.save(model1_L3, '/workspace/models/GAT_Arxiv_3L1H.pt')
# torch.save(model2_L3, '/workspace/models/GAT_Arxiv_3L2H.pt')
# torch.save(model4_L3, '/workspace/models/GAT_Arxiv_3L4H.pt')
# torch.save(model8_L3, '/workspace/models/GAT_Arxiv_3L8H.pt')