In [1]:
import torch
import torch.nn.functional as F
import torch.optim as optim

from dhg.data import Cora
from dhg import Graph
from dhg.models import GAT
from dhg.metrics import GraphVertexClassificationEvaluator as Evaluator
from dhg.random import set_seed

import time
from copy import deepcopy

In [2]:
def train(A, X, lbls, net, train_idx, optimizer, epoch):
    """Run one full-batch training step on the training vertices and log it.

    Args:
        A: Graph structure, passed to ``net`` as its second argument.
        X: Vertex feature tensor, passed to ``net`` as its first argument.
        lbls: Labels for all vertices; only ``lbls[train_idx]`` is used.
        net: Model called as ``net(X, A)``; switched to training mode.
        train_idx: Index tensor or boolean mask selecting training vertices.
        optimizer: Optimizer over ``net``'s parameters; stepped exactly once.
        epoch: Epoch number, used only in the log line.

    Returns:
        The cross-entropy loss on the training vertices (a Python float),
        computed with the parameters as they were *before* this step.
    """
    net.train()

    start = time.perf_counter()
    optimizer.zero_grad()
    outs = net(X, A)
    loss = F.cross_entropy(outs[train_idx], lbls[train_idx])
    loss.backward()
    optimizer.step()
    # ``.item()`` copies the loss to the host, which waits for the work already
    # queued on the stream (forward, backward, optimizer step) to finish.
    # Reading it before stopping the clock keeps the reported time meaningful
    # on CUDA, where kernels otherwise run asynchronously.
    loss_value = loss.item()
    elapsed = time.perf_counter() - start

    print(f"Epoch: {epoch}, Time: {elapsed:.5f}s, Loss: {loss_value:.5f}")

    return loss_value

In [3]:
@torch.no_grad()
def infer(A, X, lbls, net, idx, test=False, evaluator=None):
    """Evaluate ``net`` on a subset of vertices without tracking gradients.

    Args:
        A: Graph structure, passed to ``net`` as its second argument.
        X: Vertex feature tensor, passed to ``net`` as its first argument.
        lbls: Labels for all vertices; only ``lbls[idx]`` is used.
        net: Model called as ``net(X, A)``; switched to eval mode.
        idx: Index tensor or boolean mask selecting the vertices to score.
        test: If False, call ``evaluator.validate``; otherwise ``evaluator.test``.
        evaluator: Object exposing ``validate(y_true, y_pred)`` and
            ``test(y_true, y_pred)``. When omitted, the module-level
            ``evaluator`` (created in the main cell) is used, which keeps the
            original call signature working for existing callers.

    Returns:
        Whatever the selected evaluator method returns.

    Raises:
        NameError: If no ``evaluator`` is passed and no module-level
            ``evaluator`` has been defined yet.
    """
    if evaluator is None:
        # Fall back to the notebook-global evaluator, but fail with a clear
        # message instead of an opaque NameError deep inside the function.
        evaluator = globals().get("evaluator")
        if evaluator is None:
            raise NameError(
                "infer() needs an evaluator: pass evaluator=... or define a "
                "module-level 'evaluator' first"
            )

    net.eval()

    outs = net(X, A)[idx]
    targets = lbls[idx]

    if test:
        return evaluator.test(targets, outs)
    return evaluator.validate(targets, outs)

In [4]:
if __name__ == "__main__":

    set_seed(2024)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Hyper-parameters of this run, gathered in one place so they are easy to tune.
    NUM_EPOCHS = 300
    VAL_EVERY = 1  # run validation (and possibly checkpoint) every N epochs
    HIDDEN_DIM = 8
    NUM_HEADS = 8
    DROP_RATE = 0.6
    LR = 0.005
    WEIGHT_DECAY = 5e-4

    ################################################################
    # Data, model, optimizer and evaluator.
    data = Cora()
    G = Graph(data["num_vertices"], data["edge_list"])
    X, lbl = data["features"], data["labels"]

    train_mask = data["train_mask"]
    val_mask   = data["val_mask"]
    test_mask  = data["test_mask"]

    net = GAT(data["dim_features"], HIDDEN_DIM, data["num_classes"], num_heads=NUM_HEADS, drop_rate=DROP_RATE)
    optimizer = optim.Adam(net.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    # NOTE: ``infer`` looks up this module-level name, so keep it called ``evaluator``.
    evaluator = Evaluator(["accuracy", "f1_score", {"f1_score": {"average": "micro"}}])

    ################################################################
    # Move graph, tensors and model to the selected device.
    G = G.to(device)
    X, lbl = X.to(device), lbl.to(device)
    net = net.to(device)

    ################################################################
    # Training loop with model selection on the validation split.
    # Starting from -inf guarantees the first validation stores a checkpoint,
    # so ``best_state`` is set whenever at least one validation runs.
    best_val = float("-inf")
    best_state = None
    best_epoch = 0
    for epoch in range(NUM_EPOCHS):

        # Train: the optimizer updates the parameters once per epoch.
        train(G, X, lbl, net, train_mask, optimizer, epoch)

        # Validation. ``infer`` is already decorated with ``torch.no_grad()``,
        # so no extra no-grad context is needed here.
        if epoch % VAL_EVERY == 0:
            val_res = infer(G, X, lbl, net, val_mask)
            if val_res > best_val:
                print(f"update best: {val_res:.5f}")
                best_val   = val_res
                best_state = deepcopy(net.state_dict())
                best_epoch = epoch

    print("\ntrain finished!")
    print(f"best val: {best_val:.5f}")

    # Test with the best validation checkpoint (kept as-is if none was stored).
    print("test...")
    if best_state is not None:
        net.load_state_dict(best_state)
    res = infer(G, X, lbl, net, test_mask, test=True)
    print(f"final result: epoch: {best_epoch}")
    print(res)

Epoch: 0, Time: 0.33132s, Loss: 1.94915
update best: 0.15800
Epoch: 1, Time: 0.01032s, Loss: 1.94330
update best: 0.20800
Epoch: 2, Time: 0.01011s, Loss: 1.93807
update best: 0.25200
Epoch: 3, Time: 0.00988s, Loss: 1.93426
update best: 0.28000
Epoch: 4, Time: 0.01073s, Loss: 1.92814
update best: 0.31600
Epoch: 5, Time: 0.01040s, Loss: 1.92436
update best: 0.34800
Epoch: 6, Time: 0.01068s, Loss: 1.91924
update best: 0.49800
Epoch: 7, Time: 0.01079s, Loss: 1.91081
update best: 0.62000
Epoch: 8, Time: 0.01103s, Loss: 1.90589
update best: 0.65800
Epoch: 9, Time: 0.01079s, Loss: 1.89611
update best: 0.67800
Epoch: 10, Time: 0.01087s, Loss: 1.88985
update best: 0.73400
Epoch: 11, Time: 0.01001s, Loss: 1.88423
update best: 0.75400
Epoch: 12, Time: 0.01002s, Loss: 1.87434
update best: 0.76600
Epoch: 13, Time: 0.01009s, Loss: 1.86490
Epoch: 14, Time: 0.01075s, Loss: 1.86375
Epoch: 15, Time: 0.01052s, Loss: 1.84934
Epoch: 16, Time: 0.01076s, Loss: 1.84064
Epoch: 17, Time: 0.01094s, Loss: 1.83152