From 01c1ef11e8fb34382b74068501b4e33e50bc2f9c Mon Sep 17 00:00:00 2001
From: Leymore
Date: Tue, 29 Aug 2023 13:48:31 +0000
Subject: [PATCH] refactor prompt viewer & output reference

---
 tools/prompt_viewer.py | 142 ++++++++++++++++++++---------------------
 1 file changed, 68 insertions(+), 74 deletions(-)

diff --git a/tools/prompt_viewer.py b/tools/prompt_viewer.py
index 35280b1fe..e997d93c2 100644
--- a/tools/prompt_viewer.py
+++ b/tools/prompt_viewer.py
@@ -4,8 +4,7 @@
 from mmengine.config import Config, ConfigDict
 
-from opencompass.openicl.icl_inferencer import (CLPInferencer, GenInferencer,
-                                                PPLInferencer)
+from opencompass.openicl.icl_inferencer import GenInferencer, PPLInferencer
 from opencompass.registry import ICL_PROMPT_TEMPLATES, ICL_RETRIEVERS
 from opencompass.utils import (Menu, build_dataset_from_cfg,
                                build_model_from_cfg, dataset_abbr_from_cfg,
                                model_abbr_from_cfg)
@@ -45,6 +44,24 @@ def parse_dataset_cfg(dataset_cfg: ConfigDict) -> Dict[str, ConfigDict]:
     return dataset2cfg
 
 
+def get_prompt(ice_idx, prompt_func, max_seq_len, model):
+    prompt = prompt_func(ice_idx)
+    num_token = model.get_token_len_from_template(prompt)
+    if max_seq_len is None:
+        print(f'Number of tokens: {num_token}')
+        return prompt
+    while len(ice_idx) > 0 and num_token > max_seq_len:
+        num_ice = len(ice_idx)
+        old_num_token = num_token
+        ice_idx = ice_idx[:-1]
+        prompt = prompt_func(ice_idx)
+        num_token = model.get_token_len_from_template(prompt)
+        print(f'Truncating ice {num_ice} -> {num_ice - 1}',
+              f'Number of tokens: {old_num_token} -> {num_token}')
+    print(f'Number of tokens: {num_token}')
+    return prompt
+
+
 def print_prompts(model_cfg, dataset_cfg, count=1):
     # TODO: A really dirty method that copies code from PPLInferencer and
     # GenInferencer. In the future, the prompt extraction code should be
@@ -84,92 +101,70 @@ def print_prompts(model_cfg, dataset_cfg, count=1):
     assert infer_cfg.inferencer.type in [PPLInferencer, GenInferencer], \
         'Only PPLInferencer and GenInferencer are supported'
 
+    ice = retriever.generate_ice(ice_idx_list[0], ice_template=ice_template)
+    print('=' * 100)
+    print('Full in-context example:')
+    print('-' * 100)
+    print(ice)
+    print('=' * 100)
     for idx in range(min(count, len(ice_idx_list))):
+        print('=' * 100)
+        print(f'Data Item #{idx}:')
         if infer_cfg.inferencer.type == PPLInferencer:
             labels = retriever.get_labels(ice_template=ice_template,
                                           prompt_template=prompt_template)
-            ice = [
-                retriever.generate_ice(ice_idx_list[_idx],
-                                       ice_template=ice_template)
-                for _idx in range(len(ice_idx_list))
-            ]
-            print('-' * 100)
-            print('ICE Template:')
-            print('-' * 100)
-            print(ice[0])
-            print('-' * 100)
             for label in labels:
-                prompt = retriever.generate_label_prompt(
-                    idx,
-                    ice[idx],
-                    label,
-                    ice_template=ice_template,
-                    prompt_template=prompt_template,
-                    remain_sep=None)
-                if max_seq_len is not None:
-                    prompt_token_num = model.get_token_len_from_template(
-                        prompt)
-                    while len(ice_idx_list[idx]
-                              ) > 0 and prompt_token_num > max_seq_len:
-                        num_ice = len(ice_idx_list[idx])
-                        print(f'Truncating ice {num_ice} -> {num_ice - 1}',
-                              f'Number of tokens: {prompt_token_num} -> ...')
-                        ice_idx_list[idx] = ice_idx_list[idx][:-1]
-                        ice[idx] = retriever.generate_ice(
-                            ice_idx_list[idx], ice_template=ice_template)
-                        prompt = retriever.generate_label_prompt(
-                            idx,
-                            ice[idx],
-                            label,
-                            ice_template=ice_template,
-                            prompt_template=prompt_template)
-                        prompt_token_num = model.get_token_len_from_template(
-                            prompt)
-                    print(f'Number of tokens: {prompt_token_num}')
-                if model is not None:
-                    prompt = model.parse_template(prompt, mode='ppl')
-                print('-' * 100)
-                print(f'Label: {label}')
-                print('Sample prompt:')
-                print('-' * 100)
-                print(prompt)
-                print('-' * 100)
-        elif infer_cfg.inferencer.type in [GenInferencer, CLPInferencer]:
-            ice_idx = ice_idx_list[idx]
-            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
-            prompt = retriever.generate_prompt_for_generate_task(
-                idx,
-                ice,
-                gen_field_replace_token=infer_cfg.inferencer.get(
-                    'gen_field_replace_token', ''),
-                ice_template=ice_template,
-                prompt_template=prompt_template)
-            if max_seq_len is not None:
-                prompt_token_num = model.get_token_len_from_template(prompt)
-                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
-                    num_ice = len(ice_idx)
-                    print(f'Truncating ice {num_ice} -> {num_ice - 1}',
-                          f'Number of tokens: {prompt_token_num} -> ...')
-                    ice_idx = ice_idx[:-1]
+
+                def prompt_func(ice_idx):
                     ice = retriever.generate_ice(ice_idx,
                                                  ice_template=ice_template)
-                    prompt = retriever.generate_prompt_for_generate_task(
+                    return retriever.generate_label_prompt(
                         idx,
                         ice,
-                        gen_field_replace_token=infer_cfg.inferencer.get(
-                            'gen_field_replace_token', ''),
+                        label,
                         ice_template=ice_template,
-                        prompt_template=prompt_template)
-                    prompt_token_num = model.get_token_len_from_template(
-                        prompt)
-                print(f'Number of tokens: {prompt_token_num}')
+                        prompt_template=prompt_template,
+                        remain_sep=None)
+
+                print('-' * 100)
+                print(f'Label: {label}')
+                prompt = get_prompt(ice_idx_list[idx], prompt_func,
+                                    max_seq_len, model)
+                if model is not None:
+                    prompt = model.parse_template(prompt, mode='ppl')
+                print('Prompt:')
+                print('-' * 61)
+                print(prompt)
+                print('-' * 100)
+        elif infer_cfg.inferencer.type == GenInferencer:
+
+            def prompt_func(ice_idx):
+                ice = retriever.generate_ice(ice_idx,
+                                             ice_template=ice_template)
+                return retriever.generate_prompt_for_generate_task(
+                    idx,
+                    ice,
+                    gen_field_replace_token=infer_cfg.inferencer.get(
+                        'gen_field_replace_token', ''),
+                    ice_template=ice_template,
+                    prompt_template=prompt_template)
+
+            prompt = get_prompt(ice_idx_list[idx], prompt_func, max_seq_len,
+                                model)
             if model is not None:
                 prompt = model.parse_template(prompt, mode='gen')
             print('-' * 100)
-            print('Sample prompt:')
-            print('-' * 100)
+            print('Prompt:')
+            print('-' * 61)
             print(prompt)
             print('-' * 100)
+        else:
+            raise NotImplementedError
+
+        reference = dataset.test[idx][dataset.reader.output_column]
+        print('Reference:')
+        print('-' * 61)
+        print(reference)
 
 
 def main():
@@ -215,7 +210,6 @@ def main():
         print('=' * 64, '[BEGIN]', '=' * 64)
         print(f'[MODEL]: {model_abbr}')
         print(f'[DATASET]: {dataset_abbr}')
-        print('---')
        print_prompts(model_cfg, dataset_cfg, args.count)
         print('=' * 65, '[END]', '=' * 65)
         print()
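
Note on the refactor: the core of this patch is the new `get_prompt` helper, which repeatedly drops the trailing in-context example (ice) and rebuilds the prompt through the supplied `prompt_func` closure until the result fits within `max_seq_len`. Below is a minimal, self-contained sketch of that truncation loop, assuming nothing from OpenCompass: `FakeModel` with its whitespace-based token count, the toy `examples` list, and the illustrative `prompt_func` are invented stand-ins; only `get_prompt` itself and the closure shape mirror the patch.

```python
class FakeModel:
    """Hypothetical stand-in for an OpenCompass model; not the real API."""

    def get_token_len_from_template(self, prompt: str) -> int:
        # Crude whitespace "tokenizer", just for demonstration.
        return len(prompt.split())


def get_prompt(ice_idx, prompt_func, max_seq_len, model):
    # Same logic as the helper added in the patch above.
    prompt = prompt_func(ice_idx)
    num_token = model.get_token_len_from_template(prompt)
    if max_seq_len is None:
        print(f'Number of tokens: {num_token}')
        return prompt
    while len(ice_idx) > 0 and num_token > max_seq_len:
        num_ice = len(ice_idx)
        old_num_token = num_token
        ice_idx = ice_idx[:-1]          # drop the last in-context example
        prompt = prompt_func(ice_idx)   # rebuild the prompt without it
        num_token = model.get_token_len_from_template(prompt)
        print(f'Truncating ice {num_ice} -> {num_ice - 1}',
              f'Number of tokens: {old_num_token} -> {num_token}')
    print(f'Number of tokens: {num_token}')
    return prompt


examples = ['Q: 1+1? A: 2', 'Q: 2+2? A: 4', 'Q: 3+3? A: 6']


def prompt_func(ice_idx):
    # Mirrors the closures in the patch: rebuild the full prompt from
    # whichever in-context examples are still selected.
    ice = '\n'.join(examples[i] for i in ice_idx)
    return ice + '\nQ: 4+4? A:'


# With max_seq_len=12, one example is truncated (15 -> 11 tokens).
print(get_prompt([0, 1, 2], prompt_func, max_seq_len=12, model=FakeModel()))
```

Passing prompt construction in as a closure is what lets a single helper serve both branches: in the PPLInferencer branch the closure captures `label` and calls `generate_label_prompt`, while in the GenInferencer branch it calls `generate_prompt_for_generate_task`, and the truncation loop no longer needs to be duplicated.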