In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch as th
import os
from dataset import StrokeDataset, collate
from torch.utils.data import DataLoader

num_classes = 2

home = os.environ['HOME']
data_dir = "./data"
trainset = StrokeDataset(os.path.join(data_dir, "train"), num_classes)
validset = StrokeDataset(os.path.join(data_dir, "valid"), num_classes)
testset = StrokeDataset(os.path.join(data_dir, "test"), num_classes)

In [3]:
device = th.device('cuda')

train_loader = DataLoader(trainset,
                          batch_size=32,
                          shuffle=True,
                          collate_fn=collate(device))

valid_loader = DataLoader(validset,
                          batch_size=32,
                          shuffle=False,
                          collate_fn=collate(device))

test_loader = DataLoader(testset,
                         batch_size=32,
                         shuffle=False,
                         collate_fn=collate(device))

In [4]:
from sklearn.metrics import confusion_matrix
import numpy as np

def print_result(count):
    num_classes = count.shape[0]
    if num_classes == 5:
        id_to_tag = ['Graph', 'Text', 'Table', 'List', 'Math']
    elif num_classes == 2:
        id_to_tag = ['Non-text', 'Text']
    
    # Confusion matrix with accuracy for each tag
    print (("{: >2}{: >9}{: >9}%s{: >9}" % ("{: >9}" * num_classes)).format(
        "ID", "NE", "Total",
        *([id_to_tag[i] for i in range(num_classes)] + ["Percent"]))
    )
    for i in range(num_classes):
        print (("{: >2}{: >9}{: >9}%s{: >9}" % ("{: >9}" * num_classes)).format(
            str(i), id_to_tag[i], str(count[i].sum()),
            *([count[i][j] for j in range(num_classes)] +
              ["%.3f" % (count[i][i] * 100. / max(1, count[i].sum()))])
        ))

    # Global accuracy
    accuracy = 100. * count.trace() / max(1, count.sum())
    print ("Stroke accuracy: %i/%i (%.5f%%)" % (
        count.trace(), count.sum(), accuracy)
    )
    
def evaluate(model, loader, num_classes, name):
    model.eval()
    print(name + ":")
    count = np.zeros((num_classes, num_classes), dtype=np.int32)
    for it, (fg, lg) in enumerate(loader):
        logits = model(fg)
        _, predictions = th.max(logits, dim=1)
        labels = lg.ndata['y']
        predictions = predictions.cpu().numpy()
        labels = labels.cpu().numpy()
        count += confusion_matrix(labels, predictions)
    model.train()
    
    print_result(count)
    
    accuracy = 100. * count.trace() / max(1, count.sum())
    return accuracy, count

In [5]:
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from gcn import GCNNet
from gat import GAT

in_feats = 23
edge_f_dim = 19
hidden_feats = 8
num_heads = 8
num_out_heads = 8
num_layers = 3
residual = True
in_drop = 0
attn_drop = 0.2
lr = 5e-3
weight_decay = 5e-4
alpha = 0.2

num_epoches = 10
heads = ([num_heads] * num_layers) + [num_out_heads]

# model = GCNNet(in_feats, hidden_feats, num_classes).to(device)
model = GAT(num_layers, in_feats, edge_f_dim, hidden_feats, num_classes,
            heads, F.elu, in_drop, attn_drop, alpha, residual).to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

epoch_losses = []
best_valid_acc = 0
best_test_acc = 0
best_round = 0

for epoch in range(num_epoches):
    epoch_loss = 0
    for it, (fg, lg) in enumerate(train_loader):
        logits = model(fg)
        labels = lg.ndata['y']
        loss = loss_func(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        
    epoch_loss /= (it + 1)
    print('Epoch {:3d}, loss {:4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)
    
    train_acc, _= evaluate(model, train_loader, num_classes, "train")
    valid_acc, _ = evaluate(model, valid_loader, num_classes, "valid")
    if valid_acc > best_valid_acc:
        test_acc, test_conf_mat = evaluate(model, test_loader, num_classes, "test")
        if test_acc > best_test_acc:
            best_test_acc = test_acc
            best_conf_mat = test_conf_mat
            best_round = epoch

print("Best round: %d" % best_round)
print_result(best_conf_mat)



Epoch   0, loss 0.485049
train:
ID       NE    Total Non-text     Text  Percent
 0 Non-text    26547    20772     5775   78.246
 1     Text   116801     4181   112620   96.420
Stroke accuracy: 133392/143348 (93.05466%)
valid:
ID       NE    Total Non-text     Text  Percent
 0 Non-text    10904     8024     2880   73.588
 1     Text    57821     2137    55684   96.304
Stroke accuracy: 63708/68725 (92.69989%)
test:
ID       NE    Total Non-text     Text  Percent
 0 Non-text    12968    10366     2602   79.935
 1     Text    57959     2177    55782   96.244
Stroke accuracy: 66148/70927 (93.26209%)
Epoch   1, loss 0.204056
train:
ID       NE    Total Non-text     Text  Percent
 0 Non-text    26547    21393     5154   80.585
 1     Text   116801     2826   113975   97.581
Stroke accuracy: 135368/143348 (94.43313%)
valid:
ID       NE    Total Non-text     Text  Percent
 0 Non-text    10904     8332     2572   76.412
 1     Text    57821     1690    56131   97.077
Stroke accuracy: 64463/68725