In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

In [2]:
class GCN(torch.nn.Module):
    def __init__(self, num_features, num_classes, dropout_rate=0.6) -> None:
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate

        self.conv1 = GCNConv(self.num_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, graph):
        x = graph.x
        edge_index = graph.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout_rate)
        x = self.conv2(x, edge_index)
        x = F.log_softmax(x, dim=1)
        return x

In [3]:
dataset = Planetoid(root='./data/Cora', name='Cora')
print(f'# of graph:    {len(dataset)}')
print(f'# of nodes:    {dataset[0].num_nodes}')
print(f'# of edges:    {dataset[0].num_edges}')
print(f'# of features: {dataset.num_node_features}')
print(f'# of classes:  {dataset.num_classes}')

print(f'Train:      {dataset[0].train_mask.sum().item()}')
print(f'Validation: {dataset[0].val_mask.sum().item()}')
print(f'Test:       {dataset[0].test_mask.sum().item()}')

# of graph:    1
# of nodes:    2708
# of edges:    10556
# of features: 1433
# of classes:  7
Train:      140
Validation: 500
Test:       1000


In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [5]:
model = GCN(dataset.num_features, dataset.num_classes).to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

In [6]:
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

In [7]:
model.eval()
out = model(data).argmax(dim=1)

In [8]:
correct = (out[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')

Accuracy: 0.7480
