In [1]:
import torch
import torch.nn.functional as F
import torch.optim as optim

from dhg.data import Cora
from dhg import Graph
from dhg.models import GAT
from dhg.metrics import GraphVertexClassificationEvaluator as Evaluator
from dhg.random import set_seed

import time
from copy import deepcopy

In [2]:
def train(A, X, lbls, net, train_idx, optimizer, epoch):
    """Run one full-batch training step and return the training loss.

    Args:
        A: graph structure, passed to ``net`` as its second argument.
        X: vertex feature tensor, passed to ``net`` as its first argument.
        lbls: label tensor; only the entries selected by ``train_idx`` are used.
        net: model called as ``net(X, A)``; switched to training mode here.
        train_idx: index / boolean mask selecting the training vertices.
        optimizer: optimizer that updates ``net``'s parameters.
        epoch: epoch number, used only in the log line.

    Returns:
        float: cross-entropy loss on the training vertices (computed before
        the optimizer step).
    """
    net.train()

    start = time.perf_counter()
    optimizer.zero_grad()
    outs = net(X, A)
    loss = F.cross_entropy(outs[train_idx], lbls[train_idx])
    loss.backward()
    optimizer.step()
    # .item() waits for queued (e.g. CUDA) kernels to finish, so read it
    # before stopping the clock; otherwise the time excludes most GPU work.
    loss_value = loss.item()
    elapsed = time.perf_counter() - start

    print(f"Epoch: {epoch}, Time: {elapsed:.5f}s, Loss: {loss_value:.5f}")

    return loss_value

In [3]:
@torch.no_grad()
def infer(A, X, lbls, net, idx, test=False, evaluator=None):
    """Evaluate ``net`` on the vertices selected by ``idx`` without tracking grads.

    Args:
        A: graph structure, passed to ``net`` as its second argument.
        X: vertex feature tensor, passed to ``net`` as its first argument.
        lbls: label tensor; only the entries selected by ``idx`` are used.
        net: model called as ``net(X, A)``; switched to eval mode here.
        idx: index / boolean mask selecting the vertices to evaluate.
        test: if True call ``evaluator.test``, otherwise ``evaluator.validate``.
        evaluator: object exposing ``validate(lbls, outs)`` / ``test(lbls, outs)``.
            Defaults to the module-level ``evaluator`` for backward compatibility.

    Returns:
        Whatever the evaluator's ``validate`` / ``test`` method returns.

    Raises:
        NameError: if no evaluator is passed and no global ``evaluator`` exists.
    """
    if evaluator is None:
        # Older call sites rely on the global defined in the main cell.
        evaluator = globals().get("evaluator")
        if evaluator is None:
            raise NameError(
                "infer() needs an evaluator: pass evaluator=... or define a global 'evaluator'"
            )

    net.eval()

    outs = net(X, A)
    outs, lbls = outs[idx], lbls[idx]

    if test:
        return evaluator.test(lbls, outs)
    return evaluator.validate(lbls, outs)

In [4]:
if __name__ == "__main__":

    set_seed(2022)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    num_epochs = 300

    ################################################################
    # Load Cora, build the graph, and pick out the split masks.
    data = Cora()
    G = Graph(data["num_vertices"], data["edge_list"])
    X, lbl = data["features"], data["labels"]

    train_mask = data["train_mask"]
    val_mask   = data["val_mask"]
    test_mask  = data["test_mask"]

    # Move the model to the target device *before* building the optimizer,
    # as recommended by the torch.optim documentation.
    net = GAT(data["dim_features"], 8, data["num_classes"], num_heads=8, drop_rate=0.6)
    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=0.005, weight_decay=5e-4)

    # `infer` reads this module-level name, so it must stay global.
    evaluator = Evaluator(["accuracy", "f1_score", {"f1_score": {"average": "micro"}}])

    ################################################################
    # Use the selected device instead of hard-coded .cuda(), which crashed
    # on machines without a GPU.
    G = G.to(device)
    X, lbl = X.to(device), lbl.to(device)

    ################################################################
    best_val = 0
    best_state = None
    best_epoch = 0
    for epoch in range(num_epochs):

        # train: the optimizer updates the parameters once per epoch
        train(G, X, lbl, net, train_mask, optimizer, epoch)

        # validation every epoch (`infer` already runs under torch.no_grad())
        val_res = infer(G, X, lbl, net, val_mask)
        if val_res > best_val:
            print(f"update best: {val_res:.5f}")
            best_val   = val_res
            best_state = deepcopy(net.state_dict())
            best_epoch = epoch

    print("\ntrain finished!")
    print(f"best val: {best_val:.5f}")

    # test with the best checkpoint; if validation never improved, keep the
    # final weights rather than crashing on load_state_dict(None).
    print("test...")
    if best_state is not None:
        net.load_state_dict(best_state)
    res = infer(G, X, lbl, net, test_mask, test=True)
    print(f"final result: epoch: {best_epoch}")
    print(res)

Epoch: 0, Time: 0.32923s, Loss: 1.94815
update best: 0.12200
Epoch: 1, Time: 0.01027s, Loss: 1.94413
Epoch: 2, Time: 0.01060s, Loss: 1.93961
Epoch: 3, Time: 0.01084s, Loss: 1.93500
Epoch: 4, Time: 0.01072s, Loss: 1.93154
update best: 0.14000
Epoch: 5, Time: 0.01077s, Loss: 1.92676
update best: 0.18000
Epoch: 6, Time: 0.01021s, Loss: 1.91839
update best: 0.30000
Epoch: 7, Time: 0.01011s, Loss: 1.91322
update best: 0.48800
Epoch: 8, Time: 0.01005s, Loss: 1.91055
update best: 0.56800
Epoch: 9, Time: 0.00971s, Loss: 1.89851
update best: 0.59200
Epoch: 10, Time: 0.01029s, Loss: 1.89522
update best: 0.65000
Epoch: 11, Time: 0.01035s, Loss: 1.88976
update best: 0.66000
Epoch: 12, Time: 0.01025s, Loss: 1.87861
update best: 0.68400
Epoch: 13, Time: 0.00979s, Loss: 1.87413
Epoch: 14, Time: 0.00986s, Loss: 1.86477
update best: 0.69600
Epoch: 15, Time: 0.01024s, Loss: 1.85221
update best: 0.72200
Epoch: 16, Time: 0.01057s, Loss: 1.85074
update best: 0.73800
Epoch: 17, Time: 0.01051s, Loss: 1.84227