In [1]:
import torch
import torch.nn.functional as F
import torch.optim as optim

from dhg.data import Cooking200
from dhg import Hypergraph
from dhg.models import HGNN
from dhg.metrics import HypergraphVertexClassificationEvaluator as Evaluator
from dhg.random import set_seed

import time
from copy import deepcopy

In [2]:
def train(A, X, lbls, net, train_idx, optimizer, epoch):
    """Run one full-batch optimisation step of ``net`` on the training vertices.

    Args:
        A: structure passed as the second argument to ``net`` (the Hypergraph ``G``
            in this notebook).
        X: vertex feature matrix passed as the first argument to ``net``.
        lbls: class labels for all vertices; indexed with ``train_idx``.
        net: model called as ``net(X, A)``.
        train_idx: index tensor or boolean mask selecting the training vertices.
        optimizer: optimizer over ``net``'s parameters; stepped once.
        epoch: epoch number, used only in the progress message.

    Returns:
        The training cross-entropy loss (computed before the step) as a Python float.
    """
    net.train()

    start = time.perf_counter()
    optimizer.zero_grad()
    outs = net(X, A)
    loss = F.cross_entropy(outs[train_idx], lbls[train_idx])
    loss.backward()
    optimizer.step()
    # `.item()` copies the value to the host and so waits for the queued (possibly
    # asynchronous CUDA) work; read it BEFORE stopping the timer so the reported
    # time covers the whole step.
    loss_val = loss.item()
    elapsed = time.perf_counter() - start

    print(f"Epoch: {epoch}, Time: {elapsed:.5f}s, Loss: {loss_val:.5f}")

    return loss_val

In [3]:
@torch.no_grad()
def infer(A, X, lbls, net, idx, test=False, evaluator=None):
    """Evaluate ``net`` on the vertices selected by ``idx`` without tracking gradients.

    Args:
        A: structure passed as the second argument to ``net``.
        X: vertex feature matrix passed as the first argument to ``net``.
        lbls: labels for all vertices; indexed with ``idx``.
        net: model called as ``net(X, A)``; switched to eval mode here.
        idx: index tensor or boolean mask selecting the vertices to score.
        test: if True call ``evaluator.test``, otherwise ``evaluator.validate``.
        evaluator: object providing ``validate(lbls, outs)`` / ``test(lbls, outs)``.
            When omitted, falls back to a module-level ``evaluator`` (the one the
            main cell creates), preserving the original calling convention.

    Returns:
        Whatever the chosen evaluator method returns.

    Raises:
        NameError: if no evaluator is passed and no global ``evaluator`` exists.
    """
    if evaluator is None:
        # Backward compatibility: previously this function read the global silently.
        evaluator = globals().get("evaluator")
        if evaluator is None:
            raise NameError(
                "infer() needs an evaluator: pass evaluator=... or define a global `evaluator`"
            )

    net.eval()

    outs = net(X, A)
    outs, lbls = outs[idx], lbls[idx]

    if test:
        return evaluator.test(lbls, outs)
    return evaluator.validate(lbls, outs)

In [4]:
if __name__ == "__main__":

    # Hyperparameters (previously inline magic numbers).
    SEED = 2024
    HIDDEN_DIM = 32
    LR = 0.01
    WEIGHT_DECAY = 5e-4
    NUM_EPOCHS = 300
    VAL_EVERY = 1  # run validation every N epochs

    set_seed(SEED)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    ################################################################
    # Data: Cooking200 hypergraph with identity-matrix (one-hot) vertex features.
    data = Cooking200()
    G = Hypergraph(data["num_vertices"], data["edge_list"])
    X, lbl = torch.eye(data["num_vertices"]), data["labels"]

    train_mask = data["train_mask"]
    val_mask   = data["val_mask"]
    test_mask  = data["test_mask"]

    net = HGNN(X.shape[1], HIDDEN_DIM, data["num_classes"], use_bn=True)
    optimizer = optim.Adam(net.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    # `infer` looks this name up at module level, so it must stay `evaluator`.
    evaluator = Evaluator(["accuracy", "f1_score", {"f1_score": {"average": "micro"}}])

    ################################################################
    G = G.to(device)
    X, lbl = X.to(device), lbl.to(device)
    net = net.to(device)

    ################################################################
    # Start below any finite score so the first validation always records a checkpoint;
    # with the old `best_val = 0`, a run whose score never exceeded 0 left best_state None.
    best_val = float("-inf")
    best_state = None
    best_epoch = 0
    for epoch in range(NUM_EPOCHS):

        # Train: one optimizer step updating the parameters each epoch.
        train(G, X, lbl, net, train_mask, optimizer, epoch)

        # Validation (`infer` is decorated with @torch.no_grad()).
        if epoch % VAL_EVERY == 0:
            val_res = infer(G, X, lbl, net, val_mask)
            if val_res > best_val:
                print(f"update best: {val_res:.5f}")
                best_val   = val_res
                best_state = deepcopy(net.state_dict())
                best_epoch = epoch

    print("\ntrain finished!")
    print(f"best val: {best_val:.5f}")

    # Test with the weights from the best validation epoch.
    print("test...")
    if best_state is None:
        raise RuntimeError("no checkpoint was selected during validation; cannot run the test step")
    net.load_state_dict(best_state)
    res = infer(G, X, lbl, net, test_mask, test=True)
    print(f"final result: epoch: {best_epoch}")
    print(res)

Epoch: 0, Time: 0.32912s, Loss: 2.99382
update best: 0.05000
Epoch: 1, Time: 0.00915s, Loss: 2.69577
Epoch: 2, Time: 0.00790s, Loss: 2.42079
Epoch: 3, Time: 0.00771s, Loss: 2.23312
Epoch: 4, Time: 0.00892s, Loss: 2.04026
update best: 0.09000
Epoch: 5, Time: 0.00819s, Loss: 1.89091
update best: 0.11000
Epoch: 6, Time: 0.00753s, Loss: 1.77104
Epoch: 7, Time: 0.00888s, Loss: 1.64345
Epoch: 8, Time: 0.00888s, Loss: 1.50801
update best: 0.12000
Epoch: 9, Time: 0.00754s, Loss: 1.39273
Epoch: 10, Time: 0.00759s, Loss: 1.28937
update best: 0.13500
Epoch: 11, Time: 0.00892s, Loss: 1.17559
Epoch: 12, Time: 0.00753s, Loss: 1.08902
Epoch: 13, Time: 0.00767s, Loss: 0.99352
Epoch: 14, Time: 0.00872s, Loss: 0.90363
update best: 0.14000
Epoch: 15, Time: 0.00758s, Loss: 0.81562
update best: 0.14500
Epoch: 16, Time: 0.00756s, Loss: 0.75506
Epoch: 17, Time: 0.00845s, Loss: 0.66870
Epoch: 18, Time: 0.00781s, Loss: 0.59584
update best: 0.15500
Epoch: 19, Time: 0.00832s, Loss: 0.54505
Epoch: 20, Time: 0.009