In [182]:
# Created while I was learning PyTorch Geometric and testing a few things.
# Code loosely adapted from https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html

In [48]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid

# Record the PyTorch version this notebook was run with.
print(torch.__version__)

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

1.9.0


In [49]:
# Root directory handed to Planetoid for the dataset files.
data_dir = "data/"
cora = Planetoid(root=data_dir, name="Cora")

# Dataset-level statistics.
print(f"Number of classes: {cora.num_classes}")
print(f"Number of features: {cora.num_features}")
# Graph-level statistics: index the dataset (`cora[0]`) instead of reading its
# internal `cora.data` storage attribute.
print(f"Number of nodes: {cora[0].num_nodes}")
print(f"Number of edges: {cora[0].num_edges}")

Number of classes: 7
Number of features: 1433
Number of nodes: 2708
Number of edges: 10556


In [135]:
class GNN(torch.nn.Module):
    """Three stacked ``GCNConv`` layers producing per-node log-probabilities.

    Layer widths are ``in_channels -> hidden_channels -> mid_channels ->
    out_channels``. ReLU follows the first two convolutions and dropout is
    applied after the first one only.

    Args:
        in_channels: Input feature size. ``None`` (default) uses
            ``cora.num_node_features`` from the notebook-level dataset.
        hidden_channels: Output size of the first convolution (default 128).
        mid_channels: Output size of the second convolution (default 32).
        out_channels: Output size of the last convolution. ``None`` (default)
            uses ``cora.num_classes`` from the notebook-level dataset.
    """

    def __init__(self, in_channels=None, hidden_channels=128, mid_channels=32,
                 out_channels=None):
        super().__init__()
        # Fall back to the notebook-level Cora dataset so that a bare `GNN()`
        # builds exactly the same layers as before.
        if in_channels is None:
            in_channels = cora.num_node_features
        if out_channels is None:
            out_channels = cora.num_classes

        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, mid_channels)
        self.conv3 = GCNConv(mid_channels, out_channels)

    def forward(self, cora):
        """Run the network on one graph.

        Args:
            cora: Object exposing ``.x`` and ``.edge_index``. NOTE: inside this
                method the parameter shadows the notebook-level ``cora``
                dataset; the name is kept for interface compatibility.

        Returns:
            ``log_softmax`` of the last convolution's output, taken over dim 1.
        """
        x, edge_index = cora.x, cora.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        # No dropout is applied after the second layer.
        x = self.conv3(x, edge_index)

        return F.log_softmax(x, dim=1)

In [140]:
# Training hyperparameters read by the training cell.
epochs = 1_000        # number of full-graph optimisation steps
lr = 1e-4             # Adam learning rate
weight_decay = 5e-3   # weight-decay coefficient passed to Adam

In [152]:
# Seed PyTorch's RNG so weight initialisation and dropout masks repeat when the
# cell is re-run (some GPU kernels may still be non-deterministic).
torch.manual_seed(0)

model = GNN().to(device)
# Index the dataset to get the graph rather than moving the dataset's internal
# `cora.data` storage object to the device.
data = cora[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

model.train()  # training mode: enables the dropout inside GNN.forward
for epoch in range(epochs):
    optimizer.zero_grad()
    out = model(data)
    # The loss is computed on the nodes selected by the training mask only.
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

    # Log every 10th epoch; `.item()` extracts the Python float explicitly.
    if epoch % 10 == 0:
        print(f"epoch: {epoch} - loss: {loss.item()}")

epoch: 0 - loss: 1.9489637613296509
epoch: 10 - loss: 1.9298394918441772
epoch: 20 - loss: 1.9116084575653076
epoch: 30 - loss: 1.8860547542572021
epoch: 40 - loss: 1.8637586832046509
epoch: 50 - loss: 1.831717610359192
epoch: 60 - loss: 1.8016436100006104
epoch: 70 - loss: 1.7624397277832031
epoch: 80 - loss: 1.7200994491577148
epoch: 90 - loss: 1.668863296508789
epoch: 100 - loss: 1.6199707984924316
epoch: 110 - loss: 1.5583577156066895
epoch: 120 - loss: 1.509276032447815
epoch: 130 - loss: 1.4529714584350586
epoch: 140 - loss: 1.378476858139038
epoch: 150 - loss: 1.3157808780670166
epoch: 160 - loss: 1.2441827058792114
epoch: 170 - loss: 1.1869970560073853
epoch: 180 - loss: 1.1238960027694702
epoch: 190 - loss: 1.0601407289505005
epoch: 200 - loss: 1.0054330825805664
epoch: 210 - loss: 0.945567786693573
epoch: 220 - loss: 0.905239462852478
epoch: 230 - loss: 0.8456275463104248
epoch: 240 - loss: 0.8059667348861694
epoch: 250 - loss: 0.7510144710540771
epoch: 260 - loss: 0.71317422

In [183]:
model.eval()  # evaluation mode: disables the dropout inside GNN.forward
# Gradients are not needed for evaluation, so skip autograd tracking.
with torch.no_grad():
    pred = model(data).argmax(dim=1)

# Count the test-mask nodes whose predicted class matches the label.
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f"Accuracy: {acc:.4f}")

Accuracy: 0.8100


In [181]:
# Boolean hit/miss mask over the test nodes (element-wise comparison of
# predictions and labels). NOTE: this rebinds `correct`, which the evaluation
# cell uses for the summed count.
correct = pred[data.test_mask].eq(data.y[data.test_mask])
print(correct.shape, correct, sep="\n")

torch.Size([1000])
tensor([False,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True, False,  True,  True,  True,  True,
         True,  True,  True, False,  True, False,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True, False,  True,  True,
         True,  True,  True,  True, False, False,  True,  True,  True,  True,
         True, False,  True, False, False, False, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  T

In [189]:
# Size of the label tensor of the graph used for training.
data.y.size()

torch.Size([2708])