## Code for running the graph neural network (GNN)

gnn_models and supporting files can be found at https://github.com/SJTUBME-QianLab/AutoMetricGNN


In [1]:
import gnn_models as models
import numpy as np
import argparse
import torch

import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F
import torch.cuda as cuda
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
from statistics import mean
import os
import datetime
import pickle
import random
import textwrap
from pathlib import Path
from utils import io_utils
from sklearn import metrics as metrics

import visualize
import gnn_models as models
from predict import Predict
from data import DataGenerator

In [2]:
# Filesystem locations: run outputs, input data files, TensorBoard logs.
RESULT_PATH = Path("results/august2022/altered_300")
DATA_PATH = Path("data_prep/data")
WRITER_PATH = Path("runs/august2022/altered_300")

# TensorBoard summary writer rooted at WRITER_PATH.
tb = SummaryWriter(WRITER_PATH)

# Number of training iterations per run (copied into args.iterations).
n_iter = 300

# Input file names, resolved relative to DATA_PATH.
# Alternative pair used previously: 'orig_train.npy' / 'orig_test.npy'.
trainstring = "relabel3103_train.npy"
teststring = "relabel3103_test.npy"

# Early stopping: window length, and the switch that enables the check.
patience = 10
stopbool = False

# Number of repeated training runs that make up the experiment.
maxloops = 30

In [3]:
teststring

'relabel3103_test.npy'

In [4]:

class Args:
    """Bare attribute container used in place of an argparse namespace."""


# Experiment settings, attached to ``args`` in the order listed.
_SETTINGS = {
    "metric_network": "gnn",
    "dataset": "AD",
    # Task layout passed to DataGenerator.get_task_batch.
    "train_N_way": 3,
    "test_N_way": 3,
    "test_N_shots": 1,
    "train_N_shots": 1,
    # Base learning rate for Adam.
    "lr": 0.001,
    # Feature layout: the first `clinical_feature_num` entries of the last
    # axis are sliced off as clinical features; the rest are treated as MRI
    # features. 221 = 8 + 213.
    "feature_num": 221,
    "clinical_feature_num": 8,
    "w_feature_num": 213,
    "w_feature_list": 8,
    "iterations": n_iter,
    # The learning rate is halved once every `dec_lr` iterations.
    "dec_lr": 10000,
    "log_interval": 1,
    "batch_size": 64,
    "batch_size_train": 64,
    "batch_size_test": 64,
    "test_interval": 200,
    # Base seed; each run adds its run index to it. Earlier inline notes
    # indicate 2021 was used for the original 10 runs and 2023 for the
    # second 20 - TODO confirm.
    "random_seed": 2023,
    "cuda": True,
}

args = Args()
for _name, _value in _SETTINGS.items():
    setattr(args, _name, _value)

random_seed = args.random_seed


In [5]:

def setup_seed(seed=random_seed):
    """Seed every random number generator used by the experiment.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    then configures cuDNN for repeatable runs.

    Args:
        seed: Integer seed. The default is the value of the module-level
            ``random_seed`` at the time this function was defined.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Both flags are needed for repeatable cuDNN behaviour: `deterministic`
    # restricts cuDNN to deterministic algorithms, while disabling
    # `benchmark` stops the auto-tuner from picking a different algorithm
    # from one run to the next.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def adjust_learning_rate(optimizers, lr, iter, writer=None):
    """Apply a step-decay schedule to every optimizer.

    The learning rate is ``lr`` halved once for each completed block of
    ``args.dec_lr`` iterations (``args`` is the module-level settings object).

    Args:
        optimizers: Iterable of optimizers; the ``'lr'`` entry of each of
            their ``param_groups`` is overwritten.
        lr: Base learning rate before any decay.
        iter: Current iteration index.
        writer: Optional summary writer; when truthy, the new rate is logged
            under the tag ``"Learning rate"`` at step ``iter``.
    """
    halvings = int(iter / args.dec_lr)
    new_lr = lr * (0.5 ** halvings)

    for opt in optimizers:
        for group in opt.param_groups:
            group['lr'] = new_lr

    if not writer:
        return
    writer.add_scalar("Learning rate", new_lr, iter)


def train_batch(model, data):
    """Run one forward/backward pass on a sampled task batch.

    Only gradients are accumulated here; the caller is responsible for
    zeroing them beforehand and for stepping the optimizer afterwards.

    Args:
        model: ``[amgnn, softmax_module]``. ``amgnn`` must provide
            ``compute_adj`` and be callable; ``softmax_module`` must provide
            ``forward`` and return log-probabilities suitable for
            ``F.nll_loss``.
        data: ``[batch_x, label_x, batches_xi, labels_yi, oracles_yi]``.
            ``batch_x`` and every tensor in ``batches_xi`` are indexed as
            4-D tensors whose last axis holds the features. ``label_x`` holds
            one row of class scores per sample (presumably one-hot).

    Returns:
        The NLL loss tensor, after ``backward()`` has been called on it.
    """
    [amgnn, softmax_module] = model
    [batch_x, label_x, batches_xi, labels_yi, oracles_yi] = data
    n_clinical = args.clinical_feature_num

    # The first `n_clinical` features are the clinical / risk-factor values.
    z_clinical = batch_x[:, 0, 0, 0:n_clinical]
    zi_s_clinical = [batch_xi[:, 0, 0, 0:n_clinical] for batch_xi in batches_xi]

    # Everything after the clinical features is treated as MRI features.
    z_mri_feature = batch_x[:, :, :, n_clinical:]
    zi_s_mri_feature = [batch_xi[:, :, :, n_clinical:] for batch_xi in batches_xi]

    # The adjacency is computed from the clinical features only.
    adj = amgnn.compute_adj(z_clinical, zi_s_clinical)

    inputs = [z_clinical, z_mri_feature, zi_s_clinical, zi_s_mri_feature, labels_yi, oracles_yi, adj]
    _, out_logits = amgnn(*inputs)
    logsoft_prob = softmax_module.forward(out_logits)

    # Loss: turn each label row into a class index directly in torch (no
    # detour through numpy) and keep it on the same device as the model
    # output.
    formatted_label_x = label_x.argmax(dim=1).to(logsoft_prob.device)
    loss = F.nll_loss(logsoft_prob, formatted_label_x)
    loss.backward()

    return loss


def test_one_shot(args, fold, test_root, model, test_samples=50, partition='test', io_path='results/run.log', write_model_graph=False):
    """Evaluate the current model on freshly sampled one-shot test tasks.

    Draws four task batches from the dataset at ``test_root``, scores each
    with the module-level ``predict`` helper, and reports metrics pooled over
    all four batches. The loss is the mean of the per-batch NLL losses;
    accuracy, balanced accuracy and AUC are computed on the concatenated
    predictions, so every reported metric covers the same samples.

    NOTE: predictions come from the global ``predict`` object, not from the
    ``model`` argument. ``model`` is only used to switch its first element
    between eval and train mode, so both must wrap the same network.

    Args:
        args: Settings object; reads ``batch_size_test``, ``test_N_way``,
            ``test_N_shots`` and ``cuda``.
        fold: Unused; kept for interface compatibility.
        test_root: Path of the test data file handed to ``DataGenerator``.
        model: Sequence whose first element is the network to evaluate.
        test_samples: Unused; kept for interface compatibility.
        partition: Unused; kept for interface compatibility.
        io_path: Log file that receives the evaluation summary.
        write_model_graph: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(correct, accuracy, loss, auc, balanced_accuracy)``:
        number of correct predictions, accuracy in percent, mean NLL loss,
        mean one-vs-rest AUC over the classes, and balanced accuracy as a
        fraction in [0, 1].
    """
    io = io_utils.IOStream(io_path)

    io.cprint('\n**** TESTING BEGIN ***' )
    class_keys = ['CN','MCI','AD']
    data_loader = DataGenerator(test_root, keys=class_keys)
    [amgnn, *rest] = model
    amgnn.eval()

    n_batches = 4
    batch_losses = []
    label_chunks = []
    pred_chunks = []

    for _ in range(n_batches):
        data = data_loader.get_task_batch(
                batch_size=args.batch_size_test,
                n_way=args.test_N_way,
                num_shots=args.test_N_shots,
                cuda=args.cuda
        )

        Y, y_pred, labels_x_cpu = predict.predict_nodes_using_one_shot(data)

        # Build the targets on whatever device the model output lives on,
        # so evaluation also works when CUDA is not in use.
        targets = torch.as_tensor(labels_x_cpu, dtype=torch.long, device=Y.device)
        batch_losses.append(float(F.nll_loss(Y, targets)))

        label_chunks.append(np.asarray(labels_x_cpu))
        pred_chunks.append(np.asarray(y_pred))

    labels = np.concatenate(label_chunks)
    preds = np.concatenate(pred_chunks)

    correct = int(np.sum(preds == labels))
    total = int(preds.shape[0])
    loss_test_f = mean(batch_losses)

    balancedaccuracy = metrics.balanced_accuracy_score(labels, preds)

    # One-vs-rest AUC per class, computed from the hard predictions (as
    # before), then averaged over the classes.
    auc = [
        metrics.roc_auc_score(labels == class_idx, preds == class_idx)
        for class_idx in range(len(class_keys))
    ]
    multi_auc = mean(auc)

    # balanced_accuracy_score returns a fraction; scale it so the value
    # matches the '%' sign in the log line.
    message = textwrap.dedent("""
    ***ITERATION FINISHED***
    Loss: {}
    Correct: {}
    Total: {}
    Accuracy: {:.3f}%
    Balanced Accuracy: {:.3f}%
    AUC: {:.3f}
    """.format(loss_test_f, correct, total, (100.0*correct/total), 100.0*balancedaccuracy, multi_auc))
    io.cprint(message)

    amgnn.train()
    accuracy = 100 * (correct / total)

    return correct, accuracy, loss_test_f,multi_auc,balancedaccuracy




In [6]:
%%capture
for i in range(maxloops):
    # Each pass is one independent training run with its own results folder,
    # TensorBoard log directory and random seed.
    now = datetime.datetime.now()
    now_format = now.strftime('%Y-%m-%d-%H-%M') + "_" + str(i)

    save_path = RESULT_PATH / now_format
    save_path.mkdir(parents=True, exist_ok=True)

    io_path = save_path / 'run.log'
    io = io_utils.IOStream(io_path)

    tb_path = WRITER_PATH / now_format
    tb_path.mkdir(parents=True, exist_ok=True)
    tb = SummaryWriter(tb_path)

    print(now_format)
    print('The result will be saved in :', save_path)

    # Offset the base seed by the run index so each run is different but
    # repeatable.
    setup_seed(args.random_seed+i)

    # NOTE: CNN dimension where one CNN is used for learning the edge weight from the
    # absolute difference between each feature of the feature nodes - see notes 1b.
    amgnn = models.create_models(args, cnn_dim1=2)

    # Initialise softmax and prediction modules. `predict` must remain a
    # module-level name because test_one_shot() reads it as a global.
    softmax_module = models.SoftmaxModule()
    predict = Predict(amgnn, softmax_module, args, io_path)

    io.cprint(str(amgnn))

    if args.cuda:
        amgnn.cuda()

    weight_decay = 0

    opt_amgnn = optim.Adam(amgnn.parameters(), lr=args.lr, weight_decay=weight_decay)
    amgnn.train()

    # Running sum / count of training losses since the last log line.
    counter = 0
    total_loss = 0
    # Best test accuracy seen so far in this run.
    test_acc = 0

    root = DATA_PATH / trainstring
    data_loader = DataGenerator(root, keys=['CN', 'MCI','AD'])

    # Loop-invariant evaluation settings.
    test_samples = 112
    test_root = DATA_PATH / teststring

    # Per-iteration history: test accuracy and last training loss.
    testacclist=np.zeros(args.iterations)
    losslist=np.zeros(args.iterations)

    try:
        for batch_idx in range(args.iterations):

            data = data_loader.get_task_batch(
                    batch_size=args.batch_size_train,
                    n_way=args.train_N_way,
                    num_shots=args.train_N_shots,
                    cuda=args.cuda
            )
            [batch_x, label_x, _, _, batches_xi, labels_yi, oracles_yi] = data

            opt_amgnn.zero_grad()

            # train model
            loss_d_metric = train_batch(model=[amgnn, softmax_module],
                                        data=[batch_x, label_x, batches_xi, labels_yi, oracles_yi])
            opt_amgnn.step()

            adjust_learning_rate(optimizers=[opt_amgnn], lr=args.lr, iter=batch_idx, writer=tb)

            # Convert once to a plain float for logging and the loss history.
            train_loss = loss_d_metric.item()

            # training loss output, averaged over the logging interval
            counter += 1
            total_loss += train_loss
            if batch_idx % args.log_interval == 0:
                display_str = 'Train Iter: {}'.format(batch_idx)
                display_str += '\tLoss_d_metric: {:.6f}'.format(total_loss / counter)
                io.cprint(display_str)
                counter = 0
                total_loss = 0

            # test trained model performance
            if (batch_idx + 1) % args.log_interval == 0:

                test_correct, test_acc_aux, test_loss_,test_auc,test_balacc = test_one_shot(
                                args, 0, test_root, model=[amgnn, softmax_module],
                                test_samples=test_samples, partition='test',
                                io_path=io_path
                            )

                # record testing metrics for batch
                tb.add_scalar("Loss", test_loss_, batch_idx)
                tb.add_scalar("Correct", test_correct, batch_idx)
                tb.add_scalar("Accuracy", test_acc_aux, batch_idx)
                tb.add_scalar("Balanced Accuracy", test_balacc, batch_idx)
                tb.add_scalar("AUC", test_auc, batch_idx)

                tb = visualize.record_amgnn_bias_metrics(amgnn, tb)

                amgnn.train()

                # NOTE(review): the "best" checkpoint is chosen by accuracy on
                # the test file itself; confirm this is intended.
                if test_acc_aux >= test_acc:
                    test_acc = test_acc_aux
                    torch.save(amgnn, save_path / 'amgnn_best_model.pkl')
                if args.dataset == 'AD':
                    io.cprint("Best test accuracy {:.4f} \n".format(test_acc))
                testacclist[batch_idx]=test_acc_aux
                losslist[batch_idx]=train_loss

                # Early stopping: stop once the recorded training loss has
                # not decreased over the last `patience` steps. The guard
                # uses `patience`, so the window can never start at a
                # negative index and wrap around the array.
                if stopbool and batch_idx > patience and all(
                        losslist[j] <= losslist[j + 1]
                        for j in range(batch_idx - patience, batch_idx)):
                    io.cprint("Stopping at {0} iterations \n".format(batch_idx))
                    break
    finally:
        # Flush and close this run's writer even if the run fails part-way.
        tb.close()
