In [1]:
import torch
import torch.nn.functional as F
import torch.optim as optim

from dhg.data import Cora
from dhg import Graph
from dhg.models import GCN
from dhg.metrics import GraphVertexClassificationEvaluator as Evaluator
from dhg.random import set_seed

import time
from copy import deepcopy

In [2]:
def train(A, X, lbls, net, train_idx, optimizer, epoch):
    """Run one full-batch optimisation step and print a one-line log.

    Args:
        A: Structure object passed to ``net`` as its second positional argument.
        X: Feature tensor passed to ``net`` as its first positional argument.
        lbls: Label tensor for all vertices; indexed with ``train_idx``.
        net: Model called as ``net(X, A)``; it is put in training mode here.
        train_idx: Index or mask selecting the rows that contribute to the loss.
        optimizer: Optimizer whose gradients are cleared and which is stepped once.
        epoch: Epoch number, used only in the log line.

    Returns:
        The cross-entropy loss on the selected rows as a Python float.
    """
    net.train()

    st = time.time()
    optimizer.zero_grad()
    outs = net(X, A)
    # Only the selected training rows contribute to the loss.
    loss = F.cross_entropy(outs[train_idx], lbls[train_idx])
    loss.backward()
    optimizer.step()

    # Read the scalar once, *before* stopping the clock. On asynchronous
    # devices (e.g. CUDA) the calls above only queue work; ``.item()`` blocks
    # until the value is available, so the elapsed time covers the real compute
    # rather than just the launch overhead.
    loss_val = loss.item()
    elapsed = time.time() - st

    print(f"Epoch: {epoch}, Time: {elapsed:.5f}s, Loss: {loss_val:.5f}")

    return loss_val

In [3]:
@torch.no_grad()
def infer(A, X, lbls, net, idx, test=False):
    """Evaluate ``net`` on the rows selected by ``idx`` with gradients disabled.

    The model is switched to eval mode, called as ``net(X, A)``, and both the
    outputs and ``lbls`` are indexed with ``idx`` before scoring.

    Note: scoring uses the module-level name ``evaluator``, which is not a
    parameter of this function and must exist before the first call.

    Args:
        A: Structure object passed to ``net`` as its second positional argument.
        X: Feature tensor passed to ``net`` as its first positional argument.
        lbls: Label tensor for all vertices.
        net: Model to evaluate.
        idx: Index or mask selecting the rows to score.
        test: When truthy use ``evaluator.test``; otherwise ``evaluator.validate``.

    Returns:
        Whatever the chosen evaluator method returns.
    """
    net.eval()

    predictions = net(X, A)[idx]
    targets = lbls[idx]

    # Validation and test go through different evaluator entry points.
    if test:
        return evaluator.test(targets, predictions)
    return evaluator.validate(targets, predictions)

In [4]:
if __name__ == "__main__":
    # End-to-end script: load Cora, train a GCN with validation-based model
    # selection, then report the test metrics of the best snapshot.

    set_seed(2024)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Training-loop settings, named so they can be tuned in one place.
    num_epochs = 300
    val_interval = 1  # run validation every `val_interval` epochs

    ################################################################
    # Data, model, optimizer and evaluator.
    data = Cora()
    G = Graph(data["num_vertices"], data["edge_list"])
    X, lbl = data["features"], data["labels"]

    train_mask = data["train_mask"]
    val_mask   = data["val_mask"]
    test_mask  = data["test_mask"]

    net = GCN(data["dim_features"], 16, data["num_classes"])
    optimizer = optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)

    # `infer` reads this name as a global, so it must be bound before the loop.
    # NOTE(review): validate() presumably returns a scalar (it is compared with
    # `>` and formatted with `:.5f` below) - confirm against the DHG docs.
    evaluator = Evaluator(["accuracy", "f1_score", {"f1_score": {"average": "micro"}}])

    ################################################################
    # Move the graph, tensors and model to the selected device.
    G = G.to(device)
    X, lbl = X.to(device), lbl.to(device)
    net = net.to(device)

    ################################################################
    best_val = 0
    best_state = None
    best_epoch = 0
    for epoch in range(num_epochs):

        # train: the optimizer updates the model parameters once per epoch
        train(G, X, lbl, net, train_mask, optimizer, epoch)

        # validation
        if epoch % val_interval == 0:
            # infer() is already decorated with torch.no_grad(); no extra
            # context manager is needed here.
            val_res = infer(G, X, lbl, net, val_mask)
            if val_res > best_val:
                print(f"update best: {val_res:.5f}")
                best_val   = val_res
                # Deep-copy so later optimizer steps do not mutate the snapshot.
                best_state = deepcopy(net.state_dict())
                best_epoch = epoch

    print("\ntrain finished!")
    print(f"best val: {best_val:.5f}")

    # test
    print("test...")
    if best_state is not None:
        net.load_state_dict(best_state)
    else:
        # No epoch beat the initial best_val, so there is no snapshot to
        # restore; load_state_dict(None) would raise. Fall back to the
        # current weights instead.
        print("warning: no validation improvement recorded; testing with final weights")
    res = infer(G, X, lbl, net, test_mask, test=True)
    print(f"final result: epoch: {best_epoch}")
    print(res)

Epoch: 0, Time: 0.33405s, Loss: 1.96712
update best: 0.05800
Epoch: 1, Time: 0.00088s, Loss: 1.96182
Epoch: 2, Time: 0.00087s, Loss: 1.95624
Epoch: 3, Time: 0.00079s, Loss: 1.95098
Epoch: 4, Time: 0.00083s, Loss: 1.94301
update best: 0.10800
Epoch: 5, Time: 0.00079s, Loss: 1.93629
update best: 0.15400
Epoch: 6, Time: 0.00081s, Loss: 1.93044
Epoch: 7, Time: 0.00077s, Loss: 1.92347
Epoch: 8, Time: 0.00078s, Loss: 1.91221
Epoch: 9, Time: 0.00068s, Loss: 1.90219
Epoch: 10, Time: 0.00066s, Loss: 1.90319
update best: 0.20000
Epoch: 11, Time: 0.00069s, Loss: 1.88750
update best: 0.27600
Epoch: 12, Time: 0.00125s, Loss: 1.88159
update best: 0.29400
Epoch: 13, Time: 0.00077s, Loss: 1.87375
Epoch: 14, Time: 0.00076s, Loss: 1.86158
Epoch: 15, Time: 0.00084s, Loss: 1.85921
Epoch: 16, Time: 0.00090s, Loss: 1.84759
Epoch: 17, Time: 0.00089s, Loss: 1.83661
Epoch: 18, Time: 0.00094s, Loss: 1.81793
Epoch: 19, Time: 0.00090s, Loss: 1.81405
Epoch: 20, Time: 0.00090s, Loss: 1.80205
Epoch: 21, Time: 0.0009