# In [None]:
"""
Citations:

GCN MODEL
@article{DBLP:journals/corr/KipfW16,
  author       = {Thomas N. Kipf and
                  Max Welling},
  title        = {Semi-Supervised Classification with Graph Convolutional Networks},
  journal      = {CoRR},
  volume       = {abs/1609.02907},
  year         = {2016},
  url          = {http://arxiv.org/abs/1609.02907},
  eprinttype    = {arXiv},
  eprint       = {1609.02907},
  timestamp    = {Mon, 13 Aug 2018 16:48:31 +0200},
  biburl       = {https://dblp.org/rec/journals/corr/KipfW16.bib},
  bibsource    = {dblp computer science bibliography, https://dblp.org}
}

HGNN MODEL
@inproceedings{feng2019hypergraph,
  title={Hypergraph neural networks},
  author={Feng, Yifan and You, Haoxuan and Zhang, Zizhao and Ji, Rongrong and Gao, Yue},
  booktitle={Proceedings of the AAAI conference on artificial intelligence},
  volume={33},
  number={01},
  pages={3558--3565},
  year={2019}
}

HGNN+ MODEL
@article{gao2022hgnn,
  title={HGNN$^+$: General Hypergraph Neural Networks},
  author={Gao, Yue and Feng, Yifan and Ji, Shuyi and Ji, Rongrong},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  year={2022},
  publisher={IEEE}
}
"""


import time
from copy import deepcopy

import torch
import torch.optim as optim
import torch.nn.functional as F

from dhg import Graph
from dhg.data import Citeseer #change dataset as desired, must change any instances below as well 
# Datasets tested with this script: Cora, PubMed, and Citeseer.
from dhg.models import GCN #GCN model
from dhg.random import set_seed
from dhg.metrics import GraphVertexClassificationEvaluator as Evaluator
# Evaluator class that scores vertex-classification predictions on a graph.

#training the model
def train(net, X, A, lbls, train_idx, optimizer, epoch):
    """Run a single optimisation step of ``net`` on the training vertices.

    Args:
        net: model invoked as ``net(X, A)``; it is switched to training mode.
        X: vertex feature tensor passed to ``net``.
        A: graph structure passed straight through to ``net``.
        lbls: ground-truth class labels for every vertex.
        train_idx: index/mask selecting the training vertices.
        optimizer: optimizer over ``net``'s parameters; it is stepped once.
        epoch: epoch number, used only in the log line.

    Returns:
        The cross-entropy loss on the training vertices as a Python float,
        computed from the forward pass before the optimizer step.
    """
    net.train()
    tic = time.time()

    optimizer.zero_grad()
    # Score every vertex, then keep only the training subset for the loss.
    train_logits = net(X, A)[train_idx]
    train_targets = lbls[train_idx]
    batch_loss = F.cross_entropy(train_logits, train_targets)
    batch_loss.backward()
    optimizer.step()

    elapsed = time.time() - tic
    loss_value = batch_loss.item()
    print(f"Epoch: {epoch}, Time: {elapsed:.5f}s, Loss: {loss_value:.5f}")
    return loss_value

#the inferring or prediction function
@torch.no_grad()
def infer(net, X, A, lbls, idx, test=False, evaluator=None):
    """Evaluate ``net`` on the vertices selected by ``idx``.

    Runs without gradient tracking and puts ``net`` in eval mode.

    Args:
        net: model invoked as ``net(X, A)``.
        X: vertex feature tensor passed to ``net``.
        A: graph structure passed straight through to ``net``.
        lbls: ground-truth class labels for every vertex.
        idx: index/mask selecting the vertices to score.
        test: if False, score with ``evaluator.validate``; if True, with
            ``evaluator.test``.
        evaluator: object exposing ``validate(lbls, outs)`` and
            ``test(lbls, outs)``. When omitted, the module-level ``evaluator``
            (created in the ``__main__`` section of this script) is used, so
            existing call sites keep working.

    Returns:
        Whatever the evaluator's ``validate`` / ``test`` method returns.

    Raises:
        NameError: if no evaluator is passed and no module-level
            ``evaluator`` has been defined.
    """
    if evaluator is None:
        # Fall back to the global the script has always relied on; look it up
        # explicitly so a missing one produces an actionable message.
        evaluator = globals().get("evaluator")
        if evaluator is None:
            raise NameError(
                "infer() needs an evaluator: pass evaluator=... or define a "
                "module-level 'evaluator'"
            )

    net.eval()
    outs = net(X, A)[idx]
    lbls = lbls[idx]
    if test:
        return evaluator.test(lbls, outs)
    return evaluator.validate(lbls, outs)

if __name__ == "__main__":
    set_seed(2022)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # Module-level global: infer() looks up `evaluator` by this name.
    evaluator = Evaluator(["accuracy", "f1_score", {"f1_score": {"average": "micro"}}])
    data = Citeseer()  # change here (and in the import) to switch datasets
    X, lbl = data["features"], data["labels"]
    G = Graph(data["num_vertices"], data["edge_list"])
    train_mask = data["train_mask"]
    val_mask = data["val_mask"]
    test_mask = data["test_mask"]

    net = GCN(data["dim_features"], 16, data["num_classes"])
    optimizer = optim.Adam(net.parameters(), lr=0.01, weight_decay=5e-4)

    X, lbl = X.to(device), lbl.to(device)
    G = G.to(device)
    net = net.to(device)

    # Best validation score seen so far, the epoch it occurred at, and a
    # snapshot of the model weights from that epoch.
    best_state = None
    best_epoch, best_val = 0, 0

    # Start the clock once, before the loop, so the reported elapsed time
    # covers the whole training + validation run rather than only the last
    # epoch.
    start_time = time.time()
    for epoch in range(100):
        train(net, X, G, lbl, train_mask, optimizer, epoch)
        # Validate every epoch; infer() is already decorated with
        # @torch.no_grad(), so no extra no-grad context is needed here.
        val_res = infer(net, X, G, lbl, val_mask)
        if val_res > best_val:
            print(f"update best: {val_res:.5f}")
            best_epoch = epoch
            best_val = val_res
            best_state = deepcopy(net.state_dict())
    elapsed_time = time.time() - start_time
    print("\ntrain finished!")
    print(f"best val: {best_val:.5f}")

    # Test with the best-on-validation weights and print the final results.
    print("test...")
    # If validation never rose above the initial 0, no snapshot was taken;
    # test the final weights instead of crashing in load_state_dict(None).
    if best_state is not None:
        net.load_state_dict(best_state)
    res = infer(net, X, G, lbl, test_mask, test=True)
    print(f"final result: epoch: {best_epoch}")
    print(res)
    print("Elapsed Time: ", elapsed_time)