In [1]:
import warnings
warnings.filterwarnings('ignore')
from build_graph import BuildGraph
from py_graphsage import SAGE
import torch.nn.functional as F
import torch 
import time
from torcheval.metrics.functional import multiclass_f1_score
import numpy as np

In [2]:
# Build the word/document graph for the UIT_VFSC dataset and print its summary
# (node/edge counts plus the node- and edge-data schemes).
graph_builder = BuildGraph("UIT_VFSC")
g = graph_builder.g
del graph_builder  # only the graph itself is used by later cells
print(g)

step pre processing
step add word doc edge


Processing documents: 16175it [00:02, 7458.62it/s]


step add word word edge


Constructing word pair count frequency: 100%|██████████| 87309/87309 [00:05<00:00, 15725.12it/s]
Adding word_word edges: 100%|██████████| 337534/337534 [00:00<00:00, 358506.93it/s]


step setup graph
Graph(num_nodes=19020, num_edges=595385,
      ndata_schemes={'x': Scheme(shape=(2845,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.float32), 'train_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={'weight': Scheme(shape=(), dtype=torch.float32)})


In [3]:
# Pull the per-node tensors stored on the graph.
node_features, node_labels = g.ndata['x'], g.ndata['label']
train_mask, test_mask = g.ndata['train_mask'], g.ndata['test_mask']

# Model input width, and the class count inferred as (largest label + 1).
n_features = node_features.size(1)
n_labels = int(node_labels.max().item() + 1)

for shown in (n_features, node_features.shape):
    print(shown)

2845
torch.Size([19020, 2845])


In [4]:
# evaluate model
def evaluate(model, graph, features, labels, mask, edge_weights=None, num_classes=None):
    """Score ``model`` on the nodes selected by ``mask``.

    Runs a full-graph forward pass ``model(graph, features)`` in eval mode with
    gradient tracking disabled, then compares each masked node's arg-max
    prediction with its label.

    Args:
        model: module called as ``model(graph, features)``; expected to return one
            row of class scores per node.
        graph: graph object passed straight through to ``model``.
        features: node feature tensor passed straight through to ``model``.
        labels: per-node class labels. They may be stored as floats (this
            notebook's graph stores them as float32); they are cast to int64 so
            accuracy and F1 are computed against the same targets.
        mask: boolean tensor selecting the nodes to score.
        edge_weights: currently unused; accepted only for call compatibility.
        num_classes: number of classes for the F1 computation. Defaults to the
            notebook-level ``n_labels``.

    Returns:
        ``(accuracy, macro_f1, weighted_f1)`` as Python floats.

    Raises:
        ValueError: if ``mask`` selects no nodes (accuracy would be 0/0).

    Leaves ``model`` in eval mode; the training loop re-enables training mode
    with ``model.train()`` at the start of every epoch.
    """
    if num_classes is None:
        num_classes = n_labels

    model.eval()
    with torch.no_grad():
        # GNN message passing needs the whole graph, so mask *after* the forward pass.
        logits = model(graph, features)[mask]
        targets = labels[mask].to(torch.long)
        if targets.numel() == 0:
            raise ValueError("evaluate() received a mask that selects no nodes")

        preds = logits.argmax(dim=1)
        # Python-level division keeps the original double-precision accuracy.
        acc = (preds == targets).sum().item() / targets.numel()
        mf1 = multiclass_f1_score(preds, targets, num_classes=num_classes, average='macro').item()
        wf1 = multiclass_f1_score(preds, targets, num_classes=num_classes, average='weighted').item()

    return acc, mf1, wf1

In [5]:
# Build model
# Hyper-parameters live here so they are easy to find and tune.
SEED = 42            # fixes weight initialisation so runs are reproducible
HIDDEN_DIM = 200
LEARNING_RATE = 5e-3

# NOTE(review): assumes SAGE draws its initial weights (and any dropout masks)
# from torch's global RNG — confirm in py_graphsage.
torch.manual_seed(SEED)
model = SAGE(in_feats=n_features, hid_feats=HIDDEN_DIM, out_feats=n_labels)
opt = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [6]:
# training model
# Full-batch training: every epoch runs the model over the whole graph, takes
# the loss on training nodes only, and then scores the test nodes.
N_EPOCHS = 1000
WARMUP_EPOCHS = 3  # the first epochs are excluded from the timing average

dur = []
max_acc = 0
max_mf1 = 0.0  # best macro-F1 seen so far
max_wf1 = 0.0  # best weighted-F1 seen so far
# `max_f1` is the best *weighted* F1 (the summary cell prints it as "Max F1").
# Macro and weighted F1 are tracked separately so the two metrics are never mixed.
max_f1 = 0.0

# Labels are stored as float32 on the graph, but cross_entropy needs int64 class
# indices. The training targets never change, so cast them once outside the loop.
train_labels = node_labels[train_mask].to(torch.long)

for epoch in range(N_EPOCHS):
    timed = epoch >= WARMUP_EPOCHS
    if timed:
        t0 = time.perf_counter()
    model.train()
    # forward propagation by using all nodes
    logits = model(g, node_features)

    # compute loss on training nodes only
    loss = F.cross_entropy(logits[train_mask].to(torch.float32), train_labels)

    # backward propagation
    opt.zero_grad()
    loss.backward()
    opt.step()

    if timed:
        dur.append(time.perf_counter() - t0)

    # compute *test* metrics
    # NOTE(review): taking the max of test metrics over epochs (below) amounts to
    # model selection on the test split, which gives an optimistic estimate. An
    # unbiased number needs a separate validation mask.
    acc, mf1, wf1 = evaluate(model, g, node_features, node_labels, test_mask)
    # Avoid np.mean([]) (RuntimeWarning) during the untimed warm-up epochs.
    mean_dur = np.mean(dur) if dur else float("nan")
    print(f"Epoch {epoch:05d} | Loss {loss.item():.4f} | Test Acc {acc:.4f} | mF1 {mf1:.4f} | wF1 {wf1:.4f} | Time(s) {mean_dur:.4f}")

    max_acc = max(max_acc, acc)
    max_mf1 = max(max_mf1, float(mf1))
    max_wf1 = max(max_wf1, float(wf1))
    max_f1 = max_wf1

Epoch 00000 | Loss 1.0604 | Test Acc 0.7846 | mF1 0.5368 | wF1 0.7648 | Time(s) nan
Epoch 00001 | Loss 0.8841 | Test Acc 0.8077 | mF1 0.5527 | wF1 0.7875 | Time(s) nan
Epoch 00002 | Loss 0.7501 | Test Acc 0.8170 | mF1 0.5589 | wF1 0.7966 | Time(s) nan
Epoch 00003 | Loss 0.6457 | Test Acc 0.8238 | mF1 0.5634 | wF1 0.8031 | Time(s) 1.2052
Epoch 00004 | Loss 0.5673 | Test Acc 0.8273 | mF1 0.5657 | wF1 0.8064 | Time(s) 1.2676
Epoch 00005 | Loss 0.5107 | Test Acc 0.8318 | mF1 0.5686 | wF1 0.8105 | Time(s) 1.2176
Epoch 00006 | Loss 0.4713 | Test Acc 0.8358 | mF1 0.5713 | wF1 0.8143 | Time(s) 1.1989
Epoch 00007 | Loss 0.4437 | Test Acc 0.8379 | mF1 0.5727 | wF1 0.8164 | Time(s) 1.1894
Epoch 00008 | Loss 0.4227 | Test Acc 0.8410 | mF1 0.5749 | wF1 0.8195 | Time(s) 1.1701
Epoch 00009 | Loss 0.4048 | Test Acc 0.8433 | mF1 0.5765 | wF1 0.8218 | Time(s) 1.1880
Epoch 00010 | Loss 0.3880 | Test Acc 0.8459 | mF1 0.5783 | wF1 0.8242 | Time(s) 1.1881
Epoch 00011 | Loss 0.3709 | Test Acc 0.8492 | mF1 0.

KeyboardInterrupt: 

In [7]:
# Best test-set scores observed during training. Each maximum is tracked
# independently, so the two values may come from different epochs.
summary_lines = (
    f"Max accuracy: {max_acc:.4f}",
    f'Max F1: {max_f1:.4f}',
)
print(*summary_lines, sep="\n")

Max accuracy: 0.8886
Max F1: 0.8822
