https://github.com/chengjun/Research/blob/master/GAT.ipynb

In [1]:

import os
import glob
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

In [2]:
from utils import load_data, accuracy
from models import GAT, SpGAT

In [14]:
class Args:
    """Hyperparameters and runtime switches for the GAT training run.

    Plays the role of a parsed command line so the notebook can be executed
    top to bottom without arguments. Every attribute is read by a later cell
    through the module-level ``args`` instance.
    """

    def __init__(self):
        # --- runtime switches ---
        self.no_cuda = False     # set True to force CPU even when a GPU exists
        # Only request CUDA when it is actually usable; a hard-coded True makes
        # every later `.cuda()` call fail on a CPU-only machine.
        self.cuda = not self.no_cuda and torch.cuda.is_available()
        self.fastmode = False    # True: train() skips its separate eval-mode pass
        self.sparse = False      # True: build SpGAT instead of GAT
        self.seed = 72           # fed to random / numpy / torch seeders

        # --- optimisation ---
        self.epochs = 800        # upper bound on training epochs
        self.lr = 5e-3           # Adam learning rate
        self.weight_decay = 5e-4 # Adam weight decay
        self.patience = 100      # stop after this many epochs without a better val loss

        # --- model (forwarded verbatim to the GAT / SpGAT constructor) ---
        self.hidden = 8          # passed as `nhid`
        self.nb_heads = 8        # passed as `nheads`
        self.dropout = 0.6       # passed as `dropout`
        self.alpha = 0.2         # passed as `alpha`


args = Args()

In [4]:

# Seed every random number generator the run touches, in a fixed order,
# so repeated executions start from the same state.
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Load data
adj, features, labels, idx_train, idx_val, idx_test = load_data()

Loading cora dataset...


In [15]:
# Model and optimizer
# Both architectures are built with the same keyword arguments, so pick the
# class first and construct it once.
model_cls = SpGAT if args.sparse else GAT
model = model_cls(
    nfeat=features.shape[1],
    nhid=args.hidden,
    nclass=int(labels.max()) + 1,
    dropout=args.dropout,
    nheads=args.nb_heads,
    alpha=args.alpha,
)

optimizer = optim.Adam(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)

In [6]:
# Sanity check of the `nclass` value handed to the model constructor.
# Assumes labels are 0-based integer class ids — TODO confirm in load_data().
int(labels.max()) + 1


7

In [16]:
# Move the model and every tensor it consumes onto the GPU when requested.
if args.cuda:
    model.cuda()
    features = features.cuda()
    adj = adj.cuda()
    labels = labels.cuda()
    idx_train = idx_train.cuda()
    idx_val = idx_val.cuda()
    idx_test = idx_test.cuda()

# The former `Variable(...)` wrapping of features/adj/labels is dropped on
# purpose: since PyTorch 0.4 (already required here by the `.item()` calls in
# train()) tensors carry autograd state themselves and the wrapper is a
# deprecated no-op.

In [17]:
def train(epoch):
    """Run one optimisation step and report train/validation metrics.

    Uses the module-level ``model``, ``optimizer``, ``features``, ``adj``,
    ``labels``, ``idx_train``, ``idx_val`` and ``args``.

    Args:
        epoch: Zero-based epoch index; only used for the progress line, where
            it is printed as ``epoch + 1``.

    Returns:
        The validation loss of this epoch as a plain Python number.
    """
    t = time.time()
    model.train()
    optimizer.zero_grad()
    # The forward pass covers all nodes; the loss is taken on the training
    # indices only.
    output = model(features, adj)
    loss_train = F.nll_loss(output[idx_train], labels[idx_train])
    acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()

    if not args.fastmode:
        # Evaluate validation set performance separately,
        # deactivates dropout during validation run.
        model.eval()
        # Nothing is back-propagated from this pass, so do not record an
        # autograd graph for it.
        with torch.no_grad():
            output = model(features, adj)

    # In fastmode `output` is still the training-mode result from above.
    loss_val = F.nll_loss(output[idx_val], labels[idx_val])
    acc_val = accuracy(output[idx_val], labels[idx_val])
    print('Epoch: {:04d}'.format(epoch+1),
          'loss_train: {:.4f}'.format(loss_train.item()),
          'acc_train: {:.4f}'.format(acc_train.item()),
          'loss_val: {:.4f}'.format(loss_val.item()),
          'acc_val: {:.4f}'.format(acc_val.item()),
          'time: {:.4f}s'.format(time.time() - t))

    return loss_val.item()


def compute_test():
    """Evaluate the current ``model`` on the test indices and print the result.

    Uses the module-level ``model``, ``features``, ``adj``, ``labels`` and
    ``idx_test``. Switches the model to eval mode and leaves it there.
    Returns nothing; loss and accuracy are only printed.
    """
    model.eval()
    # Pure inference: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        output = model(features, adj)
    loss_test = F.nll_loss(output[idx_test], labels[idx_test])
    acc_test = accuracy(output[idx_test], labels[idx_test])
    print("Test set results:",
          "loss= {:.4f}".format(loss_test.item()),
          "accuracy= {:.4f}".format(acc_test.item()))

# Train model
# Early-stopping loop: train up to `args.epochs` epochs, keep on disk only the
# checkpoint with the lowest validation loss, stop after `args.patience`
# consecutive epochs without improvement, then restore that checkpoint and
# evaluate on the test split.
t_total = time.time()
loss_values = []
bad_counter = 0
# Start from +inf so the first finite validation loss always counts as an
# improvement, whatever its magnitude.
best = float('inf')
best_epoch = 0
best_path = None  # checkpoint file written by this loop and still on disk
for epoch in range(args.epochs):
    loss_values.append(train(epoch))

    improved = loss_values[-1] < best
    # Write a checkpoint only when it is the one that would be restored later.
    # The first epoch is always written so that `<best_epoch>.pkl` exists even
    # if the loss never compares as an improvement (e.g. it is NaN).
    if improved or best_path is None:
        new_path = '{}.pkl'.format(epoch)
        torch.save(model.state_dict(), new_path)
        # Remove only the file this loop created itself; other *.pkl files in
        # the working directory are neither parsed nor deleted.
        if best_path is not None:
            os.remove(best_path)
        best_path = new_path

    if improved:
        best = loss_values[-1]
        best_epoch = epoch
        bad_counter = 0
    else:
        bad_counter += 1

    if bad_counter == args.patience:
        break

print("Optimization Finished!")
print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

# Restore best model
# Skipped when no epoch ran (args.epochs == 0): there is nothing to load then.
if best_path is not None:
    print('Loading {}th epoch'.format(best_epoch))
    model.load_state_dict(torch.load(best_path))

# Testing
compute_test()


Epoch: 0001 loss_train: 1.9405 acc_train: 0.1786 loss_val: 1.9368 acc_val: 0.3733 time: 1.8959s
Epoch: 0002 loss_train: 1.9409 acc_train: 0.2143 loss_val: 1.9270 acc_val: 0.4767 time: 1.8800s
Epoch: 0003 loss_train: 1.9264 acc_train: 0.3643 loss_val: 1.9171 acc_val: 0.4933 time: 1.9199s
Epoch: 0004 loss_train: 1.8969 acc_train: 0.5000 loss_val: 1.9069 acc_val: 0.5033 time: 1.8750s
Epoch: 0005 loss_train: 1.9021 acc_train: 0.4357 loss_val: 1.8965 acc_val: 0.5133 time: 1.9219s
Epoch: 0006 loss_train: 1.8796 acc_train: 0.4857 loss_val: 1.8859 acc_val: 0.4900 time: 1.8790s
Epoch: 0007 loss_train: 1.8797 acc_train: 0.4714 loss_val: 1.8754 acc_val: 0.5033 time: 1.8969s
Epoch: 0008 loss_train: 1.8608 acc_train: 0.5143 loss_val: 1.8648 acc_val: 0.4967 time: 1.9538s
Epoch: 0009 loss_train: 1.8491 acc_train: 0.5214 loss_val: 1.8540 acc_val: 0.4933 time: 1.9717s
Epoch: 0010 loss_train: 1.8439 acc_train: 0.4786 loss_val: 1.8431 acc_val: 0.4800 time: 1.8949s
Epoch: 0011 loss_train: 1.8182 acc_train

Epoch: 0087 loss_train: 1.0172 acc_train: 0.7857 loss_val: 1.1048 acc_val: 0.8300 time: 1.9917s
Epoch: 0088 loss_train: 1.0394 acc_train: 0.7571 loss_val: 1.0982 acc_val: 0.8300 time: 1.8900s
Epoch: 0089 loss_train: 1.0260 acc_train: 0.7643 loss_val: 1.0918 acc_val: 0.8300 time: 1.9907s
Epoch: 0090 loss_train: 1.0301 acc_train: 0.7786 loss_val: 1.0855 acc_val: 0.8300 time: 1.9129s
Epoch: 0091 loss_train: 1.0136 acc_train: 0.7643 loss_val: 1.0789 acc_val: 0.8300 time: 2.0126s
Epoch: 0092 loss_train: 1.0552 acc_train: 0.7643 loss_val: 1.0726 acc_val: 0.8300 time: 1.8999s
Epoch: 0093 loss_train: 0.9770 acc_train: 0.8071 loss_val: 1.0663 acc_val: 0.8333 time: 1.9647s
Epoch: 0094 loss_train: 0.9803 acc_train: 0.7429 loss_val: 1.0602 acc_val: 0.8333 time: 1.9763s
Epoch: 0095 loss_train: 1.0522 acc_train: 0.7714 loss_val: 1.0541 acc_val: 0.8333 time: 1.9169s
Epoch: 0096 loss_train: 0.9953 acc_train: 0.7571 loss_val: 1.0481 acc_val: 0.8333 time: 1.9348s
Epoch: 0097 loss_train: 1.0536 acc_train

Epoch: 0173 loss_train: 0.6907 acc_train: 0.8286 loss_val: 0.7969 acc_val: 0.8267 time: 1.9757s
Epoch: 0174 loss_train: 0.8033 acc_train: 0.8071 loss_val: 0.7950 acc_val: 0.8233 time: 1.9378s
Epoch: 0175 loss_train: 0.8068 acc_train: 0.7929 loss_val: 0.7932 acc_val: 0.8233 time: 2.0545s
Epoch: 0176 loss_train: 0.7218 acc_train: 0.8429 loss_val: 0.7915 acc_val: 0.8233 time: 2.0292s
Epoch: 0177 loss_train: 0.7551 acc_train: 0.8000 loss_val: 0.7898 acc_val: 0.8267 time: 1.9999s
Epoch: 0178 loss_train: 0.7345 acc_train: 0.8000 loss_val: 0.7883 acc_val: 0.8267 time: 1.9952s
Epoch: 0179 loss_train: 0.6663 acc_train: 0.8714 loss_val: 0.7867 acc_val: 0.8267 time: 1.9592s
Epoch: 0180 loss_train: 0.7548 acc_train: 0.7929 loss_val: 0.7851 acc_val: 0.8233 time: 1.9953s
Epoch: 0181 loss_train: 0.7679 acc_train: 0.7857 loss_val: 0.7838 acc_val: 0.8233 time: 1.9588s
Epoch: 0182 loss_train: 0.7960 acc_train: 0.7857 loss_val: 0.7825 acc_val: 0.8233 time: 1.9221s
Epoch: 0183 loss_train: 0.8565 acc_train

Epoch: 0259 loss_train: 0.6399 acc_train: 0.8143 loss_val: 0.7200 acc_val: 0.8233 time: 1.9777s
Epoch: 0260 loss_train: 0.7204 acc_train: 0.8214 loss_val: 0.7191 acc_val: 0.8233 time: 1.9756s
Epoch: 0261 loss_train: 0.7189 acc_train: 0.8143 loss_val: 0.7183 acc_val: 0.8233 time: 1.9416s
Epoch: 0262 loss_train: 0.7633 acc_train: 0.7643 loss_val: 0.7175 acc_val: 0.8267 time: 1.8882s
Epoch: 0263 loss_train: 0.7179 acc_train: 0.7429 loss_val: 0.7169 acc_val: 0.8267 time: 1.9568s
Epoch: 0264 loss_train: 0.7484 acc_train: 0.7786 loss_val: 0.7167 acc_val: 0.8267 time: 1.8971s
Epoch: 0265 loss_train: 0.7124 acc_train: 0.8143 loss_val: 0.7167 acc_val: 0.8233 time: 1.9578s
Epoch: 0266 loss_train: 0.6900 acc_train: 0.7857 loss_val: 0.7168 acc_val: 0.8233 time: 1.9139s
Epoch: 0267 loss_train: 0.6235 acc_train: 0.8286 loss_val: 0.7167 acc_val: 0.8233 time: 1.8660s
Epoch: 0268 loss_train: 0.7217 acc_train: 0.7786 loss_val: 0.7165 acc_val: 0.8233 time: 1.9189s
Epoch: 0269 loss_train: 0.7240 acc_train

Epoch: 0345 loss_train: 0.7463 acc_train: 0.7286 loss_val: 0.6964 acc_val: 0.8300 time: 1.9149s
Epoch: 0346 loss_train: 0.7097 acc_train: 0.8000 loss_val: 0.6965 acc_val: 0.8300 time: 1.9466s
Epoch: 0347 loss_train: 0.6897 acc_train: 0.7857 loss_val: 0.6966 acc_val: 0.8300 time: 1.9874s
Epoch: 0348 loss_train: 0.7091 acc_train: 0.7643 loss_val: 0.6963 acc_val: 0.8300 time: 1.9708s
Epoch: 0349 loss_train: 0.5971 acc_train: 0.8571 loss_val: 0.6960 acc_val: 0.8300 time: 1.9695s
Epoch: 0350 loss_train: 0.6285 acc_train: 0.8286 loss_val: 0.6958 acc_val: 0.8300 time: 2.0216s
Epoch: 0351 loss_train: 0.7100 acc_train: 0.7929 loss_val: 0.6955 acc_val: 0.8300 time: 1.9249s
Epoch: 0352 loss_train: 0.7112 acc_train: 0.7857 loss_val: 0.6954 acc_val: 0.8300 time: 1.9719s
Epoch: 0353 loss_train: 0.7061 acc_train: 0.8071 loss_val: 0.6951 acc_val: 0.8300 time: 1.9742s
Epoch: 0354 loss_train: 0.6072 acc_train: 0.8214 loss_val: 0.6949 acc_val: 0.8300 time: 1.8939s
Epoch: 0355 loss_train: 0.6743 acc_train

Epoch: 0431 loss_train: 0.6642 acc_train: 0.8071 loss_val: 0.6717 acc_val: 0.8233 time: 1.9628s
Epoch: 0432 loss_train: 0.6207 acc_train: 0.8357 loss_val: 0.6715 acc_val: 0.8233 time: 1.9298s
Epoch: 0433 loss_train: 0.6235 acc_train: 0.8357 loss_val: 0.6712 acc_val: 0.8233 time: 1.8570s
Epoch: 0434 loss_train: 0.6272 acc_train: 0.8357 loss_val: 0.6709 acc_val: 0.8233 time: 1.9378s
Epoch: 0435 loss_train: 0.7320 acc_train: 0.7786 loss_val: 0.6706 acc_val: 0.8233 time: 1.9887s
Epoch: 0436 loss_train: 0.6260 acc_train: 0.8214 loss_val: 0.6705 acc_val: 0.8233 time: 1.9667s
Epoch: 0437 loss_train: 0.5773 acc_train: 0.8214 loss_val: 0.6703 acc_val: 0.8233 time: 1.9378s
Epoch: 0438 loss_train: 0.6275 acc_train: 0.8286 loss_val: 0.6701 acc_val: 0.8200 time: 1.9817s
Epoch: 0439 loss_train: 0.6445 acc_train: 0.8143 loss_val: 0.6700 acc_val: 0.8200 time: 1.9318s
Epoch: 0440 loss_train: 0.6419 acc_train: 0.8071 loss_val: 0.6697 acc_val: 0.8200 time: 1.9268s
Epoch: 0441 loss_train: 0.6745 acc_train

Epoch: 0517 loss_train: 0.6256 acc_train: 0.8214 loss_val: 0.6640 acc_val: 0.8100 time: 1.9047s
Epoch: 0518 loss_train: 0.5524 acc_train: 0.8500 loss_val: 0.6643 acc_val: 0.8133 time: 1.9784s
Epoch: 0519 loss_train: 0.6083 acc_train: 0.8429 loss_val: 0.6649 acc_val: 0.8133 time: 2.0537s
Epoch: 0520 loss_train: 0.6728 acc_train: 0.8000 loss_val: 0.6655 acc_val: 0.8133 time: 1.9664s
Epoch: 0521 loss_train: 0.5656 acc_train: 0.8643 loss_val: 0.6662 acc_val: 0.8133 time: 2.0276s
Epoch: 0522 loss_train: 0.6135 acc_train: 0.8000 loss_val: 0.6669 acc_val: 0.8133 time: 2.0426s
Epoch: 0523 loss_train: 0.6115 acc_train: 0.8500 loss_val: 0.6676 acc_val: 0.8100 time: 2.0510s
Epoch: 0524 loss_train: 0.6125 acc_train: 0.8143 loss_val: 0.6680 acc_val: 0.8100 time: 2.0071s
Epoch: 0525 loss_train: 0.6122 acc_train: 0.8357 loss_val: 0.6685 acc_val: 0.8133 time: 1.9825s
Epoch: 0526 loss_train: 0.5611 acc_train: 0.8500 loss_val: 0.6689 acc_val: 0.8167 time: 2.0987s
Epoch: 0527 loss_train: 0.6145 acc_train

Epoch: 0603 loss_train: 0.6924 acc_train: 0.8286 loss_val: 0.6575 acc_val: 0.8267 time: 1.9249s
Epoch: 0604 loss_train: 0.6165 acc_train: 0.8286 loss_val: 0.6578 acc_val: 0.8267 time: 1.9980s
Epoch: 0605 loss_train: 0.6045 acc_train: 0.8500 loss_val: 0.6579 acc_val: 0.8267 time: 1.9548s
Epoch: 0606 loss_train: 0.5579 acc_train: 0.8571 loss_val: 0.6579 acc_val: 0.8233 time: 2.0017s
Epoch: 0607 loss_train: 0.6240 acc_train: 0.8071 loss_val: 0.6577 acc_val: 0.8233 time: 1.9428s
Epoch: 0608 loss_train: 0.6166 acc_train: 0.8286 loss_val: 0.6576 acc_val: 0.8233 time: 1.9109s
Epoch: 0609 loss_train: 0.6192 acc_train: 0.8143 loss_val: 0.6573 acc_val: 0.8233 time: 1.9972s
Epoch: 0610 loss_train: 0.6795 acc_train: 0.7857 loss_val: 0.6572 acc_val: 0.8233 time: 2.0078s
Epoch: 0611 loss_train: 0.6018 acc_train: 0.8286 loss_val: 0.6570 acc_val: 0.8233 time: 1.9498s
Epoch: 0612 loss_train: 0.5646 acc_train: 0.8429 loss_val: 0.6572 acc_val: 0.8233 time: 1.9754s
Epoch: 0613 loss_train: 0.6323 acc_train

Epoch: 0689 loss_train: 0.6540 acc_train: 0.7786 loss_val: 0.6594 acc_val: 0.8200 time: 1.9318s
Epoch: 0690 loss_train: 0.6000 acc_train: 0.8143 loss_val: 0.6601 acc_val: 0.8200 time: 1.8840s
Epoch: 0691 loss_train: 0.6367 acc_train: 0.8071 loss_val: 0.6607 acc_val: 0.8200 time: 1.9606s
Epoch: 0692 loss_train: 0.5640 acc_train: 0.8500 loss_val: 0.6611 acc_val: 0.8200 time: 1.9495s
Epoch: 0693 loss_train: 0.7173 acc_train: 0.7786 loss_val: 0.6615 acc_val: 0.8200 time: 2.0293s
Epoch: 0694 loss_train: 0.5990 acc_train: 0.8286 loss_val: 0.6616 acc_val: 0.8233 time: 1.8880s
Epoch: 0695 loss_train: 0.6208 acc_train: 0.8286 loss_val: 0.6612 acc_val: 0.8200 time: 1.9294s
Epoch: 0696 loss_train: 0.5860 acc_train: 0.8000 loss_val: 0.6608 acc_val: 0.8200 time: 2.0513s
Epoch: 0697 loss_train: 0.6019 acc_train: 0.8286 loss_val: 0.6601 acc_val: 0.8233 time: 2.0556s
Epoch: 0698 loss_train: 0.5352 acc_train: 0.8571 loss_val: 0.6593 acc_val: 0.8233 time: 1.9689s
Epoch: 0699 loss_train: 0.6437 acc_train

Epoch: 0775 loss_train: 0.4916 acc_train: 0.8571 loss_val: 0.6685 acc_val: 0.8300 time: 1.9917s
Epoch: 0776 loss_train: 0.5853 acc_train: 0.8286 loss_val: 0.6687 acc_val: 0.8300 time: 1.9239s
Epoch: 0777 loss_train: 0.5871 acc_train: 0.8214 loss_val: 0.6686 acc_val: 0.8233 time: 2.0246s
Epoch: 0778 loss_train: 0.5753 acc_train: 0.8571 loss_val: 0.6686 acc_val: 0.8200 time: 1.8780s
Epoch: 0779 loss_train: 0.6818 acc_train: 0.7929 loss_val: 0.6686 acc_val: 0.8200 time: 1.8991s
Epoch: 0780 loss_train: 0.7224 acc_train: 0.8214 loss_val: 0.6685 acc_val: 0.8200 time: 1.9697s
Epoch: 0781 loss_train: 0.5112 acc_train: 0.9000 loss_val: 0.6685 acc_val: 0.8200 time: 2.0018s
Epoch: 0782 loss_train: 0.6965 acc_train: 0.7786 loss_val: 0.6682 acc_val: 0.8167 time: 1.8760s
Epoch: 0783 loss_train: 0.6606 acc_train: 0.8143 loss_val: 0.6678 acc_val: 0.8167 time: 1.8999s
Epoch: 0784 loss_train: 0.6443 acc_train: 0.7786 loss_val: 0.6670 acc_val: 0.8167 time: 1.9268s
Epoch: 0785 loss_train: 0.5469 acc_train

- Epoch: 0001 loss_train: 1.6168 acc_train: 0.5071 loss_val: 1.6278 acc_val: 0.5767 time: 31.6586s
- Epoch: 0002 loss_train: 1.5773 acc_train: 0.5714 loss_val: 1.6154 acc_val: 0.5767 time: 30.0393s

Optimization Finished!
- Total time elapsed: 62.3707s
- Loading 1th epoch
- Test set results: loss= 1.6820 accuracy= 0.4440

In [20]:
# Average wall-clock seconds per epoch. 1567.3155 is presumably the
# "Total time elapsed" value of the 800-epoch run — that line is not present
# in the saved output above, so treat the figure as unverified.
1567.3155/800

1.959144375

In [21]:
# Re-run the test-split evaluation with the weights currently held by `model`.
compute_test()

Test set results: loss= 0.6624 accuracy= 0.8470
