In [6]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import torch as th
import os
from dataset import StrokeDataset, collate
from torch.utils.data import DataLoader

home = os.environ['HOME']
data_dir = os.path.join(home, "data/IAMonDo/cls_5class")
trainset = StrokeDataset(os.path.join(data_dir, "train"))
validset = StrokeDataset(os.path.join(data_dir, "valid"))
testset = StrokeDataset(os.path.join(data_dir, "test"))

In [8]:
device = th.device('cpu')

train_loader = DataLoader(trainset,
                          batch_size=32,
                          shuffle=True,
                          collate_fn=collate(device))

valid_loader = DataLoader(validset,
                          batch_size=32,
                          shuffle=False,
                          collate_fn=collate(device))

test_loader = DataLoader(testset,
                         batch_size=32,
                         shuffle=False,
                         collate_fn=collate(device))

In [9]:
from sklearn.metrics import confusion_matrix
import numpy as np

def print_result(count):
    num_classes = count.shape[0]
    if num_classes == 5:
        id_to_tag = ['Graph', 'Text', 'Table', 'List', 'Math']
    elif num_classes == 2:
        id_to_tag = ['Non-text', 'Text']
    
    # Confusion matrix with accuracy for each tag
    print (("{: >2}{: >9}{: >9}%s{: >9}" % ("{: >9}" * num_classes)).format(
        "ID", "NE", "Total",
        *([id_to_tag[i] for i in range(num_classes)] + ["Percent"]))
    )
    for i in range(num_classes):
        print (("{: >2}{: >9}{: >9}%s{: >9}" % ("{: >9}" * num_classes)).format(
            str(i), id_to_tag[i], str(count[i].sum()),
            *([count[i][j] for j in range(num_classes)] +
              ["%.3f" % (count[i][i] * 100. / max(1, count[i].sum()))])
        ))

    # Global accuracy
    accuracy = 100. * count.trace() / max(1, count.sum())
    print ("Stroke accuracy: %i/%i (%.5f%%)" % (
        count.trace(), count.sum(), accuracy)
    )
    
def evaluate(model, loader, num_classes):
    model.eval()
    count = np.zeros((num_classes, num_classes))
    for it, (fg, lg) in enumerate(loader):
        logits = model(fg)
        _, predictions = th.max(logits, dim=1)
        labels = lg.ndata['y']
        predictions = predictions.cpu().numpy()
        labels = labels.cpu().numpy()
        count += confusion_matrix(labels, predictions)
    model.train()
    
    print_result(count)
    
    accuracy = 100. * count.trace() / max(1, count.sum())
    return accuracy, count

In [11]:
import torch.optim as optim

import torch.nn as nn
from gcn import GCNNet

in_feats = 23
hidden_feats = 30
num_classes = 5
num_epoches = 100

model = GCNNet(in_feats, hidden_feats, num_classes).to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)

epoch_losses = []
best_valid_acc = 0
best_test_acc = 0
best_round = 0

for epoch in range(num_epoches):
    epoch_loss = 0
    for it, (fg, lg) in enumerate(train_loader):
        logits = model(fg)
        labels = lg.ndata['y']
        loss = loss_func(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        
    epoch_loss /= (it + 1)
    print('Epoch {:3d}, loss {:4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)
    
    valid_acc, _ = evaluate(model, valid_loader, num_classes)
    if valid_acc > best_valid_acc:
        test_acc, test_conf_mat = evaluate(model, test_loader, num_classes)
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            best_conf_mat = test_conf_mat
            best_round = epoch

print("Best round: %d" % best_round)
print_result(best_conf_mat)

Epoch   0, loss 1.167410
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  15481.0  11843.0   3638.0      0.0      0.0      0.0   76.500
 1     Text  39796.0   7945.0  31851.0      0.0      0.0      0.0   80.036
 2    Table   6562.0   2225.0   4337.0      0.0      0.0      0.0    0.000
 3     List   3474.0    781.0   2693.0      0.0      0.0      0.0    0.000
 4     Math   3412.0   2134.0   1278.0      0.0      0.0      0.0    0.000
Stroke accuracy: 43694/68725 (63.57803%)
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  17488.0  14626.0   2862.0      0.0      0.0      0.0   83.634
 1     Text  40469.0   7636.0  32833.0      0.0      0.0      0.0   81.131
 2    Table   6883.0   2006.0   4877.0      0.0      0.0      0.0    0.000
 3     List   3115.0    820.0   2295.0      0.0      0.0      0.0    0.000
 4     Math   2972.0   1895.0   1077.0      0.0      0.0      0.0    0.000
Stroke accuracy: 47459/70927 (66.9

ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  17488.0  13503.0   3430.0     35.0      0.0    520.0   77.213
 1     Text  40469.0   1621.0  38353.0      7.0      0.0    488.0   94.771
 2    Table   6883.0    740.0   5872.0    141.0      0.0    130.0    2.049
 3     List   3115.0    222.0   2846.0      0.0      0.0     47.0    0.000
 4     Math   2972.0    511.0   1463.0      3.0      0.0    995.0   33.479
Stroke accuracy: 52992/70927 (74.71344%)
Epoch   9, loss 0.761719
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  15481.0  11769.0   3314.0    128.0      0.0    270.0   76.022
 1     Text  39796.0   2894.0  36464.0     11.0      0.0    427.0   91.627
 2    Table   6562.0   1103.0   5175.0    172.0      0.0    112.0    2.621
 3     List   3474.0    284.0   3151.0      9.0      0.0     30.0    0.000
 4     Math   3412.0    879.0   1511.0     11.0      0.0   1011.0   29.631
Stroke accuracy: 49416/68725 (71.9

Epoch  17, loss 0.737564
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  15481.0  11564.0   3425.0     54.0      3.0    435.0   74.698
 1     Text  39796.0   1857.0  37100.0     39.0      0.0    800.0   93.225
 2    Table   6562.0   1016.0   5174.0    182.0      0.0    190.0    2.774
 3     List   3474.0    237.0   3161.0     10.0      2.0     64.0    0.058
 4     Math   3412.0    636.0   1380.0     12.0      0.0   1384.0   40.563
Stroke accuracy: 50232/68725 (73.09131%)
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  17488.0  13870.0   3075.0     74.0      0.0    469.0   79.312
 1     Text  40469.0   1529.0  38451.0     26.0      0.0    463.0   95.013
 2    Table   6883.0    831.0   5736.0    209.0      0.0    107.0    3.036
 3     List   3115.0    246.0   2810.0      1.0      2.0     56.0    0.064
 4     Math   2972.0    540.0   1375.0      8.0      0.0   1049.0   35.296
Stroke accuracy: 53581/70927 (75.5

Epoch  26, loss 0.715156
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  15481.0  11604.0   3425.0    182.0      6.0    264.0   74.956
 1     Text  39796.0   1770.0  37272.0    272.0      1.0    481.0   93.658
 2    Table   6562.0   1014.0   4901.0    541.0      5.0    101.0    8.244
 3     List   3474.0    228.0   3072.0    110.0     36.0     28.0    1.036
 4     Math   3412.0    664.0   1633.0     36.0      0.0   1079.0   31.624
Stroke accuracy: 50532/68725 (73.52783%)
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  17488.0  13879.0   3114.0    173.0      7.0    315.0   79.363
 1     Text  40469.0   1332.0  38675.0    215.0      2.0    245.0   95.567
 2    Table   6883.0    833.0   5518.0    475.0      5.0     52.0    6.901
 3     List   3115.0    247.0   2716.0     98.0     35.0     19.0    1.124
 4     Math   2972.0    601.0   1536.0     32.0      0.0    803.0   27.019
Stroke accuracy: 53867/70927 (75.9

ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  17488.0  13771.0   2898.0    172.0     21.0    626.0   78.745
 1     Text  40469.0   1186.0  38314.0    188.0      9.0    772.0   94.675
 2    Table   6883.0    813.0   5330.0    567.0     17.0    156.0    8.238
 3     List   3115.0    241.0   2567.0    146.0     86.0     75.0    2.761
 4     Math   2972.0    494.0   1146.0     33.0      0.0   1299.0   43.708
Stroke accuracy: 54037/70927 (76.18678%)
Epoch  35, loss 0.701515
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  15481.0  12068.0   2718.0    173.0     37.0    485.0   77.954
 1     Text  39796.0   2295.0  35761.0    335.0     37.0   1368.0   89.861
 2    Table   6562.0   1274.0   4218.0    742.0     63.0    265.0   11.308
 3     List   3474.0    312.0   2707.0    165.0    168.0    122.0    4.836
 4     Math   3412.0    711.0   1016.0     48.0      1.0   1636.0   47.948
Stroke accuracy: 50375/68725 (73.2

Epoch  43, loss 0.709829
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  15481.0  11083.0   3640.0    293.0     85.0    380.0   71.591
 1     Text  39796.0   1109.0  37616.0    231.0     67.0    773.0   94.522
 2    Table   6562.0    774.0   4717.0    808.0     92.0    171.0   12.313
 3     List   3474.0    168.0   2820.0    166.0    269.0     51.0    7.743
 4     Math   3412.0    684.0   1376.0     60.0      8.0   1284.0   37.632
Stroke accuracy: 51060/68725 (74.29611%)
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  17488.0  13382.0   3330.0    295.0     85.0    396.0   76.521
 1     Text  40469.0   1045.0  38710.0    217.0     94.0    403.0   95.653
 2    Table   6883.0    583.0   5376.0    748.0     68.0    108.0   10.867
 3     List   3115.0    184.0   2503.0    152.0    238.0     38.0    7.640
 4     Math   2972.0    566.0   1346.0     56.0      9.0    995.0   33.479
Stroke accuracy: 54073/70927 (76.2

ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  17488.0  14539.0   2335.0    225.0    101.0    288.0   83.137
 1     Text  40469.0   2325.0  37168.0    362.0    171.0    443.0   91.843
 2    Table   6883.0   1039.0   4768.0    847.0    130.0     99.0   12.306
 3     List   3115.0    344.0   2222.0    162.0    350.0     37.0   11.236
 4     Math   2972.0    858.0   1132.0     62.0     16.0    904.0   30.417
Stroke accuracy: 53808/70927 (75.86392%)
Epoch  52, loss 0.686844
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  15481.0  11929.0   2801.0    200.0    108.0    443.0   77.056
 1     Text  39796.0   2158.0  36047.0    322.0    332.0    937.0   90.579
 2    Table   6562.0   1160.0   4182.0    842.0    173.0    205.0   12.831
 3     List   3474.0    302.0   2500.0    153.0    451.0     68.0   12.982
 4     Math   3412.0    625.0   1216.0     58.0     13.0   1500.0   43.962
Stroke accuracy: 50769/68725 (73.8

Epoch  60, loss 0.681203
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  15481.0  11439.0   3245.0    245.0    140.0    412.0   73.891
 1     Text  39796.0   1516.0  37128.0    170.0    249.0    733.0   93.296
 2    Table   6562.0    889.0   4597.0    704.0    204.0    168.0   10.728
 3     List   3474.0    214.0   2559.0    136.0    511.0     54.0   14.709
 4     Math   3412.0    588.0   1385.0     43.0     15.0   1381.0   40.475
Stroke accuracy: 51163/68725 (74.44598%)
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  17488.0  13742.0   2921.0    250.0    134.0    441.0   78.580
 1     Text  40469.0   1153.0  38447.0    204.0    263.0    402.0   95.004
 2    Table   6883.0    697.0   5279.0    657.0    160.0     90.0    9.545
 3     List   3115.0    229.0   2257.0    130.0    467.0     32.0   14.992
 4     Math   2972.0    501.0   1342.0     41.0     31.0   1057.0   35.565
Stroke accuracy: 54370/70927 (76.6

ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  17488.0  13958.0   2723.0    363.0     60.0    384.0   79.815
 1     Text  40469.0   1359.0  38279.0    328.0    108.0    395.0   94.588
 2    Table   6883.0    769.0   4957.0    989.0     79.0     89.0   14.369
 3     List   3115.0    256.0   2345.0    171.0    305.0     38.0    9.791
 4     Math   2972.0    598.0   1262.0     73.0     13.0   1026.0   34.522
Stroke accuracy: 54557/70927 (76.91993%)
Epoch  69, loss 0.669890
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  15481.0  11481.0   3410.0    271.0     53.0    266.0   74.162
 1     Text  39796.0   1308.0  37552.0    369.0     81.0    486.0   94.361
 2    Table   6562.0    894.0   4441.0   1049.0     91.0     87.0   15.986
 3     List   3474.0    192.0   2790.0    198.0    269.0     25.0    7.743
 4     Math   3412.0    712.0   1529.0     86.0      7.0   1078.0   31.594
Stroke accuracy: 51429/68725 (74.8

Epoch  77, loss 0.669851
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  15481.0  11157.0   3281.0    270.0     63.0    710.0   72.069
 1     Text  39796.0   1221.0  37077.0    215.0     77.0   1206.0   93.168
 2    Table   6562.0    857.0   4452.0    809.0    144.0    300.0   12.329
 3     List   3474.0    186.0   2706.0    113.0    364.0    105.0   10.478
 4     Math   3412.0    442.0   1179.0     44.0      6.0   1741.0   51.026
Stroke accuracy: 51148/68725 (74.42415%)
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  17488.0  13475.0   2988.0    231.0     69.0    725.0   77.053
 1     Text  40469.0    833.0  38465.0    259.0     86.0    826.0   95.048
 2    Table   6883.0    656.0   5155.0    793.0     76.0    203.0   11.521
 3     List   3115.0    207.0   2402.0    114.0    320.0     72.0   10.273
 4     Math   2972.0    394.0   1134.0     38.0     13.0   1393.0   46.871
Stroke accuracy: 54446/70927 (76.7

ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  17488.0  14129.0   2633.0    328.0     73.0    325.0   80.793
 1     Text  40469.0   1559.0  38075.0    345.0    155.0    335.0   94.084
 2    Table   6883.0    841.0   4780.0   1062.0    127.0     73.0   15.429
 3     List   3115.0    286.0   2263.0    137.0    399.0     30.0   12.809
 4     Math   2972.0    620.0   1306.0     61.0     19.0    966.0   32.503
Stroke accuracy: 54631/70927 (77.02426%)
Epoch  86, loss 0.662333
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  15481.0  11611.0   2833.0    534.0     63.0    440.0   75.002
 1     Text  39796.0   1817.0  36219.0    786.0     94.0    880.0   91.012
 2    Table   6562.0    956.0   3914.0   1412.0     97.0    183.0   21.518
 3     List   3474.0    251.0   2518.0    301.0    335.0     69.0    9.643
 4     Math   3412.0    626.0   1191.0    115.0      5.0   1475.0   43.230
Stroke accuracy: 51052/68725 (74.2

Epoch  94, loss 0.675527
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  15481.0  11935.0   2507.0    452.0     88.0    499.0   77.095
 1     Text  39796.0   2590.0  35570.0    523.0    180.0    933.0   89.381
 2    Table   6562.0   1175.0   3802.0   1236.0    144.0    205.0   18.836
 3     List   3474.0    359.0   2358.0    211.0    473.0     73.0   13.615
 4     Math   3412.0    640.0   1100.0    102.0     10.0   1560.0   45.721
Stroke accuracy: 50774/68725 (73.87996%)
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  17488.0  14245.0   2239.0    401.0     80.0    523.0   81.456
 1     Text  40469.0   2000.0  37060.0    567.0    198.0    644.0   91.576
 2    Table   6883.0    952.0   4406.0   1266.0    136.0    123.0   18.393
 3     List   3115.0    323.0   2100.0    213.0    426.0     53.0   13.676
 4     Math   2972.0    572.0   1062.0     84.0     26.0   1228.0   41.319
Stroke accuracy: 54225/70927 (76.4