From bb0919b305f4b8a0b164c55bb8b65766043977f7 Mon Sep 17 00:00:00 2001
From: Tong Gao
Date: Fri, 6 Oct 2023 23:53:41 -0500
Subject: [PATCH] Move fix_id_list to Retriever (#442)

* [Refactor] Move fix_id_list to Retriever

* update

* move to base

* fix
---
 .../datasets/GLUE_CoLA/GULE_CoLA_ppl_77d0df.py  |  4 ++--
 .../datasets/GLUE_QQP/GLUE_QQP_ppl_250d00.py    |  4 ++--
 configs/datasets/ceval/ceval_gen_2daf24.py      |  4 ++--
 configs/datasets/ceval/ceval_gen_5f30c7.py      |  4 ++--
 configs/datasets/ceval/ceval_ppl_578f8d.py      |  4 ++--
 configs/datasets/ceval/ceval_ppl_93e5ce.py      |  4 ++--
 configs/datasets/cmb/cmb_gen_72cbb7.py          |  4 ++--
 configs/datasets/cmmlu/cmmlu_gen_c13365.py      |  4 ++--
 configs/datasets/cmmlu/cmmlu_ppl_8b9c76.py      |  4 ++--
 configs/datasets/mmlu/mmlu_gen_23a9a9.py        |  4 ++--
 configs/datasets/mmlu/mmlu_gen_5d1409.py        |  4 ++--
 configs/datasets/mmlu/mmlu_gen_79e572.py        |  4 ++--
 configs/datasets/mmlu/mmlu_gen_a484b3.py        |  4 ++--
 configs/datasets/mmlu/mmlu_ppl_ac766d.py        |  4 ++--
 configs/datasets/nq/nq_gen_0356ec.py            |  4 ++--
 .../datasets/triviaqa/triviaqa_gen_0356ec.py    |  4 ++--
 docs/en/prompt/prompt_template.md               |  4 ++--
 docs/zh_cn/prompt/prompt_template.md            |  4 ++--
 .../icl_inferencer/icl_agent_inferencer.py      |  5 +----
 .../icl_inferencer/icl_attack_inferencer.py     |  7 +------
 .../icl_inferencer/icl_base_inferencer.py       | 10 +++++++---
 .../icl_inferencer/icl_clp_inferencer.py        |  7 +------
 .../icl_inferencer/icl_gen_inferencer.py        | 12 ++----------
 .../icl_inferencer/icl_ppl_inferencer.py        |  7 +------
 .../openicl/icl_inferencer/icl_sc_inferencer.py |  7 +------
 .../icl_inferencer/icl_tot_inferencer.py        |  8 +-------
 .../icl_retriever/icl_fix_k_retriever.py        | 17 ++++++++---------
 opencompass/utils/prompt.py                     |  4 ++++
 tools/prompt_viewer.py                          |  6 +-----
 tools/update_dataset_suffix.py                  |  4 ++++
 30 files changed, 68 insertions(+), 98 deletions(-)

diff --git a/configs/datasets/GLUE_CoLA/GULE_CoLA_ppl_77d0df.py b/configs/datasets/GLUE_CoLA/GULE_CoLA_ppl_77d0df.py
index 48bed05e7..b98b3a2fe 100644
--- a/configs/datasets/GLUE_CoLA/GULE_CoLA_ppl_77d0df.py
+++ b/configs/datasets/GLUE_CoLA/GULE_CoLA_ppl_77d0df.py
@@ -23,8 +23,8 @@
         },
         ice_token='</E>',
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=PPLInferencer, fix_id_list=[17, 18, 19, 20, 21]))
+    retriever=dict(type=FixKRetriever, fix_id_list=[17, 18, 19, 20, 21]),
+    inferencer=dict(type=PPLInferencer))
 
 CoLA_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
 
diff --git a/configs/datasets/GLUE_QQP/GLUE_QQP_ppl_250d00.py b/configs/datasets/GLUE_QQP/GLUE_QQP_ppl_250d00.py
index cb75aef93..5e1d18de1 100644
--- a/configs/datasets/GLUE_QQP/GLUE_QQP_ppl_250d00.py
+++ b/configs/datasets/GLUE_QQP/GLUE_QQP_ppl_250d00.py
@@ -22,8 +22,8 @@
         },
         ice_token='</E>',
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]))
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+    inferencer=dict(type=PPLInferencer))
 
 QQP_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
 
diff --git a/configs/datasets/ceval/ceval_gen_2daf24.py b/configs/datasets/ceval/ceval_gen_2daf24.py
index c203b51c7..a2e020f1e 100644
--- a/configs/datasets/ceval/ceval_gen_2daf24.py
+++ b/configs/datasets/ceval/ceval_gen_2daf24.py
@@ -161,8 +161,8 @@
             ]),
         ice_token="</E>",
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+    inferencer=dict(type=GenInferencer),
 )
 
 ceval_eval_cfg = dict(
diff --git a/configs/datasets/ceval/ceval_gen_5f30c7.py b/configs/datasets/ceval/ceval_gen_5f30c7.py
index 1ccbe4de6..caca00284 100644
--- a/configs/datasets/ceval/ceval_gen_5f30c7.py
+++ b/configs/datasets/ceval/ceval_gen_5f30c7.py
@@ -161,8 +161,8 @@
             ]),
         ice_token="</E>",
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+    inferencer=dict(type=GenInferencer),
 )
 
 ceval_eval_cfg = dict(
diff --git a/configs/datasets/ceval/ceval_ppl_578f8d.py b/configs/datasets/ceval/ceval_ppl_578f8d.py
index 212b5b333..8447c86b5 100644
--- a/configs/datasets/ceval/ceval_ppl_578f8d.py
+++ b/configs/datasets/ceval/ceval_ppl_578f8d.py
@@ -163,8 +163,8 @@
         },
         ice_token="</E>",
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+    inferencer=dict(type=PPLInferencer),
 )
 
 ceval_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
diff --git a/configs/datasets/ceval/ceval_ppl_93e5ce.py b/configs/datasets/ceval/ceval_ppl_93e5ce.py
index 56a6bb64b..9deea61a8 100644
--- a/configs/datasets/ceval/ceval_ppl_93e5ce.py
+++ b/configs/datasets/ceval/ceval_ppl_93e5ce.py
@@ -163,8 +163,8 @@
         },
         ice_token="</E>",
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+    inferencer=dict(type=PPLInferencer),
 )
 
 ceval_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
diff --git a/configs/datasets/cmb/cmb_gen_72cbb7.py b/configs/datasets/cmb/cmb_gen_72cbb7.py
index 4cb9a325b..48729b9f1 100644
--- a/configs/datasets/cmb/cmb_gen_72cbb7.py
+++ b/configs/datasets/cmb/cmb_gen_72cbb7.py
@@ -28,8 +28,8 @@
         ),
         ice_token="</E>",
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+    inferencer=dict(type=GenInferencer),
 )
 
 cmb_datasets.append(
diff --git a/configs/datasets/cmmlu/cmmlu_gen_c13365.py b/configs/datasets/cmmlu/cmmlu_gen_c13365.py
index e4b0bc4f6..f6191bdaf 100644
--- a/configs/datasets/cmmlu/cmmlu_gen_c13365.py
+++ b/configs/datasets/cmmlu/cmmlu_gen_c13365.py
@@ -96,8 +96,8 @@
             ]),
         ice_token="</E>",
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+    inferencer=dict(type=GenInferencer),
 )
 
 cmmlu_eval_cfg = dict(
diff --git a/configs/datasets/cmmlu/cmmlu_ppl_8b9c76.py b/configs/datasets/cmmlu/cmmlu_ppl_8b9c76.py
index eb9ea96b9..631407ab3 100644
--- a/configs/datasets/cmmlu/cmmlu_ppl_8b9c76.py
+++ b/configs/datasets/cmmlu/cmmlu_ppl_8b9c76.py
@@ -98,8 +98,8 @@
         },
         ice_token="</E>",
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+    inferencer=dict(type=PPLInferencer),
 )
 
 cmmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
diff --git a/configs/datasets/mmlu/mmlu_gen_23a9a9.py b/configs/datasets/mmlu/mmlu_gen_23a9a9.py
index 53595b3b5..c724902f8 100644
--- a/configs/datasets/mmlu/mmlu_gen_23a9a9.py
+++ b/configs/datasets/mmlu/mmlu_gen_23a9a9.py
@@ -29,8 +29,8 @@
             dict(role='BOT', prompt='{target}\n')
         ])),
     prompt_template=mmlu_prompt_template,
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]))
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+    inferencer=dict(type=GenInferencer))
 
 mmlu_eval_cfg = dict(
     evaluator=dict(type=AccEvaluator),
diff --git a/configs/datasets/mmlu/mmlu_gen_5d1409.py b/configs/datasets/mmlu/mmlu_gen_5d1409.py
index 8a925d428..3d530a35c 100644
--- a/configs/datasets/mmlu/mmlu_gen_5d1409.py
+++ b/configs/datasets/mmlu/mmlu_gen_5d1409.py
@@ -102,8 +102,8 @@
         ),
         ice_token="</E>",
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+    inferencer=dict(type=GenInferencer),
 )
 
 mmlu_eval_cfg = dict(
diff --git a/configs/datasets/mmlu/mmlu_gen_79e572.py b/configs/datasets/mmlu/mmlu_gen_79e572.py
index eabab8e77..18b2ea7d8 100644
--- a/configs/datasets/mmlu/mmlu_gen_79e572.py
+++ b/configs/datasets/mmlu/mmlu_gen_79e572.py
@@ -87,8 +87,8 @@
         f"{_hint}{{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer:",
         ice_token="</E>",
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+    inferencer=dict(type=GenInferencer),
 )
 
 mmlu_eval_cfg = dict(
diff --git a/configs/datasets/mmlu/mmlu_gen_a484b3.py b/configs/datasets/mmlu/mmlu_gen_a484b3.py
index 93406ea6e..69f7939ba 100644
--- a/configs/datasets/mmlu/mmlu_gen_a484b3.py
+++ b/configs/datasets/mmlu/mmlu_gen_a484b3.py
@@ -102,8 +102,8 @@
         ),
         ice_token="</E>",
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+    inferencer=dict(type=GenInferencer),
 )
 
 mmlu_eval_cfg = dict(
diff --git a/configs/datasets/mmlu/mmlu_ppl_ac766d.py b/configs/datasets/mmlu/mmlu_ppl_ac766d.py
index 900c1eb9b..f0473eb46 100644
--- a/configs/datasets/mmlu/mmlu_ppl_ac766d.py
+++ b/configs/datasets/mmlu/mmlu_ppl_ac766d.py
@@ -93,8 +93,8 @@
         },
         ice_token="</E>",
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
+    inferencer=dict(type=PPLInferencer),
 )
 
 mmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
diff --git a/configs/datasets/nq/nq_gen_0356ec.py b/configs/datasets/nq/nq_gen_0356ec.py
index f2e4dc37b..beffcccce 100644
--- a/configs/datasets/nq/nq_gen_0356ec.py
+++ b/configs/datasets/nq/nq_gen_0356ec.py
@@ -44,8 +44,8 @@
         ),
         ice_token="</E>",
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=GenInferencer, max_out_len=50, fix_id_list=list(range(k))),
+    retriever=dict(type=FixKRetriever, fix_id_list=list(range(k))),
+    inferencer=dict(type=GenInferencer, max_out_len=50),
 )
 
 nq_eval_cfg = dict(evaluator=dict(type=NQEvaluator), pred_role="BOT")
diff --git a/configs/datasets/triviaqa/triviaqa_gen_0356ec.py b/configs/datasets/triviaqa/triviaqa_gen_0356ec.py
index 79bc2d145..95c262f76 100644
--- a/configs/datasets/triviaqa/triviaqa_gen_0356ec.py
+++ b/configs/datasets/triviaqa/triviaqa_gen_0356ec.py
@@ -45,8 +45,8 @@
         ),
         ice_token="</E>",
     ),
-    retriever=dict(type=FixKRetriever),
-    inferencer=dict(type=GenInferencer, max_out_len=50, fix_id_list=list(range(k))),
+    retriever=dict(type=FixKRetriever, fix_id_list=list(range(k))),
+    inferencer=dict(type=GenInferencer, max_out_len=50),
 )
 
 triviaqa_eval_cfg = dict(evaluator=dict(type=TriviaQAEvaluator), pred_role="BOT")
diff --git a/docs/en/prompt/prompt_template.md b/docs/en/prompt/prompt_template.md
index 12d015ee3..5147cd672 100644
--- a/docs/en/prompt/prompt_template.md
+++ b/docs/en/prompt/prompt_template.md
@@ -34,8 +34,8 @@ infer_cfg = dict(
         template='Solve the following questions.\n{question}\n{answer}',
         ice_token="</E>"
     ),
-    retriever=dict(type=FixKRetriever), # Definition of how to retrieve in-context examples.
-    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1]), # Method used to generate predictions.
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1]), # Definition of how to retrieve in-context examples.
+    inferencer=dict(type=GenInferencer), # Method used to generate predictions.
 )
 ```
 
diff --git a/docs/zh_cn/prompt/prompt_template.md b/docs/zh_cn/prompt/prompt_template.md
index ee5e39286..d93c4dae5 100644
--- a/docs/zh_cn/prompt/prompt_template.md
+++ b/docs/zh_cn/prompt/prompt_template.md
@@ -34,8 +34,8 @@ infer_cfg=dict(
         template='Solve the following questions.\n{question}\n{answer}',
         ice_token="</E>"
     ),
-    retriever=dict(type=FixKRetriever), # 定义 in context example 的获取方式
-    inferencer=dict(type=GenInferencer, fix_id_list=[0, 1]), # 使用何种方式推理得到 prediction
+    retriever=dict(type=FixKRetriever, fix_id_list=[0, 1]), # 定义 in context example 的获取方式
+    inferencer=dict(type=GenInferencer), # 使用何种方式推理得到 prediction
 )
 ```
 
diff --git a/opencompass/openicl/icl_inferencer/icl_agent_inferencer.py b/opencompass/openicl/icl_inferencer/icl_agent_inferencer.py
index 781ce9dce..835468682 100644
--- a/opencompass/openicl/icl_inferencer/icl_agent_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_agent_inferencer.py
@@ -55,10 +55,7 @@ def inference(self,
             output_json_filename = self.output_json_filename
 
         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()
 
         # Create tmp json file for saving intermediate results and future
         # resuming
diff --git a/opencompass/openicl/icl_inferencer/icl_attack_inferencer.py b/opencompass/openicl/icl_inferencer/icl_attack_inferencer.py
index 7a51c96b7..f8d8ea042 100644
--- a/opencompass/openicl/icl_inferencer/icl_attack_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_attack_inferencer.py
@@ -59,7 +59,6 @@ def __init__(
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             save_every: Optional[int] = None,
-            fix_id_list: Optional[List[int]] = None,
             dataset_cfg: Optional[List[int]] = None,
             **kwargs) -> None:
         super().__init__(
@@ -78,7 +77,6 @@ def __init__(
         self.output_column = dataset_cfg['reader_cfg']['output_column']
         self.gen_field_replace_token = gen_field_replace_token
         self.max_out_len = max_out_len
-        self.fix_id_list = fix_id_list
 
         if self.model.is_api and save_every is None:
             save_every = 1
@@ -94,10 +92,7 @@ def predict(self, adv_prompt) -> List:
             output_json_filename = self.output_json_filename
 
         # 2. Get results of retrieval process
-        if 'Fix' in self.retriever.__class__.__name__:
-            ice_idx_list = self.retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = self.retriever.retrieve()
+        ice_idx_list = self.retriever.retrieve()
 
         # 3. Generate prompts for testing input
         prompt_list, label_list = self.get_generation_prompt_list_from_retriever_indices(  # noqa
diff --git a/opencompass/openicl/icl_inferencer/icl_base_inferencer.py b/opencompass/openicl/icl_inferencer/icl_base_inferencer.py
index 7dc7482af..fd3fbde76 100644
--- a/opencompass/openicl/icl_inferencer/icl_base_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_base_inferencer.py
@@ -25,9 +25,6 @@ class BaseInferencer:
             `JSON` file.
         output_json_filename (:obj:`str`, optional): File name for output
             `JSON` file.
-        api_name (:obj:`str`, optional): Name of API service.
-        call_api (:obj:`bool`): If ``True``, an API for LM models will be used,
-            determined by :obj:`api_name`.
     """
     model = None
 
@@ -38,8 +35,15 @@ def __init__(
         batch_size: Optional[int] = 1,
         output_json_filepath: Optional[str] = './icl_inference_output',
         output_json_filename: Optional[str] = 'predictions',
+        fix_id_list: Optional[List[int]] = None,
         **kwargs,
     ) -> None:
+
+        if fix_id_list:
+            raise ValueError('Passing fix_id_list to Inferencer is no longer '
+                             'allowed. Please pass it to FixKRetriever '
+                             'instead.')
+
         self.model = model
 
         self.max_seq_len = max_seq_len
diff --git a/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py b/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
index c7a2b4ac4..4d0951af2 100644
--- a/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
@@ -55,7 +55,6 @@ def __init__(
             batch_size: Optional[int] = 1,
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
-            fix_id_list: Optional[List[int]] = None,
             single_token: bool = True,
             **kwargs) -> None:
         super().__init__(
@@ -67,7 +66,6 @@ def __init__(
             **kwargs,
         )
 
-        self.fix_id_list = fix_id_list
         # TODO: support multiple token
         assert single_token, 'Only support single token choice currently.'
         self.single_token = single_token
@@ -104,10 +102,7 @@ def inference(self,
             raise ValueError(err_msg)
 
         # 2. Get results of retrieval process
-        if self.fix_id_list:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()
 
         # 3. Generate in-context examples for testing inputs
         for idx in range(len(ice_idx_list)):
diff --git a/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py b/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py
index a319b9c95..0398e2c6b 100644
--- a/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_gen_inferencer.py
@@ -51,7 +51,6 @@ def __init__(
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             save_every: Optional[int] = None,
-            fix_id_list: Optional[List[int]] = None,
             **kwargs) -> None:
         super().__init__(
             model=model,
@@ -64,7 +63,6 @@ def __init__(
 
         self.gen_field_replace_token = gen_field_replace_token
         self.max_out_len = max_out_len
-        self.fix_id_list = fix_id_list
 
         if self.model.is_api and save_every is None:
             save_every = 1
@@ -85,10 +83,7 @@ def inference(self,
             output_json_filename = self.output_json_filename
 
         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()
 
         # 3. Generate prompts for testing input
         prompt_list = self.get_generation_prompt_list_from_retriever_indices(
@@ -220,10 +215,7 @@ def inference(self,
             output_json_filename = self.output_json_filename
 
         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()
 
         # 3. Generate prompts for testing input
         prompt_list = self.get_generation_prompt_list_from_retriever_indices(
diff --git a/opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py b/opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py
index 0d8bad9c8..606afd869 100644
--- a/opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py
@@ -41,7 +41,6 @@ def __init__(
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             labels: Optional[List] = None,
-            fix_id_list: Optional[List[int]] = None,
             **kwargs) -> None:
         super().__init__(
             model=model,
@@ -53,7 +52,6 @@ def __init__(
         )
 
         self.labels = labels
-        self.fix_id_list = fix_id_list
 
     def inference(self,
                   retriever: BaseRetriever,
@@ -75,10 +73,7 @@ def inference(self,
             output_json_filename = self.output_json_filename
 
         # 2. Get results of retrieval process
-        if self.fix_id_list:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()
 
         # 3. Get labels of all the classes
         if self.labels is None:
diff --git a/opencompass/openicl/icl_inferencer/icl_sc_inferencer.py b/opencompass/openicl/icl_inferencer/icl_sc_inferencer.py
index b7e9cffe3..dbbd41c98 100644
--- a/opencompass/openicl/icl_inferencer/icl_sc_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_sc_inferencer.py
@@ -52,7 +52,6 @@ def __init__(
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             save_every: Optional[int] = None,
-            fix_id_list: Optional[List[int]] = None,
             sc_size: Optional[int] = 1,
             infer_type: Optional[str] = '',
             generation_kwargs: dict = {},
@@ -69,7 +68,6 @@ def __init__(
         self.gen_field_replace_token = gen_field_replace_token
         self.generation_kwargs = generation_kwargs
         self.max_out_len = max_out_len
-        self.fix_id_list = fix_id_list
         self.sc_size = sc_size
 
         if self.model.is_api and save_every is None:
@@ -91,10 +89,7 @@ def inference(self,
             output_json_filename = self.output_json_filename
 
         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()
 
         # 3. Generate prompts for testing input
         prompt_list = self.get_generation_prompt_list_from_retriever_indices(
diff --git a/opencompass/openicl/icl_inferencer/icl_tot_inferencer.py b/opencompass/openicl/icl_inferencer/icl_tot_inferencer.py
index 5fd174835..22a2298e3 100644
--- a/opencompass/openicl/icl_inferencer/icl_tot_inferencer.py
+++ b/opencompass/openicl/icl_inferencer/icl_tot_inferencer.py
@@ -46,7 +46,6 @@ class ToTInferencer(GenInferencer):
             `save_every` epochs.
         generation_kwargs (:obj:`Dict`, optional): Parameters for the
             :obj:`model.generate()` method.
-        fix_id_list (:obj:`List[int]`, optional): List of indices to fix
         naive_run (:obj:`bool`): if True, run naive IO/CoT sampling instead of
             ToT + BFS.
         prompt_wrapper (:obj:`dict`): wrapper for prompts
@@ -76,7 +75,6 @@ def __init__(
             output_json_filepath: Optional[str] = './icl_inference_output',
             output_json_filename: Optional[str] = 'predictions',
             save_every: Optional[int] = None,
-            fix_id_list: Optional[List[int]] = None,
             naive_run: bool = False,
             prompt_wrapper: dict = {},
             prompt_sample: str = 'standard',
@@ -97,7 +95,6 @@ def __init__(
             output_json_filename=output_json_filename,
             output_json_filepath=output_json_filepath,
             save_every=save_every,
-            fix_id_list=fix_id_list,
             sc_size=n_evaluate_sample,
             **kwargs,
         )
@@ -319,10 +316,7 @@ def inference(self,
             output_json_filename = self.output_json_filename
 
         # 2. Get results of retrieval process
-        if 'Fix' in retriever.__class__.__name__:
-            ice_idx_list = retriever.retrieve(self.fix_id_list)
-        else:
-            ice_idx_list = retriever.retrieve()
+        ice_idx_list = retriever.retrieve()
 
         # 3. Generate prompts for testing input
         prompt_list = self.get_generation_prompt_list_from_retriever_indices(
diff --git a/opencompass/openicl/icl_retriever/icl_fix_k_retriever.py b/opencompass/openicl/icl_retriever/icl_fix_k_retriever.py
index 1e6f73973..c9ade7551 100644
--- a/opencompass/openicl/icl_retriever/icl_fix_k_retriever.py
+++ b/opencompass/openicl/icl_retriever/icl_fix_k_retriever.py
@@ -19,6 +19,8 @@ class FixKRetriever(BaseRetriever):
     Args:
         dataset (`BaseDataset`): Any BaseDataset instances. Attributes of
             ``reader``, ``train`` and ``test`` will be used.
+        fix_id_list (List[int]): List of in-context example indices for every
+            test prompts.
         ice_separator (`Optional[str]`): The separator between each in-context
             example template when origin `PromptTemplate` is provided.
             Defaults to '\n'.
@@ -31,22 +33,19 @@
 
     def __init__(self,
                  dataset,
+                 fix_id_list: List[int],
                  ice_separator: Optional[str] = '\n',
                  ice_eos_token: Optional[str] = '\n',
                  ice_num: Optional[int] = 1) -> None:
         super().__init__(dataset, ice_separator, ice_eos_token, ice_num)
+        self.fix_id_list = fix_id_list
 
-    def retrieve(self, id_list: List[int]):
-        """Retrieve the in-context example index for each test example.
-
-        Args:
-            id_list (List[int]): List of in-context example indices for every
-                test prompts.
-        """
+    def retrieve(self):
+        """Retrieve the in-context example index for each test example."""
         num_idx = len(self.index_ds)
-        for idx in id_list:
+        for idx in self.fix_id_list:
             assert idx < num_idx, f'Index {idx} is out of range of {num_idx}'
         rtr_idx_list = []
         for _ in trange(len(self.test_ds), disable=not self.is_main_process):
-            rtr_idx_list.append(id_list)
+            rtr_idx_list.append(self.fix_id_list)
         return rtr_idx_list
diff --git a/opencompass/utils/prompt.py b/opencompass/utils/prompt.py
index d07126f4a..fe6e00e59 100644
--- a/opencompass/utils/prompt.py
+++ b/opencompass/utils/prompt.py
@@ -56,6 +56,10 @@ def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
             'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
     for k, v in dataset_cfg.infer_cfg.items():
         dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
+    # A compromise for the hash consistency
+    if 'fix_id_list' in dataset_cfg.infer_cfg.retriever:
+        fix_id_list = dataset_cfg.infer_cfg.retriever.pop('fix_id_list')
+        dataset_cfg.infer_cfg.inferencer['fix_id_list'] = fix_id_list
     d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
     hash_object = hashlib.sha256(d_json.encode())
     return hash_object.hexdigest()
diff --git a/tools/prompt_viewer.py b/tools/prompt_viewer.py
index 35280b1fe..99b44922b 100644
--- a/tools/prompt_viewer.py
+++ b/tools/prompt_viewer.py
@@ -61,7 +61,6 @@ def print_prompts(model_cfg, dataset_cfg, count=1):
 
     infer_cfg = dataset_cfg.get('infer_cfg')
 
-    fix_id_list = infer_cfg.inferencer.get('fix_id_list', [])
     dataset = build_dataset_from_cfg(dataset_cfg)
 
     ice_template = None
@@ -76,10 +75,7 @@ def print_prompts(model_cfg, dataset_cfg, count=1):
     infer_cfg['retriever']['dataset'] = dataset
     retriever = ICL_RETRIEVERS.build(infer_cfg['retriever'])
 
-    if fix_id_list:
-        ice_idx_list = retriever.retrieve(fix_id_list)
-    else:
-        ice_idx_list = retriever.retrieve()
+    ice_idx_list = retriever.retrieve()
 
     assert infer_cfg.inferencer.type in [PPLInferencer, GenInferencer], \
         'Only PPLInferencer and GenInferencer are supported'
diff --git a/tools/update_dataset_suffix.py b/tools/update_dataset_suffix.py
index 91ede7200..138d6f77e 100755
--- a/tools/update_dataset_suffix.py
+++ b/tools/update_dataset_suffix.py
@@ -45,6 +45,10 @@ def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
             'test_split'] = dataset_cfg.infer_cfg.reader_cfg.test_split
     for k, v in dataset_cfg.infer_cfg.items():
         dataset_cfg.infer_cfg[k]['type'] = v['type'].split('.')[-1]
+    # A compromise for the hash consistency
+    if 'fix_id_list' in dataset_cfg.infer_cfg.retriever:
+        fix_id_list = dataset_cfg.infer_cfg.retriever.pop('fix_id_list')
+        dataset_cfg.infer_cfg.inferencer['fix_id_list'] = fix_id_list
     d_json = json.dumps(dataset_cfg.infer_cfg.to_dict(), sort_keys=True)
     hash_object = hashlib.sha256(d_json.encode())
     return hash_object.hexdigest()