In [None]:
# Test file for toolbox

import argparse
from OpenAttMultiGL.model.mGCN.mGCN_node import*
import torch.nn as nn
import torch.optim as optim
import torch
from OpenAttMultiGL.utils.dataset import dataset
from OpenAttMultiGL.utils.process import split_node_data
from sklearn.metrics import roc_auc_score
import numpy as np
import random
import copy
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, roc_auc_score, recall_score, f1_score
from sklearn.metrics import normalized_mutual_info_score, pairwise, f1_score
from sklearn.cluster import KMeans

# Command-line configuration for the mGCN node-classification run.
# NOTE(review): --fast_split, --training_ratio, --weight_decay and --test_view
# are parsed but not read anywhere in this notebook cell — confirm whether
# they should be wired into the data split / optimizer.
parser = argparse.ArgumentParser(description='GCN')
parser.add_argument('--dataset', type=str, default='imdb')
parser.add_argument('--fast_split', action='store_true',
                    help="for large custom datasets (not OGB), do a fast data split")

parser.add_argument('--runs', type=int, default=1)
parser.add_argument('--hidden', type=int, default=128,
                    help='Number of hidden units.')
parser.add_argument('--epochs', type=int, default=500,
                    help='Number of training epochs.')
parser.add_argument('--alpha', type=float, default=0.6,
                    help='Hyperparameter')
parser.add_argument('--dropout', type=float, default=0.5,
                    help='Dropout')
parser.add_argument('--training_ratio', type=float, default=0.3,
                    help='Training Ratio')
parser.add_argument('--lr', type=float, default=0.001,
                    help='Learning Rate')
parser.add_argument('--weight_decay', type=float, default=1e-2,
                    help='Weight_decay')
# Help text previously duplicated the --epochs description by mistake.
parser.add_argument('--test_view', type=int, default=1,
                    help='Test view index (not used by the node-classification run).')
# Absorbs the extra "-f <file>" argument injected when running under
# IPython/Jupyter, so parse_args() does not fail inside a notebook.
parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1")

args = parser.parse_args()

def evaluate_metrics(true, pred):
    """Return the classification accuracy of ``pred`` against ``true``.

    Args:
        true: Tensor of class labels, one per row of ``pred``.
        pred: Tensor of per-class scores; the predicted class of each row is
            its argmax over dim 1.

    Returns:
        0-dim double tensor holding the fraction of rows whose argmax equals
        the corresponding entry of ``true``.
    """
    # Cast predictions to the label dtype so ``eq`` compares like with like.
    # (Previously cast to an undefined global ``labels`` -> NameError.)
    preds = pred.max(1)[1].type_as(true)
    correct = preds.eq(true).double().sum()
    return correct / len(true)

def evaluate_model(ind):
    """Evaluate the global ``model_mGCN`` on the nodes selected by ``ind``.

    Runs one forward pass over ``sample_data.dataset`` in eval mode with
    gradient tracking disabled. All metrics are then computed on the same
    node subset ``ind`` against ``sample_data.labels[ind]``.

    Args:
        ind: Index (e.g. ``sample_data.test_id`` / ``sample_data.valid_id``)
            selecting the nodes to score.

    Returns:
        Tuple ``(macro_f1, micro_f1, nmi)``. ``nmi`` is the list returned
        by ``run_kmeans``.
    """
    model_mGCN.eval()
    # Inference only: no autograd graph is needed to compute metrics.
    with torch.no_grad():
        logits = model_mGCN(sample_data.dataset)

    # Predicted class per node is the argmax over dim 1. Move to CPU/numpy
    # once, so sklearn receives plain arrays even when tensors live on GPU.
    true = sample_data.labels[ind].cpu().numpy()
    pred = logits.max(1).indices[ind].cpu().numpy()

    macro_f1 = f1_score(true, pred, average="macro")
    micro_f1 = f1_score(true, pred, average="micro")
    # NMI is scored on the same subset as F1. It was previously computed over
    # every node, including training nodes.
    nmi = run_kmeans(true, pred, sample_data.num_classes)

    return macro_f1, micro_f1, nmi
# Softmax over dim 1 (the class dimension of the model output); applied in
# write_results to turn model outputs into per-class probabilities.
soft = nn.Softmax(dim=1)

def run_kmeans(y, y_pred, k):
    """Score ``y_pred`` against ``y`` with NMI, print it and return it.

    Despite the name, no clustering is performed. ``y_pred`` is used
    directly as the cluster assignment and compared with ``y`` via
    ``normalized_mutual_info_score`` (arithmetic normalisation).

    Args:
        y: True labels.
        y_pred: Predicted labels, treated as the cluster assignment.
        k: Number of classes. Unused; kept for call-site compatibility. It
            previously sized a KMeans estimator that was never fitted.

    Returns:
        List of 5 identical NMI values, each rounded to 4 decimals. The
        score is deterministic, so it is computed once. It is repeated only
        to keep the original return shape that callers print.
    """
    score = normalized_mutual_info_score(y, y_pred, average_method='arithmetic')
    rounded = float("{:.4f}".format(score))
    NMI_list = [rounded] * 5

    mean = np.mean(NMI_list)
    std = np.std(NMI_list)
    print('\t[Clustering] NMI: {:.4f} | {:.4f}'.format(mean, std))
    return NMI_list


def write_results(ind):
    """Write softmax class probabilities for the nodes in ``ind`` to disk.

    Output file: ``./results/<args.dataset>_combined``. It holds one line per
    selected node, with that node's class probabilities separated by single
    spaces. The ``./results`` directory is created if it does not exist.
    """
    import os

    model_mGCN.eval()
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        probs = soft(model_mGCN(sample_data.dataset))

    os.makedirs("./results", exist_ok=True)
    # ``with`` guarantees the file is closed even if a write raises.
    with open("./results/" + args.dataset + "_combined", "w") as f:
        for row in probs[ind].cpu().numpy():
            f.write(" ".join(str(value) for value in row) + "\n")
    
def run_similarity_search(true_label, pred_label):
    """Return the fraction of positions where ``pred_label`` matches ``true_label``.

    Positions ``0 .. len(true_label) - 1`` are compared element by element;
    any extra trailing entries in ``pred_label`` are ignored. An empty
    ``true_label`` raises ``ZeroDivisionError``. A ``pred_label`` shorter
    than ``true_label`` raises ``IndexError``.
    """
    total = len(true_label)
    matches = sum(1 for pos in range(total) if pred_label[pos] == true_label[pos])
    return matches / total

# Module-level trackers that train_model resets through ``global``.
best_val = best_test = 0

def train_model(epochs):
    """Train the global ``model_mGCN`` for ``epochs`` full-batch epochs.

    Each epoch:
      1. Takes one optimisation step on the training nodes
         (``sample_data.train_id``), using the module-level ``criterion``
         and ``optimizer``.
      2. Prints test-split Macro/Micro-F1 and NMI (via ``evaluate_model``).
      3. Prints SIM (via ``run_similarity_search``) over all labelled nodes.

    Side effects: resets the module-level ``best_val``/``best_test`` to 0
    (they are not updated afterwards) and leaves the model in eval mode.
    """
    global best_val
    global best_test
    best_val = 0
    best_test = 0

    for epoch in range(epochs):
        model_mGCN.train()
        optimizer.zero_grad()
        logits = model_mGCN(sample_data.dataset)
        # NOTE(review): nn.NLLLoss expects log-probabilities; this assumes
        # mGCN's output ends in log_softmax — confirm.
        loss = criterion(logits[sample_data.train_id],
                         sample_data.labels[sample_data.train_id])
        loss.backward()
        optimizer.step()

        # NOTE(review): the test split is reported every epoch; use valid_id
        # here if these numbers are used for model selection.
        macro, micro, nmi = evaluate_model(sample_data.test_id)

        # SIM is computed from the updated model in eval mode. It is then
        # dropout-free and consistent with the F1/NMI printed beside it; the
        # training-mode logits above are pre-update and dropout-noised.
        model_mGCN.eval()
        with torch.no_grad():
            pred = model_mGCN(sample_data.dataset).max(1).indices
        # Plain lists let the element-wise loop compare ints, not 0-dim tensors.
        sim = run_similarity_search(sample_data.labels.tolist(), pred.tolist())

        print('Epoch:', epoch)
        print("Macro_F1:", macro)
        print("Micro_F1:", micro)
        print("NMI: ", nmi)
        print("SIM: ", sim)

# Per-run validation metrics: one (macro_f1, micro_f1, nmi) tuple per run.
results = []
taskname = 'node'

# Each run reseeds every RNG, then loads data, builds the model and trains.
# Previously only the seeding sat inside this loop, so --runs N trained a
# single model with the last seed.
for run in range(args.runs):
    np.random.seed(run)
    torch.manual_seed(run)
    torch.cuda.manual_seed(run)
    random.seed(run)

    # Loaded after seeding, as before, in case dataset() draws random splits.
    sample_data = dataset(args.dataset)
    print("Finish loading data")

    num_feat = sample_data.dataset.x.shape[1]
    model_mGCN = mGCN(num_feat, args.hidden, None, sample_data.num_dims, args.alpha,
                      sample_data.num_classes, dropout=args.dropout)
    criterion = nn.NLLLoss()
    # NOTE(review): args.weight_decay is parsed but not passed to Adam here —
    # confirm whether that is intentional before enabling it.
    optimizer = optim.Adam(model_mGCN.parameters(), lr=args.lr)

    train_model(args.epochs)
    results.append(evaluate_model(sample_data.valid_id))

print("Model training is complete")

Finish loading data
	[Clustering] NMI: 0.0000 | 0.0000
Epoch: 0
Macro_F1: 0.17277749874434956
Micro_F1: 0.34983050847457625
NMI:  [0.0, 0.0, 0.0, 0.0, 0.0]
SIM:  0.2757746478873239
	[Clustering] NMI: 0.0000 | 0.0000
Epoch: 1
Macro_F1: 0.17277749874434956
Micro_F1: 0.34983050847457625
NMI:  [0.0, 0.0, 0.0, 0.0, 0.0]
SIM:  0.3512676056338028
	[Clustering] NMI: 0.0027 | 0.0000
Epoch: 2
Macro_F1: 0.18652681357699719
Micro_F1: 0.38644067796610165
NMI:  [0.0027, 0.0027, 0.0027, 0.0027, 0.0027]
SIM:  0.3487323943661972
	[Clustering] NMI: 0.0000 | 0.0000
Epoch: 3
Macro_F1: 0.13926879413605078
Micro_F1: 0.2640677966101695
NMI:  [0.0, 0.0, 0.0, 0.0, 0.0]
SIM:  0.37295774647887325
	[Clustering] NMI: 0.0000 | 0.0000
Epoch: 4
Macro_F1: 0.13926879413605078
Micro_F1: 0.2640677966101695
NMI:  [0.0, 0.0, 0.0, 0.0, 0.0]
SIM:  0.29746478873239435
	[Clustering] NMI: 0.0000 | 0.0000
Epoch: 5
Macro_F1: 0.13926879413605078
Micro_F1: 0.2640677966101695
NMI:  [0.0, 0.0, 0.0, 0.0, 0.0]
SIM:  0.27605633802816903

	[Clustering] NMI: 0.0107 | 0.0000
Epoch: 48
Macro_F1: 0.3077826278861689
Micro_F1: 0.40711864406779663
NMI:  [0.0107, 0.0107, 0.0107, 0.0107, 0.0107]
SIM:  0.3754929577464789
	[Clustering] NMI: 0.0106 | 0.0000
Epoch: 49
Macro_F1: 0.30972331807820225
Micro_F1: 0.40508474576271186
NMI:  [0.0106, 0.0106, 0.0106, 0.0106, 0.0106]
SIM:  0.3963380281690141
	[Clustering] NMI: 0.0138 | 0.0000
Epoch: 50
Macro_F1: 0.3165100211422979
Micro_F1: 0.4061016949152542
NMI:  [0.0138, 0.0138, 0.0138, 0.0138, 0.0138]
SIM:  0.38253521126760565
	[Clustering] NMI: 0.0057 | 0.0000
Epoch: 51
Macro_F1: 0.26555353933891856
Micro_F1: 0.38610169491525426
NMI:  [0.0057, 0.0057, 0.0057, 0.0057, 0.0057]
SIM:  0.3805633802816901
	[Clustering] NMI: 0.0215 | 0.0000
Epoch: 52
Macro_F1: 0.35602194663922626
Micro_F1: 0.4101694915254237
NMI:  [0.0215, 0.0215, 0.0215, 0.0215, 0.0215]
SIM:  0.39492957746478874
	[Clustering] NMI: 0.0187 | 0.0000
Epoch: 53
Macro_F1: 0.36759068333145867
Micro_F1: 0.4111864406779661
NMI:  [0.0187

	[Clustering] NMI: 0.1997 | 0.0000
Epoch: 95
Macro_F1: 0.5840981329622325
Micro_F1: 0.6047457627118644
NMI:  [0.1997, 0.1997, 0.1997, 0.1997, 0.1997]
SIM:  0.6338028169014085
	[Clustering] NMI: 0.2028 | 0.0000
Epoch: 96
Macro_F1: 0.5986684985451731
Micro_F1: 0.6098305084745763
NMI:  [0.2028, 0.2028, 0.2028, 0.2028, 0.2028]
SIM:  0.6340845070422535
	[Clustering] NMI: 0.2068 | 0.0000
Epoch: 97
Macro_F1: 0.6044996147469056
Micro_F1: 0.6159322033898305
NMI:  [0.2068, 0.2068, 0.2068, 0.2068, 0.2068]
SIM:  0.6445070422535212
	[Clustering] NMI: 0.2021 | 0.0000
Epoch: 98
Macro_F1: 0.5920713457059467
Micro_F1: 0.6081355932203389
NMI:  [0.2021, 0.2021, 0.2021, 0.2021, 0.2021]
SIM:  0.647887323943662
	[Clustering] NMI: 0.2013 | 0.0000
Epoch: 99
Macro_F1: 0.5877594636959912
Micro_F1: 0.6071186440677966
NMI:  [0.2013, 0.2013, 0.2013, 0.2013, 0.2013]
SIM:  0.6380281690140845
	[Clustering] NMI: 0.2162 | 0.0000
Epoch: 100
Macro_F1: 0.6194798567345218
Micro_F1: 0.6261016949152542
NMI:  [0.2162, 0.2162,

	[Clustering] NMI: 0.2147 | 0.0000
Epoch: 142
Macro_F1: 0.6135347144495503
Micro_F1: 0.6196610169491525
NMI:  [0.2147, 0.2147, 0.2147, 0.2147, 0.2147]
SIM:  0.6622535211267606
	[Clustering] NMI: 0.2183 | 0.0000
Epoch: 143
Macro_F1: 0.6116900308372893
Micro_F1: 0.6237288135593221
NMI:  [0.2183, 0.2183, 0.2183, 0.2183, 0.2183]
SIM:  0.6583098591549296
	[Clustering] NMI: 0.2110 | 0.0000
Epoch: 144
Macro_F1: 0.60530750575116
Micro_F1: 0.6179661016949153
NMI:  [0.211, 0.211, 0.211, 0.211, 0.211]
SIM:  0.6585915492957747
	[Clustering] NMI: 0.2101 | 0.0000
Epoch: 145
Macro_F1: 0.6079308812340677
Micro_F1: 0.6189830508474576
NMI:  [0.2101, 0.2101, 0.2101, 0.2101, 0.2101]
SIM:  0.6577464788732394
	[Clustering] NMI: 0.2156 | 0.0000
Epoch: 146
Macro_F1: 0.6147616581384069
Micro_F1: 0.6216949152542373
NMI:  [0.2156, 0.2156, 0.2156, 0.2156, 0.2156]
SIM:  0.6563380281690141
	[Clustering] NMI: 0.2119 | 0.0000
Epoch: 147
Macro_F1: 0.610274107697017
Micro_F1: 0.616271186440678
NMI:  [0.2119, 0.2119, 0.

	[Clustering] NMI: 0.2122 | 0.0000
Epoch: 189
Macro_F1: 0.6110675531873132
Micro_F1: 0.6203389830508474
NMI:  [0.2122, 0.2122, 0.2122, 0.2122, 0.2122]
SIM:  0.6583098591549296
	[Clustering] NMI: 0.2124 | 0.0000
Epoch: 190
Macro_F1: 0.6119979634038506
Micro_F1: 0.6206779661016949
NMI:  [0.2124, 0.2124, 0.2124, 0.2124, 0.2124]
SIM:  0.6597183098591549
	[Clustering] NMI: 0.2105 | 0.0000
Epoch: 191
Macro_F1: 0.6083031685626198
Micro_F1: 0.6172881355932204
NMI:  [0.2105, 0.2105, 0.2105, 0.2105, 0.2105]
SIM:  0.6591549295774648
	[Clustering] NMI: 0.2131 | 0.0000
Epoch: 192
Macro_F1: 0.6112276017776641
Micro_F1: 0.6196610169491525
NMI:  [0.2131, 0.2131, 0.2131, 0.2131, 0.2131]
SIM:  0.6566197183098591
	[Clustering] NMI: 0.2141 | 0.0000
Epoch: 193
Macro_F1: 0.6138151501281036
Micro_F1: 0.6203389830508474
NMI:  [0.2141, 0.2141, 0.2141, 0.2141, 0.2141]
SIM:  0.6619718309859155
	[Clustering] NMI: 0.2154 | 0.0000
Epoch: 194
Macro_F1: 0.6131406007897978
Micro_F1: 0.6196610169491525
NMI:  [0.2154, 0

	[Clustering] NMI: 0.2185 | 0.0000
Epoch: 236
Macro_F1: 0.6183699746568
Micro_F1: 0.6250847457627119
NMI:  [0.2185, 0.2185, 0.2185, 0.2185, 0.2185]
SIM:  0.66
	[Clustering] NMI: 0.2165 | 0.0000
Epoch: 237
Macro_F1: 0.6129991718001149
Micro_F1: 0.6210169491525424
NMI:  [0.2165, 0.2165, 0.2165, 0.2165, 0.2165]
SIM:  0.6625352112676056
	[Clustering] NMI: 0.2137 | 0.0000
Epoch: 238
Macro_F1: 0.6090413670755656
Micro_F1: 0.6183050847457627
NMI:  [0.2137, 0.2137, 0.2137, 0.2137, 0.2137]
SIM:  0.6580281690140845
	[Clustering] NMI: 0.2107 | 0.0000
Epoch: 239
Macro_F1: 0.6059899415609461
Micro_F1: 0.616271186440678
NMI:  [0.2107, 0.2107, 0.2107, 0.2107, 0.2107]
SIM:  0.656056338028169
	[Clustering] NMI: 0.2094 | 0.0000
Epoch: 240
Macro_F1: 0.6075223471621949
Micro_F1: 0.616271186440678
NMI:  [0.2094, 0.2094, 0.2094, 0.2094, 0.2094]
SIM:  0.6538028169014084
	[Clustering] NMI: 0.2101 | 0.0000
Epoch: 241
Macro_F1: 0.609901259310382
Micro_F1: 0.6183050847457627
NMI:  [0.2101, 0.2101, 0.2101, 0.2101

	[Clustering] NMI: 0.2141 | 0.0000
Epoch: 283
Macro_F1: 0.6109531507966718
Micro_F1: 0.6196610169491525
NMI:  [0.2141, 0.2141, 0.2141, 0.2141, 0.2141]
SIM:  0.6583098591549296
	[Clustering] NMI: 0.2113 | 0.0000
Epoch: 284
Macro_F1: 0.6099021937184875
Micro_F1: 0.6183050847457627
NMI:  [0.2113, 0.2113, 0.2113, 0.2113, 0.2113]
SIM:  0.6543661971830986
	[Clustering] NMI: 0.2100 | 0.0000
Epoch: 285
Macro_F1: 0.6098691078797808
Micro_F1: 0.6179661016949153
NMI:  [0.21, 0.21, 0.21, 0.21, 0.21]
SIM:  0.6569014084507042
	[Clustering] NMI: 0.2116 | 0.0000
Epoch: 286
Macro_F1: 0.6122733980858693
Micro_F1: 0.6196610169491525
NMI:  [0.2116, 0.2116, 0.2116, 0.2116, 0.2116]
SIM:  0.6563380281690141
	[Clustering] NMI: 0.2128 | 0.0000
Epoch: 287
Macro_F1: 0.6145066494902601
Micro_F1: 0.6213559322033898
NMI:  [0.2128, 0.2128, 0.2128, 0.2128, 0.2128]
SIM:  0.6552112676056338
	[Clustering] NMI: 0.2127 | 0.0000
Epoch: 288
Macro_F1: 0.6133664466723944
Micro_F1: 0.6206779661016949
NMI:  [0.2127, 0.2127, 0.2