In [1]:
import argparse

from OpenAttMultiGL.model.X_GOAL.xgoal import XGOAL

from sklearn.metrics import normalized_mutual_info_score, pairwise, f1_score
import numpy as np
def parse_args():
    """Build the X-GOAL command-line parser and parse ``sys.argv``.

    ``parse_known_args`` is used instead of ``parse_args`` so that unexpected
    arguments do not abort parsing. Presumably this is for extra flags injected
    when running inside a notebook kernel; verify before tightening.

    List-valued options take space-separated values on the command line,
    e.g. ``--k 4 4 4``. ``--is_warmup`` accepts true/false, yes/no, t/f, y/n,
    1/0 (case-insensitive); a bare ``--is_warmup`` means ``True``.

    Returns:
        tuple: ``(args, unknown)`` where ``args`` is the parsed
        ``argparse.Namespace`` and ``unknown`` is the list of argument
        strings the parser did not recognise.
    """

    def _str2bool(value):
        """Convert a command-line string such as 'True', 'false' or '1' to a bool.

        ``type=bool`` cannot be used for this: any non-empty string is truthy,
        so ``bool('False')`` would be ``True``.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='amazon', help='name of the dataset')
    parser.add_argument('--model', type=str, default='xgoal')

    parser.add_argument('--hid_units', type=int, default=128, help='hidden dimension')
    parser.add_argument('--nb_epochs', type=int, default=20000, help='the maximum number of epochs')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--patience', type=int, default=100, help='patience for early stopping')
    parser.add_argument('--gpu_num', type=int, default=0, help='the id of gpu to use')

    # path
    parser.add_argument('--save_root', type=str, default="./saved_models", help='root for saving the model')
    parser.add_argument('--pretrained_model_path', type=str, default="./example_ckpts/warmup_amazon_xgoal.pkl",
                        help='path to the pretrained model')

    # hyper-parameters for info-nce
    parser.add_argument('--p_drop', type=float, default=0.5, help='dropout rate for attributes')

    # hyper-parameters for clusters.
    # nargs='+' with an element type parses "--k 4 4 4" into [4, 4, 4];
    # the previous type=list split a single token into characters.
    parser.add_argument('--k', type=int, nargs='+', default=[4, 4, 4], help='the numbers of clusters')
    parser.add_argument('--tau', type=float, nargs='+', default=[1, 1, 1], help='the temperature of clusters')
    parser.add_argument('--w_cluster', type=float, default=1e-2, help='weight for cluster loss')
    parser.add_argument('--cluster_step', type=int, default=5, help='every n steps to perform clustering')

    # hyper-parameters for alignment
    parser.add_argument('--w_reg_n', type=float, default=1e-3, help='weight for node level alignment regularization')
    parser.add_argument('--w_reg_c', type=float, default=1e-2, help='weight for cluster level alignment regularization')

    # hyper-parameters for different layers
    parser.add_argument('--w_list', type=float, nargs='+', default=[1, 1, 1], help="weights for different layers")

    # warm-up
    parser.add_argument('--is_warmup', type=_str2bool, nargs='?', const=True, default=False,
                        help='whether to warm up or not')
    parser.add_argument('--warmup_lr', type=float, default=5e-3, help='learning rate used during warm-up')
    parser.add_argument('--warmup_w_reg_n', type=float, default=1e-1,
                        help='weight for node level alignment regularization during warm-up')

    return parser.parse_known_args()


def printConfig(args):
    """Print all parsed options as one ``{option_name: value}`` dict.

    Args:
        args: a namespace (e.g. the ``argparse.Namespace`` from
            ``parse_args``); every attribute in ``vars(args)`` is printed,
            in ``vars`` order.
    """
    snapshot = {name: getattr(args, name) for name in vars(args)}
    print(snapshot)


def main():
    """Parse the command line, then train and evaluate an X-GOAL model.

    The parsed configuration is printed first so the run's output records
    the exact hyper-parameters that were used.
    """
    # Unrecognised arguments returned by parse_known_args are deliberately
    # ignored (see parse_args).
    args, _unknown = parse_args()
    printConfig(args)

    model = XGOAL(args)
    model.train()
    # NOTE(review): judging by the captured run output, evaluate() prints its
    # own metrics (classification F1, clustering NMI, similarity); its return
    # value is not used here.
    model.evaluate()
# Entry point: run the full parse / train / evaluate pipeline when executed directly.
if __name__ == '__main__':
    main()



{'dataset': 'amazon', 'model': 'xgoal', 'hid_units': 128, 'nb_epochs': 20, 'lr': 0.001, 'patience': 100, 'gpu_num': 0, 'save_root': './saved_models', 'pretrained_model_path': './example_ckpts/warmup_amazon_xgoal.pkl', 'p_drop': 0.5, 'k': [4, 4, 4], 'tau': [1, 1, 1], 'w_cluster': 0.01, 'cluster_step': 5, 'w_reg_n': 0.001, 'w_reg_c': 0.01, 'w_list': [1, 1, 1], 'is_warmup': False, 'warmup_lr': 0.005, 'warmup_w_reg_n': 0.1}
Started training on amazon with xgoal...
Full loss training...


  0%|                                                    | 0/20 [00:00<?, ?it/s]

loss_full: 1.683045, L_n: 1.61401, L_c: 0.0612328, R_n: 0.000822781, R_c: 0.698295


 50%|█████████████████████▌                     | 10/20 [01:21<01:18,  7.86s/it]

loss_full: 1.449641, L_n: 1.40575, L_c: 0.0393882, R_n: 0.000728563, R_c: 0.377025


100%|███████████████████████████████████████████| 20/20 [02:38<00:00,  7.94s/it]


Evaluating...
	[Classification] Macro-F1: 0.3763 (0.0022) | Micro-F1: 0.3835 (0.0026)
	[Clustering] NMI: 0.0004 | 0.0000
	[Similarity] [5,10,20,50,100] : [0.3887,0.3878,0.3839,0.3696,0.3558]
