Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] add safety to collections #185

Merged
merged 2 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from mmengine.config import read_base

with read_base():
from .civilcomments_ppl_6a2561 import civilcomments_datasets # noqa: F401, F403
from .civilcomments_clp_a3c5fd import civilcomments_datasets # noqa: F401, F403
4 changes: 4 additions & 0 deletions configs/datasets/collections/base_medium.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,9 @@
from ..triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from ..flores.flores_gen_806ede import flores_datasets
from ..crowspairs.crowspairs_ppl_e811e1 import crowspairs_datasets
from ..civilcomments.civilcomments_clp_a3c5fd import civilcomments_datasets
from ..jigsawmultilingual.jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets
from ..realtoxicprompts.realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets
from ..truthfulqa.truthfulqa_gen_5ddc62 import truthfulqa_datasets

datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
4 changes: 4 additions & 0 deletions configs/datasets/collections/chat_medium.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,9 @@
from ..triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
from ..flores.flores_gen_806ede import flores_datasets
from ..crowspairs.crowspairs_gen_21f7cb import crowspairs_datasets
from ..civilcomments.civilcomments_clp_a3c5fd import civilcomments_datasets
from ..jigsawmultilingual.jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets
from ..realtoxicprompts.realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets
from ..truthfulqa.truthfulqa_gen_5ddc62 import truthfulqa_datasets

datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from mmengine.config import read_base

with read_base():
from .jigsawmultilingual_ppl_fe50d8 import jigsawmultilingual_datasets # noqa: F401, F403
from .jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets # noqa: F401, F403
2 changes: 1 addition & 1 deletion configs/datasets/realtoxicprompts/realtoxicprompts_gen.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from mmengine.config import read_base

with read_base():
from .realtoxicprompts_gen_ac723c import realtoxicprompts_datasets # noqa: F401, F403
from .realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets # noqa: F401, F403
5 changes: 2 additions & 3 deletions configs/datasets/truthfulqa/truthfulqa_gen_1e7d8d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@

# TODO: allow empty output-column
truthfulqa_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template='{question}'),
prompt_template=dict(type=PromptTemplate, template='{question}'),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer))

Expand All @@ -31,6 +29,7 @@

truthfulqa_datasets = [
dict(
abbr='truthful_qa',
type=TruthfulQADataset,
path='truthful_qa',
name='generation',
Expand Down
1 change: 1 addition & 0 deletions configs/datasets/truthfulqa/truthfulqa_gen_5ddc62.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

truthfulqa_datasets = [
dict(
abbr='truthful_qa',
type=TruthfulQADataset,
path='truthful_qa',
name='generation',
Expand Down
25 changes: 17 additions & 8 deletions configs/summarizers/medium.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
from .groups.jigsaw_multilingual import jigsaw_multilingual_summary_groups

summarizer = dict(
dataset_abbrs = [
'--------- 考试 Exam ---------', # category
dataset_abbrs=[
'--------- 考试 Exam ---------', # category
# 'Mixed', # subcategory
"ceval",
'agieval',
'mmlu',
"GaokaoBench",
'ARC-c',
'--------- 语言 Language ---------', # category
'--------- 语言 Language ---------', # category
# '字词释义', # subcategory
'WiC',
'summedits',
Expand All @@ -33,14 +33,14 @@
'winogrande',
# '翻译', # subcategory
'flores_100',
'--------- 知识 Knowledge ---------', # category
'--------- 知识 Knowledge ---------', # category
# '知识问答', # subcategory
'BoolQ',
'commonsense_qa',
'nq',
'triviaqa',
# '多语种问答', # subcategory
'--------- 推理 Reasoning ---------', # category
'--------- 推理 Reasoning ---------', # category
# '文本蕴含', # subcategory
'cmnli',
'ocnli',
Expand All @@ -67,7 +67,7 @@
'mbpp',
# '综合推理', # subcategory
"bbh",
'--------- 理解 Understanding ---------', # category
'--------- 理解 Understanding ---------', # category
# '阅读理解', # subcategory
'C3',
'CMRC_dev',
Expand All @@ -84,11 +84,20 @@
'eprstmt-dev',
'lambada',
'tnews-dev',
'--------- 安全 Safety ---------', # category
'--------- 安全 Safety ---------', # category
# '偏见', # subcategory
'crows_pairs',
# '有毒性(判别)', # subcategory
'civil_comments',
# '有毒性(判别)多语言', # subcategory
'jigsaw_multilingual',
# '有毒性(生成)', # subcategory
'real-toxicity-prompts',
# '真实性/有用性', # subcategory
'truthful_qa',
],
summary_groups=sum([v for k, v in locals().items() if k.endswith("_summary_groups")], []),
summary_groups=sum(
[v for k, v in locals().items() if k.endswith("_summary_groups")], []),
prompt_db=dict(
database_path='configs/datasets/log.json',
config_dir='configs/datasets',
Expand Down
17 changes: 17 additions & 0 deletions opencompass/openicl/icl_inferencer/icl_base_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,23 @@ def save_prompt_and_ppl(self, label, input, prompt, ppl, idx):
self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt
self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl


class CLPInferencerOutputHandler:
results_dict = {}

def __init__(self) -> None:
    """Create a handler with an empty, per-instance ``results_dict``."""
    # Rebinding on the instance keeps collected results out of the
    # class-level ``results_dict`` declared in the class body.
    self.results_dict = {}

def write_to_json(self, save_dir: str, filename: str):
    """Dump the collected results to a json file.

    Args:
        save_dir: Directory the file is written into.
        filename: Name of the json file inside ``save_dir``.
    """
    target_path = Path(save_dir) / filename
    dump_results_dict(self.results_dict, target_path)

def save_ice(self, ice):
    """Record the in-context example for each entry.

    Args:
        ice: Iterable of in-context examples. The position of each
            example, converted to a string, is the key under which it
            is stored in ``self.results_dict``.
    """
    for idx, example in enumerate(ice):
        # setdefault creates the per-index record on first use and keeps
        # any fields already stored under that index.
        entry = self.results_dict.setdefault(str(idx), {})
        entry['in-context examples'] = example

def save_prompt_and_condprob(self, input, prompt, cond_prob, idx, choices):
if str(idx) not in self.results_dict.keys():
self.results_dict[str(idx)] = {}
Expand Down
21 changes: 18 additions & 3 deletions opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils import get_logger
from .icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler
from .icl_base_inferencer import BaseInferencer, CLPInferencerOutputHandler

logger = get_logger(__name__)

Expand Down Expand Up @@ -79,7 +79,7 @@ def inference(self,
output_json_filename: Optional[str] = None,
normalizing_str: Optional[str] = None) -> List:
# 1. Preparation for output logs
output_handler = PPLInferencerOutputHandler()
output_handler = CLPInferencerOutputHandler()

ice = []

Expand All @@ -88,6 +88,20 @@ def inference(self,
if output_json_filename is None:
output_json_filename = self.output_json_filename

# CLP cannot run conditional log-probability inference on API models
# unless the model exposes such an option, which needs a specific
# implementation. Open an issue if you encounter this case.
if self.model.is_api:
# Write a placeholder file so that this model is not always rerun
if self.is_main_process:
os.makedirs(output_json_filepath, exist_ok=True)
err_msg = 'API model is not supported for conditional log '\
'probability inference and skip this exp.'
output_handler.results_dict = {'error': err_msg}
output_handler.write_to_json(output_json_filepath,
output_json_filename)
raise ValueError(err_msg)

# 2. Get results of retrieval process
if self.fix_id_list:
ice_idx_list = retriever.retrieve(self.fix_id_list)
Expand Down Expand Up @@ -117,7 +131,7 @@ def inference(self,
choice_ids = [self.model.tokenizer.encode(c) for c in choices]
if self.model.tokenizer.__class__.__name__ == 'ChatGLMTokenizer': # noqa
choice_ids = [c[2:] for c in choice_ids]
else:
elif hasattr(self.model.tokenizer, 'add_bos_token'):
if self.model.tokenizer.add_bos_token:
choice_ids = [c[1:] for c in choice_ids]
if self.model.tokenizer.add_eos_token:
Expand All @@ -135,6 +149,7 @@ def inference(self,
ice[idx],
ice_template=ice_template,
prompt_template=prompt_template)
prompt = self.model.parse_template(prompt, mode='ppl')
if self.max_seq_len is not None:
prompt_token_num = get_token_len(prompt)
# add one because additional token will be added in the end
Expand Down