In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch as th
import os
from dataset import StrokeDataset, collate
from torch.utils.data import DataLoader

home = os.environ['HOME']
data_dir = "./cls_5class"
trainset = StrokeDataset(os.path.join(data_dir, "train"))
validset = StrokeDataset(os.path.join(data_dir, "valid"))
testset = StrokeDataset(os.path.join(data_dir, "test"))

In [3]:
device = th.device('cuda')

train_loader = DataLoader(trainset,
                          batch_size=32,
                          shuffle=True,
                          collate_fn=collate(device))

valid_loader = DataLoader(validset,
                          batch_size=32,
                          shuffle=False,
                          collate_fn=collate(device))

test_loader = DataLoader(testset,
                         batch_size=32,
                         shuffle=False,
                         collate_fn=collate(device))

In [4]:
from sklearn.metrics import confusion_matrix
import numpy as np

def print_result(count):
    num_classes = count.shape[0]
    if num_classes == 5:
        id_to_tag = ['Graph', 'Text', 'Table', 'List', 'Math']
    elif num_classes == 2:
        id_to_tag = ['Non-text', 'Text']
    
    # Confusion matrix with accuracy for each tag
    print (("{: >2}{: >9}{: >9}%s{: >9}" % ("{: >9}" * num_classes)).format(
        "ID", "NE", "Total",
        *([id_to_tag[i] for i in range(num_classes)] + ["Percent"]))
    )
    for i in range(num_classes):
        print (("{: >2}{: >9}{: >9}%s{: >9}" % ("{: >9}" * num_classes)).format(
            str(i), id_to_tag[i], str(count[i].sum()),
            *([count[i][j] for j in range(num_classes)] +
              ["%.3f" % (count[i][i] * 100. / max(1, count[i].sum()))])
        ))

    # Global accuracy
    accuracy = 100. * count.trace() / max(1, count.sum())
    print ("Stroke accuracy: %i/%i (%.5f%%)" % (
        count.trace(), count.sum(), accuracy)
    )
    
def evaluate(model, loader, num_classes):
    model.eval()
    count = np.zeros((num_classes, num_classes), dtype=np.int32)
    for it, (fg, lg) in enumerate(loader):
        logits = model(fg)
        _, predictions = th.max(logits, dim=1)
        labels = lg.ndata['y']
        predictions = predictions.cpu().numpy()
        labels = labels.cpu().numpy()
        count += confusion_matrix(labels, predictions)
    model.train()
    
    print_result(count)
    
    accuracy = 100. * count.trace() / max(1, count.sum())
    return accuracy, count

In [13]:
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from gcn import GCNNet
from gat import GAT

in_feats = 23
hidden_feats = 10
num_heads = 8
num_out_heads = 1
num_layers = 2
residual = True
in_drop = 0.2
attn_drop = 0.2
lr = 5e-3
weight_decay = 5e-4
alpha = 0.2

num_classes = 5
num_epoches = 20
heads = ([num_heads] * num_layers) + [num_out_heads]

# model = GCNNet(in_feats, hidden_feats, num_classes).to(device)
model = GAT(num_layers, in_feats, hidden_feats, num_classes,
            heads, F.elu, in_drop, attn_drop, alpha, residual).to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

epoch_losses = []
best_valid_acc = 0
best_test_acc = 0
best_round = 0

for epoch in range(num_epoches):
    epoch_loss = 0
    for it, (fg, lg) in enumerate(train_loader):
        logits = model(fg)
        labels = lg.ndata['y']
        loss = loss_func(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        
    epoch_loss /= (it + 1)
    print('Epoch {:3d}, loss {:4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)
    
    valid_acc, _ = evaluate(model, valid_loader, num_classes)
    if valid_acc > best_valid_acc:
        test_acc, test_conf_mat = evaluate(model, test_loader, num_classes)
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            best_conf_mat = test_conf_mat
            best_round = epoch

print("Best round: %d" % best_round)
print_result(best_conf_mat)



Epoch   0, loss 2.375172
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  15481.0   9740.0   4206.0    469.0    133.0    933.0   62.916
 1     Text  39796.0   2445.0  32980.0   1476.0    435.0   2460.0   82.873
 2    Table   6562.0    947.0   4311.0    509.0    222.0    573.0    7.757
 3     List   3474.0    202.0   2690.0    229.0    120.0    233.0    3.454
 4     Math   3412.0    585.0    959.0    245.0     86.0   1537.0   45.047
Stroke accuracy: 44886/68725 (65.31248%)
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  17488.0  12356.0   3427.0    480.0    148.0   1077.0   70.654
 1     Text  40469.0   2293.0  33104.0   1645.0    567.0   2860.0   81.801
 2    Table   6883.0    867.0   4694.0    523.0    203.0    596.0    7.598
 3     List   3115.0    227.0   2333.0    187.0     97.0    271.0    3.114
 4     Math   2972.0    521.0    797.0    264.0     64.0   1326.0   44.616
Stroke accuracy: 47406/70927 (66.8

ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  17488.0  13695.0   3137.0     77.0      1.0    578.0   78.311
 1     Text  40469.0   2368.0  37117.0     19.0      0.0    965.0   91.717
 2    Table   6883.0    820.0   5648.0    145.0      0.0    270.0    2.107
 3     List   3115.0    259.0   2738.0      3.0      1.0    114.0    0.032
 4     Math   2972.0    590.0   1168.0     24.0      0.0   1190.0   40.040
Stroke accuracy: 52148/70927 (73.52348%)
Epoch   9, loss 0.929424
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  15481.0  10901.0   3694.0     99.0     11.0    776.0   70.415
 1     Text  39796.0   2373.0  35778.0     38.0      3.0   1604.0   89.904
 2    Table   6562.0    922.0   4988.0    167.0      1.0    484.0    2.545
 3     List   3474.0    214.0   3091.0     11.0      3.0    155.0    0.086
 4     Math   3412.0    517.0   1167.0     16.0      0.0   1712.0   50.176
Stroke accuracy: 48561/68725 (70.6

Epoch  17, loss 0.874024
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  15481.0  11527.0   3613.0     66.0      1.0    274.0   74.459
 1     Text  39796.0   2529.0  36726.0     26.0      1.0    514.0   92.286
 2    Table   6562.0    984.0   5219.0    170.0      0.0    189.0    2.591
 3     List   3474.0    224.0   3195.0     11.0      1.0     43.0    0.029
 4     Math   3412.0    785.0   1502.0     12.0      0.0   1113.0   32.620
Stroke accuracy: 49537/68725 (72.08003%)
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  17488.0  14030.0   3069.0     65.0      0.0    324.0   80.226
 1     Text  40469.0   2148.0  37829.0     22.0      0.0    470.0   93.476
 2    Table   6883.0    818.0   5729.0    185.0      0.0    151.0    2.688
 3     List   3115.0    256.0   2805.0      1.0      0.0     53.0    0.000
 4     Math   2972.0    689.0   1376.0     21.0      0.0    886.0   29.812
Stroke accuracy: 52930/70927 (74.6