In [1]:
!pip install torch_geometric
!pip install early-stopping-pytorch

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m37.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1
Collecting early-stopping-pytorch
  Downloading early_stopping_pytorch-1.0.10-py3-none-any.whl.metadata (3.4 kB)
Downloading early_stopping_pytorch-1.0.10-py3-none-any.whl (4.6 kB)
Installing collected packages: early-stopping-pytorch
Successfully installed early-stopping-pytorch-1.0.10


In [2]:
from torch_geometric.datasets import Planetoid

# Cora citation graph from the Planetoid benchmark. The raw split files are
# downloaded into this Colab-specific directory (see the download log below).
CORA_ROOT = "/content/Cora"
dataset = Planetoid(name="Cora", root=CORA_ROOT)

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [3]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
  """Two-layer graph convolutional network: GCNConv -> ReLU -> dropout -> GCNConv.

  Args:
    in_channels: size of each input node feature vector. When None, falls back to
      ``dataset.num_node_features`` from the notebook-level ``dataset``.
    hidden_channels: width of the hidden layer (default 16).
    out_channels: number of output classes. When None, falls back to
      ``dataset.num_classes``.
    dropout: dropout probability on the hidden layer while training
      (default 0.5, the value ``F.dropout`` used implicitly before).

  ``forward`` returns raw (unnormalized) class scores, one row per node, which is
  what ``F.cross_entropy`` expects.
  """

  def __init__(self, in_channels=None, hidden_channels=16, out_channels=None, dropout=0.5):
    super().__init__()
    # Only consult the global `dataset` when sizes aren't given: `GCN()` keeps
    # working unchanged, but the model can also be built without that hidden dependency.
    if in_channels is None:
      in_channels = dataset.num_node_features
    if out_channels is None:
      out_channels = dataset.num_classes
    self.dropout = dropout
    # Attribute names conv1/conv2 are unchanged so saved state_dicts still load.
    self.conv1 = GCNConv(in_channels, hidden_channels)
    self.conv2 = GCNConv(hidden_channels, out_channels)

  def forward(self, data):
    """Compute per-node class logits.

    Args:
      data: graph object exposing ``x`` (node features) and ``edge_index``.
    """
    x, edge_index = data.x, data.edge_index

    x = F.relu(self.conv1(x, edge_index))
    # training=self.training ties dropout to model.train() / model.eval().
    x = F.dropout(x, p=self.dropout, training=self.training)
    return self.conv2(x, edge_index)

In [10]:
from early_stopping_pytorch import EarlyStopping

# Training configuration, named so a re-run is self-describing.
SEED = 42
MAX_EPOCHS = 300
PATIENCE = 10
CHECKPOINT_PATH = 'checkpoint.pt'

# Seed before building the model so weight init and dropout masks are reproducible
# (some GPU scatter kernels may still be nondeterministic).
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# Saves a checkpoint to CHECKPOINT_PATH whenever validation loss improves (see log
# below) and sets `.early_stop` after PATIENCE epochs without improvement.
# The same constant is used to reload it, so the two paths cannot drift apart.
early_stopping = EarlyStopping(patience=PATIENCE, verbose=True, path=CHECKPOINT_PATH)

for epoch in range(MAX_EPOCHS):
  # training step on the labelled training nodes
  model.train()
  optimizer.zero_grad()
  out = model(data)
  loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
  loss.backward()
  optimizer.step()

  # validation: eval() disables dropout; no_grad() avoids building an autograd
  # graph that would only be thrown away.
  model.eval()
  with torch.no_grad():
    out = model(data)
    val_loss = F.cross_entropy(out[data.val_mask], data.y[data.val_mask])

  # Hand the callback a plain Python float rather than a (possibly CUDA) tensor.
  early_stopping(val_loss.item(), model)

  if early_stopping.early_stop:
    print("Early stopping")
    break

# Restore the best (lowest validation loss) weights; map_location keeps this
# working if the checkpoint was written on a different device.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True))

Validation loss decreased (inf --> 1.900381).  Saving model ...
Validation loss decreased (1.900381 --> 1.834695).  Saving model ...
Validation loss decreased (1.834695 --> 1.759267).  Saving model ...
Validation loss decreased (1.759267 --> 1.685596).  Saving model ...
Validation loss decreased (1.685596 --> 1.615939).  Saving model ...
Validation loss decreased (1.615939 --> 1.542956).  Saving model ...
Validation loss decreased (1.542956 --> 1.465196).  Saving model ...
Validation loss decreased (1.465196 --> 1.387739).  Saving model ...
Validation loss decreased (1.387739 --> 1.313817).  Saving model ...
Validation loss decreased (1.313817 --> 1.245491).  Saving model ...
Validation loss decreased (1.245491 --> 1.180979).  Saving model ...
Validation loss decreased (1.180979 --> 1.122972).  Saving model ...
Validation loss decreased (1.122972 --> 1.068007).  Saving model ...
Validation loss decreased (1.068007 --> 1.016896).  Saving model ...
Validation loss decreased (1.016896 -->

<All keys matched successfully>

In [7]:
# Test-set accuracy of `model` (expected to hold the checkpoint reloaded by the
# training cell — run that cell first).
model.eval()  # disable dropout
with torch.no_grad():  # inference only: no autograd graph needed
  logits = model(data)
# softmax is monotonic per row, so argmax over raw logits gives the same class
# without the extra op (and without underflow-induced ties).
pred = logits.argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')

Accuracy: 0.8110
