In [None]:
import warnings

warnings.filterwarnings("ignore")

import torch
import torch.nn.functional as F
from torch_geometric.datasets import WebKB
from torch_geometric.nn import GCNConv

torch.manual_seed(0)

# Load dataset
dataset = WebKB(root="../dataset/WebKB", name="Texas")
data = dataset[0]

data.x = data.x.float()
data.y = data.y.long()
data.edge_index = data.edge_index.long()

num_nodes = data.num_nodes
num_classes = dataset.num_classes

# Explore dataset: number of nodes per class across all splits
counts = torch.bincount(data.y, minlength=num_classes)
print(f"Total nodes: {num_nodes} (per class: {counts.tolist()})")

# Choose which official split to use (0â€“9)
SPLIT = 0

train_mask = data.train_mask[:, SPLIT]
val_mask = data.val_mask[:, SPLIT]
test_mask = data.test_mask[:, SPLIT]
train_counts = torch.bincount(data.y[train_mask], minlength=num_classes)
val_counts = torch.bincount(data.y[val_mask], minlength=num_classes)
test_counts = torch.bincount(data.y[test_mask], minlength=num_classes)


def fmt_per_class(counts, total):
    # Returns list of strings like '23.1%' per class
    # return [f"{100.0 * c / total:.1f}%" for c in counts.tolist()]
    return {f"{i}": f"{100.0 * c / total:.1f}%" for i, c in enumerate(counts.tolist())}


train_pct = fmt_per_class(train_counts, train_mask.sum().item())
val_pct = fmt_per_class(val_counts, val_mask.sum().item())
test_pct = fmt_per_class(test_counts, test_mask.sum().item())

print(
    f"Split {SPLIT}:\n"
    f"  train: {int(train_mask.sum())} nodes | per class: {train_pct}\n"
    f"  val:   {int(val_mask.sum())} nodes | per class: {val_pct}\n"
    f"  test:  {int(test_mask.sum())} nodes | per class: {test_pct}"
)

Total nodes: 183 ([33, 1, 18, 101, 30])
Split 0:
  train: 87 nodes | per class: {'0': '16.1%', '1': '0.0%', '2': '8.0%', '3': '52.9%', '4': '23.0%'}
  val:   59 nodes | per class: {'0': '25.4%', '1': '1.7%', '2': '11.9%', '3': '52.5%', '4': '8.5%'}
  test:  37 nodes | per class: {'0': '10.8%', '1': '0.0%', '2': '10.8%', '3': '64.9%', '4': '13.5%'}


In [24]:
class GCN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)
        self.dropout = dropout

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x


def accuracy(logits, labels):
    preds = logits.argmax(dim=-1)
    return (preds == labels).float().mean().item()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = data.to(device)
train_mask = train_mask.to(device)
val_mask = val_mask.to(device)
test_mask = test_mask.to(device)

model = GCN(
    in_dim=data.x.size(-1),
    hidden_dim=16,
    out_dim=num_classes,
    dropout=0.6,
).to(device)

optimiser = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

# Train on official train split, evaluate on official test split
for epoch in range(1, 101):
    model.train()
    optimiser.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[train_mask], data.y[train_mask])
    loss.backward()
    optimiser.step()

    if epoch % 20 == 0 or epoch == 1:
        model.eval()
        with torch.no_grad():
            out = model(data.x, data.edge_index)
            train_acc = accuracy(out[train_mask], data.y[train_mask])
            test_acc = accuracy(out[test_mask], data.y[test_mask])
        print(
            f"Epoch {epoch:03d} | "
            f"loss {loss.item():.4f} | "
            f"train acc {train_acc:.4f} | "
            f"test acc {test_acc:.4f}"
        )

model.eval()
with torch.no_grad():
    out = model(data.x, data.edge_index)
final_test_acc = accuracy(out[test_mask], data.y[test_mask])

print(f"Final test accuracy (split {SPLIT}): {final_test_acc:.4f}")


Epoch 001 | loss 1.5800 | train acc 0.1264 | test acc 0.2432
Epoch 020 | loss 1.1096 | train acc 0.6897 | test acc 0.6486
Epoch 040 | loss 0.8077 | train acc 0.7471 | test acc 0.5946
Epoch 060 | loss 0.7887 | train acc 0.7586 | test acc 0.5946
Epoch 080 | loss 0.7056 | train acc 0.8161 | test acc 0.5946
Epoch 100 | loss 0.5949 | train acc 0.8506 | test acc 0.5946
Final test accuracy (split 0): 0.5946


In [18]:
# Since there are 5 classes, we need to report accuracy for each class
class_accuracies = []
for class_label in range(num_classes):
    class_mask = (data.y == class_label) & test_mask
    if class_mask.sum().item() == 0:
        print(
            f"No test samples for class {class_label}, skipping accuracy calculation."
        )
        class_accuracies.append(float("nan"))
        continue
    class_acc = accuracy(out[class_mask], data.y[class_mask])
    class_accuracies.append(class_acc)
    print(f"Accuracy for class {class_label}: {class_acc:.4f}")
print("Class-wise accuracies:", class_accuracies)


Accuracy for class 0: 0.0000
No test samples for class 1, skipping accuracy calculation.
Accuracy for class 2: 0.0000
Accuracy for class 3: 0.5833
Accuracy for class 4: 0.0000
Class-wise accuracies: [0.0, nan, 0.0, 0.5833333134651184, 0.0]
