In [3]:
import torch
#!pip install torch_geometric
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# Load the 'Cora' Planetoid dataset into ./data, applying NormalizeFeatures as
# the dataset transform, and take its first (index 0) graph object.
dataset = Planetoid(
    root='./data',
    name='Cora',
    transform=NormalizeFeatures(),
)
data = dataset[0]

# Quick summary of the loaded graph.
print(data)
print("Nodes:", data.num_nodes)
print("Edges:", data.edge_index.size(1))  # number of columns of edge_index
print("Features per node:", dataset.num_features)
print("Num classes:", dataset.num_classes)
print(data.x.shape)
print(data.y.shape)
# Count of True entries in each split mask, in train / val / test order.
print(
    "Train/Val/Test sizes:",
    *(int(split.sum()) for split in (data.train_mask, data.val_mask, data.test_mask)),
)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    """Two-layer graph attention network assembled from ``GATConv`` layers.

    Args:
        in_channels: Input feature size given to the first ``GATConv``.
        hidden_channels: Per-head output size given to the first ``GATConv``.
        out_channels: Output size given to the second ``GATConv``.
        heads: Number of attention heads of the first layer. The second layer
            is built with ``hidden_channels * heads`` input features, a single
            head and ``concat=False``.
        dropout: Probability used for the ``nn.Dropout`` applied before each
            layer; the same value is also passed to both ``GATConv`` layers.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=8, dropout=0.6):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        self.conv2 = GATConv(
            hidden_channels * heads,
            out_channels,
            heads=1,
            concat=False,
            dropout=dropout,
        )

    def forward(self, x, edge_index):
        """Return the output of the second layer for features ``x`` and ``edge_index``.

        Pipeline: dropout -> conv1 -> ELU -> dropout -> conv2. No softmax is
        applied to the returned tensor.
        """
        hidden = self.conv1(self.dropout(x), edge_index)
        hidden = self.dropout(F.elu(hidden))
        return self.conv2(hidden, edge_index)

def main():
    """Train the GAT on the module-level graph and report its accuracy.

    Reads the module-level ``dataset`` and ``data`` objects, trains for 200
    epochs with Adam and cross-entropy on the training mask, and keeps a CPU
    copy of the weights from the epoch with the highest validation accuracy.
    After training, those weights are restored and the test accuracy of the
    restored model is printed as the final result.

    Returns:
        None. Per-epoch metrics and the final summary are printed.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    graph = data.to(device)
    model = GAT(in_channels=dataset.num_features,
                hidden_channels=8,
                out_channels=dataset.num_classes,
                heads=8,
                dropout=0.6).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    def train_epoch():
        """Run one full-batch optimisation step and return the training loss."""
        model.train()
        optimizer.zero_grad()
        logits = model(graph.x, graph.edge_index)
        loss = criterion(logits[graph.train_mask], graph.y[graph.train_mask])
        loss.backward()
        # NOTE(review): max_norm=0.1 is a very tight clip and may slow
        # convergence; kept as in the original -- confirm it is intentional.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        return loss.item()

    def evaluate():
        """Return (train, val, test) accuracy from a single eval-mode forward pass."""
        model.eval()
        with torch.no_grad():
            pred = model(graph.x, graph.edge_index).argmax(dim=1)
        correct = pred == graph.y
        return tuple(
            correct[mask].float().mean().item()
            for mask in (graph.train_mask, graph.val_mask, graph.test_mask)
        )

    best_val = 0.0
    best_state = None

    for epoch in range(1, 201):
        loss = train_epoch()
        train_acc, val_acc, test_acc = evaluate()

        if val_acc > best_val:
            # Snapshot the weights on CPU so later epochs cannot overwrite them.
            best_val = val_acc
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')

    # Report the test accuracy of the best-validation checkpoint, not of the
    # last epoch's weights.
    if best_state is not None:
        model.load_state_dict(best_state)
    final_test = evaluate()[2]
    print(f"Best Val {best_val:.3f} | Final Test {final_test:.3f}")


# Start training only when this code runs as the main module, not when it is imported.
if __name__ == "__main__":
  main()




Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
Nodes: 2708
Edges: 10556
Features per node: 1433
Num classes: 7
torch.Size([2708, 1433])
torch.Size([2708])
Train/Val/Test sizes: 140 500 1000
Epoch: 001, Loss: 1.9486, Train: 0.2714, Val: 0.1880, Test: 0.2140
Epoch: 002, Loss: 1.9452, Train: 0.2571, Val: 0.1080, Test: 0.1410
Epoch: 003, Loss: 1.9349, Train: 0.3143, Val: 0.1280, Test: 0.1500
Epoch: 004, Loss: 1.9304, Train: 0.5000, Val: 0.2880, Test: 0.3030
Epoch: 005, Loss: 1.9223, Train: 0.5857, Val: 0.3820, Test: 0.3930
Epoch: 006, Loss: 1.9145, Train: 0.6286, Val: 0.4700, Test: 0.4740
Epoch: 007, Loss: 1.9143, Train: 0.6286, Val: 0.4600, Test: 0.4740
Epoch: 008, Loss: 1.9032, Train: 0.6929, Val: 0.4900, Test: 0.4840
Epoch: 009, Loss: 1.8980, Train: 0.7286, Val: 0.4920, Test: 0.4880
Epoch: 010, Loss: 1.8936, Train: 0.7714, Val: 0.5060, Test: 0.5170
Epoch: 011, Loss: 1.8802, Train: 0.8143, Val: 0.5340, Test: 0.5510
Epoch: 012, 

In [None]:
# Sanity check: out.shape should be [data.num_nodes, dataset.num_classes]
#model = GAT(in_channels=dataset.num_features,
#            hidden_channels=8,
#            out_channels=dataset.num_classes,
#            heads=8,
#            dropout=0.6)
#out = model(data.x, data.edge_index)
#print("Expected Shape:", out.shape)  # should be [data.num_nodes, dataset.num_classes]