In [1]:
import json
import argparse
from dataclasses import dataclass
import pandas as pd
import sys
sys.path.append('../')
from utils import load_simple_json,load_json

In [2]:
def process_prediction(pred,thresh=0.5,topk=None):
    '''
    Select predicted page ids for every query in a raw prediction dict.

    input:
        pred   -- mapping qid -> {"score": [...], "page_ids": [...]} where
                  score[k] is the score of page_ids[k]
        thresh -- keep pages whose score is >= thresh (used when topk is falsy)
        topk   -- if truthy, keep the topk highest-scoring pages instead of
                  thresholding (topk=0 / None fall back to thresholding)
    return: dictionary mapping each qid to its list of selected page ids
            (top-k mode: descending score, ties kept in original order;
             threshold mode: original order)
    '''
    re_dic = {}
    for qid, entry in pred.items():
        scores = entry["score"]
        page_ids = entry["page_ids"]
        if topk:
            # Rank positions rather than score values: mapping a value back
            # with scores.index() returns the first match, so tied scores
            # yielded the same page repeatedly. sorted() is stable (also with
            # reverse=True), so tied pages keep their original order.
            ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
            re_dic[qid] = [page_ids[idx] for idx in ranked[:topk]]
        else:
            re_dic[qid] = [page_ids[idx] for idx, value in enumerate(scores) if value >= thresh]
    return re_dic
def join_data_with_preds(ori_data,preds):
    '''
    Attach processed predictions to the matching original records.

    input:
        ori_data -- iterable of record dicts, each with an "id" field
        preds    -- processed predictions keyed by str(record id)
    return: list of shallow copies of the records that have a prediction,
            each with an added "predicted_pages_1" field, in ori_data order.
            Records without a prediction are dropped. ori_data itself is not
            modified, so re-running a cell with different predictions does
            not leave stale fields on the source records.
    '''
    joined = []
    for record in ori_data:
        record_id = str(record["id"])  # prediction keys are strings
        if record_id in preds:
            joined.append({**record, "predicted_pages_1": preds[record_id]})
    return joined

In [3]:
def calculate_precision(data,predictions: pd.Series) -> None:
    '''
    Print macro-averaged page precision over claims not labelled "NOT ENOUGH INFO".

    input:
        data        -- records with "label" and "evidence" (list of evidence sets,
                       each a list of evidence entries whose index 2 is a page title)
        predictions -- Series positionally aligned with data; each element is a
                       set of predicted page titles
    A claim with no predicted pages contributes 0 but is still counted.
    Prints 0.0 when no claim qualifies (instead of dividing by zero).
    '''
    precision = 0
    count = 0

    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue

        # Ground-truth page titles: evidence[2] of every entry in every evidence set.
        gt_pages = {
            evidence[2]
            for evidence_set in d["evidence"]
            for evidence in evidence_set
        }

        predicted_pages = predictions.iloc[i]
        if predicted_pages:
            hits = predicted_pages.intersection(gt_pages)
            precision += len(hits) / len(predicted_pages)

        count += 1

    # Macro precision
    print(f"Precision: {precision / count if count else 0.0}")
def at_least_get_one(data,predictions: pd.Series) -> None:
    '''
    Print the fraction of claims (excluding "NOT ENOUGH INFO") for which at
    least one evidence set is fully covered by the predicted pages.

    input:
        data        -- records with "label" and "evidence" (list of evidence sets,
                       each a list of evidence entries whose index 2 is a page title)
        predictions -- Series positionally aligned with data; each element is a
                       collection of predicted page titles
    An evidence set counts as covered when every one of its page titles
    (evidence[2]) is among the predictions. Prints 0.0 when no claim qualifies.
    '''
    hit_counts = 0
    count = 0
    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue
        predicted_pages = predictions.iloc[i]
        covered = any(
            all(evidence[2] in predicted_pages for evidence in evidence_set)
            for evidence_set in d["evidence"]
        )
        hit_counts += int(covered)
        count += 1
    print(f"at least get one rate : {hit_counts / count if count else 0.0}")


def calculate_recall(data,predictions: pd.Series) -> None:
    '''
    Print macro-averaged page recall over claims not labelled "NOT ENOUGH INFO".

    input:
        data        -- records with "label" and "evidence" (list of evidence sets,
                       each a list of evidence entries whose index 2 is a page title)
        predictions -- Series positionally aligned with data; each element is a
                       set of predicted page titles
    A claim with no ground-truth pages contributes 0 but is still counted
    (mirroring how calculate_precision treats claims with no predictions).
    Prints 0.0 when no claim qualifies (instead of dividing by zero).
    '''
    recall = 0
    count = 0

    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue

        # Ground-truth page titles: evidence[2] of every entry in every evidence set.
        gt_pages = {
            evidence[2]
            for evidence_set in d["evidence"]
            for evidence in evidence_set
        }
        predicted_pages = predictions.iloc[i]
        if gt_pages:
            hits = predicted_pages.intersection(gt_pages)
            recall += len(hits) / len(gt_pages)
        count += 1

    print(f"Recall: {recall / count if count else 0.0}")

In [11]:
def join_kfold_preds(pred_lst):
    '''
    Merge per-fold processed predictions into a single dictionary.

    input: list of dicts mapping id -> list of page ids (one dict per fold)
    return: new dict mapping id -> list of page ids; when an id appears in
            several folds its pages are concatenated in fold order. Key order
            follows first appearance. The input dicts and their lists are not
            modified (previously the first fold was extended in place), and an
            empty pred_lst yields {}.
    '''
    merged = {}
    for fold_preds in pred_lst:
        for key, pages in fold_preds.items():
            merged.setdefault(key, []).extend(pages)
    return merged

In [4]:
# Load the preprocessed records that predictions are joined against; the bare
# len() is the cell's displayed output (record count).
ori_data = load_json("../preprocess/pre_train_doc100.jsonl")
len(ori_data)

3839

In [159]:
# Per-fold validation prediction files. Each fold is filtered with a fixed score
# threshold (pages scoring >= 0.0015 are kept, no top-k cut), the folds are
# merged into one {id: [page, ...]} mapping, and the result is joined onto ori_data.
preds = [
    "../pert_large/page/doc100_cluster_folds_recall/val_f0.json",
    "../pert_large/page/doc100_cluster_folds_recall/val_f1.json",
]
processed_preds = []
for pred_path in preds:
    pred = load_simple_json(pred_path)
    processed_preds.append(process_prediction(pred, thresh=0.0015, topk=None))
paired_data = join_data_with_preds(ori_data, join_kfold_preds(processed_preds))

In [160]:
# One set of predicted page ids per joined record, positionally aligned with paired_data.
predictions = pd.Series(list(map(lambda record: set(record["predicted_pages_1"]), paired_data)))

In [161]:
# Page-retrieval metrics for the merged validation-fold predictions; each call prints one line.
calculate_precision(paired_data,predictions)
calculate_recall(paired_data,predictions)
at_least_get_one(paired_data,predictions)

Precision: 0.24516348320572232
Recall: 0.8420076377523185
at least get one rate : 0.7937806873977087
