Skip to content

Commit

Permalink
[Feature] re-implement ceval load dataset (#446)
Browse files Browse the repository at this point in the history
  • Loading branch information
Leymore committed Sep 27, 2023
1 parent d9f3e88 commit 9db5652
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 24 deletions.
23 changes: 23 additions & 0 deletions configs/summarizers/groups/ceval.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,26 @@

_ceval_all = _ceval_stem + _ceval_social_science + _ceval_humanities + _ceval_other
ceval_summary_groups.append({'name': 'ceval', 'subsets': _ceval_all})

# Summary groups for the C-Eval *test* split.  Each subject name is prefixed
# with 'ceval-test-' and the resulting list is registered as one group in
# ``ceval_summary_groups``.
_ceval_stem = [
    f'ceval-test-{s}' for s in (
        'computer_network',
        'operating_system',
        'computer_architecture',
        'college_programming',
        'college_physics',
        'college_chemistry',
        'advanced_mathematics',
        'probability_and_statistics',
        'discrete_mathematics',
        'electrical_engineer',
        'metrology_engineer',
        'high_school_mathematics',
        'high_school_physics',
        'high_school_chemistry',
        'high_school_biology',
        'middle_school_mathematics',
        'middle_school_biology',
        'middle_school_physics',
        'middle_school_chemistry',
        'veterinary_medicine',
    )
]
ceval_summary_groups.append(dict(name='ceval-test-stem', subsets=_ceval_stem))

_ceval_social_science = [
    f'ceval-test-{s}' for s in (
        'college_economics',
        'business_administration',
        'marxism',
        'mao_zedong_thought',
        'education_science',
        'teacher_qualification',
        'high_school_politics',
        'high_school_geography',
        'middle_school_politics',
        'middle_school_geography',
    )
]
ceval_summary_groups.append(dict(name='ceval-test-social-science', subsets=_ceval_social_science))

_ceval_humanities = [
    f'ceval-test-{s}' for s in (
        'modern_chinese_history',
        'ideological_and_moral_cultivation',
        'logic',
        'law',
        'chinese_language_and_literature',
        'art_studies',
        'professional_tour_guide',
        'legal_professional',
        'high_school_chinese',
        'high_school_history',
        'middle_school_history',
    )
]
ceval_summary_groups.append(dict(name='ceval-test-humanities', subsets=_ceval_humanities))

_ceval_other = [
    f'ceval-test-{s}' for s in (
        'civil_servant',
        'sports_science',
        'plant_protection',
        'basic_medicine',
        'clinical_medicine',
        'urban_and_rural_planner',
        'accountant',
        'fire_engineer',
        'environmental_impact_assessment_engineer',
        'tax_accountant',
        'physician',
    )
]
ceval_summary_groups.append(dict(name='ceval-test-other', subsets=_ceval_other))

# The "hard" group is a hand-picked selection of subjects that also appear in
# the STEM list above.
_ceval_hard = [
    f'ceval-test-{s}' for s in (
        'advanced_mathematics',
        'discrete_mathematics',
        'probability_and_statistics',
        'college_chemistry',
        'college_physics',
        'high_school_mathematics',
        'high_school_chemistry',
        'high_school_physics',
    )
]
ceval_summary_groups.append(dict(name='ceval-test-hard', subsets=_ceval_hard))

# Overall test group: the four category lists concatenated in order (the
# "hard" list is not added again).
_ceval_all = [
    *_ceval_stem,
    *_ceval_social_science,
    *_ceval_humanities,
    *_ceval_other,
]
ceval_summary_groups.append(dict(name='ceval-test', subsets=_ceval_all))
38 changes: 14 additions & 24 deletions opencompass/datasets/ceval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import csv
import os.path as osp

from datasets import DatasetDict, load_dataset
from datasets import Dataset, DatasetDict

from opencompass.registry import LOAD_DATASET

Expand All @@ -12,26 +13,15 @@ class CEvalDataset(BaseDataset):

@staticmethod
def load(path: str, name: str):
    """Load the dev, val and test CSV files of one subject.

    Each split is read from ``{path}/{split}/{name}_{split}.csv`` with the
    standard ``csv`` module, so every field is kept as a string exactly as
    it appears in the file.

    Args:
        path: Root directory that contains the ``dev``, ``val`` and
            ``test`` sub-directories.
        name: Subject name used as the file-name prefix.

    Returns:
        A ``DatasetDict`` with the keys ``'dev'``, ``'val'`` and
        ``'test'``.  Rows whose file has no ``explanation`` or ``answer``
        column get an empty string for that field, so all splits expose
        both keys.

    Raises:
        OSError: If one of the split files cannot be opened.
    """
    dataset = {}
    for split in ['dev', 'val', 'test']:
        filename = osp.join(path, split, f'{name}_{split}.csv')
        # Explicit UTF-8 keeps decoding independent of the platform
        # locale; newline='' is what the csv module asks for so that
        # quoted fields containing line breaks are parsed correctly.
        with open(filename, encoding='utf-8', newline='') as f:
            reader = csv.reader(f)
            # An empty file yields an empty header and therefore no rows,
            # instead of leaking StopIteration to the caller.
            header = next(reader, [])
            rows = []
            for row in reader:
                item = dict(zip(header, row))
                # Splits without these columns still expose the keys.
                item.setdefault('explanation', '')
                item.setdefault('answer', '')
                rows.append(item)
        # Always register the split, even when it has no rows, so callers
        # can index all three keys unconditionally.
        dataset[split] = Dataset.from_list(rows)
    return DatasetDict(dataset)

0 comments on commit 9db5652

Please sign in to comment.