In [1]:
import dgl
import torch
import torch.nn.functional as F
# 其中包括激活函数, 损失函数, 池化函数 ,通过 F.xxx() 的形式，可以方便地调用 torch.nn.functional 模块中的各种函数
import numpy
import argparse
import time
from dataset_process.dataset import Dataset
from sklearn.metrics import f1_score, accuracy_score, recall_score, roc_auc_score, precision_score, confusion_matrix
from model.GPRGNN_anomaly import *
from sklearn.model_selection import train_test_split
import scipy.sparse as sp

In [2]:
def train(model, g, edge_index, args):
    features = g.ndata['feature']
    labels = g.ndata['label']
    index = list(range(len(labels)))
    if dataset_name == 'amazon':
        index = list(range(3305, len(labels)))

    idx_train, idx_rest, y_train, y_rest = train_test_split(index, labels[index], stratify=labels[index],
                                                            train_size=args.train_ratio,
                                                            random_state=2, shuffle=True)
    idx_valid, idx_test, y_valid, y_test = train_test_split(idx_rest, y_rest, stratify=y_rest,
                                                            test_size=0.67,
                                                            random_state=2, shuffle=True)
    train_mask = torch.zeros([len(labels)]).bool()
    val_mask = torch.zeros([len(labels)]).bool()
    test_mask = torch.zeros([len(labels)]).bool()

    train_mask[idx_train] = 1
    val_mask[idx_valid] = 1
    test_mask[idx_test] = 1
    print('train/dev/test samples: ', train_mask.sum().item(), val_mask.sum().item(), test_mask.sum().item())
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    best_f1, final_tf1, final_trec, final_tpre, final_tmf1, final_tauc = 0., 0., 0., 0., 0., 0.

    weight = (1-labels[train_mask]).sum().item() / labels[train_mask].sum().item()
    print('cross entropy weight: ', weight)
    time_start = time.time()
    for e in range(args.epoch):
        # 训练
        model.train()
        # 调用模型中的forward函数
        logits = model(features,edge_index)
        loss = F.cross_entropy(logits[train_mask], labels[train_mask], weight=torch.tensor([1., weight]))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #验证
        model.eval()
        probs = logits.softmax(1)
        f1, thres = get_best_f1(labels[val_mask], probs[val_mask])
        preds = numpy.zeros_like(labels)
        preds[probs[:, 1] > thres] = 1
        trec = recall_score(labels[test_mask], preds[test_mask])
        tpre = precision_score(labels[test_mask], preds[test_mask])
        tmf1 = f1_score(labels[test_mask], preds[test_mask], average='macro')
        tauc = roc_auc_score(labels[test_mask], probs[test_mask][:, 1].detach().numpy())

        if best_f1 < f1:
            best_f1 = f1
            final_trec = trec
            final_tpre = tpre
            final_tmf1 = tmf1
            final_tauc = tauc
        print('Epoch {}, loss: {:.4f}, val mf1: {:.4f}, (best {:.4f})'.format(e, loss, f1, best_f1))

    time_end = time.time()
    print('time cost: ', time_end - time_start, 's')
    print('Test: REC {:.2f} PRE {:.2f} MF1 {:.2f} AUC {:.2f}'.format(final_trec*100,
                                                                     final_tpre*100, final_tmf1*100, final_tauc*100))
    return final_tmf1, final_tauc


# threshold adjusting for best macro f1
def get_best_f1(labels, probs):
    best_f1, best_thre = 0, 0
    for thres in np.linspace(0.05, 0.95, 19):
        #构建一个与labels同维度的数组,并初始化所有变量为零
        preds = np.zeros_like(labels)
        preds[probs[:,1] > thres] = 1
        #average='binary'：计算二分类问题中的 F1 分数（默认值）。
        #average='micro'：对所有类别的真实和预测样本进行汇总，然后计算 F1 分数。
        #average='macro'：计算每个类别的 F1 分数，然后取平均值。
        #average=None：返回每个类别的 F1 分数。
        # F1_score 详细原理间“备份”
        mf1 = f1_score(labels, preds, average='macro')
        if mf1 > best_f1:
            best_f1 = mf1
            best_thre = thres
    return best_f1, best_thre


In [3]:
parser = argparse.ArgumentParser(description='GPRGNN_anomaly')
parser.add_argument("--dataset", type=str, default="yelp",
                        help="Dataset for this model (yelp/amazon/tfinance/tsocial)")
parser.add_argument("--train_ratio", type=float, default=0.01, help="Training ratio")
parser.add_argument("--hid_dim", type=int, default=64, help="Hidden layer dimension")
parser.add_argument("--homo", type=int, default=1, help="1 for GCN_GAD(Homo) and 0 for GCN_GAD(Hetero)")
parser.add_argument("--epoch", type=int, default=200, help="The max number of epochs")
parser.add_argument("--run", type=int, default=1, help="Running times")
parser.add_argument('--net', type=str, choices=['JKNet', 'GPRGNN'],default='GPRGNN')
parser.add_argument('--K', type=int, default=10)
parser.add_argument('--Gamma', default=None)
parser.add_argument('--dprate', type=float, default=0.5)
parser.add_argument('--Init', type=str,
                        choices=['SGC', 'PPR', 'NPPR', 'Random', 'WS', 'Null'],
                        default='PPR')
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--ppnp', default='GPR_prop',
                        choices=['PPNP', 'GPR_prop'])
parser.add_argument('--alpha', type=float, default=0.1)


args = parser.parse_args(args = [])
print(args)
dataset_name = args.dataset
homo = args.homo
graph = Dataset(dataset_name, homo).graph
#edge_index = Dataset(dataset_name, homo).edge_index

Namespace(dataset='tfinance', train_ratio=0.4, hid_dim=64, homo=1, epoch=200, run=1, net='GPRGNN', K=10, Gamma=None, dprate=0.5, Init='PPR', dropout=0.5, ppnp='GPR_prop', alpha=0.1)
Graph(num_nodes=39357, num_edges=42445086,
      ndata_schemes={'feature': Scheme(shape=(10,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={})


In [4]:
in_feats = graph.ndata['feature'].shape[1]
num_classes = 2

    
if homo:
    edges = graph.edges()
    edge_index =edges = torch.stack((edges[0], edges[1]))
else:#((dataset_name ==  "yelp") |(dataset_name ==  "amazon")) and hetero
    #三类边
    edges_upu = graph[graph.canonical_etypes[0]].edges()
    edge_index_upu = torch.stack(edges_upu)
    edges_usu = graph[graph.canonical_etypes[1]].edges()
    edge_index_usu = torch.stack(edges_usu)
    edges_uvu = graph[graph.canonical_etypes[2]].edges()
    edge_index_uvu = torch.stack(edges_uvu)
    
    # 合并连个Tensor，dim=1 按列合并
    combined_tensor = torch.cat((edge_index_upu, edge_index_usu), dim=1)
    edge_index = torch.cat((combined_tensor, edge_index_uvu), dim=1)


gnn_name = args.net
if gnn_name == 'JKNet':
    Net = GCN_JKNet
elif gnn_name == 'GPRGNN':
    Net = GPRGNN

if args.run == 0:
    if homo:
        print("hello")
        model = Net(in_feats, num_classes, graph,args)
    else:
        model = Net_Hetero(in_feats, num_classes,args)
        train(model, graph,edge_index, args)

else:
    final_mf1s, final_aucs = [], []
    for tt in range(args.run):
        if homo:
            #in_feats 特征点维度；h_feats：隐层维度；num_classes：节点分类数（nomal，anomaly）
            model = Net(in_feats, num_classes, graph,args)
        else:
            model = Net_Hetero(in_feats, num_classes, graph,args)
        mf1, auc = train(model, graph, edge_index, args)
        final_mf1s.append(mf1)
        final_aucs.append(auc)
    final_mf1s = np.array(final_mf1s)
    final_aucs = np.array(final_aucs)
    # np.std :计算全局标准差
    print('MF1-mean: {:.2f}, MF1-std: {:.2f}, AUC-mean: {:.2f}, AUC-std: {:.2f}'.format(100 * np.mean(final_mf1s),
                                                                                            100 * np.std(final_mf1s),
                                                               100 * np.mean(final_aucs), 100 * np.std(final_aucs)))

train/dev/test samples:  15742 7792 15823
cross entropy weight:  20.83356449375867


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 0, loss: 499.3223, val mf1: 0.4883, (best 0.4883)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1, loss: 410.5814, val mf1: 0.4883, (best 0.4883)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 2, loss: 330.6917, val mf1: 0.4883, (best 0.4883)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 3, loss: 259.7615, val mf1: 0.4883, (best 0.4883)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 4, loss: 197.8220, val mf1: 0.4883, (best 0.4883)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 5, loss: 144.8074, val mf1: 0.4883, (best 0.4883)


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 6, loss: 100.5319, val mf1: 0.4883, (best 0.4883)
Epoch 7, loss: 64.6655, val mf1: 0.4882, (best 0.4883)
Epoch 8, loss: 36.7399, val mf1: 0.4876, (best 0.4883)
Epoch 9, loss: 16.3656, val mf1: 0.4809, (best 0.4883)
Epoch 10, loss: 18.4094, val mf1: 0.1449, (best 0.4883)
Epoch 11, loss: 28.4285, val mf1: 0.0516, (best 0.4883)
Epoch 12, loss: 33.1033, val mf1: 0.0567, (best 0.4883)
Epoch 13, loss: 33.3327, val mf1: 0.0713, (best 0.4883)
Epoch 14, loss: 31.4686, val mf1: 0.1152, (best 0.4883)
Epoch 15, loss: 29.4447, val mf1: 0.2035, (best 0.4883)
Epoch 16, loss: 28.4043, val mf1: 0.2732, (best 0.4883)
Epoch 17, loss: 28.3324, val mf1: 0.3250, (best 0.4883)
Epoch 18, loss: 28.8811, val mf1: 0.3708, (best 0.4883)
Epoch 19, loss: 29.8405, val mf1: 0.4140, (best 0.4883)
Epoch 20, loss: 30.7874, val mf1: 0.4331, (best 0.4883)
Epoch 21, loss: 31.2196, val mf1: 0.4432, (best 0.4883)
Epoch 22, loss: 31.0671, val mf1: 0.4588, (best 0.4883)
Epoch 23, loss: 30.3567, val mf1: 0.4736, (best 0.4

Epoch 154, loss: 0.4802, val mf1: 0.7623, (best 0.8044)
Epoch 155, loss: 0.5081, val mf1: 0.7312, (best 0.8044)
Epoch 156, loss: 0.4801, val mf1: 0.7677, (best 0.8044)
Epoch 157, loss: 0.5257, val mf1: 0.7271, (best 0.8044)
Epoch 158, loss: 0.5035, val mf1: 0.7403, (best 0.8044)
Epoch 159, loss: 0.5092, val mf1: 0.7280, (best 0.8044)
Epoch 160, loss: 0.5057, val mf1: 0.7326, (best 0.8044)
Epoch 161, loss: 0.4927, val mf1: 0.7463, (best 0.8044)
Epoch 162, loss: 0.4918, val mf1: 0.7463, (best 0.8044)
Epoch 163, loss: 0.4901, val mf1: 0.7475, (best 0.8044)
Epoch 164, loss: 0.4770, val mf1: 0.7668, (best 0.8044)
Epoch 165, loss: 0.5046, val mf1: 0.7348, (best 0.8044)
Epoch 166, loss: 0.4831, val mf1: 0.7513, (best 0.8044)
Epoch 167, loss: 0.5078, val mf1: 0.7273, (best 0.8044)
Epoch 168, loss: 0.4936, val mf1: 0.7407, (best 0.8044)
Epoch 169, loss: 0.4964, val mf1: 0.7399, (best 0.8044)
Epoch 170, loss: 0.4927, val mf1: 0.7421, (best 0.8044)
Epoch 171, loss: 0.4840, val mf1: 0.7503, (best 