In [23]:
import torch
from l2r import Rank
from sklearn import preprocessing
import numpy as np
import ujson
from tqdm.notebook import tqdm
np.set_printoptions(threshold=np.inf, suppress=True)

In [21]:
def process_StandardScal(x):
    standard_scaler = preprocessing.StandardScaler().fit(x)
    x_train_mean =standard_scaler.mean_
    x_train_std = standard_scaler.scale_
    x_train_standard = standard_scaler.transform(x)
    return x_train_standard, x_train_mean, x_train_std

def process_StandardScal_test(x,mean,std):
    standard_scaler = preprocessing.StandardScaler().fit(x)
    standard_scaler.mean_ = mean
    standard_scaler.std_ = std
    x_test = standard_scaler.transform(x)
    return x_test

def evaluate(pred_result,true_result):
    pred_data = {}
    for val in pred_result.values():
        for item in val:
            pred_data[item[0]] = item[1]
    true_data = {}
    for res in true_result:
        true_data[res[0]] = res[1]
    return (factor(pred_data, true_data))

def factor(pred_data,true_data):
    tp = 0
    fp = 0
    fn = 0
    tn = 0
    for key in pred_data.keys():
        if true_data[key]==1:
            tp += 1
        else:
            fp += 1
    for key in true_data.keys():
        if key not in pred_data.keys():
            if true_data[key]==1:
                fn += 1
            else:
                tn += 1
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1_score = 2 * precision * recall / (precision + recall)
    accuracy = (tp+tn)/(tp+tn+fp+fn)
    print("factors in predict")
    print("precision: {},\nrecall: {},\nf1_score: {},\naccuracy: {}".format(precision, recall, f1_score, accuracy))
    return precision, recall, f1_score, accuracy

def plot(ndcg_record):
    x = np.array(list(ndcg_record.keys()))
    y = np.array(list(ndcg_record.values()))
    fig = plt.figure()
    axes = fig.add_axes([0.1, 0.1, 0.9, 0.9])
    axes.plot(x, y, 'r')
    plt.show()

In [2]:
model = "Lambdarank"
train = np.load('./train.npy')
test = np.load('./test.npy')
benchmark = np.load('./benchmark.npy')
alldata = np.load('./alldata.npy')
n_feature = 6
h1_units = 512
h2_units = 256
epoch = 10
lr = 0.0001
nt = 20
k = 1

In [3]:
alldata.size

113052942

In [27]:
# 获取训练数据集的大小
alldata_len = alldata.shape[0]
# 训练集和测试集拼接
data = np.concatenate((train,test),axis=0)
# 获取基于总量标准化的训练数据、以及均值和标准差
x_data, x_train_mean, x_train_std = process_StandardScal(data[:,3:])
# 对benchmark进行标准化，并拼接
x_alldata = process_StandardScal_test(alldata[:,3:],x_train_mean, x_train_std)
alldata = np.hstack((alldata[:,:3],x_alldata))

In [28]:
rank = Rank(rank_model=model, training_data=train, n_feature=n_feature, h1_units=h1_units, h2_units=h2_units, epoch=epoch, lr=lr, number_of_trees=nt)
rank.handler.model.load_state_dict(torch.load('model.pth'))
predict_result_score, predict_result_pair = rank.predict(alldata, 1)

100%|██████████| 1144441/1144441 [07:31<00:00, 2536.02it/s]


In [29]:
f = open('res.txt', 'w')
for k in tqdm(predict_result_pair.keys()):
    print(k, predict_result_pair[k], file=f)
f.close()

HBox(children=(FloatProgress(value=0.0, max=12561438.0), HTML(value='')))




In [9]:
!head res.txt

12522182 -2.83149790763855
12522183 -4.406253337860107
12522184 -2.715312957763672
12522185 -1.2879914045333862
12522186 -2.398498773574829
12522187 -2.063746690750122
12522188 -1.4096819162368774
12522189 -1.107888102531433
12522190 -2.152543783187866
12522191 -2.223851203918457


In [30]:
tmin = 0
for k in tqdm(predict_result_score):
    if k < tmin:
        tmin = k

HBox(children=(FloatProgress(value=0.0, max=12561438.0), HTML(value='')))




In [31]:
tmin

-3.852494239807129

In [32]:
threshold = [-2,-1.5,-1,-0.5,0,2,1.5,1,0.5]
for th in threshold:
    th_sum = 0
    for k in tqdm(predict_result_pair.keys()):
        if predict_result_pair[k] > th:
            th_sum+=1
    print(th, th_sum, float(th_sum/432140))

HBox(children=(FloatProgress(value=0.0, max=12561438.0), HTML(value='')))


-2 10879375 25.17557967325404


HBox(children=(FloatProgress(value=0.0, max=12561438.0), HTML(value='')))


-1.5 9111574 21.084773453047625


HBox(children=(FloatProgress(value=0.0, max=12561438.0), HTML(value='')))


-1 7062370 16.342782431619383


HBox(children=(FloatProgress(value=0.0, max=12561438.0), HTML(value='')))


-0.5 4819718 11.153140186050816


HBox(children=(FloatProgress(value=0.0, max=12561438.0), HTML(value='')))


0 3526412 8.160346184107002


HBox(children=(FloatProgress(value=0.0, max=12561438.0), HTML(value='')))


2 772794 1.788295459804693


HBox(children=(FloatProgress(value=0.0, max=12561438.0), HTML(value='')))


1.5 1193638 2.762155782848151


HBox(children=(FloatProgress(value=0.0, max=12561438.0), HTML(value='')))


1 1815042 4.2001249595038646


HBox(children=(FloatProgress(value=0.0, max=12561438.0), HTML(value='')))


0.5 2605968 6.030379043828389


In [33]:
threshold = [1.92, 2.5, 3]
for th in threshold:
    th_sum = 0
    for k in tqdm(predict_result_pair.keys()):
        if predict_result_pair[k] > th:
            th_sum+=1
    print(th, th_sum, float(th_sum/432140))

HBox(children=(FloatProgress(value=0.0, max=12561438.0), HTML(value='')))


1.92 828824 1.9179525153885315


HBox(children=(FloatProgress(value=0.0, max=12561438.0), HTML(value='')))


2.5 462295 1.0697806266487713


HBox(children=(FloatProgress(value=0.0, max=12561438.0), HTML(value='')))


3 236470 0.5472069236821401


In [41]:
aeid2pqid = {}
for item in tqdm(alldata):
    aeid2pqid[int(item[2])] = int(item[1])

HBox(children=(FloatProgress(value=0.0, max=12561438.0), HTML(value='')))




In [42]:
pqids = []
for k in tqdm(predict_result_pair.keys()):
    if predict_result_pair[k] > 3:
        pqids.append(aeid2pqid[int(k)])

HBox(children=(FloatProgress(value=0.0, max=12561438.0), HTML(value='')))




In [43]:
len(pqids)

236470

In [44]:
pqids[:10]

[1, 3, 3, 5, 5, 5, 5, 9, 14, 22]

In [46]:
relation_map = ujson.load(open('rel.json', 'r'))

In [47]:
all_pqid = open('data/allpq.txt', 'r')
allqid = {}
allidq = {}
flag = 0
for line in tqdm(all_pqid):
    c = line[:-1].replace("\"","").split('\t')
    if str(c[0]) + "@" + relation_map[str(c[1])] in allqid.keys():
        continue
    allqid[str(c[0]) + "@" + relation_map[str(c[1])]] = flag
    allidq[flag] = str(c[0]) + "@" + relation_map[str(c[1])]
    flag += 1

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [49]:
!head data/allpq.txt

"198244830"	"aim"
"198244830"	"aim"
"198244830"	"aim"
"198244830"	"aim"
"198244830"	"aim"
"198244830"	"aim"
"198244830"	"aim"
"198244830"	"method"
"198244830"	"method"
"198244830"	"method"


In [50]:
ana = {} 
for one_pqid in tqdm(pqids):
    paper_id = allidq[one_pqid].split('@')[0]
    question = allidq[one_pqid].split('@')[1]
    if paper_id not in ana.keys():
        ana[paper_id] = 1
    else:
        ana[paper_id] += 1

HBox(children=(FloatProgress(value=0.0, max=236470.0), HTML(value='')))




In [51]:
tmax = 0
for k in ana.keys():
    if ana[k] > tmax:
        tmax = ana[k]

In [52]:
tmax

24

In [54]:
len(ana.keys())

119151

In [55]:
for k in ana.keys():
    if ana[k] == 24:
        print(k)

39202247
