In [1]:
import warnings
# Suppress every warning for the rest of the kernel session to keep cell output clean.
# NOTE(review): with no category/module argument this filter is global, so numerical
# warnings raised by later cells (e.g. divide-by-zero) will not be shown either.
warnings.filterwarnings("ignore")

In [20]:
import os
import pickle
import numpy as np
import pandas as pd
import scipy.io as sio
import matplotlib.pyplot as plt

In [3]:
import sklearn.metrics as metrics
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc

### Ablation Study

In [4]:
def precision_recall_k(y_true, y_score, k=50):
    """Compute precision@k and recall@k for a ranked list of scores.

    Samples are ranked by descending ``y_score``; the top ``k`` are treated as
    the predicted positives.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels; positives are counted by summing, so they are
        expected to be 0/1 (or boolean). Flattened to 1-D.
    y_score : array-like
        Scores, higher meaning "more likely positive". Flattened to 1-D and
        must contain as many entries as ``y_true``.
    k : int, default 50
        Number of top-ranked samples to evaluate. Must be at least 1. If it
        exceeds the number of samples, precision is still divided by ``k``.

    Returns
    -------
    (pk, rk) : tuple
        ``pk`` is (# positives in the top k) / k.
        ``rk`` is (# positives in the top k) / (# positives overall), or NaN
        when ``y_true`` contains no positives at all.

    Raises
    ------
    ValueError
        If ``k`` is smaller than 1 or the two inputs differ in length.
    """
    # Accept lists/tuples as well as arrays; fancy indexing below needs ndarrays.
    y_true = np.asarray(y_true).ravel()
    y_score = np.asarray(y_score).ravel()

    if y_true.shape != y_score.shape:
        raise ValueError(
            f"y_true and y_score must have the same length, "
            f"got {y_true.size} and {y_score.size}"
        )
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    # Ascending argsort reversed -> indices ordered from highest to lowest score.
    desc_sort_order = np.argsort(y_score)[::-1]
    y_true_sorted = y_true[desc_sort_order]

    true_positives = y_true_sorted[:k].sum()
    total_positives = y_true.sum()

    pk = true_positives / k
    # Recall is undefined without any positives; return NaN explicitly rather
    # than relying on a 0/0 division whose warning may be suppressed.
    if total_positives > 0:
        rk = true_positives / total_positives
    else:
        rk = np.float64("nan")
    return pk, rk

In [32]:
# Convert each run's MATLAB result file (../outputs/<dataset><flag>.mat) into a
# pickle with the same base name, so later cells can load results with pickle.
# "pubmed" is listed as well because the evaluation cell reads
# ../outputs/pubmed*.pkl; converting only "enron" left those files dependent on
# an earlier, no-longer-visible run of this cell.
# (`os`, `pickle` and `sio` come from the notebook's import cell.)
for dataset in ["enron", "pubmed"]:
    # Flags select the result variants: "" (full), "_none" and "_base".
    for flag in ["", "_none", "_base"]:
        mat_path = f"../outputs/{dataset}{flag}.mat"
        pkl_path = f"../outputs/{dataset}{flag}.pkl"

        if not os.path.exists(mat_path):
            # Report the gap instead of skipping silently, so a missing
            # variant is visible before the evaluation cell fails on it.
            print(f"Skipping {mat_path}: file not found")
            continue

        mat_contents = sio.loadmat(mat_path)
        with open(pkl_path, 'wb') as fw:
            pickle.dump(mat_contents, fw)


In [46]:
# Ablation study: compare the full model against two ablated variants on one
# dataset, reporting ROC-AUC plus precision/recall at several cut-offs k.
dataset = "pubmed"

# NOTE: pickle.load can execute arbitrary code while deserialising; only load
# result files that this project produced itself.
with open(f"../outputs/{dataset}.pkl", "rb") as fr:
    result = pickle.load(fr)
with open(f"../outputs/{dataset}_none.pkl", "rb") as fr:
    result_v1 = pickle.load(fr)
with open(f"../outputs/{dataset}_base.pkl", "rb") as fr:
    result_v2 = pickle.load(fr)

# Display names, in the same order as the loaded results.
methods = ["Full", "w/o sampling", "w/o clustering"]
data_list = [result, result_v1, result_v2]

# Cut-offs for precision@k / recall@k; identical for every method, so built once.
k_list = [10, 50, 100, 200, 300]

for data, method in zip(data_list, methods):
    print(f"{method}")
    labels, scores = data['labels'].flatten(), data['scores'].flatten()

    # Stored as `roc_auc`, not `auc`: assigning to `auc` would overwrite the
    # `auc` function imported from sklearn.metrics and break any later call to it.
    roc_auc = roc_auc_score(labels, scores)
    print(f"AUC: {roc_auc:.4f}")

    for k in k_list:
        pk, rk = precision_recall_k(labels, scores, k)
        print(f"Precision@{k}: {pk:.4f}; Recall@{k}: {rk:.4f};")

Full
AUC: 0.9593
Precision@10: 0.9000; Recall@10: 0.0450;
Precision@50: 0.6200; Recall@50: 0.1550;
Precision@100: 0.5700; Recall@100: 0.2850;
Precision@200: 0.4450; Recall@200: 0.4450;
Precision@300: 0.3767; Recall@300: 0.5650;
w/o sampling
AUC: 0.9577
Precision@10: 0.8000; Recall@10: 0.0400;
Precision@50: 0.6200; Recall@50: 0.1550;
Precision@100: 0.5700; Recall@100: 0.2850;
Precision@200: 0.4500; Recall@200: 0.4500;
Precision@300: 0.3733; Recall@300: 0.5600;
w/o clustering
AUC: 0.9560
Precision@10: 0.7000; Recall@10: 0.0350;
Precision@50: 0.6600; Recall@50: 0.1650;
Precision@100: 0.5700; Recall@100: 0.2850;
Precision@200: 0.4500; Recall@200: 0.4500;
Precision@300: 0.3800; Recall@300: 0.5700;
