In [1]:
import warnings
warnings.filterwarnings('ignore')
from build_graph import BuildGraph
from py_gcn import GCN
import torch.nn.functional as F
import torch 
import time
from torcheval.metrics.functional import multiclass_f1_score
import numpy as np

In [2]:
# Build the graph for the "UIT_VFSC" dataset key and print its summary.
graph_builder = BuildGraph("UIT_VFSC")
g = graph_builder.g
print(g)

step pre processing
step add word doc edge


Processing documents: 16175it [00:02, 5528.17it/s]


step add word word edge


Constructing word pair count frequency: 100%|██████████| 87309/87309 [00:04<00:00, 21045.41it/s]
Adding word_word edges: 100%|██████████| 337534/337534 [00:00<00:00, 408558.98it/s]


step setup graph
Graph(num_nodes=19020, num_edges=595385,
      ndata_schemes={'x': Scheme(shape=(2845,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.float32), 'train_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'val_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={'weight': Scheme(shape=(), dtype=torch.float32)})


In [3]:
# Pull out the per-node tensors that the training and evaluation cells use.
node_features, node_labels = g.ndata['x'], g.ndata['label']
train_mask, test_mask = g.ndata['train_mask'], g.ndata['test_mask']

# Model dimensions: input width comes from the feature matrix,
# class count from the largest label value plus one.
n_features = node_features.size(1)
n_labels = int(node_labels.max().item() + 1)

print(n_features, node_features.shape, sep="\n")

2845
torch.Size([19020, 2845])


In [4]:
# evaluate model
def evaluate(model, graph, features, labels, mask, edge_weights=None):
    """Score ``model`` on the nodes selected by ``mask``.

    Args:
        model: object called as ``model(features, graph, edge_weights)``; it must
            expose ``eval()`` and return a 2-D tensor of per-node class scores.
        graph: graph object, forwarded to ``model`` unchanged.
        features: node feature tensor, forwarded to ``model`` unchanged.
        labels: per-node class ids; cast to ``torch.long`` for the F1 metrics.
        mask: boolean mask (or index tensor) selecting the nodes to score.
        edge_weights: optional edge weights, forwarded to ``model`` unchanged.

    Returns:
        Tuple ``(acc, mf1, wf1)``: accuracy as a Python float, then the macro
        and weighted F1 values exactly as returned by ``multiclass_f1_score``.

    Raises:
        ValueError: if ``mask`` selects no nodes (accuracy would be 0/0).
    """
    selected_labels = labels[mask]
    n_selected = selected_labels.shape[0]
    if n_selected == 0:
        raise ValueError("evaluate() needs at least one node selected by `mask`")

    # The model is left in eval mode; callers switch back with model.train().
    model.eval()
    with torch.no_grad():
        logits = model(features, graph, edge_weights)[mask]
        predictions = logits.argmax(dim=1)

        # Accuracy compares against the labels as given (no cast).
        acc = (predictions == selected_labels).sum().item() / n_selected

        # Take the class count from the model output instead of a notebook
        # global, so the function does not depend on cell execution order.
        num_classes = logits.shape[1]
        targets = selected_labels.type(torch.long)
        mf1 = multiclass_f1_score(predictions, targets, num_classes=num_classes, average='macro')
        wf1 = multiclass_f1_score(predictions, targets, num_classes=num_classes, average='weighted')

    return acc, mf1, wf1

In [5]:
# Build Model
# Seed torch's global RNG before constructing the model so that any torch-based
# random initialisation and dropout sampling is repeatable across
# Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

model = GCN(
    in_feats=n_features,   # width of the node feature matrix
    n_hidden=200,
    n_classes=n_labels,    # one output score per class
    n_layers=1,
    activation=F.elu,
    dropout=0.5,
)
opt = torch.optim.Adam(model.parameters(), lr=5e-3)

In [6]:
# training model
N_EPOCHS = 1000
TIMING_WARMUP_EPOCHS = 3  # the first epochs are left out of the timing average

# Loop-invariant inputs, looked up / cast once instead of every epoch.
edge_weights = g.edata['weight']
train_targets = node_labels[train_mask].to(torch.long)

dur = []        # wall-clock seconds per timed training step
max_acc = 0.0
max_mf1 = 0.0   # best macro F1 seen so far
max_wf1 = 0.0   # best weighted F1 seen so far
max_f1 = 0.0    # best of either F1 variant (kept for the summary cell)
for epoch in range(N_EPOCHS):
    timed = epoch >= TIMING_WARMUP_EPOCHS
    if timed:
        t0 = time.time()
    model.train()
    # forward propagation by using all nodes
    logits = model(node_features, g, edge_weights)

    # compute loss on the training nodes only
    loss = F.cross_entropy(logits[train_mask].to(torch.float32), train_targets)

    # backward propagation
    opt.zero_grad()
    loss.backward()
    opt.step()

    if timed:
        dur.append(time.time() - t0)

    # Metrics are computed on the *test* split.
    # NOTE(review): taking the maximum over epochs on the test split gives an
    # optimistic estimate; the graph's printed schema lists a 'val_mask' that
    # could be used for model selection instead - confirm it is populated.
    acc, mf1, wf1 = evaluate(model, g, node_features, node_labels, test_mask, edge_weights)
    # Plain floats keep the running maxima from silently turning into tensors.
    mf1, wf1 = float(mf1), float(wf1)

    # No timed epochs yet -> report nan explicitly instead of averaging an
    # empty list (which only stays quiet because warnings are suppressed).
    mean_dur = np.mean(dur) if dur else float('nan')
    print(f"Epoch {epoch:05d} | Loss {loss.item():.4f} | Test Acc {acc:.4f} | mF1 {mf1:.4f} | wF1 {wf1:.4f} | Time(s) {mean_dur:.4f}")

    max_acc = max(max_acc, acc)
    max_mf1 = max(max_mf1, mf1)
    max_wf1 = max(max_wf1, wf1)
    max_f1 = max(max_mf1, max_wf1)


Epoch 00000 | Loss 2.2434 | Test Acc 0.4570 | mF1 0.2226 | wF1 0.2999 | Time(s) nan
Epoch 00001 | Loss 11.5292 | Test Acc 0.5411 | mF1 0.3273 | wF1 0.4554 | Time(s) nan
Epoch 00002 | Loss 9.7142 | Test Acc 0.6608 | mF1 0.4424 | wF1 0.6251 | Time(s) nan
Epoch 00003 | Loss 7.2514 | Test Acc 0.7527 | mF1 0.5148 | wF1 0.7311 | Time(s) 1.0434
Epoch 00004 | Loss 5.1819 | Test Acc 0.7745 | mF1 0.5303 | wF1 0.7543 | Time(s) 1.0460
Epoch 00005 | Loss 4.4022 | Test Acc 0.7432 | mF1 0.5048 | wF1 0.7196 | Time(s) 1.0296
Epoch 00006 | Loss 4.9085 | Test Acc 0.7227 | mF1 0.4861 | wF1 0.6941 | Time(s) 1.0245
Epoch 00007 | Loss 5.9783 | Test Acc 0.7344 | mF1 0.4957 | wF1 0.7074 | Time(s) 1.0131
Epoch 00008 | Loss 5.4559 | Test Acc 0.7517 | mF1 0.5112 | wF1 0.7285 | Time(s) 1.0324
Epoch 00009 | Loss 5.1341 | Test Acc 0.7735 | mF1 0.5286 | wF1 0.7525 | Time(s) 1.0299
Epoch 00010 | Loss 4.5637 | Test Acc 0.7836 | mF1 0.5365 | wF1 0.7633 | Time(s) 1.0324
Epoch 00011 | Loss 4.2856 | Test Acc 0.7843 | mF1 0

KeyboardInterrupt: 

In [None]:
# Report the best scores recorded by the training loop.
print("Max accuracy: {:.4f}".format(max_acc))
print("Max F1: {:.4f}".format(max_f1))