In [1]:
import sys
sys.path += ['../utils']
import csv
from tqdm import tqdm 
import collections
from collections import defaultdict
import math
import gzip
import pickle
import numpy as np
import faiss
import os
import pytrec_eval
import json
import time
from msmarco_eval import quality_checks_qids, compute_metrics, load_reference


use_gpu_faiss = False

In [2]:
def convert_to_string_id(result_dict):
    string_id_dict = {}

    # format [string, dict[string, val]]
    for k, v in result_dict.items():
        _temp_v = {}
        for inner_k, inner_v in v.items():
            _temp_v[str(inner_k)] = inner_v

        string_id_dict[str(k)] = _temp_v

    return string_id_dict


def dump_query_scores(run_files_dir, fname, results, r_qidmap):
    out_f = os.path.join(run_files_dir, f'query_scores.{fname}.txt')
    os.makedirs(run_files_dir, exist_ok=True)
    with open(out_f, 'w') as fout:
        print('MF_Qid\tPP_Qid\tndcg@10', file=fout)
        for k, v in results.items():
            print(f"{r_qidmap[int(k)]}\t{k}\t{v['ndcg_cut_10']}", file=fout)


def EvalDevQuery(query_embedding2id, passage_embedding2id, dev_query_positive_id, I_nearest_neighbor,
                 r_qidmap, r_pidmap,
                 topN, run_files_dir, shortcut=False, cuts=None):
    prediction = {} #[qid][docid] = docscore, here we use -rank as score, so the higher the rank (1 > 2), the higher the score (-1 > -2)

    total_5 = 0
    labeled_5 = 0
    total_20 = 0
    labeled_20 = 0
    total_100 = 0
    labeled_100 = 0
    Atotal = 0
    Alabeled = 0
    qids_to_ranked_candidate_passages = {}
    for query_idx in range(len(I_nearest_neighbor)): 
        seen_pid = set()
        query_id = query_embedding2id[query_idx]
        prediction[query_id] = {}
        

        top_ann_pid = I_nearest_neighbor[query_idx].copy()
        selected_ann_idx = top_ann_pid[:topN]
        rank = 0
        
        if query_id in qids_to_ranked_candidate_passages:
            pass    
        else:
            # By default, all PIDs in the list of 1000 are 0. Only override those that are given
            tmp = [0] * 1000
            qids_to_ranked_candidate_passages[query_id] = tmp
                
        for idx in selected_ann_idx:
            pred_pid = passage_embedding2id[idx]
            
            if not pred_pid in seen_pid:
                # this check handles multiple vector per document
                qids_to_ranked_candidate_passages[query_id][rank]=pred_pid
                Atotal += 1
                
                if not shortcut:
                    if pred_pid not in dev_query_positive_id[query_id]:
                        Alabeled += 1
                    if rank < 5:
                        total_5 += 1
                        if pred_pid not in dev_query_positive_id[query_id]:
                            labeled_5 += 1
                    if rank < 20:
                        total_20 += 1
                        if pred_pid not in dev_query_positive_id[query_id]:
                            labeled_20 += 1
                    if rank < 100:
                        total_100 += 1
                        if pred_pid not in dev_query_positive_id[query_id]:
                            labeled_100 += 1
                            
                rank += 1
                if not (task=='arguana' and r_qidmap[query_id]==r_pidmap[pred_pid]):
                    prediction[query_id][pred_pid] = -rank
                seen_pid.add(pred_pid)
        
    if shortcut:
        return [prediction] + [None]*15
    # use out of the box evaluation script
    evaluator = pytrec_eval.RelevanceEvaluator(
        convert_to_string_id(dev_query_positive_id), {'map_cut', 'ndcg_cut', 'recip_rank','recall', 'P'})

    eval_query_cnt = 0
    result = evaluator.evaluate(convert_to_string_id(prediction))
    dump_query_scores(run_files_dir, ckpt_choice.replace('/', '__'), result, r_qidmap)
    
    qids_to_relevant_passageids = {}
    for qid in dev_query_positive_id:
        qid = int(qid)
        if qid in qids_to_relevant_passageids:
            pass
        else:
            qids_to_relevant_passageids[qid] = []
            for pid in dev_query_positive_id[qid]:
                if pid>0:
                    qids_to_relevant_passageids[qid].append(pid)
            
    ms_mrr = compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)

    ndcg = defaultdict(int)
    precision = defaultdict(int)
    Map = defaultdict(int)
    mrr = 0
    recall = 0
    recall_100 = 0
    recall_10 = 0

    for k in result.keys():
        eval_query_cnt += 1
        for cut in cuts:
            ndcg[cut] += result[k][f"ndcg_cut_{cut}"]
            precision[cut] += result[k][f"P_{cut}"]
            Map[cut] += result[k][f"map_cut_{cut}"]
        mrr += result[k]["recip_rank"]
        recall += result[k]["recall_"+str(topN)]
        recall_100 += result[k]["recall_100"]
        recall_10 += result[k]["recall_10"]

    final_ndcg, final_precision, final_Map = {}, {}, {}
    for cut in cuts:
        final_ndcg[cut] = ndcg[cut] / eval_query_cnt
        final_precision[cut] = precision[cut] / eval_query_cnt
        final_Map[cut] = Map[cut] / eval_query_cnt
    final_mrr = mrr / eval_query_cnt
    final_recall = recall / eval_query_cnt
    final_recall_100 = recall_100 / eval_query_cnt
    final_recall_10 = recall_10 / eval_query_cnt
    hole_rate_5 = labeled_5 / total_5
    hole_rate_20 = labeled_20 / total_20
    hole_rate_100 = labeled_100 / total_100
    Ahole_rate = Alabeled/Atotal
    
    return (prediction, final_ndcg, final_precision,
            eval_query_cnt, final_Map, final_mrr, final_recall, final_recall_100, final_recall_10,
            hole_rate_5, hole_rate_20, hole_rate_100, ms_mrr, Ahole_rate, result, prediction)


def dump_pred_to_trecrun(prediction, run_files_dir, run_type, r_qidmap, r_pidmap):
    mf_prediction = {  # marco-format prediction
        r_qidmap[qid]: {r_pidmap[k]: v for k,v in values.items()}
        for qid, values in prediction.items()
    }

    # dump it as trec run file
    os.makedirs(run_files_dir, exist_ok=True)
    out_f = os.path.join(run_files_dir, f'runs.{run_type}.txt')
#     print('Dump to', out_f)
    with open(out_f, 'w') as fout:
        for qid, qpreds in mf_prediction.items():
            for pid, score in qpreds.items():
                print(f'{qid} Q0 {pid} {-score} {score} TEAM', file=fout)
    return out_f


def faiss_search_with_gpu_partitioning(passages, queries, topN):
    p_size = int(1000000)
    partition_results = []
    for i_p in range(math.ceil(passages.shape[0] / p_size)):
        cpu_index = faiss.IndexFlatIP(dim)
        gpu_index = faiss.index_cpu_to_all_gpus(cpu_index)
        current_passages = passages[i_p*p_size : (i_p+1)*p_size]
        print(i_p, current_passages.shape)
        start_time = time.time()
        gpu_index.add(current_passages)
        partition_results.append(gpu_index.search(queries, topN))
        print(time.time() - start_time)
        del cpu_index
        del gpu_index
    
    # merge
    collection = []
    for i_p in range(len(partition_results)):
        collection.append(np.dstack([
            partition_results[i_p][0],
            partition_results[i_p][1] + p_size*i_p  # fix the offset
        ]))  # [query, passage, DISTANCE/PID]
    print('end of stacking'); del partition_results
    collection = np.concatenate(collection, axis=1)
    print('end of concat')
    
    results = []
    for i, query_result in enumerate(collection):
        if i%10000 == 0:
            print(i)
        results.append([int(x[1]) for x in sorted(query_result, key=lambda x: x[0], reverse=True)[:topN]])
    
    return None, np.array(results)         
        
    

# end2end eval

In [3]:
def end2end(
    task,
    data_type,
    processed_data_dir,
    marcoformat_data_dir,
    run_files_dir,
    checkpoint_path,
    verbose
):
    # load qrel
    if data_type == 0:
        topN = 100
    else:
        topN = 1000
    dev_query_positive_id = {}

    query_positive_id_path = os.path.join(processed_data_dir, "dev-qrel.tsv")
    if verbose:
        print(query_positive_id_path)

    with open(query_positive_id_path, 'r', encoding='utf8') as f:
        tsvreader = csv.reader(f, delimiter="\t")
        for [topicid, docid, rel] in tsvreader:
            topicid = int(topicid)
            docid = int(docid)
            if topicid not in dev_query_positive_id:
                dev_query_positive_id[topicid] = {}
            rel_score = max(int(rel), 0)
            dev_query_positive_id[topicid][docid] = rel_score
    
    
    # load mapping
    qidmap_path = processed_data_dir+"/qid2offset.pickle"
    pidmap_path = processed_data_dir+"/pid2offset.pickle"

    with open(qidmap_path, 'rb') as handle:
        qidmap = pickle.load(handle)

    with open(pidmap_path, 'rb') as handle:
        pidmap = pickle.load(handle)

    r_qidmap = {v: k for k,v in qidmap.items()}
    r_pidmap = {v: k for k,v in pidmap.items()}
    
    
    # load embeddings
    if task=='fulltreccovid':
        checkpoint_path = checkpoint_path.replace('fulltreccovid-ann', 'beirtreccovid-ann')
    dev_query_embedding = []
    dev_query_embedding2id = []
    passage_embedding = []
    passage_embedding2id = []
    for i in range(8):
        try:
            with open(checkpoint_path + "dev_query_0__emb_p__data_obj_"+str(i)+".pb", 'rb') as handle:
                dev_query_embedding.append(pickle.load(handle))
            with open(checkpoint_path + "dev_query_0__embid_p__data_obj_"+str(i)+".pb", 'rb') as handle:
                dev_query_embedding2id.append(pickle.load(handle))
            with open(checkpoint_path + "passage_0__emb_p__data_obj_"+str(i)+".pb", 'rb') as handle:
                passage_embedding.append(pickle.load(handle))
            with open(checkpoint_path + "passage_0__embid_p__data_obj_"+str(i)+".pb", 'rb') as handle:
                passage_embedding2id.append(pickle.load(handle))
        except FileNotFoundError:
            break
        if verbose:
            print(i)
    if verbose:
        print(checkpoint_path + "dev_query_0__emb_p__data_obj_"+str(i)+".pb")
    if (not dev_query_embedding) or (not dev_query_embedding2id) or (not passage_embedding) or not (passage_embedding2id):
        print("No data found for checkpoint")

    dev_query_embedding = np.concatenate(dev_query_embedding, axis=0)
    dev_query_embedding2id = np.concatenate(dev_query_embedding2id, axis=0)
    passage_embedding = np.concatenate(passage_embedding, axis=0)
    passage_embedding2id = np.concatenate(passage_embedding2id, axis=0)
    
    
    # search
    if verbose:
        print('search')
        print(passage_embedding.shape, dev_query_embedding.shape)
    dim = passage_embedding.shape[1]
    faiss.omp_set_num_threads(64)
    if not use_gpu_faiss:
        cpu_index = faiss.IndexFlatIP(dim)
        start_time = time.time()
        cpu_index.add(passage_embedding)
        index_time = time.time()
        if verbose:
            print(f"Indexing: {index_time-start_time}")
        _, dev_I = cpu_index.search(dev_query_embedding, topN)
        search_time = time.time()
        if verbose:
            print(f"searching: {search_time-start_time}")
    else:
        start_time = time.time()    
        _, dev_I = faiss_search_with_gpu_partitioning(passage_embedding, dev_query_embedding, topN)
        if verbose:
            print(f"GPU index and search: {time.time()-start_time}")
    
    
    # evaluation
    cuts = [5, 10, 20]
    if verbose:
        print('eval', task, ckpt_choice)
    result = EvalDevQuery(dev_query_embedding2id, passage_embedding2id, dev_query_positive_id, dev_I,
                          r_qidmap, r_pidmap,
                          topN, run_files_dir, shortcut=False, cuts=cuts)

    (prediction, final_ndcg, final_precision, eval_query_cnt,
     final_Map, final_mrr, final_recall,
     final_recall_100, final_recall_10,
     hole_rate_5, hole_rate_20, hole_rate_100,
     ms_mrr, Ahole_rate, metrics, prediction) = result
    if verbose:
        for cut in cuts:
            print(f"NDCG@{cut}:" + str(final_ndcg[cut]))
            print(f"P@{cut}:" + str(final_precision[cut]))
            print(f"map@{cut}:" + str(final_Map[cut]))
        print("pytrec_mrr:" + str(final_mrr))
        print("recall@"+str(topN)+":" + str(final_recall))
        print("recall@10"+":" + str(final_recall_10))
        print("recall@100"+":" + str(final_recall_100))
        print("ms_mrr:" + str(ms_mrr))
    trec_run_fname = dump_pred_to_trecrun(
        prediction, run_files_dir, ckpt_choice.replace('/', '__'), r_qidmap, r_pidmap)


    # calculate hole rates
    labels = defaultdict(dict) # [qid][pid] = 0/1
    valid_query = set()
    with open(os.path.join(marcoformat_data_dir, 'qrels.tsv')) as fin:
        for line in fin:
            a = line.strip().split('\t')
            qid, pid, label = a[0], a[2], a[3]
            label = 0 if int(label)<=0 else 1
            labels[qid][pid] = label
            if label==1:
                valid_query.add(qid)

    cuts = [5, 10, 20, 100]
    total, missing = {cut: 0 for cut in cuts}, {cut: 0 for cut in cuts}

    run_fname = trec_run_fname  # from previous dense retrieval results
    # run_fname = f"/home/xinji/pseudo-root/rblob/{task}/runs/runs.bm25.txt"  # bm25
    with open(run_fname) as fin:
        for line in fin:
            a = line.strip().split(' ')
            qid, pid, rank = a[0], a[2], int(a[3])
            if qid not in valid_query:
                continue
            for cut in cuts:
                if rank <= cut:
                    total[cut] += 1
                    if qid not in labels or pid not in labels[qid]:
                        missing[cut] += 1

    if verbose:
        for cut in cuts:
            print(f'hole@{cut}:', missing[cut]/total[cut])


    if verbose:
        print('\n')
        print('ndcg@10', final_ndcg[10])
        print('recall@100', final_recall_100)
        print('hole@10', missing[10]/total[10])
    else:
        print('{:.4f}\t{:.4f}\t{:.4f}'.format(
            final_ndcg[10], final_recall_100, missing[10]/total[10]
        ))

# Define params below

In [4]:
# it's advised to do "restart kernel and run all" for a new dataset or new checkpoint - to avoid namespace confusion

cqa = False
data_type = 1 # 0 for document, 1 for passage


if not cqa:
    task = 'treccovid'
    ckpt_choice = f'modir-{task}-10k'
    
    processed_data_dir = f'../data/{task}/preprocessed_data/'
    marcoformat_data_dir = f'../data/{task}/marco-format/'
    run_files_dir = f'eval_tmp/'
    inference_output_dir = f'../inference_output/'
    
    end2end(
        task,
        data_type,
        processed_data_dir,
        marcoformat_data_dir,
        run_files_dir,
        inference_output_dir,
        verbose=True,
    )
    
else:
    subsets = [
        'android',
        'english',
        'gaming',
        'gis',
        'mathematica',
        'physics',
        'programmers',
        'stats',
        'tex',
        'unix',
        'webmasters',
        'wordpress',
    ]
    
    print(f'ndcg\trec\thole')
    for task in subsets:
        ckpt_choice = f'modir-cqadupstack-10k'
        
        processed_data_dir = f'../data/cqadupstack/{task}/preprocessed_data/'
        marcoformat_data_dir = f'../data/cqadupstack/{task}/marco-format/'
        run_files_dir = f'eval_tmp/'
        inference_output_dir = f'../inference_output/{task}'
        
        print(task)
        end2end(
            task,
            data_type,
            processed_data_dir,
            marcoformat_data_dir,
            run_files_dir,
            inference_output_dir,
            verbose=False,
        )
        print()

../data/treccovid/preprocessed_data/dev-qrel.tsv
0
1
2
3
../inference_output/dev_query_0__emb_p__data_obj_4.pb
search
(171331, 768) (50, 768)
Indexing: 0.12249016761779785
searching: 0.3689405918121338
eval treccovid modir-treccovid-10k
NDCG@5:0.7146378341011254
P@5:0.74
map@5:0.008810734367829533
NDCG@10:0.6761737743183631
P@10:0.698
map@10:0.015274506169486139
NDCG@20:0.6258467684317476
P@20:0.6469999999999999
map@20:0.02570818829818436
pytrec_mrr:0.870468253968254
recall@1000:0.3482583404325919
recall@10:0.017418426100629346
recall@100:0.10514951192393761
ms_mrr:{'MRR @10': 0.9495238095238094, 'QueriesRanked': 50}
hole@5: 0.16
hole@10: 0.192
hole@20: 0.244
hole@100: 0.411


ndcg@10 0.6761737743183631
recall@100 0.10514951192393761
hole@10 0.192
