-
Notifications
You must be signed in to change notification settings - Fork 360
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Support inference ppl datasets (#1315)
* commit inference ppl datasets * revised format * revise * revise * revise * revise * revise * revise
- Loading branch information
Showing
12 changed files
with
662 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Inference-PPL Datasets | ||
|
||
- **Description**: Compute the loss only on the labeled positions, especially used for reasoning corpus. | ||
- **Datasets**: cn-reasoning-val.jsonl (example datasets, inference-ppl can be generalized to more corpus). | ||
|
||
# PPL Computation | ||
|
||
$$ \text{ppl} = - \frac{1}{n} \sum_{i=0}^n \sum_{c=0}^{vocab\_size} y_{i,c} \log p_{i,c} \tag{1} $$ | ||
|
||
where Eq. (1) is the normal mean ppl computation formula, for inference-ppl, we only compute the average score based on pre-labeled position. | ||
|
||
# Quick Start | ||
|
||
```shell | ||
cd opencompass | ||
python run.py configs/eval_inference_ppl.py | ||
``` | ||
|
||
# Some results | ||
|
||
| Model | Result | | ||
| ----------- | ----------- | | ||
| Qwen1.5-7b | 0.59 | | ||
| Qwen1.5-14b | 0.54 | | ||
| Llama2-7b | 0.49 | | ||
| Llama2-13b | 0.43 | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from opencompass.openicl.icl_prompt_template import PromptTemplate | ||
from opencompass.openicl.icl_retriever import ZeroRetriever | ||
from opencompass.openicl.icl_inferencer import InferencePPLOnlyInferencer | ||
from opencompass.openicl.icl_evaluator import AverageInferencePPLEvaluator | ||
|
||
from opencompass.datasets import InferencePPLDataset | ||
|
||
# Build InferencePPLDataset | ||
inference_ppl_datasets = [] | ||
|
||
llm_cmp_infer_cfg = dict( | ||
prompt_template=dict( | ||
type=PromptTemplate, | ||
template='{text}', | ||
), | ||
# No in-context example, using ZeroRetriever | ||
retriever=dict(type=ZeroRetriever), | ||
# compute inference-ppl | ||
inferencer=dict(type=InferencePPLOnlyInferencer), | ||
) | ||
|
||
# Average the inference-ppl scores | ||
llm_cmp_eval_cfg = dict(evaluator=dict(type=AverageInferencePPLEvaluator)) | ||
|
||
inference_ppl_datasets.append( | ||
dict( | ||
abbr=f'inference-ppl', | ||
type=InferencePPLDataset, | ||
path='./data/inference_ppl', | ||
name='cn-reasoning-val', | ||
samples=None, # Set small samples for testing | ||
reader_cfg=dict( | ||
input_columns=['text'], | ||
output_column=None, | ||
), | ||
infer_cfg=llm_cmp_infer_cfg, | ||
eval_cfg=llm_cmp_eval_cfg, | ||
)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from mmengine.config import read_base | ||
|
||
with read_base(): | ||
# Inference PPL datasets | ||
from .datasets.inference_ppl.inference_ppl import inference_ppl_datasets | ||
|
||
# Model configs | ||
from .models.qwen.hf_qwen1_5_7b import models as qwen1_5_7b | ||
from .models.qwen.hf_qwen1_5_14b import models as qwen1_5_14b | ||
from .models.hf_llama.hf_llama2_7b import models as llama2_7b | ||
from .models.hf_llama.hf_llama2_13b import models as llama2_13b | ||
|
||
|
||
from opencompass.partitioners import NaivePartitioner | ||
from opencompass.runners import LocalRunner | ||
from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask | ||
|
||
|
||
# -------------Inference Stage ---------------------------------------- | ||
|
||
datasets = [*inference_ppl_datasets] | ||
workdir = 'outputs/inference_ppl' | ||
|
||
models = [ | ||
*qwen1_5_7b, | ||
*qwen1_5_14b, | ||
*llama2_7b, | ||
*llama2_13b, | ||
] | ||
|
||
|
||
|
||
# Set custom batch_size and num_gpus for faster loss calculation | ||
# Smaller batch_size should give more precise results, at the cost of worse efficiency | ||
model_cfg = dict( | ||
batch_size=8, | ||
run_cfg=dict(num_gpus=4, num_procs=1) | ||
) | ||
|
||
for mdl in models: | ||
mdl.update(model_cfg) | ||
|
||
|
||
infer = dict( | ||
partitioner=dict(type=NaivePartitioner), | ||
runner=dict( | ||
type=LocalRunner, | ||
task=dict(type=OpenICLInferTask), | ||
max_num_workers=256, # Maximum concurrent evaluation task count | ||
), | ||
) | ||
|
||
|
||
# -------------Evaluation Stage ---------------------------------------- | ||
eval = dict( | ||
partitioner=dict(type=NaivePartitioner), | ||
runner=dict( | ||
type=LocalRunner, | ||
task=dict(type=OpenICLEvalTask), | ||
max_num_workers=256, | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import os.path as osp | ||
from typing import List | ||
|
||
from datasets import load_dataset | ||
|
||
from opencompass.registry import LOAD_DATASET | ||
|
||
from .base import BaseDataset | ||
|
||
|
||
@LOAD_DATASET.register_module() | ||
class InferencePPLDataset(BaseDataset): | ||
|
||
@staticmethod | ||
def load(path: str, name: List[str] = None, samples: int = None): | ||
|
||
# Check if file exists in the given path | ||
supported_extensions = ['jsonl'] | ||
for ext in supported_extensions: | ||
filename = osp.join( | ||
path, f'{name}.{ext}') # name refers to data subset name | ||
|
||
if osp.exists(filename): | ||
break | ||
else: | ||
raise FileNotFoundError(f'{filename} not found.') | ||
|
||
samples = 'test' if samples is None else f'test[:{samples}]' | ||
|
||
data_files = {'test': filename} | ||
|
||
dataset = load_dataset('json', data_files=data_files, split=samples) | ||
|
||
# Filter out empty samples | ||
dataset = dataset.filter(lambda example: len(example['text']) > 0) | ||
|
||
return dataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.