-
Notifications
You must be signed in to change notification settings - Fork 228
/
Copy pathmbpp.py
93 lines (78 loc) · 3.48 KB
/
mbpp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""Program Synthesis with Large Language Models
https://arxiv.org/abs/2108.07732
The benchmark consists of around 1,000 crowd-sourced Python programming problems,
designed to be solvable by entry level programmers, covering programming fundamentals,
standard library functionality, and so on. Each problem consists of a task description,
code solution and 3 automated test cases. As described in the paper, a subset of the data
has been hand-verified by the authors.
Homepage: https://github.com/google-research/google-research/tree/master/mbpp
"""
import re
from evaluate import load
from lm_eval.base import Task
# BibTeX entry for the paper this task implements
# ("Program Synthesis with Large Language Models", Austin et al., 2021).
_CITATION = """
@article{austin2021program,
title={Program Synthesis with Large Language Models},
author={Austin, Jacob and Odena, Augustus and Nye, Maxwell and Bosma, Maarten and Michalewski, Henryk and Dohan, David and Jiang, Ellen and Cai, Carrie and Terry, Michael and Le, Quoc and others},
journal={arXiv preprint arXiv:2108.07732},
year={2021}
}
"""
class MBPP(Task):
    """A task represents an entire benchmark including its dataset, problems,
    answers, generation settings and evaluation methods.
    """

    DATASET_PATH = "mbpp"

    def __init__(self):
        super().__init__(
            # Generations are truncated at the start of a new top-level
            # construct (or the InCoder-style "\n<|/" end marker).
            stop_words=["\nclass", "\nassert", '\n"""', "\nprint", "\nif", "\n<|/"],
            requires_execution=True,
        )

    def get_dataset(self):
        """Returns dataset for the task or an iterable of any object, that get_prompt can handle."""
        dataset = self.dataset["test"]
        # The wrong split of MBPP can be loaded with an old datasets cache;
        # the expected test split has exactly 500 rows.
        assert (
            len(dataset) == 500
        ), "please ensure you have the latest version of MBPP dataset, try deleting its old cache"
        return dataset

    def get_prompt(self, doc):
        """Builds the prompt for the LM to generate from.

        MBPP prompt is built following the InCoder (Fried et al.) approach:
        prompt = docstring that includes one test.

        :param doc: dict
            sample (row) from the MBPP test split, with "text" and
            "test_list" fields
        :return: str
        """
        description = doc["text"]
        test_example = doc["test_list"][0]
        prompt = f'"""\n{description}\n{test_example}\n"""\n'
        return prompt

    def get_reference(self, doc):
        """Builds the reference solution for the doc (sample from the test dataset)."""
        return "\n".join(doc["test_list"])

    @staticmethod
    def first_block(string, stop_words):
        """Split off first block of code by scanning for class, def etc. on newlines.

        Each stop word is regex-escaped before being joined into an
        alternation. Without escaping, the stop word "\\n<|/" is parsed by
        `re` as the alternation "\\n<" OR "/", which would wrongly truncate
        any generation containing a "/" (division, comments, paths).

        :param string: str
            raw model output (prompt already stripped)
        :param stop_words: list(str)
            literal markers at which to cut the generation
        :return: str
        """
        pattern = "|".join(re.escape(word) for word in stop_words)
        return re.split(pattern, string)[0].rstrip()

    def postprocess_generation(self, generation, idx):
        """Defines the postprocessing for a LM generation.

        :param generation: str
            code generation from LM
        :param idx: int
            index of doc in the dataset to which the generation belongs
        :return: str
            generation with its prompt removed and truncated at the first
            stop word
        """
        prompt = self.get_prompt(self.get_dataset()[idx])
        output = generation[len(prompt) :]
        return self.first_block(output, self.stop_words)

    def process_results(self, generations, references):
        """Takes the list of LM generations and evaluates them against ground truth references,
        returning the metric for the generations.

        :param generations: list(list(str))
            list of lists containing generations
        :param references: list(str)
            list of str containing references
        """
        # pass@k via the HF `code_eval` metric; NOTE: this executes
        # model-generated code.
        code_metric = load("code_eval")
        results, _ = code_metric.compute(
            references=references,
            predictions=generations,
        )
        return results