In [57]:
import json
import re
from collections import defaultdict
from re import match

import regex
import requests
from datasets import load_dataset
from langchain import PromptTemplate, FewShotPromptTemplate
from langchain.llms import OpenAI
from langchain.prompts.example_selector import LengthBasedExampleSelector

from inference_api import YiyanInferenceApi, HfInferenceApi

# Source of the C-Eval subject mapping consumed by get_data_mapping(). Plain string:
# it has no placeholders, so the former f-prefix was unnecessary.
dataset_mapping_url = "https://raw.githubusercontent.com/SJTU-LIT/ceval/main/subject_mapping.json"

# Model under evaluation; gen_by_fewshot_prompt() calls it with the prompt string.
# NOTE(review): this cell's captured output printed the Yiyan API/secret keys,
# presumably because of debug=True — confirm and disable before sharing the notebook.
llm = YiyanInferenceApi("yiyan-007", debug=True)
# Alternative backend kept for comparison runs:
# llm = OpenAI(model_name="text-davinci-003")

def get_data_mapping(mapping_url, timeout=30):
    """Download and parse a JSON document (used here for the C-Eval subject mapping).

    Args:
        mapping_url: URL of the JSON file.
        timeout: Seconds to wait for the server. Without a timeout, ``requests``
            can block forever on a stalled connection.

    Returns:
        The decoded JSON value. The evaluation cell treats it as a dict keyed by
        subject name.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status. Raising
            here is clearer than trying to JSON-decode an error page.
    """
    # A GET has no request body, so only the Accept header is meaningful.
    headers = {'Accept': 'application/json'}
    response = requests.get(mapping_url, headers=headers, timeout=timeout)
    response.raise_for_status()
    # Decode explicitly as UTF-8 rather than relying on the detected charset;
    # the mapping may contain non-ASCII text.
    return json.loads(response.content.decode('utf-8'))

# Translation table that doubles every literal brace, so str.format-style
# templates render the braces verbatim instead of reading them as placeholders.
_BRACE_ESCAPE_TABLE = str.maketrans({'{': '{{', '}': '}}'})


def escape_format(text: str) -> str:
    """Return ``text`` with each '{' and '}' doubled, for safe embedding in a format template."""
    return text.translate(_BRACE_ESCAPE_TABLE)

# Per-question fields substituted into both the few-shot examples and the test question.
_QUESTION_FIELDS = ('question', 'A', 'B', 'C', 'D')


def _build_fewshot_prompt(task_name, fewshot_set):
    """Build the FewShotPromptTemplate for one subject from its worked-example rows.

    Args:
        task_name: Subject name interpolated into the prompt prefix.
        fewshot_set: Iterable of dict rows with 'question', 'A'-'D', 'explanation'
            and 'answer' keys (the caller passes the C-Eval ``dev`` split).

    Returns:
        A FewShotPromptTemplate whose ``format`` takes the test-question fields.
    """
    example_template = """问题：{question} \n选项\nA: {A}\nB: {B}\nC: {C}\nD: {D}\n思考：{explanation}\n答案：{answer}\n\n"""
    example_prompt = PromptTemplate(
        input_variables=[*_QUESTION_FIELDS, 'explanation', 'answer'],
        template=example_template,
    )

    # LangChain's FewShotPromptTemplate renders each example, joins prefix + examples
    # + suffix into ONE template string, and formats that string again. Any text that
    # ends up inside the joined template (example fields, the subject name) therefore
    # needs its literal braces doubled, or LaTeX like \frac{a}{b} breaks formatting.
    # (Re-verify this if LangChain is upgraded.)
    fewshot_examples = [
        {
            **{field: escape_format(row[field]) for field in (*_QUESTION_FIELDS, 'explanation')},
            # presumably always a bare option letter, so it is left unescaped as before
            'answer': row['answer'],
        }
        for row in fewshot_set
    ]

    example_selector = LengthBasedExampleSelector(
        examples=fewshot_examples,
        example_prompt=example_prompt,
        max_length=1024,
        # Count characters: the default length function splits on spaces/newlines,
        # which badly undercounts Chinese text (it has no spaces between words).
        get_text_length=len,
    )

    return FewShotPromptTemplate(
        example_selector=example_selector,
        example_prompt=example_prompt,
        prefix=f"请作为一个{escape_format(task_name)}科目的考生，参考示例回答单项选择题，示例如下：\n",
        # Implicit concatenation: the old triple-quoted literal leaked the source
        # indentation (20 spaces) into the prompt before "问题：".
        suffix=(
            "下面是要回答的问题，可以先写出思考过程，然后给出ABCD的一个选项作为答案，确保用<答案：>作为答案的前缀\n\n"
            "问题：{question}\n选项\nA: {A}\nB: {B}\nC: {C}\nD: {D}\n思考："
        ),
        input_variables=list(_QUESTION_FIELDS),
        example_separator="\n\n",
    )


def gen_by_fewshot_prompt(task_name, fewshot_set, test_set, model=None):
    """Score an LLM on one subject using a few-shot, explain-then-answer prompt.

    Args:
        task_name: Subject name shown to the model in the prompt prefix.
        fewshot_set: Worked-example rows (see ``_build_fewshot_prompt``).
        test_set: Rows to answer. Each needs 'question', 'A'-'D' and the gold 'answer'.
        model: Callable that takes the prompt string and returns the completion
            text. Defaults to the module-level ``llm``.

    Returns:
        ``(correct, total)``. A question whose generation or parsing raised is
        logged and NOT counted in ``total``.
    """
    model = llm if model is None else model
    few_shot_prompt = _build_fewshot_prompt(task_name, fewshot_set)

    correct, total = 0, 0
    for idx, row in enumerate(test_set):
        try:
            print(f"question: {row['question']}")
            # Test fields are passed as format *values*, and str.format inserts values
            # verbatim. Escaping them (as before) showed the model doubled braces.
            grounded_prompt = few_shot_prompt.format(**{field: row[field] for field in _QUESTION_FIELDS})
            completion = model(grounded_prompt)
            llm_answer = extract_answer(completion)
            answer = row['answer']
            print(f"question: {idx}, llm_answer={llm_answer}, answer={answer}")
            total += 1
            if llm_answer == answer:
                correct += 1
        except Exception as e:
            # Best effort: one failed API call should not abort the whole subject,
            # but the reason is now recorded instead of being swallowed.
            print(f"Exception in answer generation for question <{idx}> in task <{task_name}>: {e!r}")

    return (correct, total)

# Literal markers that the option letter is expected to follow. They are checked in
# this order; for each one, only its last occurrence in the completion is inspected.
_ANSWER_PREFIXES = (
    '答案：', '答案:', '答案为', '答案是', '答案选',
    '故选', '选项中最接近的是', '选项为', '答案应该是',
    '因此选择', '因此答案为：', '即选项', '故选择',
    '正确的选项是', '因此答案为 ', '故选 ',
)
# Characters allowed between a prefix and the letter, e.g. "答案： B", "答案：{B}",
# "答案：<B>", "答案为：C", "答案：(B)", "答案：**B**".
_ANSWER_LEAD_CHARS = ' \t\r\n:：{<(（【[*'
# Phrasings such as "选项B为正确" / "因此D选项正确". Group 1 holds the candidate letter.
_ANSWER_PATTERNS = tuple(re.compile(pattern) for pattern in (
    r'选项\((.*?)\)为正确', r'选项(.*?)为正确',
    r'选项\((.*?)\)正确', r'选项(.*?)正确',
    r'选项\((.*?)\)是正确', r'选项(.*?)是正确',
    r'选择\((.*?)\)为正确', r'选择(.*?)为正确',
    r'因此\((.*?)\)选项正确', r'因此(.*?)选项正确',
    r'答案：<\((.*?)\)>', r'答案：<(.*?)>',
))
_OPTION_LETTERS = ('A', 'B', 'C', 'D')


def extract_answer(completion):
    """Extract the chosen option letter (A-D) from a free-form model completion.

    Two passes are tried in order:
      1. Literal answer prefixes such as "答案：" or "故选". For each prefix, the text
         after its last occurrence is checked, skipping whitespace, colons and opening
         brackets before the letter.
      2. Regex phrasings such as "选项B为正确". Every match of each pattern is tried,
         with surrounding whitespace stripped from the captured group.

    Args:
        completion: Raw text produced by the LLM.

    Returns:
        One of 'A', 'B', 'C', 'D' when an answer is found. Otherwise the unchanged
        ``completion``, so a comparison against the gold letter scores it as wrong
        and the log shows what the model actually said.
    """
    for prefix in _ANSWER_PREFIXES:
        pos = completion.rfind(prefix)
        if pos == -1:
            continue
        rest = completion[pos + len(prefix):].lstrip(_ANSWER_LEAD_CHARS)
        if rest[:1] in _OPTION_LETTERS:
            return rest[0]

    for pattern in _ANSWER_PATTERNS:
        # finditer: the leftmost match may capture junk ("A和B都不") while a later one
        # is valid ("选项C正确"); plain search() would give up after the first.
        for found in pattern.finditer(completion):
            candidate = found.group(1).strip()
            if candidate in _OPTION_LETTERS:
                return candidate

    return completion

Yiyan api key: <REDACTED>, Yiyan sec key: <REDACTED>


In [None]:
# gen and compare: evaluate up to topK C-Eval subjects on their val split, using each
# subject's dev split as few-shot examples, and report running global accuracy.
data_mapping = get_data_mapping(dataset_mapping_url)
g_correct, g_total = 0.0, 0.0

tested_subject = 0
# Maximum number of subjects to evaluate; lower it for a quick smoke run.
topK = len(data_mapping.keys())

for name, subject_info in data_mapping.items():
    # Check the budget before loading anything, so topK == 0 really evaluates nothing.
    if tested_subject >= topK:
        break
    dataset = load_dataset("ceval/ceval-exam", name=name)
    # subject_info[1] becomes the prompt's task name. It is presumably the subject's
    # display name in subject_mapping.json — confirm against the file.
    (correct, total) = gen_by_fewshot_prompt(subject_info[1], dataset['dev'], dataset['val'])
    g_correct += correct
    g_total += total
    print(f"Task <{name}>: correct={correct}, total={total}")
    # Guard against 0/0 when every question so far failed to generate.
    g_accuracy = g_correct / g_total if g_total else 0.0
    print(f"Global accuracy so far: {g_accuracy:.4f}(g_correct={g_correct}, g_total={g_total})")
    tested_subject += 1

In [58]:
# fix extract answer: re-score logged completions with the current extract_answer().
# NOTE(review): assumes each log line ends with "<key>=<gold answer>, <key>=<completion>".
# A completion that itself contains ", " is still split apart (rows whose "answer" is
# not a letter come from this) — TODO confirm the log format and parse it robustly.
NUM_LOG_RANKS = 4

correct, total = 0, 0
for rank in range(NUM_LOG_RANKS):
    log_path = f'yiyan-ceval-result-rank{rank}.log'
    # Explicit encoding (assumes UTF-8 logs) so Chinese text decodes on any platform.
    with open(log_path, 'r', encoding='utf-8') as log_file:
        for line in log_file:
            line = line.rstrip('\n')
            fields = line.split(', ')
            if len(fields) < 2:
                # Blank or malformed line: no answer/completion pair to score.
                continue
            answer_field, completion_field = fields[-2], fields[-1]
            # Split on the FIRST '=' only: the completion text may itself contain '='.
            answer = answer_field.split('=', 1)[-1]
            completion = completion_field.split('=', 1)[-1]
            llm_answer = extract_answer(completion)
            print(f"answer:{answer} llm_answer:{llm_answer}")
            correct += (answer == llm_answer)
            total += 1

accuracy = correct / total if total else 0.0
print(f"correct:{correct} total:{total} accuracy:{accuracy}")

answer:C llm_answer:'<答案：3>。根据位填充方法，需要将数据补足到32位的倍数，因此需要添加3个0。'

answer:C llm_answer:B
answer:C llm_answer:A
answer:C llm_answer:D
answer:D llm_answer:D
answer:D llm_answer:C
answer:C llm_answer:D
answer:D llm_answer:A
answer:C llm_answer:C
answer:D llm_answer:D
answer:B llm_answer:B
answer:C llm_answer:B
answer:B llm_answer:C
answer:A llm_answer:A
answer:D llm_answer:D
answer:C llm_answer:C
answer:A llm_answer:D
answer:B llm_answer:C
answer:C llm_answer:C
answer:7 llm_answer:19

answer:7.0 llm_answer:19.0)

answer:A llm_answer:D
answer:B llm_answer:D
answer:C llm_answer:A
answer:C llm_answer:2ν_1-ν_0'

answer:C llm_answer:A
answer:D llm_answer:A
answer:A llm_answer:B
answer:A llm_answer:A
answer:B llm_answer:B
answer:A llm_answer: E2，U1 > U2'

answer:C llm_answer:D
answer:C llm_answer:'在波长为$\\lambda$的驻波中，两个相邻波节之间的距离为C选项，即$\\frac{\\lambda}{2}$。'

answer:\\lambda'$。\n\n又因为锂的原子序数$Z$比铁小，所以锂的原子半径$r_Z$也比铁小，根据散射截面公式$A_Z\\propto Z^2$，锂的散射截面$A_Z$也比铁小；此外，由于原子核对电子的吸引力随原子半径的减小而增大，所以锂的电离能$W_Z$也比铁大，

In [54]:
# Spot-check the answer-extraction patterns on completions taken from the logs.
import regex

sample_option_correct = '44%。\n4. 因此，该股票的预期报酬率为144%，选项B为正确答案。'
# group(1) is the whole alternation, so this prints "选项B为正确", not just the letter.
found = regex.search(r'(选项(.*?)为正确|选项(.*?)正确)', sample_option_correct)
if found:
    print(found.group(1))

sample_choose = '1. 新航路开辟后，欧洲大西洋沿岸的工商业经济繁荣起来，促进了资本主义的发展，打破了世界各地区的封闭和孤立状态，推动了海外扩张和世界市场的初步形成。\n2. 新航路开辟后，欧洲殖民者开始了殖民活动，加速了欧洲的资本原始积累。\n3. 第一次鸦片战争的发生是由英国向满清走私鸦片从而引发的一场战争，其原因不能归结于新航路的开辟。\n\n因此，选择D为正确答案'
sample_therefore = '问题中给出的无差异曲线相切于E1、E2、E3，这表示消费者1和消费者2在E1、E2、E3这三个点上达到帕累托最优状态，因此D选项正确。'

found = regex.search(r'因此(.*?)选项正确', sample_therefore)
if found:
    print(found.group(1))

print(extract_answer(sample_therefore))

# extract_answer() must recover the letter from each phrasing above.
for sample, expected in ((sample_option_correct, 'B'), (sample_choose, 'D'), (sample_therefore, 'D')):
    assert extract_answer(sample) == expected, (expected, extract_answer(sample))


选项B为正确
D
问题中给出的无差异曲线相切于E1、E2、E3，这表示消费者1和消费者2在E1、E2、E3这三个点上达到帕累托最优状态，因此D选项正确。
