# In [1]:
import torch
from sklearn.metrics import f1_score
from mGCN_Toolbox.utils.dataset import dataset
from mGCN_Toolbox.utils.process import * 
import datetime
import errno
import os
import pickle
import random
from pprint import pprint
from sklearn.metrics import normalized_mutual_info_score, pairwise, f1_score
from sklearn.cluster import KMeans
from mGCN_Toolbox.model.HAN.embedder_link import evaluate
from torch_geometric.utils import *
import dgl


from dgl.data.utils import _get_dgl_url, download, get_download_dir
import numpy as np

def score(logits, labels, t):
    """Compare arg-max predictions taken from ``logits`` against ``labels``.

    Args:
        logits: tensor of per-class scores; the prediction for each row is the
            index of its maximum along dim 1.
        labels: tensor of class ids aligned with the rows of ``logits``.
        t: object exposing ``num_classes``, forwarded to ``run_kmeans``.

    Returns:
        ``(sim, micro_f1, macro_f1, nmi)`` where ``sim`` is the fraction of
        rows whose prediction equals the label and ``nmi`` is whatever
        ``run_kmeans`` returns for (labels, predictions).
    """
    y_true = labels.cpu().numpy()
    y_pred = torch.max(logits, dim=1)[1].long().cpu().numpy()

    accuracy = (y_pred == y_true).sum() / len(y_pred)
    micro, macro = (
        f1_score(y_true, y_pred, average=mode) for mode in ("micro", "macro")
    )
    nmi = run_kmeans(y_true, y_pred, t.num_classes)
    return accuracy, micro, macro, nmi


#def evaluate(model, g, features, labels, mask, loss_func,t):
    #model.eval()
    #with torch.no_grad():
        #logits = model(g, features)
    #print("embeds:",logits.shape)
    #loss = loss_func(logits[mask], labels[mask])
    #sim, micro_f1, macro_f1,nmi = score(logits[mask], labels[mask],t)
    
    

    #return micro_f1, macro_f1,nmi,sim

def run_kmeans(y, y_pred, k):
    """Report the normalized mutual information between ``y`` and ``y_pred``.

    Despite the name, no clustering happens here: ``y_pred`` is used directly
    as the cluster assignment and compared against the ground truth ``y``.

    Args:
        y: ground-truth labels.
        y_pred: predicted labels.
        k: number of classes. Currently unused; kept so existing callers
            keep working.

    Returns:
        The NMI (arithmetic averaging) rounded to four decimals, as a float.
    """
    # NMI is a deterministic function of (y, y_pred), so a single evaluation
    # is enough. Repeating it cannot change the result, and the spread across
    # repeats is always 0. The std column is printed only to keep the log
    # format unchanged.
    nmi = normalized_mutual_info_score(y, y_pred, average_method='arithmetic')
    nmi = float("{:.4f}".format(nmi))
    print('\t[Clustering] NMI: {:.4f} | {:.4f}'.format(nmi, 0.0))
    return nmi

# Per supported dataset: path of the raw file, and the keys of the metapath
# adjacency matrices it stores. The graphs reach the model in this key order.
_HAN_DATASETS = {
    "amazon": ("mGCN_Toolbox/data/HAN/AMAZON/amazon.pkl", ("IVI", "IBI", "IOI")),
    "acm": ("mGCN_Toolbox/data/HAN/ACM/acm.mat", ("PAP", "PLP")),
    "dblp": ("mGCN_Toolbox/data/HAN/DBLP/dblp.pkl", ("PAP", "PPrefP", "PATAP")),
    "imdb": ("mGCN_Toolbox/data/HAN/IMDB/imdb.pkl", ("MDM", "MAM")),
}


def _load_metapath_graphs(dataname):
    """Load the raw file for ``dataname`` and build one DGL graph per metapath.

    Returns:
        ``(data, gs)``: the loaded mapping, with each metapath entry replaced
        in place by ``sp.csr_matrix(...)`` of itself, and the list of graphs
        built from those matrices with ``dgl.from_scipy``.

    Raises:
        ValueError: if ``dataname`` is not a key of ``_HAN_DATASETS``.
    """
    try:
        path, metapaths = _HAN_DATASETS[dataname]
    except KeyError:
        raise ValueError(
            "Unsupported dataset {!r}; expected one of {}".format(
                dataname, sorted(_HAN_DATASETS)
            )
        ) from None

    if path.endswith(".mat"):
        data = sio.loadmat(path)
    else:
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(path, "rb") as fh:
            data = pkl.load(fh)

    gs = []
    for key in metapaths:
        data[key] = sp.csr_matrix(data[key])
        gs.append(dgl.from_scipy(data[key]))
    return data, gs


def _attach_link_labels(split_edges):
    """Give each split a ``'label'`` tensor: positive labels followed by negatives."""
    for part in ("train", "valid", "test"):
        split = split_edges[part]
        split["label"] = torch.cat((split["label_pos"], split["label_neg"]))
    return split_edges


def main(args):
    """Embed nodes with HAN and score the outputs on link prediction.

    Steps:
      1. Load the dataset named by ``args["dataset"]`` and build one
         homogeneous DGL graph per metapath.
      2. Run a single forward pass of HAN over those graphs.
      3. Pass the output to ``evaluate`` (from
         ``mGCN_Toolbox.model.HAN.embedder_link``), together with a
         train/valid/test split of the edges in ``c.edge_list[0]``.

    Args:
        args: dict with at least ``dataset``, ``device``, ``hetero``,
            ``hidden_units``, ``num_heads`` and ``dropout``. Presumably filled
            in by ``setup``.

    Raises:
        NotImplementedError: if ``args["hetero"]`` is set.
        ValueError: if the dataset is not supported.
    """
    if args["hetero"]:
        # The heterogeneous HAN variant expects a single DGL heterograph.
        # This script only ever builds a list of homogeneous metapath graphs.
        raise NotImplementedError(
            "--hetero is not supported by this script: it builds one homogeneous "
            "graph per metapath rather than a heterogeneous graph."
        )

    dataname = args["dataset"].lower()
    device = args["device"]
    c = dataset(dataname)
    data, gs = _load_metapath_graphs(dataname)

    # The class count is read from the toolbox's (one-hot) labels. Labels and
    # features are then replaced by the ones stored in the raw dataset file.
    num_classes = c.gcn_labels.shape[1]
    c.gcn_labels = torch.from_numpy(data["label"]).long().nonzero()[:, 1]
    c.features = torch.from_numpy(data["feature"]).float()

    c.features = c.features.to(device)
    c.gcn_labels = c.gcn_labels.to(device)

    from mGCN_Toolbox.model.HAN.model import HAN

    model = HAN(
        num_meta_paths=len(gs),
        in_size=c.features.shape[1],
        hidden_size=args["hidden_units"],
        out_size=num_classes,
        num_heads=args["num_heads"],
        dropout=args["dropout"],
    ).to(device)
    gs = [graph.to(device) for graph in gs]

    # NOTE(review): no training loop runs and model.eval() is never called, so
    # these outputs come from the randomly initialised model. Confirm that
    # this is intended.
    logits = model(gs, c.features)

    split_edges = mask_test_edges(c.features, c.edge_list[0], neg_num=1)
    split_edges = _attach_link_labels(split_edges)

    print("shape", type(logits))
    AUC, ap, _hits = evaluate(logits, split_edges)
    print("Average-percision:", np.mean(ap), np.std(ap))
    print("Average-AUC:", np.mean(AUC), np.std(AUC))

    


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser("HAN")
    parser.add_argument("-s", "--seed", type=int, default=1, help="Random seed")
    parser.add_argument("-ld", "--log-dir", type=str, default="results",
                        help="Dir for saving training results")
    parser.add_argument("--hetero", action="store_true",
                        help="Use metapath coalescing with DGL's own dataset")
    # Under IPython/Jupyter an extra "-f <file>" argument is passed in;
    # declaring it keeps argparse from rejecting it.
    parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1")

    # `setup` comes from mGCN_Toolbox.utils.process (star-imported above). It
    # presumably adds the remaining keys main() reads, e.g. "dataset".
    args = setup(vars(parser.parse_args()))

    print(args["dataset"])
    main(args)

# --- Captured cell output (the run ended with a KeyboardInterrupt) ---
# Created directory results/DBLP_2023-07-13_13-25-43
# DBLP
# shape <class 'torch.Tensor'>
# retshape torch.Size([7907, 2])
# retshape torch.Size([7907, 2])
# retshape torch.Size([7907, 2])
# <class 'torch.Tensor'>
# Epoch: 0
# Best Validation: 0.5716016618702275
# Best Test: 0.5667267838890772
# retshape torch.Size([7907, 2])
# retshape torch.Size([7907, 2])
# retshape torch.Size([7907, 2])
# retshape torch.Size([7907, 2])
# retshape torch.Size([7907, 2])
# <class 'torch.Tensor'>
# Epoch: 0
# Best Validation: 0.5719361965955018
# Best Test: 0.5736838901137314
# retshape torch.Size([7907, 2])
#
#
# KeyboardInterrupt: