
[WIP; Input Requested] Integrated DS1000 task #39

Merged: 37 commits, Mar 3, 2023
Commits
3e4c3d4
integrated DS1000 task; added option for task to handle downloading o…
benlipkin Jan 25, 2023
6b40a20
alphabetize imports
benlipkin Jan 30, 2023
31f8352
add trust_remote_code and use_auth_token args
loubnabnl Jan 29, 2023
7b229b8
update readme with trust_remote_code arg
loubnabnl Jan 29, 2023
7b8786a
shifted ds1000 dependency installation from runtime to setup
benlipkin Jan 30, 2023
681f47f
incoder and santacoder infilling support first commit; passing for si…
benlipkin Feb 3, 2023
51ace91
remove redundant code; special tokens already added to tokenizer
benlipkin Feb 5, 2023
935569c
vectorized implementation to strip left padding
benlipkin Feb 5, 2023
2fc6875
rolled back strip_left_padding implementation; parse_infill now only …
benlipkin Feb 5, 2023
bf38f8c
updated task builder function names for consistency
benlipkin Feb 5, 2023
111d46f
explicitly specified prompt prefix and suffix newlines
benlipkin Feb 5, 2023
0767073
specified ds1000 commit; added download warning to user
benlipkin Feb 5, 2023
e872f12
reverted to right padding; fixes index misalignment during generation
benlipkin Feb 5, 2023
b5a580c
track iterations in progress bar
benlipkin Feb 5, 2023
9367df4
handle different num_samples and batch_size when tracking iters
benlipkin Feb 6, 2023
0550610
change format of boolean arguments
loubnabnl Jan 30, 2023
bba79b5
change how bool args are called and fix typos
loubnabnl Jan 30, 2023
d6acbff
modified accuracy from pass@k to mean pass@1 averaged over num_sample…
benlipkin Feb 7, 2023
2ce5463
more concise naming for ds1000 accuracy metric.
benlipkin Feb 7, 2023
8a98e47
track iters on test case execution during scoring
benlipkin Feb 7, 2023
d21afa0
handle edge case where model generates mask tokens
benlipkin Feb 8, 2023
a9e11a7
added support for matplotlib different problem format
benlipkin Feb 8, 2023
d87aadc
covered edge case where model generates new start token
benlipkin Feb 13, 2023
bc19ec2
only insertion returned in insert mode so prefix prompt already stripped
benlipkin Feb 13, 2023
a2bc888
suppress tensorflow warnings at import
benlipkin Feb 14, 2023
e163484
synchronize dataset download across parallel processes using lockfile
benlipkin Feb 14, 2023
514e701
refactored infill mode tracker to global variable; original approach …
benlipkin Feb 14, 2023
9f37578
account for num processes in tqdm total
benlipkin Feb 14, 2023
2c734f2
specify python and torch version to user as well as how to suppress t…
benlipkin Feb 16, 2023
8345c89
type hinting refactor for version compatibility; code formatting with…
benlipkin Feb 16, 2023
1ecddbd
added note about OOM error to README
benlipkin Feb 27, 2023
26ac685
update setup.py license
benlipkin Mar 2, 2023
74868ef
added note reflecting resolution of tensorflow cuda oom exception
benlipkin Mar 2, 2023
f9b9379
new dataset warning more friendly for contributors
benlipkin Mar 2, 2023
62934c2
isort imports
benlipkin Mar 2, 2023
38a5fcd
added task description to docs
benlipkin Mar 2, 2023
be12ec5
Merge branch 'main' into DS-1000
benlipkin Mar 2, 2023
14 changes: 14 additions & 0 deletions README.md
@@ -38,6 +38,20 @@ Install [`torch`](https://pytorch.org/get-started/locally/) based on your device
```
pip install -r requirements.txt
```
To run the `DS-1000` benchmark, a few additional dependency constraints must be satisfied.
```
# python version must be 3.7.10
pip install -e ".[ds1000]" # installs all additional dependencies except PyTorch
# torch==1.12.1 required. Download version with relevant GPU support etc., e.g.,
pip install torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116

# to suppress any tensorflow optimization warnings,
# precede call to "accelerate launch" with "TF_CPP_MIN_LOG_LEVEL=3"

# on some systems, tensorflow will attempt to allocate all GPU memory
# to its process at import, which raises a CUDA out-of-memory error;
# setting "export TF_FORCE_GPU_ALLOW_GROWTH=true" resolves this
```
Also make sure you have `git-lfs` installed and are logged in to the Hub
```
huggingface-cli login
23 changes: 23 additions & 0 deletions docs/README.md
@@ -144,6 +144,29 @@ accelerate launch main.py \
We expect a model [finetuned](https://github.com/bigcode-project/bigcode-evaluation-harness/tree/main/finetuning/APPS) on the train split of APPS.
TODO: add few-shot setup for APPS.

### DS-1000
[DS-1000](https://ds1000-code-gen.github.io/): Code generation benchmark with 1000 data science questions spanning seven Python libraries that (1) reflects diverse, realistic, and practical use cases, (2) has a reliable metric, (3) defends against memorization by perturbing questions.

The task can be specified as `--tasks ds1000-$SUBSET-$MODE`, where `$SUBSET` is either `all` (all libraries) or one of `numpy`, `scipy`, `pandas`, `tensorflow`, `pytorch`, `sklearn`, or `matplotlib`, and `$MODE` is either `completion` (purely autoregressive generation) or `insertion` (fill-in-the-middle, FIM). Note that `insertion` is not supported for `matplotlib`; that subset falls back to `completion`. The resulting task names are sketched below.
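
For a quick sanity check of the naming scheme, the following minimal Python sketch (illustrative, not part of the docs) lists the registered task names:

```python
# Illustrative check of the naming scheme: the DS-1000 task builders
# register names of the form ds1000-<subset>-<mode>.
from lm_eval.tasks.ds1000 import create_all_tasks

print(sorted(create_all_tasks()))
# ['ds1000-all-completion', 'ds1000-all-insertion',
#  'ds1000-matplotlib-completion', 'ds1000-matplotlib-insertion', ...]
```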

- Prompts & Generation: prompts contain partial code with one or more missing lines. The form of the prompt differs between `completion` and `insertion` modes (an `[insert]` token marks the FIM region). Default generation args are shown in the command below.
- Evaluation: generations are evaluated by executing the benchmark's unit tests. As in the original manuscript, $pass@1$ is computed for each of the `num_samples` generations and the mean pass rate is reported as the metric (see the formula sketch below). Default evaluation args are shown in the command below.
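
Concretely, the reported value is the fraction of all generated samples that pass their problem's tests. A sketch of the computation performed in `process_results`, with $|D|$ the number of problems and $n$ the number of samples per problem:

$$\text{mean pass@1} = \frac{1}{|D| \cdot n} \sum_{i=1}^{|D|} \sum_{j=1}^{n} \mathbf{1}\big[\text{sample } j \text{ for problem } i \text{ passes its unit tests}\big]$$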

Below is the command to run evaluation on the full benchmark in insertion mode with the arguments that correspond to the original manuscript.

```bash
export TF_FORCE_GPU_ALLOW_GROWTH=true
TF_CPP_MIN_LOG_LEVEL=3 accelerate launch main.py \
--model <MODEL_NAME> \
--batch_size <BATCH_SIZE> \
--tasks ds1000-all-insertion \
--n_samples 40 \
--max_length_generation 1024 \
--temperature 0.2 \
--top_p 0.95 \
--allow_code_execution
```

## Code generation benchmarks without unit tests

For these tasks, we do a single generation per problem, compare the generated code against reference solutions, and compute a BLEU score. For the following tasks, we use a two-shot setting where we include 2 example inputs and their solutions in the prompt, all preceded by an instruction such as: ` "Answer the following instructions in a one line SQL query:\n"` (a minimal sketch of this prompt format is given below). The solutions consist of a single line, so we stop the generation when a new line is generated. Three languages are present: Python, SQL, and Java.
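
As a rough illustration of the prompt format just described, here is a minimal sketch; the helper, field names, and example problems are hypothetical, not the harness's actual prompt-building code:

```python
# Illustrative sketch only: assembling a two-shot prompt of the kind described above.
def build_two_shot_prompt(instruction: str, shots: list, query: str) -> str:
    """Concatenate an instruction, two solved examples, and the new question."""
    parts = [instruction]
    for shot in shots:  # exactly two (question, one-line solution) pairs
        parts.append(f"{shot['question']}\n{shot['answer']}\n")
    parts.append(query)  # the model is expected to complete a one-line solution
    return "\n".join(parts)


prompt = build_two_shot_prompt(
    "Answer the following instructions in a one line SQL query:\n",
    shots=[
        {"question": "List all employee names.", "answer": "SELECT name FROM employees;"},
        {"question": "Count the orders.", "answer": "SELECT COUNT(*) FROM orders;"},
    ],
    query="Find the highest salary.",
)
# Generation is then stopped at the first newline, since reference solutions are one line.
```
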
8 changes: 7 additions & 1 deletion lm_eval/base.py
@@ -1,5 +1,6 @@
from abc import abstractmethod, ABC
from datasets import load_dataset
from warnings import warn


class Task(ABC):
@@ -22,7 +23,12 @@ def __init__(self, stop_words=None, requires_execution=True):
"""
self.stop_words = stop_words
self.requires_execution = requires_execution
self.dataset = load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)
try:
self.dataset = load_dataset(path=self.DATASET_PATH, name=self.DATASET_NAME)
except:
warn(
"This task will use a locally downloaded dataset, not from the HF hub."
)

@abstractmethod
def get_dataset(self):
4 changes: 3 additions & 1 deletion lm_eval/tasks/__init__.py
@@ -1,6 +1,7 @@
from pprint import pprint

from . import apps, codexglue_code_to_text, conala, concode, humaneval, mbpp, codexglue_text_to_text
from . import (apps, codexglue_code_to_text, codexglue_text_to_text, conala,
concode, ds1000, humaneval, mbpp)

TASK_REGISTRY = {
**apps.create_all_tasks(),
@@ -9,6 +10,7 @@
"codexglue_code_to_text-python-left": codexglue_code_to_text.LeftCodeToText,
"conala": conala.Conala,
"concode": concode.Concode,
**ds1000.create_all_tasks(),
"humaneval": humaneval.HumanEval,
"mbpp": mbpp.MBPP,
}
18 changes: 12 additions & 6 deletions lm_eval/tasks/codexglue_code_to_text.py
@@ -8,6 +8,7 @@

import os
import re
import typing

from mosestokenizer import MosesDetokenizer

@@ -53,28 +54,33 @@ def __init__(self):
return CodeToText


def compute_codexglue_code_to_text_bleu(gold_and_predicted_items: list[tuple[str, str]]):
def compute_codexglue_code_to_text_bleu(
gold_and_predicted_items: typing.List[typing.Tuple[str, str]]
):
"""
Compute BLEU scores using codexglue_code_to_text_bleu.computeMaps (codexglue_summarization_evaluator)
Compute BLEU scores using codexglue_code_to_text_bleu.computeMaps (codexglue_summarization_evaluator)
This uses a specific BLEU tokenization and preprocessing necessary for this task by
the original authors of the dataset.

Taken from: https://github.com/dpfried/lm-evaluation-harness/blob/5d9a6aaaaa929bcad95bb73d85e78fe75eb64b4e/lm_eval/tasks/codexglue_summarization.py#L102
"""
from lm_eval.tasks.custom_metrics import codexglue_code_to_text_bleu

predicted_map = {}
gold_map = {}

for ix, (gold_str, predicted_str) in enumerate(gold_and_predicted_items):
gold, *rest = gold_str.strip().split('\t')
gold, *rest = gold_str.strip().split("\t")
if len(rest) > 0:
print(f"warning: gold instance {ix} contains a tab; ignoring text after")
gold_map[ix] = [codexglue_code_to_text_bleu.splitPuncts(gold.strip().lower())]
pred, *rest = predicted_str.strip().split('\t')

pred, *rest = predicted_str.strip().split("\t")
if len(rest) > 0:
print(f"warning: predicted instance {ix} contains a tab; ignoring text after")
predicted_map[ix] = [codexglue_code_to_text_bleu.splitPuncts(pred.strip().lower())]
predicted_map[ix] = [
codexglue_code_to_text_bleu.splitPuncts(pred.strip().lower())
]

return codexglue_code_to_text_bleu.bleuFromMaps(gold_map, predicted_map)[0]

191 changes: 191 additions & 0 deletions lm_eval/tasks/ds1000.py
@@ -0,0 +1,191 @@
"""
DS-1000: A Natural and Reliable Benchmark for Data Science Code Generation

https://arxiv.org/pdf/2211.11501.pdf

DS-1000 is a code generation benchmark with a thousand data science questions spanning seven Python libraries that (1) reflects diverse, realistic, and practical use cases, (2) has a reliable metric, (3) defends against memorization by perturbing questions.

Homepage: https://ds1000-code-gen.github.io/
"""

import fcntl
import functools
import io
import itertools
import pathlib
import warnings
import zipfile

import requests
import tqdm

from lm_eval.base import Task

_CITATION = """
@article{Lai2022DS1000,
title={DS-1000: A Natural and Reliable Benchmark for Data Science Code Generation},
author={Yuhang Lai and Chengxi Li and Yiming Wang and Tianyi Zhang and Ruiqi Zhong and Luke Zettlemoyer and Scott Wen-tau Yih and Daniel Fried and Sida Wang and Tao Yu},
journal={ArXiv},
year={2022},
volume={abs/2211.11501}
}
"""


def create_all_tasks():
def create_task(key, mode):
class DS1000(GeneralDS1000):
def __init__(self):
super().__init__(key, mode)

return DS1000

return {
f"ds1000-{key.lower()}-{mode.lower()}": create_task(key, mode)
for key in [
"All",
"Numpy",
"Pandas",
"Scipy",
"Matplotlib",
"Sklearn",
"Tensorflow",
"Pytorch",
]
for mode in ["Completion", "Insertion"]
}


class GeneralDS1000(Task):
DATASET_PATH = None
DATASET_NAME = None

def __init__(self, key, mode):
super().__init__(
stop_words=["</code>", "# SOLUTION END"], requires_execution=True
)
self._key = key
self._mode = mode
if self._key == "Matplotlib" and self._mode == "Insertion":
warnings.warn("Insertion not supported for Matplotlib. Running Completion.")
self._mode = "Completion"
self._dir = pathlib.Path(__file__).parent / "ds"
self._dir.mkdir(parents=True, exist_ok=True)
self._src = self._dir / "ds1000.py"
self._data = self._dir / "ds1000_data"
self._download_source()
self._download_dataset()

def _download_source(self):
url = "https://github.com/HKUNLP/DS-1000/blob/49c1c543ada8b58138181333cdc62e613204efcf/ds1000.py?raw=true"
lock = self._src.with_suffix(".lock")
with open(lock, "w") as f_lock:
fcntl.flock(f_lock, fcntl.LOCK_EX)
if not self._src.exists():
warnings.warn(f"DS-1000 source is being saved to {self._src}.")
print("Downloading source code...")
r = requests.get(url, stream=True)
with open(self._src, "wb") as f_src:
f_src.write(r.content)
open(self._src.parent / "__init__.py", "w").close()
print("Done.")
fcntl.flock(f_lock, fcntl.LOCK_UN)

def _download_dataset(self):
url = "https://github.com/HKUNLP/DS-1000/blob/49c1c543ada8b58138181333cdc62e613204efcf/ds1000_data.zip?raw=true"
lock = self._data.with_suffix(".lock")
with open(lock, "w") as f_lock:
fcntl.flock(f_lock, fcntl.LOCK_EX)
if not self._data.exists():
warnings.warn(f"DS-1000 data is being saved to {self._data}.")
print("Downloading dataset...")
r = requests.get(url, stream=True)
z = zipfile.ZipFile(io.BytesIO(r.content))
z.extractall(self._dir)
print("Done.")
fcntl.flock(f_lock, fcntl.LOCK_UN)

@functools.lru_cache()
def get_dataset(self):
"""Returns dataset for the task or an iterable of any object, that get_prompt can handle"""
from .ds.ds1000 import DS1000Dataset

data = DS1000Dataset(self._data, mode=self._mode).data
if self._key == "All":
if self._mode == "Insertion":
warnings.warn(
"Insertion not supported for Matplotlib. Only running others."
)
data = {k: v for k, v in data.items() if k != "Matplotlib"}
dataset = list(itertools.chain(*data.values()))
else:
dataset = data[self._key]
return dataset

def get_prompt(self, doc):
"""
Builds the prompt for the LM to generate from.
:param doc: dict[str: str]
sample from the test dataset
:return: str | dict[str: str]
"""
if self._mode == "Completion":
return doc["prompt"]
elif self._mode == "Insertion":
prefix, suffix = doc["prompt"].split("[insert]")
prefix = f"{prefix.strip()}\n"
suffix = f"\n{suffix.strip()}\n"
return {"prefix": prefix, "suffix": suffix}
else:
raise ValueError(f"Invalid mode: {self._mode}")

def get_reference(self, doc):
"""
Builds the reference solution for the doc (sample from the test dataset).
:param doc: dict[str: str]
sample from the test dataset
:return: str
"""
return doc["reference_code"]

def postprocess_generation(self, generation, idx):
"""
Defines the postprocessing for an LM generation.
:param generation: str
code generation from LM
:param idx: int (if needed)
index of doc in the dataset to which the generation belongs
:return: str
"""
if self._mode == "Completion":
for start in ["BEGIN SOLUTION\n<code>", "# SOLUTION START"]:
try:
generation = generation.split(start, 1)[-1]
except IndexError:
pass
for stop in self.stop_words:
generation = generation.split(stop)[0]
return generation.strip()

def process_results(self, generations, references):
"""
Takes the list of LM generations and evaluates them against ground truth references,
returning the metric for the generations as in {"metric_name": result}.
We encourage directly loading the metric from the `evaluate` library to keep the code concise.
:param generations: list(list(str))
list of lists containing generations
:param references: list(str)
list of str containing references
:return: dict[str: float]
"""
dataset = self.get_dataset()
num_correct = 0
print("Scoring generations...")
for i, ref in tqdm.tqdm(enumerate(references), total=len(references)):
test = [doc for doc in dataset if doc["reference_code"] == ref][0]
for gen in generations[i]:
is_correct = test.test(gen)
if is_correct:
num_correct += 1
accuracy = num_correct / len(references) / len(generations[0])
return {f"mean pass@1 accuracy ({len(generations[0])} samples)": accuracy}
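
For orientation, here is a minimal usage sketch of the task defined above; the subset/mode choice and the commented-out generation step are illustrative placeholders, not part of the file:

```python
# Illustrative usage sketch (assumes the DS-1000 environment described in the README).
from lm_eval.tasks.ds1000 import create_all_tasks

task = create_all_tasks()["ds1000-numpy-completion"]()  # downloads ds1000.py and ds1000_data on first use
doc = task.get_dataset()[0]
prompt = task.get_prompt(doc)        # str in completion mode, {"prefix", "suffix"} dict in insertion mode
reference = task.get_reference(doc)
# generation = ...                   # produced by the model from `prompt`
# cleaned = task.postprocess_generation(generation, 0)
# result = task.process_results([[cleaned]], [reference])
# result is e.g. {"mean pass@1 accuracy (1 samples)": 1.0}
```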