In [64]:
import time
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

In [65]:
# Execute models.ipynb in this kernel's namespace; it provides the GCN class used below.
%run models.ipynb

In [66]:
# Execute utils.ipynb in this kernel's namespace; it provides load_data() (and,
# presumably, the accuracy() helper used by train/test below — confirm).
%run utils.ipynb

Load the data with `load_data()` from the utils notebook, which returns `adj`, `features`, `labels`, and the `idx_train` / `idx_val` / `idx_test` index splits.

In [67]:
# Unpack the six values returned by load_data() from the utils notebook:
# graph/feature/label tensors followed by the train, validation and test index splits.
(adj, features, labels,
 idx_train, idx_val, idx_test) = load_data()

Loading cora dataset...


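Initialize the `GCN` model (defined in the models notebook). The input size is the number of feature columns, with 20 hidden units and one output per class; dropout is disabled.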

In [68]:
# Fix the RNG seeds before building the model so that its (presumably random)
# weight initialization, and therefore the whole training run, is reproducible
# under Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# nfeat: number of input feature columns; nclasses: labels.max() + 1, which
# assumes labels are integer class ids 0..C-1 (TODO confirm in load_data).
model = GCN(nfeat=features.shape[1], nhidd=20, nclasses=labels.max().item()+1, dropout=0)

We use the Adam optimizer with a learning rate of 0.01 and weight decay of 1e-4.

In [69]:
# Adam over every trainable parameter of the GCN; weight_decay applies an
# L2 penalty to the weights.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-2,
    weight_decay=1e-4,
)

Define one training epoch: a gradient step on the training nodes (`idx_train`), followed by evaluation on the validation nodes (`idx_val`).

In [74]:
def train(epochs):
    """Run one training step followed by a validation pass, and print metrics.

    Args:
        epochs: zero-based index of the current epoch (the caller passes the
            loop variable). Despite the plural name, it is a single index and
            is only used for the log line. The name is kept so existing
            keyword callers keep working.

    Reads the notebook-level globals ``model``, ``optimizer``, ``features``,
    ``adj``, ``labels``, ``idx_train`` and ``idx_val``, plus the
    ``accuracy`` helper (presumably from utils.ipynb — confirm).
    Returns None; results are printed.
    """
    t = time.time()

    # One full-batch gradient step: forward over all nodes, loss only on
    # the training indices.
    model.train()
    optimizer.zero_grad()
    output = model(features, adj)
    loss_train = F.nll_loss(output[idx_train], labels[idx_train])
    acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()

    # Fresh forward pass in eval mode for validation. Validation metrics are
    # never backpropagated, so skip building an autograd graph.
    model.eval()
    with torch.no_grad():
        output = model(features, adj)
        loss_val = F.nll_loss(output[idx_val], labels[idx_val])
        acc_val = accuracy(output[idx_val], labels[idx_val])

    # Use the argument, not a global: the original read the loop variable
    # `epoch` from the calling cell, which only worked by accident.
    print('Epoch: {:04d}'.format(epochs + 1),
          'loss_train: {:.4f}'.format(loss_train.item()),
          'acc_train: {:.4f}'.format(acc_train.item()),
          'loss_val: {:.4f}'.format(loss_val.item()),
          'acc_val: {:.4f}'.format(acc_val.item()),
          'time: {:.4f}s'.format(time.time() - t))

Define evaluation on the test nodes (`idx_test`).

In [75]:
def test():
    """Evaluate the model on the test split and print its loss and accuracy.

    Reads the notebook-level globals ``model``, ``features``, ``adj``,
    ``labels`` and ``idx_test``, plus the ``accuracy`` helper (presumably
    from utils.ipynb — confirm). Returns None; results are printed.
    """
    model.eval()
    # Pure evaluation: no gradients are needed, so don't build a graph.
    with torch.no_grad():
        output = model(features, adj)
        loss_test = F.nll_loss(output[idx_test], labels[idx_test])
        acc_test = accuracy(output[idx_test], labels[idx_test])

    print("Test set results:",
          "loss= {:.4f}".format(loss_test.item()),
          "accuracy= {:.4f}".format(acc_test.item()))

In [76]:
# Wall-clock start for the training loop. NOTE(review): because this lives in
# its own cell, re-running only the loop cell makes the reported total include
# all time elapsed since this cell last ran.
t_total = time.time()

In [79]:
# Start the timer in the same cell as the loop, so that re-running this cell
# on its own still reports only this run's training time.
t_total = time.time()
for epoch in range(400):
    train(epoch)
print("Optimization Finished!")
print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

Epoch: 0001 loss_train: 0.0657 acc_train: 1.0000 loss_val: 0.6257 acc_val: 0.8033 time: 0.0320s
Epoch: 0002 loss_train: 0.0653 acc_train: 1.0000 loss_val: 0.6255 acc_val: 0.8033 time: 0.0360s
Epoch: 0003 loss_train: 0.0649 acc_train: 1.0000 loss_val: 0.6256 acc_val: 0.8033 time: 0.0240s
Epoch: 0004 loss_train: 0.0645 acc_train: 1.0000 loss_val: 0.6262 acc_val: 0.8033 time: 0.0280s
Epoch: 0005 loss_train: 0.0642 acc_train: 1.0000 loss_val: 0.6259 acc_val: 0.8033 time: 0.0280s
Epoch: 0006 loss_train: 0.0638 acc_train: 1.0000 loss_val: 0.6266 acc_val: 0.8033 time: 0.0280s
Epoch: 0007 loss_train: 0.0635 acc_train: 1.0000 loss_val: 0.6268 acc_val: 0.8033 time: 0.0280s
Epoch: 0008 loss_train: 0.0632 acc_train: 1.0000 loss_val: 0.6269 acc_val: 0.8033 time: 0.0320s
Epoch: 0009 loss_train: 0.0628 acc_train: 1.0000 loss_val: 0.6272 acc_val: 0.8033 time: 0.0280s
Epoch: 0010 loss_train: 0.0625 acc_train: 1.0000 loss_val: 0.6280 acc_val: 0.8033 time: 0.0280s
Epoch: 0011 loss_train: 0.0622 acc_train

Epoch: 0090 loss_train: 0.0471 acc_train: 1.0000 loss_val: 0.6496 acc_val: 0.7933 time: 0.0312s
Epoch: 0091 loss_train: 0.0469 acc_train: 1.0000 loss_val: 0.6511 acc_val: 0.7900 time: 0.0318s
Epoch: 0092 loss_train: 0.0468 acc_train: 1.0000 loss_val: 0.6499 acc_val: 0.7933 time: 0.0280s
Epoch: 0093 loss_train: 0.0467 acc_train: 1.0000 loss_val: 0.6515 acc_val: 0.7900 time: 0.0265s
Epoch: 0094 loss_train: 0.0466 acc_train: 1.0000 loss_val: 0.6508 acc_val: 0.7933 time: 0.0320s
Epoch: 0095 loss_train: 0.0465 acc_train: 1.0000 loss_val: 0.6510 acc_val: 0.7900 time: 0.0360s
Epoch: 0096 loss_train: 0.0464 acc_train: 1.0000 loss_val: 0.6524 acc_val: 0.7900 time: 0.0360s
Epoch: 0097 loss_train: 0.0463 acc_train: 1.0000 loss_val: 0.6509 acc_val: 0.7900 time: 0.0320s
Epoch: 0098 loss_train: 0.0462 acc_train: 1.0000 loss_val: 0.6528 acc_val: 0.7900 time: 0.0400s
Epoch: 0099 loss_train: 0.0461 acc_train: 1.0000 loss_val: 0.6523 acc_val: 0.7900 time: 0.0320s
Epoch: 0100 loss_train: 0.0460 acc_train

Epoch: 0182 loss_train: 0.0396 acc_train: 1.0000 loss_val: 0.6682 acc_val: 0.7900 time: 0.0240s
Epoch: 0183 loss_train: 0.0395 acc_train: 1.0000 loss_val: 0.6696 acc_val: 0.7867 time: 0.0360s
Epoch: 0184 loss_train: 0.0395 acc_train: 1.0000 loss_val: 0.6680 acc_val: 0.7900 time: 0.0320s
Epoch: 0185 loss_train: 0.0394 acc_train: 1.0000 loss_val: 0.6701 acc_val: 0.7867 time: 0.0320s
Epoch: 0186 loss_train: 0.0394 acc_train: 1.0000 loss_val: 0.6683 acc_val: 0.7900 time: 0.0320s
Epoch: 0187 loss_train: 0.0393 acc_train: 1.0000 loss_val: 0.6701 acc_val: 0.7867 time: 0.0320s
Epoch: 0188 loss_train: 0.0392 acc_train: 1.0000 loss_val: 0.6690 acc_val: 0.7900 time: 0.0360s
Epoch: 0189 loss_train: 0.0392 acc_train: 1.0000 loss_val: 0.6699 acc_val: 0.7900 time: 0.0280s
Epoch: 0190 loss_train: 0.0391 acc_train: 1.0000 loss_val: 0.6700 acc_val: 0.7900 time: 0.0280s
Epoch: 0191 loss_train: 0.0391 acc_train: 1.0000 loss_val: 0.6695 acc_val: 0.7900 time: 0.0280s
Epoch: 0192 loss_train: 0.0390 acc_train

Epoch: 0273 loss_train: 0.0354 acc_train: 1.0000 loss_val: 0.6808 acc_val: 0.7833 time: 0.0320s
Epoch: 0274 loss_train: 0.0354 acc_train: 1.0000 loss_val: 0.6811 acc_val: 0.7833 time: 0.0360s
Epoch: 0275 loss_train: 0.0354 acc_train: 1.0000 loss_val: 0.6805 acc_val: 0.7833 time: 0.0280s
Epoch: 0276 loss_train: 0.0353 acc_train: 1.0000 loss_val: 0.6818 acc_val: 0.7833 time: 0.0280s
Epoch: 0277 loss_train: 0.0353 acc_train: 1.0000 loss_val: 0.6803 acc_val: 0.7833 time: 0.0320s
Epoch: 0278 loss_train: 0.0353 acc_train: 1.0000 loss_val: 0.6823 acc_val: 0.7833 time: 0.0320s
Epoch: 0279 loss_train: 0.0352 acc_train: 1.0000 loss_val: 0.6804 acc_val: 0.7833 time: 0.0280s
Epoch: 0280 loss_train: 0.0352 acc_train: 1.0000 loss_val: 0.6822 acc_val: 0.7833 time: 0.0360s
Epoch: 0281 loss_train: 0.0351 acc_train: 1.0000 loss_val: 0.6811 acc_val: 0.7833 time: 0.0240s
Epoch: 0282 loss_train: 0.0351 acc_train: 1.0000 loss_val: 0.6818 acc_val: 0.7833 time: 0.0280s
Epoch: 0283 loss_train: 0.0351 acc_train

Epoch: 0363 loss_train: 0.0328 acc_train: 1.0000 loss_val: 0.6886 acc_val: 0.7800 time: 0.0320s
Epoch: 0364 loss_train: 0.0328 acc_train: 1.0000 loss_val: 0.6877 acc_val: 0.7800 time: 0.0320s
Epoch: 0365 loss_train: 0.0328 acc_train: 1.0000 loss_val: 0.6886 acc_val: 0.7800 time: 0.0320s
Epoch: 0366 loss_train: 0.0327 acc_train: 1.0000 loss_val: 0.6883 acc_val: 0.7800 time: 0.0320s
Epoch: 0367 loss_train: 0.0327 acc_train: 1.0000 loss_val: 0.6882 acc_val: 0.7800 time: 0.0280s
Epoch: 0368 loss_train: 0.0327 acc_train: 1.0000 loss_val: 0.6889 acc_val: 0.7800 time: 0.0320s
Epoch: 0369 loss_train: 0.0327 acc_train: 1.0000 loss_val: 0.6880 acc_val: 0.7800 time: 0.0320s
Epoch: 0370 loss_train: 0.0326 acc_train: 1.0000 loss_val: 0.6891 acc_val: 0.7800 time: 0.0360s
Epoch: 0371 loss_train: 0.0326 acc_train: 1.0000 loss_val: 0.6881 acc_val: 0.7800 time: 0.0320s
Epoch: 0372 loss_train: 0.0326 acc_train: 1.0000 loss_val: 0.6892 acc_val: 0.7800 time: 0.0280s
Epoch: 0373 loss_train: 0.0326 acc_train

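The model reaches 81.8% accuracy on the test set (output below). The training log above already starts at a training accuracy of 1.0000, and the execution counts jump from In [76] to In [79], so this run continued training a model from earlier executions. Re-run from a fresh kernel (Restart & Run All) to reproduce the number. The validation loss rising from about 0.626 to about 0.689 while training accuracy stays at 1.0 also suggests the extra epochs are overfitting.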

In [80]:
# Evaluate the trained model on the test split; prints loss and accuracy.
test()

Test set results: loss= 0.5781 accuracy= 0.8180
