In [6]:
import json
import re
from collections import defaultdict

import regex
import requests
from datasets import load_dataset
from langchain import PromptTemplate, FewShotPromptTemplate
from langchain.llms import OpenAI
from langchain.prompts.example_selector import LengthBasedExampleSelector

from inference_api import YiyanInferenceApi, HfInferenceApi

# Raw-file URL of the C-Eval subject-mapping JSON, fetched by get_data_mapping.
# (Plain string: the former f-prefix had no placeholders.)
dataset_mapping_url = "https://raw.githubusercontent.com/SJTU-LIT/ceval/main/subject_mapping.json"

# Completion model that answers the exam questions (used by
# gen_by_fewshot_prompt). The project-local Yiyan backend can be swapped in:
#     llm = YiyanInferenceApi("yiyan-007", debug=True)
llm = OpenAI(
    model_name="text-davinci-003",
)

def get_data_mapping(mapping_url, timeout=30):
    """Download and decode a JSON document (here: the C-Eval subject mapping).

    Args:
        mapping_url: URL of the JSON document to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` may block indefinitely on a stalled
            connection.

    Returns:
        The decoded JSON value. The caller iterates its keys and indexes
        ``[1]`` into each value, so presumably a dict of subject -> list;
        confirm against the mapping file.

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status,
            rather than attempting to JSON-decode an error page.
    """
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json',
    }
    response = requests.get(mapping_url, headers=headers, timeout=timeout)
    # Fail loudly on HTTP errors instead of a confusing JSONDecodeError later.
    response.raise_for_status()
    return json.loads(response.content.decode('utf-8'))

def _escape_braces(text):
    """Return *text* with ``{``/``}`` doubled so a later ``str.format`` pass keeps them literal."""
    return str(text).replace('{', '{{').replace('}', '}}')


def gen_by_fewshot_prompt(task_name, fewshot_set, test_set, model=None):
    """Ask the model every question of one subject using a few-shot prompt.

    Each prompt starts with an instruction prefix naming the subject, then
    shows examples drawn from ``fewshot_set`` (question, options A-D,
    explanation, answer), selected by length, then the test question. The
    extracted model answer is printed next to the reference answer.

    Args:
        task_name: Subject name inserted into the instruction prefix.
        fewshot_set: Iterable of row mappings with keys ``question``,
            ``A``-``D``, ``explanation`` and ``answer``.
        test_set: Iterable of row mappings with keys ``question``,
            ``A``-``D`` and ``answer``.
        model: Callable mapping a prompt string to the completion text.
            Defaults to the module-level ``llm``.

    Returns:
        List of ``(llm_answer, answer)`` tuples in ``test_set`` order, so
        callers can compute accuracy instead of reading the printed log.
    """
    if model is None:
        model = llm

    fields = ('question', 'A', 'B', 'C', 'D', 'explanation', 'answer')
    prompt_template = """问题：{question} \n选项\nA: {A}\nB: {B}\nC: {C}\nD: {D}\n思考：{explanation}\n答案：{answer}\n\n"""

    example_prompt = PromptTemplate(
        input_variables=list(fields),
        template=prompt_template,
    )

    # FewShotPromptTemplate joins prefix, formatted examples and suffix into a
    # single template and formats it again with the test question. Any literal
    # brace in example text (e.g. LaTeX like \frac{a}{b}) would then be read as
    # a placeholder and break formatting, so escape them up front.
    fewshot_examples = [
        {field: _escape_braces(row[field]) for field in fields}
        for row in fewshot_set
    ]

    example_selector = LengthBasedExampleSelector(
        examples=fewshot_examples,
        example_prompt=example_prompt,
        max_length=1024,
        # Rough length estimate: assume ~2 tokens per Chinese character.
        get_text_length=lambda x: 2 * len(x),
    )

    few_shot_prompt = FewShotPromptTemplate(
        example_selector=example_selector,
        example_prompt=example_prompt,
        # task_name goes into the template text, so it needs the same escaping.
        prefix=f"请作为一个{_escape_braces(task_name)}科目的考生，参考示例回答单项选择题，示例如下：\n",
        # The doubled braces render as a literal "{答案：}" in the final prompt.
        suffix="下面是要回答的问题，可以先写出思考过程，然后给出ABCD的一个选项作为答案，确保用{{答案：}}作为答案的前缀\n问题：{question}\n选项\nA: {A}\nB: {B}\nC: {C}\nD: {D}\n思考：",
        input_variables=['question', 'A', 'B', 'C', 'D'],
        example_separator="\n\n",
    )

    results = []
    for idx, row in enumerate(test_set):
        grounded_prompt = few_shot_prompt.format(
            question=row['question'],
            A=row['A'],
            B=row['B'],
            C=row['C'],
            D=row['D'],
        )
        completion = model(grounded_prompt)
        llm_answer = extract_answer(completion)
        print(f"question: {idx}, llm_answer={llm_answer}, answer={row['answer']}")
        results.append((llm_answer, row['answer']))
    return results

def extract_answer(completion):
    """Pull the chosen option out of a model completion.

    Strategy, in order:
      1. For the first of the answer prefixes ('答案：', '答案为', '答案是')
         that occurs in the completion, look just after its first occurrence.
         Skip whitespace, an optional opening bracket and an optional '选项'
         and return the option letter A-D. If no letter follows, return the
         first non-space character, as the original single-char lookup did.
         If nothing follows the prefix, try the next prefix.
      2. Otherwise return the first option letter written in parentheses,
         e.g. "(B)" or "（B）", anywhere in the text.
      3. Otherwise return the completion unchanged.

    Args:
        completion: Raw completion text from the model.

    Returns:
        A single option letter when one can be located, otherwise a best-effort
        fragment or the whole completion.
    """
    for answer_prefix in ('答案：', '答案为', '答案是'):
        pos = completion.find(answer_prefix)
        if pos == -1:
            continue
        rest = completion[pos + len(answer_prefix):]
        matched = re.match(r'\s*[(（\[【]?\s*(?:选项)?\s*([A-D])', rest)
        if matched:
            return matched.group(1)
        stripped = rest.lstrip()
        if stripped:
            return stripped[0]
        # Prefix ends the completion: nothing to read here, keep looking.

    # First option letter enclosed in (ASCII or full-width) parentheses,
    # searched anywhere, since completions often start with whitespace/reasoning.
    matched = re.search(r'[(（]\s*([A-D])\s*[)）]', completion)
    if matched:
        return matched.group(1)
    return completion

data_mapping = get_data_mapping(dataset_mapping_url)
# Walk the subjects in mapping order. The trailing `break` means only the
# first subject is actually evaluated in this run.
for name, subject_info in data_mapping.items():
    dataset = load_dataset("ceval/ceval-exam", name=name)
    # Echo one dev record so the row layout is visible in the cell output.
    print(dataset['dev'][0])
    # dev rows serve as few-shot examples, val rows as the questions asked.
    # subject_info[1] is used as the subject name in the prompt prefix
    # (presumably the Chinese display name; confirm against the mapping file).
    gen_by_fewshot_prompt(subject_info[1], dataset['dev'], dataset['val'])
    break


Found cached dataset ceval-exam (/Users/wujianmin/.cache/huggingface/datasets/ceval___ceval-exam/computer_network/1.0.0/955d76b4f8ea8c9b6a14519905c8c10f781d3feb698126e274d2eb16cefc671f)


  0%|          | 0/3 [00:00<?, ?it/s]

{'id': 0, 'question': '下列设备属于资源子网的是____。', 'A': '计算机软件', 'B': '网桥', 'C': '交换机', 'D': '路由器', 'answer': 'A', 'explanation': '1. 首先，资源子网是指提供共享资源的网络，如打印机、文件服务器等。\r\n2. 其次，我们需要了解选项中设备的功能。网桥、交换机和路由器的主要功能是实现不同网络之间的通信和数据传输，是通信子网设备。而计算机软件可以提供共享资源的功能。'}


Retrying langchain.llms.openai.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.


question: 0, llm_answer=A, answer=C
question: 1, llm_answer=C, answer=C
question: 2, llm_answer=A, answer=C
question: 3, llm_answer=C, answer=C
question: 4, llm_answer=B, answer=D
question: 5, llm_answer=D, answer=D
question: 6, llm_answer=D, answer=C
question: 7, llm_answer=D, answer=D
question: 8, llm_answer=D, answer=C
question: 9, llm_answer=C, answer=D
question: 10, llm_answer=B, answer=B


Retrying langchain.llms.openai.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.
Retrying langchain.llms.openai.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.


question: 11, llm_answer=C, answer=C
question: 12, llm_answer=A, answer=B


Retrying langchain.llms.openai.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: The server had an error while processing your request. Sorry about that!.


question: 13, llm_answer=A, answer=A
question: 14, llm_answer=B, answer=D
question: 15, llm_answer=C, answer=C
question: 16, llm_answer=A, answer=A
question: 17, llm_answer=D, answer=B
question: 18, llm_answer=C, answer=C
