In [6]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell is executed, so edits to the project-local modules imported
# by this notebook are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import torch as th
import os
from dataset import StrokeDataset, collate
from torch.utils.data import DataLoader

# Locate the dataset under the current user's home directory.
# os.path.expanduser("~") uses $HOME when it is set and otherwise falls back to
# the platform's own notion of the home directory, so this does not raise
# KeyError in environments where HOME is undefined (e.g. Windows).
home = os.path.expanduser("~")
data_dir = os.path.join(home, "data/IAMonDo/cls_5class")

# One StrokeDataset per split; each is read from its own sub-directory.
trainset = StrokeDataset(os.path.join(data_dir, "train"))
validset = StrokeDataset(os.path.join(data_dir, "valid"))
testset = StrokeDataset(os.path.join(data_dir, "test"))

In [8]:
device = th.device('cpu')

# Build the three loaders in one pass. Only the training split is shuffled;
# validation and test keep dataset order. Every loader gets its own collate
# function bound to `device`.
train_loader, valid_loader, test_loader = (
    DataLoader(split, batch_size=32, shuffle=reshuffle,
               collate_fn=collate(device))
    for split, reshuffle in ((trainset, True),
                             (validset, False),
                             (testset, False))
)

In [9]:
from sklearn.metrics import confusion_matrix
import numpy as np

def print_result(count):
    """Print a confusion matrix with per-class and overall accuracy.

    Args:
        count: square 2-D array-like (numpy array) of shape
            (num_classes, num_classes). Row ``i`` is printed as the breakdown
            for class ``i``; the diagonal entry ``count[i][i]`` is treated as
            the number of correct predictions for that class.

    The class names are taken from a fixed table for the 5-class and 2-class
    tasks. For any other number of classes generic names ``Class0``,
    ``Class1``, ... are used instead of failing with an unbound variable.
    """
    num_classes = count.shape[0]
    if num_classes == 5:
        id_to_tag = ['Graph', 'Text', 'Table', 'List', 'Math']
    elif num_classes == 2:
        id_to_tag = ['Non-text', 'Text']
    else:
        # Unknown task size: fall back to generic labels.
        id_to_tag = ['Class%d' % i for i in range(num_classes)]

    # One right-aligned row layout shared by the header and every class row:
    # ID, name, row total, one column per class, then the per-class accuracy.
    row_format = "{: >2}{: >9}{: >9}%s{: >9}" % ("{: >9}" * num_classes)

    # Confusion matrix with accuracy for each tag
    print(row_format.format(
        "ID", "NE", "Total",
        *([id_to_tag[i] for i in range(num_classes)] + ["Percent"]))
    )
    for i in range(num_classes):
        row_total = count[i].sum()
        # max(1, ...) avoids a division by zero for classes with no samples.
        row_percent = "%.3f" % (count[i][i] * 100. / max(1, row_total))
        print(row_format.format(
            str(i), id_to_tag[i], str(row_total),
            *([count[i][j] for j in range(num_classes)] + [row_percent])
        ))

    # Global accuracy
    accuracy = 100. * count.trace() / max(1, count.sum())
    print("Stroke accuracy: %i/%i (%.5f%%)" % (
        count.trace(), count.sum(), accuracy)
    )
    
def evaluate(model, loader, num_classes):
    """Run `model` over `loader`, print the confusion matrix and return accuracy.

    Args:
        model: module called as ``model(fg)``; its output is reduced with
            ``th.max(..., dim=1)`` to obtain one predicted class per row.
        loader: iterable yielding ``(fg, lg)`` pairs; the ground-truth labels
            are read from ``lg.ndata['y']``.
        num_classes: number of classes; fixes the confusion matrix size.

    Returns:
        ``(accuracy, count)`` where ``accuracy`` is the overall percentage of
        correct predictions and ``count`` is the accumulated
        ``(num_classes, num_classes)`` confusion matrix.

    The model is switched to eval mode for the pass and put back into training
    mode afterwards, even if evaluation is interrupted.
    """
    model.eval()
    # Always request the full label set. Without it, a batch that lacks some
    # class yields a smaller matrix, which either fails to broadcast into
    # `count` or (for a 1x1 result) is silently added to every cell.
    class_ids = np.arange(num_classes)
    count = np.zeros((num_classes, num_classes))
    try:
        # No gradients are needed for evaluation; skip building the graph.
        with th.no_grad():
            for fg, lg in loader:
                logits = model(fg)
                _, predictions = th.max(logits, dim=1)
                labels = lg.ndata['y']
                predictions = predictions.cpu().numpy()
                labels = labels.cpu().numpy()
                count += confusion_matrix(labels, predictions, labels=class_ids)
    finally:
        model.train()

    print_result(count)

    accuracy = 100. * count.trace() / max(1, count.sum())
    return accuracy, count

In [19]:
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from gcn import GCNNet
from gat import GAT

# Hyper-parameters. The values are forwarded positionally to GAT below; their
# exact meaning is defined in gat.py (not visible here).
in_feats = 23  # NOTE(review): presumably the per-node input feature size; confirm
hidden_feats = 30
num_heads = 8
num_out_heads = 1
num_layers = 1
residual = True
in_drop = 0.6
attn_drop = 0.6
lr = 1e-2
weight_decay = 5e-4
alpha = 0.2

num_classes = 5
num_epoches = 200
# `num_heads` for each of the `num_layers` layers, then `num_out_heads` last.
heads = ([num_heads] * num_layers) + [num_out_heads]

# model = GCNNet(in_feats, hidden_feats, num_classes).to(device)
# Place the model on the same device the loaders' collate functions use.
model = GAT(num_layers, in_feats, hidden_feats, num_classes,
            heads, F.elu, in_drop, attn_drop, alpha, residual).to(device)
loss_func = nn.CrossEntropyLoss()
# `weight_decay` was declared above but never handed to the optimizer.
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

epoch_losses = []
best_valid_acc = 0
best_test_acc = 0
best_round = 0
# Stays None until some epoch improves validation accuracy.
best_conf_mat = None

for epoch in range(num_epoches):
    epoch_loss = 0
    num_batches = 0
    for fg, lg in train_loader:
        logits = model(fg)
        labels = lg.ndata['y']
        loss = loss_func(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1

    # Mean loss per batch; max(1, ...) keeps an empty loader from crashing.
    epoch_loss /= max(1, num_batches)
    print('Epoch {:3d}, loss {:4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)

    # Model selection uses the validation split only. The test split is
    # evaluated just for the epoch that sets a new validation best, so the
    # reported test result is not chosen by looking at test accuracy.
    valid_acc, _ = evaluate(model, valid_loader, num_classes)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
        test_acc, test_conf_mat = evaluate(model, test_loader, num_classes)
        best_test_acc = test_acc
        best_conf_mat = test_conf_mat
        best_round = epoch

print("Best round: %d" % best_round)
if best_conf_mat is not None:
    print_result(best_conf_mat)

Epoch   0, loss 2.942719
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  15481.0  10279.0   3558.0    159.0     77.0   1408.0   66.398
 1     Text  39796.0   5329.0  29825.0    595.0    410.0   3637.0   74.945
 2    Table   6562.0   1518.0   3808.0    293.0    137.0    806.0    4.465
 3     List   3474.0    496.0   2375.0     95.0    105.0    403.0    3.022
 4     Math   3412.0    833.0    781.0     84.0     74.0   1640.0   48.066
Stroke accuracy: 42142/68725 (61.31975%)
ID       NE    Total    Graph     Text    Table     List     Math  Percent
 0    Graph  17488.0  12884.0   2707.0    177.0    105.0   1615.0   73.673
 1     Text  40469.0   4647.0  30382.0    574.0    577.0   4289.0   75.075
 2    Table   6883.0   1280.0   4331.0    285.0    161.0    826.0    4.141
 3     List   3115.0    512.0   1992.0     68.0     91.0    452.0    2.921
 4     Math   2972.0    802.0    654.0     72.0     53.0   1391.0   46.803
Stroke accuracy: 45033/70927 (63.4

KeyboardInterrupt: 