From f6445a45d7b4828d0b367e6022c13afd9f3de2ed Mon Sep 17 00:00:00 2001 From: KevinNuNu <34083603+KevinNuNu@users.noreply.github.com> Date: Fri, 27 Oct 2023 17:32:11 +0800 Subject: [PATCH] [SIG] add LogiQA dataset --- configs/datasets/LogiQA/LogiQA_en_ppl.py | 4 ++ .../datasets/LogiQA/LogiQA_en_ppl_20dfb3.py | 48 +++++++++++++++++++ configs/datasets/LogiQA/LogiQA_zh_ppl.py | 4 ++ .../datasets/LogiQA/LogiQA_zh_ppl_19fd62.py | 47 ++++++++++++++++++ opencompass/datasets/__init__.py | 1 + opencompass/datasets/logiqa.py | 25 ++++++++++ 6 files changed, 129 insertions(+) create mode 100644 configs/datasets/LogiQA/LogiQA_en_ppl.py create mode 100644 configs/datasets/LogiQA/LogiQA_en_ppl_20dfb3.py create mode 100644 configs/datasets/LogiQA/LogiQA_zh_ppl.py create mode 100644 configs/datasets/LogiQA/LogiQA_zh_ppl_19fd62.py create mode 100644 opencompass/datasets/logiqa.py diff --git a/configs/datasets/LogiQA/LogiQA_en_ppl.py b/configs/datasets/LogiQA/LogiQA_en_ppl.py new file mode 100644 index 000000000..56bfc7559 --- /dev/null +++ b/configs/datasets/LogiQA/LogiQA_en_ppl.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .LogiQA_en_ppl_20dfb3 import LogiQA_en_datasets # noqa: F401, F403 diff --git a/configs/datasets/LogiQA/LogiQA_en_ppl_20dfb3.py b/configs/datasets/LogiQA/LogiQA_en_ppl_20dfb3.py new file mode 100644 index 000000000..de7e746e4 --- /dev/null +++ b/configs/datasets/LogiQA/LogiQA_en_ppl_20dfb3.py @@ -0,0 +1,48 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import FixKRetriever +from opencompass.openicl.icl_inferencer import PPLInferencer +from opencompass.openicl.icl_evaluator import AccEvaluator +from opencompass.datasets import LogiQADataset + + +_hint = "The following are logical reasoning questions from the Chinese National " \ + "Civil Service Examination. Please choose the correct answer.\n" +LogiQA_en_infer_cfg = dict( + ice_template=dict( + type=PromptTemplate, + template="Context: {context}\nQuery: {query}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer: {correct_option}", + ), + prompt_template=dict( + type=PromptTemplate, + template={ + answer: + f"{_hint}Context: {{context}}\nQuery: {{query}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: {answer}" + for answer in ['A', 'B', 'C', 'D'] + }, + ice_token='', + ), + retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]), + inferencer=dict(type=PPLInferencer)) + +LogiQA_en_eval_cfg = dict(evaluator=dict(type=AccEvaluator), ) + + +LogiQA_en_datasets = [] +for _split in ["validation", "test"]: + + LogiQA_en_reader_cfg = dict( + input_columns=['context', 'query', 'A', 'B', 'C', 'D'], + output_column='correct_option', + test_split=_split + ) + + LogiQA_en_datasets.append( + dict( + abbr=f'LogiQA_en-{_split}', + type=LogiQADataset, + path='lucasmccabe/logiqa', + reader_cfg=LogiQA_en_reader_cfg, + infer_cfg=LogiQA_en_infer_cfg, + eval_cfg=LogiQA_en_eval_cfg + ) + ) diff --git a/configs/datasets/LogiQA/LogiQA_zh_ppl.py b/configs/datasets/LogiQA/LogiQA_zh_ppl.py new file mode 100644 index 000000000..71ccd5e76 --- /dev/null +++ b/configs/datasets/LogiQA/LogiQA_zh_ppl.py @@ -0,0 +1,4 @@ +from mmengine.config import read_base + +with read_base(): + from .LogiQA_zh_ppl_19fd62 import LogiQA_zh_datasets # noqa: F401, F403 diff --git a/configs/datasets/LogiQA/LogiQA_zh_ppl_19fd62.py b/configs/datasets/LogiQA/LogiQA_zh_ppl_19fd62.py new file mode 100644 index 000000000..3b3344ea8 --- /dev/null +++ b/configs/datasets/LogiQA/LogiQA_zh_ppl_19fd62.py @@ -0,0 +1,47 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import FixKRetriever +from opencompass.openicl.icl_inferencer import PPLInferencer +from opencompass.openicl.icl_evaluator import AccEvaluator +from opencompass.datasets import LogiQADataset + + +_hint = "以下是中国国家公务员考试的逻辑推理题,请选出其中的正确答案。\n" +LogiQA_zh_infer_cfg = dict( + ice_template=dict( + type=PromptTemplate, + template="上下文信息: {context}\n提问: {query}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\n答案: {correct_option}", + ), + prompt_template=dict( + type=PromptTemplate, + template={ + answer: + f"{_hint}上下文信息: {{context}}\n提问: {{query}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\n答案: {answer}" + for answer in ['A', 'B', 'C', 'D'] + }, + ice_token='', + ), + retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]), + inferencer=dict(type=PPLInferencer)) + +LogiQA_zh_eval_cfg = dict(evaluator=dict(type=AccEvaluator), ) + + +LogiQA_zh_datasets = [] +for _split in ["validation", "test"]: + + LogiQA_zh_reader_cfg = dict( + input_columns=['context', 'query', 'A', 'B', 'C', 'D'], + output_column='correct_option', + test_split=_split + ) + + LogiQA_zh_datasets.append( + dict( + abbr=f'LogiQA_zh-{_split}', + type=LogiQADataset, + path='jiacheng-ye/logiqa-zh', + reader_cfg=LogiQA_zh_reader_cfg, + infer_cfg=LogiQA_zh_infer_cfg, + eval_cfg=LogiQA_zh_eval_cfg + ) + ) diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py index 70d875ab5..e4d669ea3 100644 --- a/opencompass/datasets/__init__.py +++ b/opencompass/datasets/__init__.py @@ -52,6 +52,7 @@ from .lawbench import * # noqa: F401, F403 from .lcsts import * # noqa: F401, F403 from .leval import * # noqa: F401, F403 +from .logiqa import * # noqa: F401, F403 from .longbench import * # noqa: F401, F403 from .math import * # noqa: F401, F403 from .mathbench import * # noqa: F401, F403 diff --git a/opencompass/datasets/logiqa.py b/opencompass/datasets/logiqa.py new file mode 100644 index 000000000..9c444bae9 --- /dev/null +++ b/opencompass/datasets/logiqa.py @@ -0,0 +1,25 @@ +from datasets import load_dataset + +from opencompass.registry import LOAD_DATASET + +from .base import BaseDataset + + +@LOAD_DATASET.register_module() +class LogiQADataset(BaseDataset): + + @staticmethod + def load(path: str): + dataset = load_dataset(path=path) + + def preprocess(example): + options = example.pop('options') + example['A'] = options[0] + example['B'] = options[1] + example['C'] = options[2] + example['D'] = options[3] + example['correct_option'] = chr(65 + example['correct_option']) + return example + + dataset = dataset.map(preprocess) + return dataset