Skip to content

Commit

Permalink
Add humaneval postprocessor for GPT models & eval config for GPT4, en…
Browse files Browse the repository at this point in the history
…hance the original humaneval postprocessor (open-compass#129)

* [Enhancement] Enhance humaneval postprocessor

* add human-eval testcase

* update

* update

---------

Co-authored-by: Leymore <zfz-960727@163.com>
  • Loading branch information
2 people authored and BunnyRunnerX committed Aug 10, 2023
1 parent 5a8ef3a commit 768388c
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 6 deletions.
4 changes: 2 additions & 2 deletions configs/datasets/glm/humaneval.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import HFDataset, HumanEvaluator
from opencompass.datasets import HFDataset, HumanEvaluator, humaneval_postprocess

humaneval_reader_cfg = dict(
input_columns=['prompt'], output_column='task_id', train_split='test')
Expand All @@ -17,7 +17,7 @@
humaneval_eval_cfg = dict(
evaluator=dict(type=HumanEvaluator),
k=[1, 10, 100], # the parameter only for humaneval
pred_postprocessor=dict(type='humaneval'),
pred_postprocessor=dict(type=humaneval_postprocess),
)

humaneval_datasets = [
Expand Down
40 changes: 40 additions & 0 deletions configs/eval_gpt4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from mmengine.config import read_base
from opencompass.models import OpenAI
from opencompass.partitioners import NaivePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask

with read_base():
from .datasets.collections.chat_medium import datasets
from .summarizers.medium import summarizer

# GPT4 needs a special humaneval postprocessor
from opencompass.datasets.humaneval import humaneval_gpt_postprocess
for _dataset in datasets:
if _dataset['path'] == 'openai_humaneval':
_dataset['eval_cfg']['pred_postprocessor']['type'] = humaneval_gpt_postprocess


api_meta_template = dict(
round=[
dict(role='HUMAN', api_role='HUMAN'),
dict(role='BOT', api_role='BOT', generate=True),
],
)

models = [
dict(abbr='GPT4',
type=OpenAI, path='gpt-4-0613',
key='ENV', # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
meta_template=api_meta_template,
query_per_second=1,
max_out_len=2048, max_seq_len=2048, batch_size=8),
]

infer = dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(
type=LocalRunner,
max_num_workers=4,
task=dict(type=OpenICLInferTask)),
)
42 changes: 38 additions & 4 deletions opencompass/datasets/humaneval.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os.path as osp
import re
import tempfile
from typing import List

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, TEXT_POSTPROCESSORS


@ICL_EVALUATORS.register_module()
class HumanEvaluator(BaseEvaluator):
"""Evaluator for human eval."""

Expand Down Expand Up @@ -41,11 +40,46 @@ def score(self, predictions, references):
return {f'humaneval_{k}': score[k] * 100 for k in score}


@TEXT_POSTPROCESSORS.register_module('humaneval')
def humaneval_postprocess(text: str) -> str:
if '```' in text:
blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
if len(blocks) == 0:
text = text.split('```')[1] # fall back to default strategy
else:
text = blocks[0] # fetch the first code block
if not text.startswith('\n'): # in case starting with ```python
text = text[max(text.find('\n') + 1, 0):]
if text.strip().startswith('from') or text.strip().startswith('import'):
def_idx = text.find('def')
if def_idx != -1:
text = text[max(text.find('\n', def_idx) + 1, 0):]
text = text.split('\n\n')[0]
if text.strip().startswith('def'):
text = '\n'.join(text.split('\n')[1:])
if not text.startswith(' '):
if text.startswith(' '):
text = ' ' + text.lstrip()
else:
text = '\n'.join([' ' + line for line in text.split('\n')])
return text


def humaneval_gpt_postprocess(text: str) -> str:
"""Better answer postprocessor for better instruction-aligned models like
GPT."""
if '```' in text:
text = text.split('```')[1]
blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
if len(blocks) == 0:
text = text.split('```')[1] # fall back to default strategy
else:
text = blocks[0] # fetch the first code block
if not text.startswith('\n'): # in case starting with ```python
text = text[max(text.find('\n') + 1, 0):]
if text.strip().startswith('from') or text.strip().startswith('import'):
def_idx = text.find('def')
if def_idx != -1:
text = text[max(text.find('\n', def_idx) + 1, 0):]
text = text.split('\n\n\n')[0]
if text.strip().startswith('def'):
text = '\n'.join(text.split('\n')[1:])
if not text.startswith(' '):
Expand Down
110 changes: 110 additions & 0 deletions tests/dataset/test_humaneval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import unittest

from opencompass.datasets.humaneval import humaneval_postprocess


def run_humaneval_check(completion):
program = [
'def get_fraction(x: float) -> float:',
humaneval_postprocess(completion),
'',
'assert get_fraction(1.28) == 0.28',
'assert get_fraction(1.0) == 0.0',
]
program = '\n'.join(program)
exec(program)


class TestHumaneval(unittest.TestCase):

def test_vanilla(self):
raw = ' return x - int(x)'
run_humaneval_check(raw)

def test_python_quote(self):
lines = [
'```python',
' return x - int(x)',
'```',
]
raw = '\n'.join(lines)
run_humaneval_check(raw)

def test_bare_quote(self):
lines = [
'```',
' return x - int(x)',
'```',
]
raw = '\n'.join(lines)
run_humaneval_check(raw)

def test_error_space_quote(self):
lines = [
'```',
' return x - int(x)',
'```',
]
raw = '\n'.join(lines)
run_humaneval_check(raw)

def test_import_1(self):
lines = [
'import numpy as np',
'import math',
'from typing import List',
'',
'def func(x):',
' return x - int(x)',
]
raw = '\n'.join(lines)
run_humaneval_check(raw)

def test_import_2(self):
lines = [
'from typing import List',
'import numpy as np',
'import math',
'def func(x):',
' return x - int(x)',
]
raw = '\n'.join(lines)
run_humaneval_check(raw)

def test_import_3(self):
lines = [
'import math',
'',
'',
'def func(x):',
' return x - int(x)',
]
raw = '\n'.join(lines)
run_humaneval_check(raw)

def test_comment(self):
lines = [
'def func(x: float) -> float:',
" '''",
' blah blah blah',
' blah blah blah',
" '''",
' return x - int(x)',
]
raw = '\n'.join(lines)
run_humaneval_check(raw)

def test_additional(self):
lines = [
' return x - int(x)',
'',
'',
'def func(x: float) -> float:',
" '''",
' blah blah blah',
' blah blah blah',
" '''",
' return x - int(x)',
]
raw = '\n'.join(lines)
run_humaneval_check(raw)

0 comments on commit 768388c

Please sign in to comment.