# 下游任务Demo：相似度预估

In [None]:
import os
import numpy as np
import pandas as pd
import pickle
import faiss
import time
from tqdm import tqdm
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics.pairwise import paired_cosine_distances

In [None]:
# Expose only GPU 0 to the CUDA-based libraries used in the cells below.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

## 1. 获取题目表征

In [None]:
# Load the question texts for the similarity downstream task.
# Format: one question text per line; the list index + 1 is the question id
# used by the similarity CSV below.
# The path must be a plain string (passing a set to open() raises TypeError);
# the file is assumed to be UTF-8 encoded — TODO confirm for your data.
with open('path/to/your/data/math.tsv', 'r', encoding='utf-8') as f:
    ques = [line.strip('\n') for line in f]

In [None]:
# Example using DisenQNet: the tokenizer and the vectorizer are both loaded
# from the same pretrained checkpoint directory (placeholder path — replace it).

from EduNLP.Pretrain import DisenQTokenizer
from EduNLP.Vector import DisenQModel

path = "/path/to/disenqnet/checkpoint"
tokenizer = DisenQTokenizer.from_pretrained(path)
# Text-to-vector model placed on the GPU; used below via t2v.infer_vector(...).
t2v = DisenQModel(path, device="cuda")

In [None]:
# Encode every question: tokenize it, take the model's "k"-type vector for
# the "stem" field, and flatten it to a 1-D NumPy array.
# np.errstate(all='raise') turns NumPy floating-point warnings into errors.
with np.errstate(all='raise'):
    ques_emb = [
        t2v.infer_vector(
            tokenizer([text], key=lambda x: x),
            key=lambda x: x["stem"],
            vector_type="k",
        ).detach().cpu().reshape(-1).numpy()
        for text in tqdm(ques)
    ]
# Stack into one array: one row per question.
ques_emb = np.array(ques_emb)
ques_emb.shape

In [None]:
# Cache the question embeddings so later steps can reload them without
# re-running the model. open(..., 'wb') fails if the directory is missing,
# so create it first.
os.makedirs('./cache', exist_ok=True)
with open('./cache/disenq_300_embs.pkl', 'wb') as f:
    pickle.dump(ques_emb, f)

## 2. 相似度预估 Ranking

In [None]:
# Load the annotated question pairs for the ranking evaluation.
# Each row is unpacked as (id1, id2, <3 unused columns>, ratings), where
# ratings is a '|'-separated string of integer scores — TODO confirm schema.
sim = pd.read_csv('/path/to/your/data/similarity.csv')
test_id1 = []
test_id2 = []
labels = []
# Use distinct loop names so the DataFrame `sim` is not shadowed.
for id1, id2, _, _, _, rating in sim.itertuples(index=False):
    try:
        # Question ids are 1-based; embedding rows are 0-based.
        # int() also rejects NaN ids, which would otherwise break indexing.
        idx1 = int(id1) - 1
        idx2 = int(id2) - 1
        # Mean rating; the divisor 3 presumably is the number of annotators.
        score = sum(int(x) for x in rating.split('|')) / 3
    except (AttributeError, TypeError, ValueError):
        # Malformed row (NaN id, missing or non-numeric rating): report the
        # raw values and skip it. `score` is not printed because it is unset
        # (or stale from the previous row) when parsing fails.
        print(id1, id2, rating)
        continue
    test_id1.append(idx1)
    test_id2.append(idx2)
    labels.append(score)
np.array(labels)

In [None]:
def compute_ranking_metrics(ques_emb, pair_idx1=None, pair_idx2=None, scores=None):
    """Correlate embedding cosine similarity with human similarity labels.

    Args:
        ques_emb: 2-D array of question embeddings, one row per question.
        pair_idx1: row indices of the first question of each pair; defaults
            to the global ``test_id1``.
        pair_idx2: row indices of the second question of each pair; defaults
            to the global ``test_id2``.
        scores: human similarity score for each pair; defaults to the global
            ``labels``.

    Returns:
        ``(pearson, spearman)`` correlation coefficients (also printed).
    """
    # Fall back to the notebook globals so existing calls keep working.
    if pair_idx1 is None:
        pair_idx1 = test_id1
    if pair_idx2 is None:
        pair_idx2 = test_id2
    if scores is None:
        scores = labels
    ques_emb1 = ques_emb[pair_idx1]
    ques_emb2 = ques_emb[pair_idx2]
    cosine_scores = 1 - paired_cosine_distances(ques_emb1, ques_emb2)
    pearson_cosine, _ = pearsonr(scores, cosine_scores)
    spearman_cosine, _ = spearmanr(scores, cosine_scores)
    print(f'Pearson: {pearson_cosine:.4f}, Spearman: {spearman_cosine:.4f}')
    return pearson_cosine, spearman_cosine

In [None]:
# Reload the question embeddings cached in step 1 (trusted local file).
with open('./cache/disenq_300_embs.pkl', mode='rb') as fh:
    embs = pickle.load(fh)

In [None]:
compute_ranking_metrics(ques_emb=embs)

## 3. 相似度预估 Recall

In [None]:
# Reload the question embeddings cached in step 1 (trusted local file).
with open('./cache/disenq_300_embs.pkl', mode='rb') as fh:
    embs = pickle.load(fh)

# L2-normalise each row (the 1e-12 avoids division by zero for all-zero
# rows) and cast to float32, which is what the faiss index below consumes.
row_norms = np.linalg.norm(embs, ord=2, axis=-1, keepdims=True)
norm_embs = (embs / (row_norms + 1e-12)).astype('float32')

In [None]:
# Build an approximate nearest-neighbour index over the normalised embeddings.
# 'IVF512,PQ15' = inverted file with 512 coarse clusters + product
# quantisation into 15 sub-vectors; faiss requires `dim` to be a multiple of
# 15 for PQ15, and training needs at least as many vectors as clusters.
# Since the rows are (approximately) unit-norm, ranking by L2 distance is
# equivalent to ranking by cosine similarity.
dim = norm_embs.shape[-1]
param = 'IVF512,PQ15'
measure = faiss.METRIC_L2
index = faiss.index_factory(dim, param, measure)
index.train(norm_embs)
index.add(norm_embs)
# faiss.write_index fails if the target directory does not exist.
os.makedirs('./index', exist_ok=True)
faiss.write_index(index, './index/disenq.index')

In [None]:
# Read the annotated pairs and build the recall query set: for every question
# id, the list of (similar_question_id, score) pairs whose mean rating is
# >= 5, sorted by score (best first).
sim = pd.read_csv('/path/to/your/data/similarity.csv')
query = {}
# Use distinct loop names so the DataFrame `sim` is not shadowed.
for id1, id2, _, _, _, rating in sim.itertuples(index=False):
    try:
        id1 = int(id1)
        id2 = int(id2)
        # Mean rating; the divisor 3 presumably is the number of annotators.
        score = sum(int(x) for x in rating.split('|')) / 3
    except (AttributeError, TypeError, ValueError):
        # Malformed row: report and skip it, as the ranking cell does.
        print(id1, id2, rating)
        continue
    if score >= 5:
        # A pair counts as a positive for both of its questions.
        query.setdefault(id1, []).append((id2, score))
        query.setdefault(id2, []).append((id1, score))
for k in query:
    query[k].sort(key=lambda x: x[1], reverse=True)

In [None]:
def compute_recall_metrics(query, result, p=100):
    """Compute, print and return HR@p and NDCG@p for a retrieval result.

    Args:
        query: dict mapping a query id to its list of relevant
            ``(item_id, score)`` pairs, sorted by score descending; this
            order is taken as the ideal ranking.
        result: dict mapping every query id to its ranked list of retrieved
            item ids (best first).
        p: cut-off rank; only the first ``p`` retrieved items are scored.

    Returns:
        ``(hr, ndcg)``: mean hit ratio (fraction of relevant items found in
        the top ``p``) and mean NDCG@p over all queries; both are ``nan``
        when ``query`` is empty.
    """
    if not query:
        print(f'HR@{p}: nan, NDCG@{p}: nan (empty query set)')
        return float('nan'), float('nan')
    total_hr, total_ndcg = 0.0, 0.0
    for k, v in query.items():
        # 0-based position of each retrieved id within the top p; setdefault
        # keeps the first occurrence (same as list.index) with O(1) lookups.
        rank = {}
        for pos, item in enumerate(result[k][:p]):
            rank.setdefault(item, pos)
        hit, dcg, idcg = 0, 0.0, 0.0
        for i, (label, score) in enumerate(v):
            gain = 2 ** score - 1
            # Ideal DCG@p only counts the best p relevant items; summing over
            # all of them would cap NDCG below 1 when len(v) > p.
            if i < p:
                idcg += gain / np.log2(i + 2)
            pos = rank.get(label)
            if pos is not None:
                hit += 1
                dcg += gain / np.log2(pos + 2)
        total_hr += hit / len(v)
        total_ndcg += dcg / idcg if idcg > 0 else 0.0
    hr = total_hr / len(query)
    ndcg = total_ndcg / len(query)
    print(f'HR@{p}: {hr:.4f}, NDCG@{p}: {ndcg:.4f}')
    return hr, ndcg

In [None]:
# Retrieval benchmark: for each query question, search its nearest neighbours
# (101 requested because the question itself is normally among them), time
# each search, and report HR/NDCG at several cut-offs. The whole pass is
# repeated n_runs times to average out timing noise.
n_runs = 5
cutoffs = (10, 20, 30, 50, 100)
avg_time = 0
for _ in range(n_runs):
    result = {}
    total_time = 0
    for k in tqdm(query):
        idx = k - 1  # question ids are 1-based; embedding rows are 0-based
        start = time.perf_counter()  # monotonic high-resolution clock
        _, idxs = index.search(norm_embs[idx].reshape(1, -1), 101)
        end = time.perf_counter()
        total_time += (end - start) * 1000  # milliseconds
        # Drop the query itself and faiss's -1 padding (returned when fewer
        # than 101 neighbours are found) instead of turning -1 into id 0.
        res_ids = [i for i in idxs[0].tolist() if i != idx and i != -1]
        # Keep the top 100 and convert back to 1-based question ids.
        result[k] = [i + 1 for i in res_ids[:100]]
    print('Average time: ', total_time / len(query))
    avg_time += total_time / len(query)
    for cutoff in cutoffs:
        compute_recall_metrics(query, result, cutoff)
print(avg_time / n_runs)