In [1]:
import warnings
warnings.filterwarnings('ignore')
from build_graph import BuildGraph
from py_hgat import GAT
import torch.nn.functional as F
import torch 
import time
from torcheval.metrics.functional import multiclass_f1_score
import numpy as np

In [2]:
# Build Graph
g = BuildGraph("UIT_VFSC").g
print(g)

step pre processing
step add word doc edge


Processing documents: 16175it [00:02, 6311.29it/s]


step add word word edge


Constructing word pair count frequency: 100%|██████████| 87309/87309 [00:03<00:00, 21977.41it/s]
Adding word_word edges: 100%|██████████| 337534/337534 [00:00<00:00, 520706.29it/s]


step setup graph
Graph(num_nodes=19020, num_edges=595385,
      ndata_schemes={'x': Scheme(shape=(2845,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.float32), 'train_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={'weight': Scheme(shape=(), dtype=torch.float32)})


In [3]:
node_features = g.ndata['x']
node_labels = g.ndata['label']
train_mask = g.ndata['train_mask']
test_mask = g.ndata['test_mask']
n_features = node_features.shape[1]
n_labels = int(node_labels.max().item() + 1)
print(n_features)
print(node_features.shape)

2845
torch.Size([19020, 2845])


In [4]:
def evaluate(model, graph, features, labels, mask, edge_weights=None):
    
    with torch.no_grad():
        model.eval()
        logits = model(graph, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)

        correct = torch.sum(indices == labels)

        acc = correct.item() * 1.0 / len(labels)
        mf1 = multiclass_f1_score(indices.type(torch.long), labels.type(torch.long), num_classes=n_labels, average='macro')
        wf1 = multiclass_f1_score(indices.type(torch.long), labels.type(torch.long), num_classes=n_labels, average='weighted')

        return acc, mf1, wf1

In [5]:
model = GAT(
                 num_layers=1,
                 in_dim=n_features,
                 num_hidden=200,
                 num_classes=n_labels,
                 heads=[4] + [1],
                 activation=F.elu,
                 feat_drop=0.5,
                 attn_drop=0.5,
        )
opt = torch.optim.Adam(model.parameters(), lr=5e-3)

In [6]:
dur = []
max_acc = 0
max_f1 = 0
for epoch in range(1000):
    if epoch >= 3:
        t0 = time.time()
    model.train()
    # forward propagation by using all nodes
    logits = model(g, node_features)
    
    # compute loss
    loss = F.cross_entropy(logits[train_mask].to(torch.float32), node_labels[train_mask].to(torch.long))

    # backward propagation
    opt.zero_grad()
    loss.backward()
    opt.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    # compute validation accuracy
    acc, mf1, wf1 = evaluate(model, g, node_features, node_labels, test_mask)
    print(f"Epoch {epoch:05d} | Loss {loss.item():.4f} | Test Acc {acc:.4f} | mF1 {mf1:.4f} | wF1 {wf1:.4f} | Time(s) {np.mean(dur):.4f}")

    if acc > max_acc:
        max_acc = acc
    max_f1 = max(max(max_f1, wf1), mf1)

Epoch 00000 | Loss 1.1054 | Test Acc 0.5816 | mF1 0.3367 | wF1 0.4916 | Time(s) nan
Epoch 00001 | Loss 0.8827 | Test Acc 0.7593 | mF1 0.5163 | wF1 0.7374 | Time(s) nan
Epoch 00002 | Loss 0.7825 | Test Acc 0.7707 | mF1 0.5273 | wF1 0.7516 | Time(s) nan
Epoch 00003 | Loss 0.7426 | Test Acc 0.7597 | mF1 0.5198 | wF1 0.7407 | Time(s) 4.3610
Epoch 00004 | Loss 0.7368 | Test Acc 0.7543 | mF1 0.5159 | wF1 0.7356 | Time(s) 4.2659
Epoch 00005 | Loss 0.7412 | Test Acc 0.7446 | mF1 0.5093 | wF1 0.7254 | Time(s) 4.3451
Epoch 00006 | Loss 0.7248 | Test Acc 0.7326 | mF1 0.5004 | wF1 0.7122 | Time(s) 4.9874
Epoch 00007 | Loss 0.7411 | Test Acc 0.7494 | mF1 0.5128 | wF1 0.7305 | Time(s) 5.4236
Epoch 00008 | Loss 0.7313 | Test Acc 0.7614 | mF1 0.5210 | wF1 0.7425 | Time(s) 5.2635
Epoch 00009 | Loss 0.7171 | Test Acc 0.7635 | mF1 0.5224 | wF1 0.7446 | Time(s) 5.2834
Epoch 00010 | Loss 0.7260 | Test Acc 0.7501 | mF1 0.5131 | wF1 0.7309 | Time(s) 5.2956
Epoch 00011 | Loss 0.6967 | Test Acc 0.7543 | mF1 0.

KeyboardInterrupt: 

In [7]:
print(f"Max accuracy: {max_acc:.4f}")
print(f'Max F1: {max_f1:.4f}')

Max accuracy: 0.8699
Max F1: 0.8492
