add safety to collections (open-compass#185)
* [Feat] add safety to collections

* minor fix
yingfhu authored and BunnyRunnerX committed Aug 11, 2023
1 parent 488b7bd commit bb4c4b3
Showing 14 changed files with 66 additions and 17 deletions.
@@ -1,4 +1,4 @@
 from mmengine.config import read_base

 with read_base():
-    from .civilcomments_ppl_6a2561 import civilcomments_datasets  # noqa: F401, F403
+    from .civilcomments_clp_a3c5fd import civilcomments_datasets  # noqa: F401, F403
4 changes: 4 additions & 0 deletions configs/datasets/collections/base_medium.py
@@ -53,5 +53,9 @@
     from ..triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
     from ..flores.flores_gen_806ede import flores_datasets
     from ..crowspairs.crowspairs_ppl_e811e1 import crowspairs_datasets
+    from ..civilcomments.civilcomments_clp_a3c5fd import civilcomments_datasets
+    from ..jigsawmultilingual.jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets
+    from ..realtoxicprompts.realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets
+    from ..truthfulqa.truthfulqa_gen_5ddc62 import truthfulqa_datasets

 datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
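Note: the last line of these collection configs relies on a small aggregation idiom: every *_datasets name pulled in under read_base() is a plain list of dataset configs, and sum(..., []) flattens them into one list. A minimal, stand-alone sketch of that idiom follows; the two placeholder lists and their type strings are invented for illustration and are not real OpenCompass configs.

# Stand-alone sketch of the aggregation idiom used above.
civilcomments_datasets = [dict(abbr='civil_comments', type='CivilCommentsDataset')]  # placeholder
truthfulqa_datasets = [dict(abbr='truthful_qa', type='TruthfulQADataset')]  # placeholder

# Collect every module-level list whose name ends with '_datasets' and
# concatenate them into a single flat list, exactly as the config does.
datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

print([d['abbr'] for d in datasets])  # ['civil_comments', 'truthful_qa']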
4 changes: 4 additions & 0 deletions configs/datasets/collections/chat_medium.py
@@ -53,5 +53,9 @@
     from ..triviaqa.triviaqa_gen_2121ce import triviaqa_datasets
     from ..flores.flores_gen_806ede import flores_datasets
     from ..crowspairs.crowspairs_gen_21f7cb import crowspairs_datasets
+    from ..civilcomments.civilcomments_clp_a3c5fd import civilcomments_datasets
+    from ..jigsawmultilingual.jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets
+    from ..realtoxicprompts.realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets
+    from ..truthfulqa.truthfulqa_gen_5ddc62 import truthfulqa_datasets

 datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
@@ -1,4 +1,4 @@
 from mmengine.config import read_base

 with read_base():
-    from .jigsawmultilingual_ppl_fe50d8 import jigsawmultilingual_datasets  # noqa: F401, F403
+    from .jigsawmultilingual_clp_fe50d8 import jigsawmultilingual_datasets  # noqa: F401, F403
2 changes: 1 addition & 1 deletion configs/datasets/realtoxicprompts/realtoxicprompts_gen.py
@@ -1,4 +1,4 @@
 from mmengine.config import read_base

 with read_base():
-    from .realtoxicprompts_gen_ac723c import realtoxicprompts_datasets  # noqa: F401, F403
+    from .realtoxicprompts_gen_7605e4 import realtoxicprompts_datasets  # noqa: F401, F403
5 changes: 2 additions & 3 deletions configs/datasets/truthfulqa/truthfulqa_gen_1e7d8d.py
@@ -11,9 +11,7 @@

 # TODO: allow empty output-column
 truthfulqa_infer_cfg = dict(
-    prompt_template=dict(
-        type=PromptTemplate,
-        template='{question}'),
+    prompt_template=dict(type=PromptTemplate, template='{question}'),
     retriever=dict(type=ZeroRetriever),
     inferencer=dict(type=GenInferencer))

@@ -31,6 +29,7 @@

 truthfulqa_datasets = [
     dict(
+        abbr='truthful_qa',
         type=TruthfulQADataset,
         path='truthful_qa',
         name='generation',
1 change: 1 addition & 0 deletions configs/datasets/truthfulqa/truthfulqa_gen_5ddc62.py
@@ -31,6 +31,7 @@

 truthfulqa_datasets = [
     dict(
+        abbr='truthful_qa',
         type=TruthfulQADataset,
         path='truthful_qa',
         name='generation',
25 changes: 17 additions & 8 deletions configs/summarizers/medium.py
@@ -10,15 +10,15 @@
     from .groups.jigsaw_multilingual import jigsaw_multilingual_summary_groups

 summarizer = dict(
-    dataset_abbrs = [
-        '--------- 考试 Exam ---------', # category
+    dataset_abbrs=[
+        '--------- 考试 Exam ---------',  # category
         # 'Mixed', # subcategory
         "ceval",
         'agieval',
         'mmlu',
         "GaokaoBench",
         'ARC-c',
-        '--------- 语言 Language ---------', # category
+        '--------- 语言 Language ---------',  # category
         # '字词释义', # subcategory
         'WiC',
         'summedits',
@@ -33,14 +33,14 @@
         'winogrande',
         # '翻译', # subcategory
         'flores_100',
-        '--------- 知识 Knowledge ---------', # category
+        '--------- 知识 Knowledge ---------',  # category
         # '知识问答', # subcategory
         'BoolQ',
         'commonsense_qa',
         'nq',
         'triviaqa',
         # '多语种问答', # subcategory
-        '--------- 推理 Reasoning ---------', # category
+        '--------- 推理 Reasoning ---------',  # category
         # '文本蕴含', # subcategory
         'cmnli',
         'ocnli',
@@ -67,7 +67,7 @@
         'mbpp',
         # '综合推理', # subcategory
         "bbh",
-        '--------- 理解 Understanding ---------', # category
+        '--------- 理解 Understanding ---------',  # category
         # '阅读理解', # subcategory
         'C3',
         'CMRC_dev',
@@ -84,11 +84,20 @@
         'eprstmt-dev',
         'lambada',
         'tnews-dev',
-        '--------- 安全 Safety ---------', # category
+        '--------- 安全 Safety ---------',  # category
         # '偏见', # subcategory
         'crows_pairs',
+        # '有毒性(判别)', # subcategory
+        'civil_comments',
+        # '有毒性(判别)多语言', # subcategory
+        'jigsaw_multilingual',
+        # '有毒性(生成)', # subcategory
+        'real-toxicity-prompts',
+        # '真实性/有用性', # subcategory
+        'truthful_qa',
     ],
-    summary_groups=sum([v for k, v in locals().items() if k.endswith("_summary_groups")], []),
+    summary_groups=sum(
+        [v for k, v in locals().items() if k.endswith("_summary_groups")], []),
     prompt_db=dict(
         database_path='configs/datasets/log.json',
         config_dir='configs/datasets',
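Note on why the truthfulqa configs in this commit gain abbr='truthful_qa': each plain string in dataset_abbrs is matched against a dataset's abbr when the report is assembled, while the dashed entries are just category dividers. The sketch below is a simplified stand-in for that lookup, not the real summarizer code, and the scores are invented.

# Simplified stand-in for the summarizer lookup; scores are made up.
dataset_abbrs = [
    '--------- 安全 Safety ---------',  # category divider, printed as-is
    'crows_pairs',
    'civil_comments',
    'jigsaw_multilingual',
    'real-toxicity-prompts',
    'truthful_qa',
]
scores = {'civil_comments': 74.2, 'truthful_qa': 31.5}  # invented numbers

for abbr in dataset_abbrs:
    if abbr.startswith('-----'):
        print(abbr)  # keep the section divider line
    else:
        print(f'{abbr:<25} {scores.get(abbr, "-")}')  # '-' when no result exists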
17 changes: 17 additions & 0 deletions opencompass/openicl/icl_inferencer/icl_base_inferencer.py
@@ -147,6 +147,23 @@ def save_prompt_and_ppl(self, label, input, prompt, ppl, idx):
         self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt
         self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl

+
+class CLPInferencerOutputHandler:
+    results_dict = {}
+
+    def __init__(self) -> None:
+        self.results_dict = {}
+
+    def write_to_json(self, save_dir: str, filename: str):
+        """Dump the result to a json file."""
+        dump_results_dict(self.results_dict, Path(save_dir) / filename)
+
+    def save_ice(self, ice):
+        for idx, example in enumerate(ice):
+            if str(idx) not in self.results_dict.keys():
+                self.results_dict[str(idx)] = {}
+            self.results_dict[str(idx)]['in-context examples'] = example
+
     def save_prompt_and_condprob(self, input, prompt, cond_prob, idx, choices):
         if str(idx) not in self.results_dict.keys():
             self.results_dict[str(idx)] = {}
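A self-contained sketch of how the new handler is driven end to end. json.dump stands in for the repo's dump_results_dict helper, the body of save_prompt_and_condprob beyond the lines shown above is guessed, and the sample prompt, probability, and file names are invented.

import json
from pathlib import Path


class CLPInferencerOutputHandler:
    def __init__(self) -> None:
        self.results_dict = {}

    def save_ice(self, ice):
        # one entry per test example, keyed by its string index
        for idx, example in enumerate(ice):
            self.results_dict.setdefault(str(idx), {})['in-context examples'] = example

    def save_prompt_and_condprob(self, input, prompt, cond_prob, idx, choices):
        # fields past the lines shown in the diff are an assumption
        entry = self.results_dict.setdefault(str(idx), {})
        entry['prompt'] = prompt
        entry['conditional probability'] = cond_prob
        entry['choices'] = choices

    def write_to_json(self, save_dir: str, filename: str):
        # stand-in for dump_results_dict(self.results_dict, Path(save_dir) / filename)
        Path(save_dir).mkdir(parents=True, exist_ok=True)
        with open(Path(save_dir) / filename, 'w', encoding='utf-8') as f:
            json.dump(self.results_dict, f, ensure_ascii=False, indent=2)


handler = CLPInferencerOutputHandler()
handler.save_ice(['Comment: great video!\nToxic: no\n'])  # one in-context example
handler.save_prompt_and_condprob(input='you are awful',
                                 prompt='Comment: you are awful\nToxic:',
                                 cond_prob=0.83,  # invented value
                                 idx=0,
                                 choices=['no', 'yes'])
handler.write_to_json('tmp_clp_demo', 'demo.json')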
21 changes: 18 additions & 3 deletions opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
@@ -14,7 +14,7 @@
 from ..icl_prompt_template import PromptTemplate
 from ..icl_retriever import BaseRetriever
 from ..utils import get_logger
-from .icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler
+from .icl_base_inferencer import BaseInferencer, CLPInferencerOutputHandler

 logger = get_logger(__name__)

@@ -80,7 +80,7 @@ def inference(self,
                   output_json_filename: Optional[str] = None,
                   normalizing_str: Optional[str] = None) -> List:
         # 1. Preparation for output logs
-        output_handler = PPLInferencerOutputHandler()
+        output_handler = CLPInferencerOutputHandler()

         ice = []

@@ -89,6 +89,20 @@
         if output_json_filename is None:
             output_json_filename = self.output_json_filename

+        # CLP cannot infer with log probability for api models
+        # unless model provided such options which needs specific
+        # implementation, open an issue if you encounter the case.
+        if self.model.is_api:
+            # Write empty file in case always rerun for this model
+            if self.is_main_process:
+                os.makedirs(output_json_filepath, exist_ok=True)
+                err_msg = 'API model is not supported for conditional log '\
+                          'probability inference and skip this exp.'
+                output_handler.results_dict = {'error': err_msg}
+                output_handler.write_to_json(output_json_filepath,
+                                             output_json_filename)
+            raise ValueError(err_msg)
+
         # 2. Get results of retrieval process
         if self.fix_id_list:
             ice_idx_list = retriever.retrieve(self.fix_id_list)
@@ -118,7 +132,7 @@ def inference(self,
         choice_ids = [self.model.tokenizer.encode(c) for c in choices]
         if self.model.tokenizer.__class__.__name__ == 'ChatGLMTokenizer':  # noqa
             choice_ids = [c[2:] for c in choice_ids]
-        else:
+        elif hasattr(self.model.tokenizer, 'add_bos_token'):
             if self.model.tokenizer.add_bos_token:
                 choice_ids = [c[1:] for c in choice_ids]
             if self.model.tokenizer.add_eos_token:
@@ -139,6 +153,7 @@
                 ice[idx],
                 ice_template=ice_template,
                 prompt_template=prompt_template)
+            prompt = self.model.parse_template(prompt, mode='ppl')
             if self.max_seq_len is not None:
                 prompt_token_num = get_token_len(prompt)
                 # add one because additional token will be added in the end
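A stand-alone sketch of the tokenizer guard changed above: switching from else: to elif hasattr(...) means a tokenizer that exposes no add_bos_token attribute is simply left untouched instead of raising AttributeError. FakeTokenizer and the [:-1] EOS-trimming line are illustrative assumptions, not OpenCompass code.

class FakeTokenizer:
    """Invented stand-in for a HuggingFace-style tokenizer."""
    add_bos_token = True
    add_eos_token = False

    def encode(self, text):
        return [1] + [ord(ch) for ch in text]  # pretend id 1 is a BOS token


tokenizer = FakeTokenizer()
choices = [' yes', ' no']
choice_ids = [tokenizer.encode(c) for c in choices]

if tokenizer.__class__.__name__ == 'ChatGLMTokenizer':  # noqa
    choice_ids = [c[2:] for c in choice_ids]      # ChatGLM prepends two special ids
elif hasattr(tokenizer, 'add_bos_token'):         # the new, safer branch
    if tokenizer.add_bos_token:
        choice_ids = [c[1:] for c in choice_ids]  # drop the leading BOS id
    if tokenizer.add_eos_token:
        choice_ids = [c[:-1] for c in choice_ids]  # assumed EOS handling, not shown above

print(choice_ids)  # choice ids with the fake BOS stripped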
