In [1]:
import re
import string

import jieba
import difflib
import os
from typing import List
from collections import Counter

def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def normalize_zh_answer(s):
    """Lower text and remove punctuation, extra whitespace."""

    def white_space_fix(text):
        return "".join(text.split())

    def remove_punc(text):
        cn_punctuation = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
        all_punctuation = set(string.punctuation + cn_punctuation)
        return "".join(ch for ch in text if ch not in all_punctuation)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_punc(lower(s)))

def count_score(prediction, ground_truth, **kwargs):
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)

def retrieval_score(prediction, ground_truth, **kwargs):
    pattern = r'Paragraph (\d+)'
    matches = re.findall(pattern, ground_truth)
    ground_truth_id = matches[0]
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth_id):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)

def retrieval_zh_score(prediction, ground_truth, **kwargs):
    pattern = r'段落(\d+)'
    matches = re.findall(pattern, ground_truth)
    ground_truth_id = matches[0]
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth_id):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)
def f1_score(prediction, ground_truth, **kwargs):
    common = Counter(prediction) & Counter(ground_truth)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction)
    recall = 1.0 * num_same / len(ground_truth)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

def qa_f1_score(prediction, ground_truth, **kwargs):
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    return f1_score(prediction_tokens, ground_truth_tokens)


def qa_f1_zh_score(prediction, ground_truth, **kwargs):
    prediction_tokens = list(jieba.cut(prediction, cut_all=False))
    ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
    prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
    ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
    prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
    ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
    return f1_score(prediction_tokens, ground_truth_tokens)
def scorer(predictions, answers):
    total_score = 0.
    for (prediction, ground_truths) in zip(predictions, answers):
        score = 0.

        for ground_truth in ground_truths:
            score = max(score, qa_f1_score(prediction, ground_truth,))
        total_score += score

    return round(100 * total_score / len(predictions), 2)

In [20]:
import json
predictions, answers,inputs = [], [],[]
dataset='multifieldqa_zh'
# if os.path.exists(filename):
folder='data/multifieldqa_zh'
answer_folder='data/zh_answer'
num=len(os.listdir(folder))
for i in range(num):
    file=f'data/multifieldqa_zh/output_{i}.json'
    answer_file=f'data/zh_answer/aoutput_{i}.json'
    if os.path.exists(file) and os.path.exists(answer_file):
        with open(file, 'r', encoding='utf-8') as f:
            file_data = json.load(f)
            with open(answer_file, 'r', encoding='utf-8') as answerf:
                answerf_data = json.load(answerf)
                input=file_data['input']
                answer=file_data['answers']
                prediction=answerf_data['answer']
                predictions.append(prediction)
                answers.append(answer)
print(predictions)
print(answers)
score = scorer(predictions, answers)
print('accuracy: ', score)

['厦门大学。', '地方财政过去五年用于社会保障支出的金额是6.5亿元。', '二审法院确定了赔偿金额为57081.86元。', '外戚庾亮的建议。', '奇力锅炉公司支付了10万元的预付款。', '10mm', '邓某因为在狱中接触到毒品，并在网络赌博中输掉近二十万元，导致一发不可收拾，最终学会了制造冰毒的技术，成为毒品贩子。', '文章主要讲了织田信长和丰臣秀吉这两位历史人物。', '在2012年，铜仁市的财政总收入增长了25.89%。', '禹在治理九州的过程中遇到了水灾的困难，经过九年的努力治水，但效果不佳，最后舜殛鯀於羽山以解决问题。', '美女姐姐、女总裁、交际花、女警花。', '许茂忠是在2014年1月23日将此案诉至法院的。', 'Brother Industries, Ltd。', '被告人王德山在2013年11月1日9时许被公安机关抓获。', '根据提供的文本内容，2018年韶关市的重点工作包括：\n1. 经济平稳增长，固定资产投资增长5%。\n2. 优化政务服务环境，推行“一门式一网式”政务服务模式，企业开办时间缩短至3.5个工作日。\n3. 改善群众民生福祉，民生投入增长5.4%，居民人均可支配收入增长8%。\n4. 实现预脱贫，6471户1.97万人脱贫。\n5. 推进教育现代化，新增学位1.29万个。\n6. 公立医院综合改革效果排名全省第三。\n\n这些是2018年韶关市的重点工作内容。', '40754.95元', '湖南省湘潭市中级人民法院', '凤凰是通过钢铁侠制造的机器将凤凰之力打散，然后附体到镭射眼等X战警身上的。', '万小霞提供的第二组证据证明了同一事实，能够互相印证。', '规档云的联系电话是021-50710282。', '黃智海居士的解釋方法有兩種：一種是解釋，另一種是釋解。', '一审判决认为郭海燕、范舒炜之间不存在真实的股权转让关系，而是股权赠与关系。', '25%', '导师培训考核采用积分制，每三学年为一轮次。每位导师每一轮次至少获得3个培训积分，同种培训方式最多只计1次。导师培训情况将与导师资格聘任及招生资格审核挂钩，并纳入导师聘期考核指标体系。', '从提供的内容中无法确定西北农林科技大学图书馆何时开始使用计算机流通管理系统。', '万小霞提供的第二组证据证明了她向相关职能部门提交了信访材料并得到了反馈

In [16]:
print('accuracy: ', score)

accuracy:  10.24


In [4]:
import json
# predictions, answers,inputs = [], [],[]
def count_answer_f1(folder,answer_folder):
    predictions, answers = [], []
    num=len(os.listdir(folder))
    for i in range(num):
        file=f'{folder}/output_{i}.json'
        answer_file=f'{answer_folder}/aoutput_{i}.json'
        if os.path.exists(file) and os.path.exists(answer_file):
            with open(file, 'r', encoding='utf-8') as f:
                file_data = json.load(f)
                with open(answer_file, 'r', encoding='utf-8') as answerf:
                    answerf_data = json.load(answerf)
                    input=file_data['input']
                    answer=file_data['answers']
                    prediction=answerf_data['answer']
                    predictions.append(prediction)
                    answers.append(answer)
    score = scorer(predictions, answers)
    return score
folder='data/multifieldqa_zh'
answer_folder='data/zh_answer'
answer_dataset=answer_folder.split('/')[-1]
print(f'{answer_dataset} accuracy: ', count_answer_f1(folder,answer_folder))

zh_answer accuracy:  7.0


In [6]:
folders=[
    # 'data/multifieldqa_zh',
    #      'data/multifieldqa_zh',
    #      'data/multifieldqa_zh',
         'data/multifieldqa_en',
         'data/multifieldqa_en',
         'data/multifieldqa_en',
         'data/qasper',
         'data/qasper',
         'data/qasper',
         'data/narrativeqa',
         'data/narrativeqa',
         'data/narrativeqa',
         'data/hotpotqa',
         'data/hotpotqa',
         'data/hotpotqa',
         ]
answer_folders=[
    # 'data/zh_answer',
    #             'data/zh_answer_1024',
    #             'data/zh_answer_2048',
                'data/en_answer',
                'data/en_answer_1024',
                'data/en_answer_2048',
                'data/qasper_answer',
                'data/qasper_answer_1024',
                'data/qasper_answer_2048',
                'data/narrativeqa_answer',
                'data/narrativeqa_answer_1024',
                'data/narrativeqa_answer_2048',
                'data/hotpotqa_answer',
                'data/hotpotqa_answer_1024',
                'data/hotpotqa_answer_2048']
for folder,answer_folder in zip(folders,answer_folders):
    answer_dataset=answer_folder.split('/')[-1]
    print(f'{answer_dataset} accuracy: ', count_answer_f1(folder,answer_folder))
    print('-------------------')


en_answer accuracy:  46.91
-------------------
en_answer_1024 accuracy:  47.57
-------------------
en_answer_2048 accuracy:  46.1
-------------------
qasper_answer accuracy:  25.91
-------------------
qasper_answer_1024 accuracy:  27.89
-------------------
qasper_answer_2048 accuracy:  26.86
-------------------
narrativeqa_answer accuracy:  14.32
-------------------
narrativeqa_answer_1024 accuracy:  15.22
-------------------
narrativeqa_answer_2048 accuracy:  12.53
-------------------
hotpotqa_answer accuracy:  27.93
-------------------
hotpotqa_answer_1024 accuracy:  24.99
-------------------
hotpotqa_answer_2048 accuracy:  28.12
-------------------
