In [1]:
import os
import pickle as pkl
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

from src.args import get_args
from src.get_data import load_benchmark, load_synthetic
from src.models import get_model
from src.normalization import get_adj_feats
from src.plots import plot_feature
from src.utils import accuracy, LDA_loss

In [2]:
# Load the dataset; every returned object is a dense torch tensor.
dataset_name = 'citeseer'  # other choices: 'cora', 'pubmed', ...

(adj, feats, labels,
 idx_train, idx_val, idx_test) = load_benchmark(dataset_name)
# For a synthetic graph use instead:
# adj, feats, labels, idx_train, idx_val, idx_test = load_synthetic(dataset_name)

Loading citeseer dataset...


  r_inv = np.power(rowsum, -1).flatten()


finish load data


In [3]:
# Build the hyper-parameter namespace for the chosen model/dataset pair.
# Available models: GCN / SGC / GFNN / GFN / AGNN / GIN / ...
model_name = 'AGNN'
args = get_args(model_opt=model_name, dataset=dataset_name)
# Handed to get_adj_feats in the next cell.
# NOTE(review): presumably populated in place there — confirm.
weights = []

In [4]:
# Model-specific preprocessing of the adjacency matrix and node features.
adj, feats = get_adj_feats(
    adj=adj,
    feats=feats,
    model_opt=model_name,
    degree=args.degree,
    weights=weights,
)

In [5]:
# Per-split class statistics consumed by LDA_loss in train().
# Assumes labels are integer class ids in 0..nb_class-1.
nb_class = int(labels.max().item()) + 1
# One-hot encoding of every node's label: shape (num_nodes, nb_class).
Y_onehot = torch.zeros(labels.shape[0], nb_class).scatter_(1, labels.unsqueeze(-1), 1)


def _split_class_stats(y_onehot, idx):
    """Count the samples of each class among rows ``idx`` of ``y_onehot``.

    Parameters
    ----------
    y_onehot : torch.Tensor
        One-hot label matrix of shape (num_nodes, num_classes).
    idx : torch.Tensor
        Indices of the nodes in the split.

    Returns
    -------
    tuple of torch.Tensor
        ``(counts, inv_counts, diag(inv_counts))``. A class with no sample in
        the split gets an inverse count of 0 instead of ``inf``, so it cannot
        inject inf/NaN into the LDA loss.
    """
    counts = torch.sum(y_onehot[idx], dim=0)
    # 1/0 evaluates to inf here, but torch.where discards it for empty classes.
    inv_counts = torch.where(counts > 0, 1.0 / counts, torch.zeros_like(counts))
    return counts, inv_counts, torch.diag(inv_counts)


nb_each_class_train, nb_each_class_inv_train, nb_each_class_inv_mat_train = \
    _split_class_stats(Y_onehot, idx_train)
nb_each_class_val, nb_each_class_inv_val, nb_each_class_inv_mat_val = \
    _split_class_stats(Y_onehot, idx_val)
nb_each_class_test, nb_each_class_inv_test, nb_each_class_inv_mat_test = \
    _split_class_stats(Y_onehot, idx_test)

In [6]:
# get model
model = get_model(model_opt=model_name, nfeat=feats.size(1),
                  nclass=labels.max().item() + 1, nhid=args.hidden,
                  dropout=args.dropout, cuda=args.cuda,
                  dataset=dataset_name, degree=args.degree)

# Move everything to the GPU when requested. AGNN and GIN are left on the CPU,
# as in the original setup. NOTE(review): confirm why they are excluded.
if args.cuda and model_name not in ('AGNN', 'GIN'):
    model.cuda()
    feats = feats.cuda()
    adj = adj.cuda()
    labels = labels.cuda()
    idx_train = idx_train.cuda()
    idx_val = idx_val.cuda()
    idx_test = idx_test.cuda()
    # train() indexes Y_onehot with the (now CUDA) split indices and passes the
    # per-split class matrices to LDA_loss together with CUDA model outputs,
    # so these tensors must live on the same device.
    Y_onehot = Y_onehot.cuda()
    nb_each_class_inv_mat_train = nb_each_class_inv_mat_train.cuda()
    nb_each_class_inv_mat_val = nb_each_class_inv_mat_val.cuda()
    nb_each_class_inv_mat_test = nb_each_class_inv_mat_test.cuda()

# optimizer: built after the device move, as the PyTorch docs recommend.
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

# Print model's state_dict
print("Model's state_dict:")
for param_name, param_value in model.state_dict().items():
    print(param_name, "\t", param_value.size())

# Print optimizer's state_dict
print("optimizer's state_dict:")
for key, value in optimizer.state_dict().items():
    print(key, "\t", value)

Model's state_dict:
mapping 	 torch.Size([6, 6])
gc1.weight 	 torch.Size([3703, 6])
gc1.linear_weight 	 torch.Size([7])
gc1.bias 	 torch.Size([6])
optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.001, 'amsgrad': False, 'params': [4891145704, 4891145920, 4891145848, 4891145776]}]


In [7]:
# train, test


def train(epoch):
    """Run one optimisation step, then evaluate the model on every split.

    Relies on notebook globals defined in earlier cells: ``model``,
    ``optimizer``, ``feats``, ``adj``, ``labels``, ``Y_onehot``,
    ``idx_train``/``idx_val``/``idx_test``, ``nb_each_class_inv_mat_*`` and
    ``model_name``.

    Training objective:
    - AGNN: the negated LDA loss on the training nodes (``norm_or_not=False``).
      Cross-entropy is computed but not optimised.
    - Every other model: the NLL loss.

    Validation and test losses are always the negated LDA loss with
    ``norm_or_not=True``.

    Parameters
    ----------
    epoch : int
        Zero-based epoch index; reported as ``epoch + 1``.

    Returns
    -------
    tuple
        ``(epoch + 1, loss_train, acc_train, loss_val, acc_val, loss_test,
        acc_test, elapsed_seconds)`` as Python scalars.
    """
    t = time.time()

    # ---- one optimisation step on the training nodes ----
    model.train()
    optimizer.zero_grad()
    output, fp1, _ = model(feats, adj)
    CE_loss_train = F.nll_loss(output[idx_train], labels[idx_train])
    if model_name == 'AGNN':
        LDA_loss_train = LDA_loss(fp1[idx_train], Y_onehot[idx_train],
                                  nb_each_class_inv_mat_train, norm_or_not=False)
        # alternative objective: loss_train = CE_loss_train - LDA_loss_train
        loss_train = -LDA_loss_train
    else:
        loss_train = CE_loss_train
    acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()

    # ---- evaluation: no gradients are needed, so don't build a graph ----
    model.eval()
    with torch.no_grad():
        output, fp1, _ = model(feats, adj)
        loss_val = -LDA_loss(fp1[idx_val], Y_onehot[idx_val],
                             nb_each_class_inv_mat_val, norm_or_not=True)
        acc_val = accuracy(output[idx_val], labels[idx_val])
        loss_test = -LDA_loss(fp1[idx_test], Y_onehot[idx_test],
                              nb_each_class_inv_mat_test, norm_or_not=True)
        acc_test = accuracy(output[idx_test], labels[idx_test])

    # Measure once so the printed and returned durations agree.
    elapsed = time.time() - t
    print('Epoch: {:04d}'.format(epoch + 1),
          'loss_train: {:.4f}'.format(loss_train.item()),
          'acc_train: {:.4f}'.format(acc_train.item()),
          'loss_val: {:.4f}'.format(loss_val.item()),
          'acc_val: {:.4f}'.format(acc_val.item()),
          'time: {:.4f}s'.format(elapsed))

    return (epoch + 1, loss_train.item(), acc_train.item(), loss_val.item(),
            acc_val.item(), loss_test.item(), acc_test.item(), elapsed)


# def test():
#     model.eval()
#     output, fp1, fp2 = model(feats, adj)
#     loss_test = F.nll_loss(output[idx_test], labels[idx_test])
#     acc_test = accuracy(output[idx_test], labels[idx_test])
#     print("Test set results:",
#           "loss= {:.4f}".format(loss_test.item()),
#           "accuracy= {:.4f}".format(acc_test.item()))
#     return 




In [None]:
training_log = []

# Train the model and checkpoint whenever the validation loss does not get worse.
t_total = time.time()
temp_val_loss = float('inf')   # lowest validation loss seen so far
temp_test_loss = 0             # test loss at that epoch
temp_test_acc = 0              # test accuracy at that epoch
saved_checkpoint = False
PATH = "save/model_param/{}{}.pt".format(model_name, dataset_name)
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(PATH), exist_ok=True)

for epoch in range(args.epochs):
    epo, trainloss, trainacc, valloss, valacc, testloss, testacc, epotime = train(epoch)
    training_log.append([epo, trainloss, trainacc, valloss, valacc, testloss, testacc, epotime])

    if valloss <= temp_val_loss:
        temp_val_loss = valloss
        temp_test_loss = testloss
        temp_test_acc = testacc
        torch.save(model.state_dict(), PATH)
        saved_checkpoint = True

print("Optimization Finished!")
print("Total time elapsed: {:.4f}s".format(time.time() - t_total))
print("Best result:",
      "val_loss=", temp_val_loss,
      "test_loss=", temp_test_loss,
      "test_acc=", temp_test_acc)

if saved_checkpoint:
    # Restore the best-validation weights so later cells use that model.
    bestmodel = torch.load(PATH, map_location='cpu')
    model.load_state_dict(bestmodel)
    if model_name == 'AGNN':
        print("the weight is: ", torch.softmax(bestmodel['gc1.linear_weight'], dim=0))
else:
    # Only possible when every validation loss was NaN.
    print("No checkpoint was saved; validation loss was never comparable (NaN?).")

# output, fp1, fp2 = model(feats, adj)
# test_LDA = LDA_loss(fp1[idx_test], Y_onehot[idx_test], nb_each_class_inv_mat_test, norm_or_not = False)
# print("test_LDA:", test_LDA)

# # Testing
# test()

# # save training log
# # expname = input('input experiment name: ')
# expname = dataset_name + '_' + model_name
# log_pk = open('./save/trainlog_'+expname+'.pkl','wb')
# pkl.dump(np.array(training_log),log_pk)
# log_pk.close()
# print("finish save log")





Epoch: 0001 loss_train: -81.8612 acc_train: 0.1667 loss_val: -72.9693 acc_val: 0.1880 time: 2.6709s
Epoch: 0002 loss_train: -82.6054 acc_train: 0.1667 loss_val: -73.8103 acc_val: 0.1880 time: 2.1588s
Epoch: 0003 loss_train: -83.3120 acc_train: 0.1667 loss_val: -74.6017 acc_val: 0.1880 time: 1.7595s
Epoch: 0004 loss_train: -83.9686 acc_train: 0.1667 loss_val: -75.3373 acc_val: 0.1880 time: 2.3347s
Epoch: 0005 loss_train: -84.5673 acc_train: 0.1667 loss_val: -76.0170 acc_val: 0.1880 time: 1.5639s
Epoch: 0006 loss_train: -85.1058 acc_train: 0.1667 loss_val: -76.6476 acc_val: 0.1880 time: 1.5209s
Epoch: 0007 loss_train: -85.5868 acc_train: 0.1667 loss_val: -77.2406 acc_val: 0.1880 time: 1.7539s
Epoch: 0008 loss_train: -86.0176 acc_train: 0.1667 loss_val: -77.8069 acc_val: 0.1880 time: 1.4946s
Epoch: 0009 loss_train: -86.4071 acc_train: 0.1667 loss_val: -78.3524 acc_val: 0.1880 time: 1.6059s
Epoch: 0010 loss_train: -86.7629 acc_train: 0.1667 loss_val: -78.8772 acc_val: 0.1880 time: 1.6768s


Epoch: 0083 loss_train: -91.3780 acc_train: 0.1667 loss_val: -85.6566 acc_val: 0.1880 time: 1.6902s
Epoch: 0084 loss_train: -91.3788 acc_train: 0.1667 loss_val: -85.6626 acc_val: 0.1880 time: 2.0107s
Epoch: 0085 loss_train: -91.3797 acc_train: 0.1667 loss_val: -85.6671 acc_val: 0.1880 time: 1.6891s
Epoch: 0086 loss_train: -91.3805 acc_train: 0.1667 loss_val: -85.6701 acc_val: 0.1880 time: 1.7663s
Epoch: 0087 loss_train: -91.3813 acc_train: 0.1667 loss_val: -85.6719 acc_val: 0.1880 time: 1.6225s
Epoch: 0088 loss_train: -91.3822 acc_train: 0.1667 loss_val: -85.6723 acc_val: 0.1880 time: 1.5805s
Epoch: 0089 loss_train: -91.3831 acc_train: 0.1667 loss_val: -85.6717 acc_val: 0.1880 time: 1.4664s
Epoch: 0090 loss_train: -91.3841 acc_train: 0.1667 loss_val: -85.6700 acc_val: 0.1880 time: 1.4446s
Epoch: 0091 loss_train: -91.3850 acc_train: 0.1667 loss_val: -85.6675 acc_val: 0.1880 time: 1.4596s
Epoch: 0092 loss_train: -91.3860 acc_train: 0.1667 loss_val: -85.6643 acc_val: 0.1880 time: 1.5390s
