In [None]:
import warnings

# Suppress every warning for the rest of the kernel session. Note that this
# also hides deprecation notices raised by the libraries imported below.
warnings.filterwarnings("ignore")

import torch
import torch.nn.functional as F
from torch_geometric.datasets import WebKB
from torch_geometric.nn import GCNConv, MixHopConv

# Seed torch's global RNG so that random draws made through torch later in the
# notebook (e.g. the torch.randperm shuffle used for the split) are repeatable.
torch.manual_seed(0)

In [None]:
# ----------------------------------------------------------------------
# 1. Load dataset and select split
# ----------------------------------------------------------------------
dataset = WebKB(name="Texas", root="../dataset/WebKB")
data = dataset[0]  # first entry of the dataset

# Cast features to float and labels / edge indices to int64 in one step.
data.x, data.y, data.edge_index = (
    data.x.float(),
    data.y.long(),
    data.edge_index.long(),
)

num_classes = dataset.num_classes
num_nodes = data.num_nodes

print(f"Total nodes: {num_nodes}, num_classes: {num_classes}")
print("Per class counts:", data.y.bincount(minlength=num_classes).tolist())


def calculate_split_masks(data, train_ratio=0.7, val_ratio=0.0):
    """Create a fresh per-class (stratified) train/val/test split over all nodes.

    Every node is assigned to exactly one of the three masks. Within each class
    the node indices are shuffled with ``torch.randperm`` and then cut into
    ``int(train_ratio * n)`` training nodes, ``int(val_ratio * n)`` validation
    nodes and the remainder as test nodes. Because of the truncation, a class
    with very few nodes (e.g. a single node) can end up entirely in the test
    split.

    The set of classes is derived from ``data.y`` itself, so the function does
    not depend on any notebook-level variable.

    Args:
        data: Object exposing ``y`` (1-D integer label tensor) and ``num_nodes``.
        train_ratio: Fraction of each class placed in the training split.
        val_ratio: Fraction of each class placed in the validation split. The
            default of 0.0 reproduces the original 70/0/30 split, i.e. an
            empty validation mask.

    Returns:
        Tuple ``(train_mask, val_mask, test_mask)`` of boolean tensors of length
        ``data.num_nodes``, allocated on the same device as ``data.y``.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more than 1.
    """
    # Small tolerance so that pairs such as 0.9 + 0.1 are not rejected because
    # of floating-point rounding.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to at most 1, "
            f"got train_ratio={train_ratio}, val_ratio={val_ratio}"
        )

    labels = data.y
    num_nodes = data.num_nodes
    device = labels.device

    train_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)

    # torch.unique returns the labels in ascending order, so classes are
    # visited (and the RNG is consumed) in the same order as range(num_classes).
    for class_label in torch.unique(labels).tolist():
        class_indices = (labels == class_label).nonzero(as_tuple=True)[0]
        num_in_class = class_indices.size(0)
        num_train = int(train_ratio * num_in_class)
        num_val = int(val_ratio * num_in_class)

        # Draw the permutation from the default (CPU) generator so results do
        # not depend on where the labels live, then move it next to the indices.
        perm = torch.randperm(num_in_class).to(device)
        shuffled_indices = class_indices[perm]

        train_mask[shuffled_indices[:num_train]] = True
        val_mask[shuffled_indices[num_train:num_train + num_val]] = True
        test_mask[shuffled_indices[num_train + num_val:]] = True

    return train_mask, val_mask, test_mask


train_mask, val_mask, test_mask = calculate_split_masks(data)

# Per-class label histogram of every split, in train / val / test order.
train_counts, val_counts, test_counts = (
    torch.bincount(data.y[split_mask], minlength=num_classes)
    for split_mask in (train_mask, val_mask, test_mask)
)


def fmt_per_class(counts, total):
    """Format per-class counts as percentages of ``total``.

    Args:
        counts: Tensor of per-class counts (anything exposing ``tolist()``).
        total: Number of items the counts are taken from.

    Returns:
        Dict mapping the class index (as a string) to a string such as
        ``'23.1%'``. When ``total`` is 0 (an empty split) every class maps to
        ``'n/a'`` instead of raising ``ZeroDivisionError``.
    """
    per_class = counts.tolist()
    if total == 0:
        # An empty split has no meaningful class distribution.
        return {f"{i}": "n/a" for i in range(len(per_class))}
    return {f"{i}": f"{100.0 * c / total:.1f}%" for i, c in enumerate(per_class)}


# Per-class share of each split; the split size is passed as a plain Python int.
train_pct = fmt_per_class(train_counts, int(train_mask.sum()))
val_pct = fmt_per_class(val_counts, int(val_mask.sum()))
test_pct = fmt_per_class(test_counts, int(test_mask.sum()))

# One summary line per split: node count followed by the class distribution.
print(
    f"  train: {int(train_mask.sum())} nodes | per class: {train_pct}\n"
    f"  val:   {int(val_mask.sum())} nodes | per class: {val_pct}\n"
    f"  test:  {int(test_mask.sum())} nodes | per class: {test_pct}"
)

In [16]:
# ----------------------------------------------------------------------
# 2. Utility: accuracy
# ----------------------------------------------------------------------
def accuracy(logits, labels):
    """Return the fraction of rows whose arg-max over the last dim equals the label.

    The result is a Python float; for an empty selection the mean of an empty
    tensor is taken, which yields NaN.
    """
    hits = torch.eq(logits.argmax(dim=-1), labels)
    return hits.float().mean().item()


In [None]:
class GCN(torch.nn.Module):
    """Two stacked ``GCNConv`` layers with ReLU and dropout in between.

    ``forward`` returns the raw output of the second convolution; no softmax
    is applied.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)
        # Dropout probability applied to the hidden representation.
        self.dropout = dropout

    def forward(self, x, edge_index):
        hidden = F.relu(self.conv1(x, edge_index))
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index)


# GCN instance sized to the dataset's feature and class counts.
gcn_model = GCN(
    dataset.num_features,
    16,
    dataset.num_classes,
    dropout=0.5,
)

In [None]:
# ----------------------------------------------------------------------
# 3. Model: MixHop-based network
#    MixHopConv mixes A^0 X, A^1 X, A^2 X inside one layer.
# ----------------------------------------------------------------------
class MixHopNet(torch.nn.Module):
    """Two ``MixHopConv`` layers followed by a linear classifier.

    Each MixHop layer emits ``hidden_dim * len(powers)`` features, which is the
    input width of the next layer and of the final linear head.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.5, powers=(0, 1, 2)):
        super().__init__()
        self.powers = powers

        # Width of a MixHop layer's output: one hidden_dim-sized chunk per power.
        mixed_dim = hidden_dim * len(powers)

        self.conv1 = MixHopConv(in_dim, hidden_dim, powers=powers)
        self.conv2 = MixHopConv(mixed_dim, hidden_dim, powers=powers)
        # Linear head mapping the mixed features to out_dim outputs.
        self.lin = torch.nn.Linear(mixed_dim, out_dim)

        self.dropout = dropout

    def forward(self, x, edge_index):
        # Both MixHop layers share the same conv -> ReLU -> dropout pattern.
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x, edge_index))
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)


# ----------------------------------------------------------------------
# 4. Setup device and model
# ----------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = data.to(device)
train_mask = train_mask.to(device)
val_mask = val_mask.to(device)
test_mask = test_mask.to(device)

model = MixHopNet(
    in_dim=data.x.size(-1),
    hidden_dim=32,
    out_dim=num_classes,
    dropout=0.8,
    powers=(0, 1, 2),
).to(device)

# ----------------------------------------------------------------------
# 5. Class weights for cross-entropy (computed on training split)
# ----------------------------------------------------------------------
train_labels = data.y[train_mask]
train_counts = torch.bincount(train_labels, minlength=num_classes).float()

# Inverse-frequency weights. A class with no training nodes gets weight 0
# rather than 1 / eps: such a class never appears as a training target, and a
# 1 / eps entry would dominate the normalisation below and squash the weights
# of all real classes towards zero.
present = train_counts > 0
inv_freq = torch.zeros_like(train_counts)
inv_freq[present] = 1.0 / train_counts[present]

# Normalise so that the average weight over the classes present in the
# training split is 1.
class_weights = inv_freq * (present.sum() / inv_freq.sum())
class_weights = class_weights.to(device)

print("Train class counts:", train_counts.tolist())
print("Class weights:", class_weights.tolist())

optimiser = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

In [None]:
# ----------------------------------------------------------------------
# 6. Training loop
# ----------------------------------------------------------------------
best_val_acc = 0.0
best_test_acc = 0.0

# The split may contain no validation nodes. accuracy() on an empty selection
# is NaN, and NaN never compares greater than the running best, so without this
# flag the summary below would silently report 0.0000 for both "best" figures.
has_val = bool(val_mask.any())

for epoch in range(1, 501):
    model.train()
    optimiser.zero_grad()

    # Full-graph forward pass; the loss is computed on training nodes only.
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[train_mask], data.y[train_mask], weight=class_weights)
    loss.backward()
    optimiser.step()

    # Evaluation on the first epoch and then every 50 epochs
    if epoch % 50 == 0 or epoch == 1:
        model.eval()
        with torch.no_grad():
            out = model(data.x, data.edge_index)

            train_acc = accuracy(out[train_mask], data.y[train_mask])
            val_acc = accuracy(out[val_mask], data.y[val_mask])
            test_acc = accuracy(out[test_mask], data.y[test_mask])

        # Track the test accuracy observed at the best validation accuracy.
        if has_val and val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        print(
            f"Epoch {epoch:03d} | "
            f"loss {loss.item():.4f} | "
            f"train acc {train_acc:.4f} | "
            f"val acc {val_acc:.4f} | "
            f"test acc {test_acc:.4f}"
        )

if has_val:
    print(f"Best validation accuracy: {best_val_acc:.4f}")
    print(f"Test accuracy at best validation: {best_test_acc:.4f}")
else:
    # best_val_acc / best_test_acc keep their initial 0.0 and carry no meaning.
    print("No validation nodes in the split; validation-based model selection skipped.")

# ----------------------------------------------------------------------
# 7. Final evaluation, including per class accuracies
# ----------------------------------------------------------------------
model.eval()
with torch.no_grad():
    out = model(data.x, data.edge_index)

final_test_acc = accuracy(out[test_mask], data.y[test_mask])
# The message previously interpolated a `SPLIT` variable that is not defined in
# this notebook, which raised NameError on a fresh kernel.
print(f"Final test accuracy: {final_test_acc:.4f}")

# Class-wise accuracy
num_classes = dataset.num_classes
class_accuracies = []
for class_label in range(num_classes):
    # Test nodes whose ground-truth label is this class.
    class_mask = (data.y == class_label) & test_mask
    num_in_class = int(class_mask.sum().item())
    if num_in_class == 0:
        print(
            f"No test samples for class {class_label}, skipping accuracy calculation."
        )
        # Keep list positions aligned with class indices.
        class_accuracies.append(float("nan"))
        continue

    class_acc = accuracy(out[class_mask], data.y[class_mask])
    class_accuracies.append(class_acc)
    print(
        f"Accuracy for class {class_label}: {class_acc:.4f} "
        f"(num test samples: {num_in_class})"
    )

print("Class-wise accuracies:", class_accuracies)
