In [1]:
import json
import argparse
from dataclasses import dataclass
import pandas as pd
import sys
sys.path.append('../')
from utils import load_simple_json,load_json

In [2]:
def process_prediction(pred, thresh=0.5, topk=None):
    '''
    Convert raw page-retrieval predictions into a mapping of id -> page list.

    Args:
        pred: dict keyed by id; each value must contain a "score" list and a
            parallel "page_ids" list (score j belongs to page_ids[j]).
        thresh: keep pages whose score is >= thresh (used when topk is falsy).
        topk: if truthy, keep the ``topk`` highest-scoring pages instead of
            thresholding.

    Returns:
        dict mapping each id in ``pred`` to a list of page ids. In top-k mode
        the list is ordered by descending score (ties keep their original
        order); in threshold mode it keeps the original page order.
    '''
    re_dic = {}
    for qid, entry in pred.items():
        scores = entry["score"]
        page_ids = entry["page_ids"]
        if topk:
            # Rank positions rather than score values: mapping a value back
            # with scores.index(value) returns the first tied page every time,
            # duplicating it and dropping the other tied pages.
            ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
            re_dic[qid] = [page_ids[j] for j in ranked[:topk]]
        else:
            re_dic[qid] = [page_ids[j] for j, value in enumerate(scores) if value >= thresh]
    return re_dic
def join_data_with_preds(ori_data, preds, key="predicted_pages_1"):
    '''
    Attach predictions to the original records that have one.

    Args:
        ori_data: iterable of dict records, each with an "id" field.
        preds: dict mapping str(id) -> predicted pages (e.g. the output of
            process_prediction or join_kfold_preds).
        key: field name under which the predictions are stored.

    Returns:
        list of shallow copies of the records whose str(id) is in ``preds``,
        each with ``key`` set to the matching prediction, in ``ori_data``
        order. The input records are not modified.
    '''
    joined = []
    for record in ori_data:
        id_key = str(record["id"])
        if id_key in preds:
            # Copy instead of writing into ``record``: mutating shared records
            # leaks predictions into later joins of the same data.
            joined.append({**record, key: preds[id_key]})
    return joined

In [3]:
def calculate_precision(data, predictions: pd.Series) -> None:
    '''
    Print the macro-averaged page precision over verifiable claims.

    Claims labelled "NOT ENOUGH INFO" are skipped. For every other claim the
    score is len(predicted & gold) / len(predicted). A claim with no predicted
    pages contributes 0 but is still counted.

    Args:
        data: list of records with "label" and "evidence"; "evidence" is a
            list of evidence sets, each a list of items whose index 2 is a
            gold Wikipedia page title.
        predictions: Series of sets of predicted page titles, positionally
            aligned with ``data`` (record i <-> ``predictions.iloc[i]``).

    Prints "Precision: nan" instead of raising ZeroDivisionError when no
    claim is scored.
    '''
    precision = 0
    count = 0

    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue

        # All gold page titles across every evidence set (evidence[2] = title).
        gt_pages = {
            evidence[2]
            for evidence_set in d["evidence"]
            for evidence in evidence_set
        }

        predicted_pages = predictions.iloc[i]
        hits = predicted_pages.intersection(gt_pages)
        if len(predicted_pages) != 0:
            precision += len(hits) / len(predicted_pages)

        count += 1

    # Macro precision: mean of the per-claim precisions.
    macro = precision / count if count else float("nan")
    print(f"Precision: {macro}")
def at_least_get_one(data, predictions: pd.Series) -> None:
    '''
    Print the fraction of verifiable claims with at least one fully covered
    evidence set.

    An evidence set is covered when every page title in it (index 2 of each
    evidence item) appears in the claim's predicted pages. Claims labelled
    "NOT ENOUGH INFO" are skipped.

    Args:
        data: list of records with "label" and "evidence" (list of evidence
            sets, each a list of items whose index 2 is a page title).
        predictions: Series of collections of predicted page titles,
            positionally aligned with ``data`` (record i <-> ``.iloc[i]``).

    Prints "at least get one rate : nan" instead of raising
    ZeroDivisionError when no claim is scored.
    '''
    hit_counts = 0
    count = 0
    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue
        predicted_pages = predictions.iloc[i]
        covered = any(
            all(evidence[2] in predicted_pages for evidence in evidence_set)
            for evidence_set in d["evidence"]
        )
        hit_counts += int(covered)
        count += 1
    rate = hit_counts / count if count else float("nan")
    print(f"at least get one rate : {rate}")


def calculate_recall(data, predictions: pd.Series) -> None:
    '''
    Print the macro-averaged page recall over verifiable claims.

    Claims labelled "NOT ENOUGH INFO" are skipped. For every other claim the
    score is len(predicted & gold) / len(gold), where gold is the set of page
    titles (index 2) across all of its evidence sets.

    Args:
        data: list of records with "label" and "evidence".
        predictions: Series of sets of predicted page titles, positionally
            aligned with ``data`` (record i <-> ``predictions.iloc[i]``).

    Prints "Recall: nan" instead of raising ZeroDivisionError when no claim
    is scored.
    '''
    recall = 0
    count = 0

    for i, d in enumerate(data):
        if d["label"] == "NOT ENOUGH INFO":
            continue

        gt_pages = {
            evidence[2]
            for evidence_set in d["evidence"]
            for evidence in evidence_set
        }
        predicted_pages = predictions.iloc[i]
        hits = predicted_pages.intersection(gt_pages)
        recall += len(hits) / len(gt_pages)
        count += 1

    macro = recall / count if count else float("nan")
    print(f"Recall: {macro}")

In [235]:
def unique(input_lst):
    '''
    Return the distinct elements of ``input_lst`` in first-seen order.

    Hashable elements are de-duplicated in O(n) via dict.fromkeys (dicts
    preserve insertion order). If any element is unhashable, this falls back
    to an O(n^2) list-membership scan.
    '''
    # Materialize first so a one-shot iterator is not half-consumed before
    # the fallback path runs.
    items = list(input_lst)
    try:
        return list(dict.fromkeys(items))
    except TypeError:
        seen = []
        for elem in items:
            if elem not in seen:
                seen.append(elem)
        return seen
def join_kfold_preds(pred_lst):
    '''
    Merge per-fold prediction dicts into one dict of de-duplicated page lists.

    Args:
        pred_lst: list of dicts mapping id -> list of predicted pages (e.g.
            one process_prediction output per fold).

    Returns:
        new dict mapping every id seen in any fold to the concatenation of
        its page lists across folds (in fold order), with duplicates removed
        while keeping first-seen order. Ids keep first-seen order as well.
        An empty ``pred_lst`` yields {}.

    The input dicts and their lists are not modified.
    '''
    merged = {}
    for fold_preds in pred_lst:
        for key, pages in fold_preds.items():
            # setdefault(..., []) starts a fresh list, so extending it never
            # aliases or mutates a caller's list.
            merged.setdefault(key, []).extend(pages)
    return {key: unique(pages) for key, pages in merged.items()}

In [5]:
# Load the records from pre_train_0522.jsonl (presumably the training split -- confirm).
# The bare len() displays how many records were read.
ori_data = load_json("../preprocess/pre_train_0522.jsonl")
len(ori_data)

11349

In [28]:
# Turn the fold-0 validation page scores into page lists and join them onto ori_data
# under the default "predicted_pages_1" key.
processed_preds=[]
preds = ["../pert_large/page/0522_base_cluster/val_f0.json"]
# NOTE(review): thr_param is never used in this cell; looks like a leftover.
thr_param =[]
for pred_path in preds:
    pred = load_simple_json(pred_path)
    # Threshold mode (topk=None): keep pages scoring >= 0.005.
    processed_preds.append(process_prediction(pred,0.005,None))
paired_data = join_data_with_preds(ori_data,join_kfold_preds(processed_preds))

In [29]:
# One set of predicted page titles per paired record. The metric functions index
# it positionally (.iloc), so it must stay aligned with paired_data.
predictions = pd.Series([set(elem["predicted_pages_1"]) for elem in paired_data])

In [30]:
# Page-retrieval metrics for the predictions above (NOT ENOUGH INFO claims are skipped).
calculate_precision(paired_data,predictions)
calculate_recall(paired_data,predictions)
at_least_get_one(paired_data,predictions)
# Results recorded from earlier experiment runs. The leading number is presumably
# recall; the trailing tags appear to name the fold(s)/variant -- TODO confirm.
#0.7855973813420621 Precision: 0.5449945444626291 f0 recall
#0.7823240589198036 Precision: 0.5635077097843053 f0 f1
#0.812111801242236 Precision: 0.5406277728482696 f1 recall
#0.7839607201309329 Precision: 0.6245382277297165 f0 recall b14


Precision: 0.8921061318808289
Recall: 0.9125024070864624
at least get one rate : 0.8700173310225303


Preprocess page data: de-duplicate each record's predicted pages and save the paired data as JSON Lines.

In [415]:
# De-duplicate each record's predicted pages in place. dict.fromkeys keeps
# first-seen order; list(set(...)) produced an order that depends on per-process
# string hash randomization, so the saved JSONL changed between kernel runs.
for record in paired_data:
    record["predicted_pages_1"] = list(dict.fromkeys(record["predicted_pages_1"]))

In [417]:
# Save the paired records as JSON Lines: one UTF-8 JSON object per line, with
# non-ASCII characters written as-is (ensure_ascii=False).
with open("../pert_large/page/doc100_cluster_folds_recall/all.jsonl", "w", encoding="utf8") as f:
    for d in paired_data:
        f.write(json.dumps(d, ensure_ascii=False) + "\n")

Evaluate on the test data

In [340]:
# Re-bind ori_data to the records in pre_all_0522.jsonl for the test-set evaluation.
# NOTE(review): any later cell that still expects the pre_train records in ori_data
# will silently use these instead.
ori_data = load_json("../preprocess/pre_all_0522.jsonl")
len(ori_data)

11349

In [341]:
# Ensemble the five fold models: threshold each fold's page scores at 0.5, then
# merge the folds' page lists per claim id (concatenated, duplicates removed).
processed_preds = []
preds = ["all_test_f0", "all_test_f1", "all_test_f2", "all_test_f3", "all_test_f4"]
for pred_path in preds:
    pred = load_simple_json(f"../pert_large/page/0522_base_clu_folds/no_info/{pred_path}.json")
    processed_preds.append(process_prediction(pred, 0.5, None))
paired_data = join_data_with_preds(ori_data, join_kfold_preds(processed_preds))

In [342]:
# Number of records that received a prediction (join_data_with_preds drops the rest).
len(paired_data)

3268

In [343]:
# Peek at the first joined record.
paired_data[:1]

[{'id': 8075,
  'label': 'NOT ENOUGH INFO',
  'claim': 'F.I.R.的團員有主唱Faye飛（詹雯婷）、吉他手Real阿沁（黃漢青）、鍵盤手Ian（陳建寧），是亞洲樂壇不常見的一女二男三人組合樂團。',
  'evidence': [[7208, None, None, None]],
  'predicted_pages_1': ['F.I.R.飛兒樂團']}]

In [326]:
#evaluate
predictions = pd.Series([set(elem["predicted_pages_1"]) for elem in paired_data])
calculate_precision(paired_data,predictions)
calculate_recall(paired_data,predictions)
at_least_get_one(paired_data,predictions)

Precision: 0.8458178527821373
Recall: 0.9234787087912093
at least get one rate : 0.8945054945054945


In [344]:
# Save the five-fold ensemble results as JSON Lines: one UTF-8 JSON object per
# line, with non-ASCII characters written as-is (ensure_ascii=False).
with open("../pert_large/page/0522_base_clu_folds/no_info_ens5.jsonl", "w", encoding="utf8") as f:
    for d in paired_data:
        f.write(json.dumps(d, ensure_ascii=False) + "\n")

# New document retrieval

In [15]:
# Load the wiki-search retrieval results; the next cell reads each row's "id" and "result".
wiki_emb = load_json("../preprocess/pre_train_wikisearch_base.jsonl")

In [16]:
# Index the search results by string id, matching the str(id) lookup in join_data_with_preds.
wiki_preds = dict((str(row["id"]), row["result"]) for row in wiki_emb)

In [17]:
# Join against the pre_train records explicitly. The wiki-search file is built from
# pre_train data (presumably keyed by the same claim ids), but ori_data is re-bound
# to pre_all_0522 earlier in the notebook, so under Restart & Run All this join
# would silently use the wrong records.
# join_kfold_preds on a single dict only de-duplicates each page list.
train_data = load_json("../preprocess/pre_train_0522.jsonl")
paired_data = join_data_with_preds(train_data, join_kfold_preds([wiki_preds]), key="predicted_pages")

In [18]:
# Inspect the first joined record and its retrieved pages.
paired_data[0]

{'id': 2663,
 'label': 'refutes',
 'claim': '天衛三軌道在天王星內部的磁層，以《仲夏夜之夢》作者緹坦妮雅命名。',
 'evidence': [[[4209, 4331, '天衛三', 2]]],
 'predicted_pages': ['天衛三',
  '天衛四',
  '天王星',
  '天王星的衛星',
  '仲夏夜之淫夢',
  '天衛二',
  '天衛一',
  '天衛十',
  '仲夏夜之夢_(消歧義)',
  '緹坦妮雅']}

In [19]:
# One set of retrieved page titles per record, positionally aligned with paired_data.
predictions = pd.Series([set(elem["predicted_pages"]) for elem in paired_data])

In [20]:
# Label-based lookup; it matches .iloc[0] here because the Series has the default RangeIndex.
predictions[0]

{'仲夏夜之夢_(消歧義)',
 '仲夏夜之淫夢',
 '天王星',
 '天王星的衛星',
 '天衛一',
 '天衛三',
 '天衛二',
 '天衛十',
 '天衛四',
 '緹坦妮雅'}

In [21]:
# Metrics for the wiki-search retriever. Each claim gets a whole candidate list
# (10 pages in the record shown above), so low precision alongside high recall is
# expected.
calculate_precision(paired_data,predictions)
calculate_recall(paired_data,predictions)
at_least_get_one(paired_data,predictions)

Precision: 0.10348846581817957
Recall: 0.926340723309388
at least get one rate : 0.9001733960862026
