# Fill-in-the Middle Dataset Builder

This notebook builds a dataset of prompts to test fill-in-the-middle (FIM) capabilities of code generation
models. This version only supports single-line FIM tasks.

In [1]:
from pathlib import Path
import sys
import itertools
from tqdm import tqdm
import random
import json

# https://stackoverflow.com/questions/61058798/python-relative-import-in-jupyter-notebook
sys.path.append(str(Path(".").absolute().parent))
from util import gunzip_json

To build this dataset, we adapt the approach described in the MBXP paper. We 
find working solutions to the MultiPL problems generated by code generation
models that pass all tests. However, we also strip all comments that
appear in the solution body.

In [3]:
def is_comment(line: str, language: str):
    """
    Given a line of code in language, determine if it is a single line comment.
    """
    COMMENT_PREFIX = {
        "java": "//",
        "js": "//",
        "py": "#"
    }
    return line.lstrip().startswith(COMMENT_PREFIX[language])

def find_working_solution(completions_path: Path):
    """
    completions_path is named FILENAME.json.gz. Open FILENAME.results.json.gz and find the index of 
    an entry in the "results" array that has `"status": "OK"`. If found, open the completions_path
    return the corresponding entry in  the "completions" array. Otherwise, return None.
    """
    filename = completions_path.name[:-len(".json.gz")]
    results_json = gunzip_json(completions_path.parent / (filename + ".results.json.gz"))
    if results_json is None:
        return None
    for i, result in enumerate(results_json["results"]):
        if result["status"] == "OK":
            completions_json = gunzip_json(completions_path)
            language = completions_json["language"]
            body = completions_json["completions"][i]
            # In MultiPL-E, we use "}" as a stop token for Java, thus need to append these to
            # get a complete program.
            if language == "java":
                body = body + "\n    }\n}"
            return {
                "language": language,
                "prompt": completions_json["prompt"],
                "body": [ line for line in body.split("\n") if not is_comment(line, language) ],
                "tests": completions_json["tests"]
            }
    return None


def build_fim_tasks(name_prefix: str, completions_path: Path):
    """
    Produces FIM problems for completions_path, and returns [] if no working solution
    is found.

    Each returned problem is a dictionary of strings:

    {
        "name": f"{name_prefix}-{problem_name}-L{line_number}",
        "language": language,
        "prompt": prompt,
        "suffix": suffix,
        "canonical_solution": line,
        "tests": tests
    }

    The prompt includes the original prompt and the prefix. The canonical solution is a single
    line, and we ensure that it is never just whitespace.
    """
    working_solution = find_working_solution(completions_path)
    if working_solution is None:
        return
    language = working_solution["language"]
    prompt = working_solution["prompt"]
    lines = working_solution["body"]
    prefix = working_solution["prompt"]
    tests = working_solution["tests"]
    line_offset = len(prefix.split("\n"))
    for line_number, line in enumerate(lines):
        if line.strip() == "":
            continue
        yield {
            "name": f"{name_prefix}-{completions_path.stem}-L{line_number + line_offset}",
            "language": language,
            "prompt": prompt + ("\n".join(lines[:line_number])),
            "suffix": "\n".join(lines[line_number + 1:]),
            "canonical_solution": line,
            "tests": tests,
        }

def gather_benchmarks(root: Path, languages):
  generators = [
    (root / f"humaneval-{lang}-davinci-0.8-reworded").glob("*.json.gz") for
    lang in languages
  ]
  for path in tqdm(itertools.chain(*generators)):
    if path.stem.endswith(".results"):
        continue
    yield from build_fim_tasks("humaneval", path)

## Building the Benchmarks

The line below hard-codes the path to MultiPL-E completions.

In [4]:
benches = list(gather_benchmarks(Path("../../../datasets/MultiPL-E-experiments/"), ["java", "js", "py"]))
len(benches)

956it [00:03, 270.57it/s]


4792

## Eyeball Validation

I found a solution with a blank line, and the examples below show that we skip it.

In [5]:
def print_bench(bench):
    print(bench["prompt"])
    print(">>>>>>>>>>>>>>>>>>>>>>>>>")
    print(bench["canonical_solution"])
    print(">>>>>>>>>>>>>>>>>>>>>>>>>")
    print(bench["suffix"])

examples = [ bench for bench in benches if bench["language"] == "js" and "is_prime" in bench["name"] ]
len(examples)

17

In [89]:
print_bench(examples[0])

//Return true if a given number is prime, and false otherwise.
// >>> is_prime(6)
// false
// >>> is_prime(101)
// true
// >>> is_prime(11)
// true
// >>> is_prime(13441)
// true
// >>> is_prime(61)
// true
// >>> is_prime(4)
// false
// >>> is_prime(1)
// false
function is_prime(n){

>>>>>>>>>>>>>>>>>>>>>>>>>
    if (n==1) {
>>>>>>>>>>>>>>>>>>>>>>>>>
        return false;
    }

    if (n==2) {
        return true;
    }

    if (n%2==0) {
        return false;
    }

    var limit = Math.sqrt(n);
    for (var i = 3; i <= limit; i += 2) {
        if (n%i==0) {
            return false;
        }
    }

    return true;
}



In [90]:
print_bench(examples[16])

//Return true if a given number is prime, and false otherwise.
// >>> is_prime(6)
// false
// >>> is_prime(101)
// true
// >>> is_prime(11)
// true
// >>> is_prime(13441)
// true
// >>> is_prime(61)
// true
// >>> is_prime(4)
// false
// >>> is_prime(1)
// false
function is_prime(n){
    if (n==1) {
        return false;
    }

    if (n==2) {
        return true;
    }

    if (n%2==0) {
        return false;
    }

    var limit = Math.sqrt(n);
    for (var i = 3; i <= limit; i += 2) {
        if (n%i==0) {
            return false;
        }
    }

    return true;
>>>>>>>>>>>>>>>>>>>>>>>>>
}
>>>>>>>>>>>>>>>>>>>>>>>>>



In [91]:
print_bench(examples[3])

//Return true if a given number is prime, and false otherwise.
// >>> is_prime(6)
// false
// >>> is_prime(101)
// true
// >>> is_prime(11)
// true
// >>> is_prime(13441)
// true
// >>> is_prime(61)
// true
// >>> is_prime(4)
// false
// >>> is_prime(1)
// false
function is_prime(n){
    if (n==1) {
        return false;
    }

>>>>>>>>>>>>>>>>>>>>>>>>>
    if (n==2) {
>>>>>>>>>>>>>>>>>>>>>>>>>
        return true;
    }

    if (n%2==0) {
        return false;
    }

    var limit = Math.sqrt(n);
    for (var i = 3; i <= limit; i += 2) {
        if (n%i==0) {
            return false;
        }
    }

    return true;
}



## Save the Dataset

We shuffle the dataset before saving. This is the stupid hack to reduce peak memory consumption.
Without it, we'll consume the most memory on some awful Java program.

In [97]:
random.shuffle(benches)
with Path("MultiPLE-fim.jsonl").open("wt") as f:
    for bench in benches:
        f.write(json.dumps(bench))
        f.write("\n")