In [10]:
import time
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from dhg import Hypergraph
from dhg.data import Cooking200
from dhg.nn import HGNNConv
from dhg.random import set_seed
from typing import Union, Dict, List
from dhg.metrics.classification import VertexClassificationEvaluator

In [7]:
class HGNN(nn.Module):
    """Two-layer hypergraph neural network built from dhg ``HGNNConv`` layers.

    Args:
        in_channels: size of each input vertex feature.
        hid_channels: size of the hidden representation after the first layer.
        num_classes: output size of the second (last) layer.
        use_bn: forwarded as ``use_bn`` to both ``HGNNConv`` layers.
        drop_rate: forwarded as ``drop_rate`` to the first ``HGNNConv`` only.
    """

    def __init__(self, in_channels, hid_channels, num_classes, use_bn, drop_rate=0.5):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                HGNNConv(in_channels, hid_channels, use_bn=use_bn, drop_rate=drop_rate),
                # is_last=True flags the output layer for HGNNConv (see dhg docs).
                HGNNConv(hid_channels, num_classes, use_bn=use_bn, is_last=True),
            ]
        )

    def forward(self, X, hg):
        """Apply every layer in ``self.layers`` in order to ``X`` using ``hg``.

        Returns:
            The output of the last layer. Previously nothing was returned, so
            callers received ``None``.
        """
        for layer in self.layers:
            X = layer(X, hg)
        return X

In [11]:
class HypergraphVertexClassificationEvaluator(VertexClassificationEvaluator):
    """Evaluator used for the hypergraph vertex-classification run.

    Adds no behaviour of its own. Construction, ``validate`` and ``test`` are
    all forwarded unchanged to dhg's ``VertexClassificationEvaluator``.
    """

    def __init__(
        self, metric_configs: List[Union[str, Dict[str, dict]]], validate_index: int = 0
    ):
        # Both arguments are handed straight to the parent constructor.
        super().__init__(metric_configs, validate_index)

    def validate(self, y_true: torch.LongTensor, y_pred: torch.Tensor):
        """Delegate to the parent's ``validate`` and return its result unchanged."""
        parent_result = super().validate(y_true, y_pred)
        return parent_result

    def test(self, y_true: torch.LongTensor, y_pred: torch.Tensor):
        """Delegate to the parent's ``test`` and return its result unchanged."""
        parent_result = super().test(y_true, y_pred)
        return parent_result

In [3]:
def train(net, X, A, lbls, train_idx, optimizer, epoch):
    """Perform a single full-batch optimisation step and print a log line.

    Args:
        net: model called as ``net(X, A)``; switched to training mode.
        X: input features passed to ``net``.
        A: structure argument passed to ``net`` (the hypergraph in this notebook).
        lbls: target labels, indexed with ``train_idx`` like the outputs.
        train_idx: index or mask selecting the rows that contribute to the loss.
        optimizer: optimizer whose ``zero_grad``/``step`` bracket the update.
        epoch: epoch number, used only in the printed message.

    Returns:
        The cross-entropy loss on the selected rows as a Python float. It is
        computed before ``optimizer.step()`` is applied.
    """
    net.train()

    start = time.time()
    optimizer.zero_grad()
    selected_logits = net(X, A)[train_idx]
    selected_targets = lbls[train_idx]
    loss = F.cross_entropy(selected_logits, selected_targets)
    loss.backward()
    optimizer.step()

    elapsed = time.time() - start
    loss_value = loss.item()
    print(f"Epoch: {epoch}, Time: {elapsed:.5f}s, Loss: {loss_value:.5f}")
    return loss_value


@torch.no_grad()
def infer(net, X, A, lbls, idx, test=False, metric_evaluator=None):
    """Evaluate ``net`` on the rows selected by ``idx`` without tracking gradients.

    Args:
        net: model called as ``net(X, A)``; switched to eval mode.
        X: input features passed to ``net``.
        A: structure argument passed to ``net`` (the hypergraph in this notebook).
        lbls: ground-truth labels, indexed with ``idx`` like the outputs.
        idx: index or mask selecting the rows to score.
        test: if True call the evaluator's ``test``; otherwise call ``validate``.
        metric_evaluator: object exposing ``validate(y_true, y_pred)`` and
            ``test(y_true, y_pred)``. When omitted, falls back to the
            module-level ``evaluator`` global (the original behaviour). Passing
            it explicitly removes the hidden dependency on cell execution order.

    Returns:
        Whatever the evaluator's ``validate``/``test`` returns.
    """
    net.eval()
    outs = net(X, A)
    outs, lbls = outs[idx], lbls[idx]
    # Only look up the global when no evaluator was supplied, so callers that
    # pass one never need the global to exist.
    scorer = evaluator if metric_evaluator is None else metric_evaluator
    if test:
        return scorer.test(lbls, outs)
    return scorer.validate(lbls, outs)

In [4]:
if __name__ == "__main__":
    # Run configuration: tune here instead of inside the training loop.
    SEED = 42
    HIDDEN_CHANNELS = 32
    LR, WEIGHT_DECAY = 0.01, 5e-4
    NUM_EPOCHS = 200
    VAL_EVERY = 1  # run validation every N epochs

    set_seed(SEED)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # NOTE: infer() reads this module-level `evaluator`, so it must be defined
    # before any call to infer().
    evaluator = HypergraphVertexClassificationEvaluator(["accuracy", "f1_score", {"f1_score": {"average": "micro"}}])
    data = Cooking200()

    # Identity features: one one-hot input row per vertex.
    X, lbl = torch.eye(data["num_vertices"]), data["labels"]
    G = Hypergraph(data["num_vertices"], data["edge_list"])
    train_mask = data["train_mask"]
    val_mask = data["val_mask"]
    test_mask = data["test_mask"]

    net = HGNN(X.shape[1], HIDDEN_CHANNELS, data["num_classes"], use_bn=True)
    optimizer = optim.Adam(net.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    X, lbl = X.to(device), lbl.to(device)
    G = G.to(device)
    net = net.to(device)

    best_state = None
    best_epoch, best_val = 0, 0
    for epoch in range(NUM_EPOCHS):
        # train
        train(net, X, G, lbl, train_mask, optimizer, epoch)
        # validation (infer is already decorated with @torch.no_grad())
        if epoch % VAL_EVERY == 0:
            val_res = infer(net, X, G, lbl, val_mask)
            if val_res > best_val:
                print(f"update best: {val_res:.5f}")
                best_epoch = epoch
                best_val = val_res
                best_state = deepcopy(net.state_dict())
    print("\ntrain finished!")
    print(f"best val: {best_val:.5f}")
    # test
    print("test...")
    if best_state is None:
        # Validation never exceeded the initial best_val of 0; loading None
        # would raise, so evaluate the final weights instead.
        print("no validation improvement recorded; testing final weights")
    else:
        net.load_state_dict(best_state)
    res = infer(net, X, G, lbl, test_mask, test=True)
    print(f"final result: epoch: {best_epoch}")
    print(res)


Epoch: 0, Time: 0.65269s, Loss: 3.01102
update best: 0.05000
Epoch: 1, Time: 0.42421s, Loss: 2.27640
Epoch: 2, Time: 0.29357s, Loss: 2.14381
Epoch: 3, Time: 0.29305s, Loss: 2.04013
Epoch: 4, Time: 0.29318s, Loss: 1.98351
Epoch: 5, Time: 0.30398s, Loss: 1.92233
Epoch: 6, Time: 0.44434s, Loss: 1.87511
Epoch: 7, Time: 0.41907s, Loss: 1.82192
Epoch: 8, Time: 0.41729s, Loss: 1.76315
update best: 0.08500
Epoch: 9, Time: 0.45520s, Loss: 1.74401
update best: 0.11500
Epoch: 10, Time: 0.53395s, Loss: 1.71227
Epoch: 11, Time: 0.42974s, Loss: 1.66392
Epoch: 12, Time: 0.43261s, Loss: 1.62948
Epoch: 13, Time: 0.44239s, Loss: 1.60624
Epoch: 14, Time: 0.44863s, Loss: 1.57577
Epoch: 15, Time: 0.59132s, Loss: 1.55663
Epoch: 16, Time: 0.43944s, Loss: 1.50966
Epoch: 17, Time: 0.54647s, Loss: 1.48429
Epoch: 18, Time: 0.42638s, Loss: 1.44926
Epoch: 19, Time: 0.42827s, Loss: 1.43334
Epoch: 20, Time: 0.42928s, Loss: 1.42709
Epoch: 21, Time: 0.43701s, Loss: 1.38506
Epoch: 22, Time: 0.42893s, Loss: 1.35650
Epoc