Copyright 2022 Google LLC. SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

# Experiment: HumanEval Benchmark

This notebook is a part of the open-source code release associated with the paper:

[Code as Policies: Language Model Programs for Embodied Control](https://code-as-policies.github.io/)

This notebook gives the results corresponding to Table III in the paper which evaluates different code-gen approaches on the [HumanEval benchmark](https://github.com/openai/human-eval)

1) Please obtain an OpenAI API Key here:
https://openai.com/blog/openai-api/

2) Gain Codex access by joining the waitlist here:
https://openai.com/blog/openai-codex/

Once you have Codex access you can use `code-davinci-002` as the `model_name`. Using the GPT-3 model (`text-davinci-002`) is also OK, but performance won't be as good (there will be more code logic errors).

3) Please also specify a location in your Google Drive on which the results will be stored.

Note: due to current rate limiting of the Codex API, this entire notebook may take 20+ hours to finish.

In [None]:
import os

# Prefer the OPENAI_API_KEY environment variable so the secret never has to be
# saved inside the notebook; if it is unset, paste the key over the placeholder.
openai_api_key = os.environ.get('OPENAI_API_KEY', 'YOUR KEY HERE')
# Engine name used for completion requests (alternative shown in the trailing comment).
model_name = 'code-davinci-002' # 'text-davinci-002'
# Google Drive folder in which the result .jsonl files are stored.
google_drive_folder = 'drive/MyDrive/...'

# HumanEval CodeGen Benchmark

From [HumanEval](https://github.com/openai/human-eval)

In [None]:
# Mount Google Drive at /content/drive; the results folder configured in
# google_drive_folder is expected to live on this drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from pathlib import Path
# Directory (taken from google_drive_folder) under which every result .jsonl
# file of this notebook is written.
# NOTE(review): the default value is a relative path, so it presumably relies on
# the working directory being /content — confirm.
results_path = Path(google_drive_folder)

# Install HumanEval benchmark

In [None]:
# Install the HumanEval package straight from its GitHub repository.
!pip install git+https://github.com/openai/human-eval.git

# Their package also needs some data not installed by pip.
# Get this straight from git cloning the repo,
# and hacking some paths.
# The clone is also what later cells run evaluate_functional_correctness.py from.
# NOTE(review): the destination hard-codes the python3.7 dist-packages directory;
# verify it matches the Python version of the current runtime.
!git clone https://github.com/openai/human-eval
!cp -r human-eval/data /usr/local/lib/python3.7/dist-packages

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/openai/human-eval.git
  Cloning https://github.com/openai/human-eval.git to /tmp/pip-req-build-842wcofg
  Running command git clone -q https://github.com/openai/human-eval.git /tmp/pip-req-build-842wcofg
Collecting fire
  Downloading fire-0.4.0.tar.gz (87 kB)
[K     |████████████████████████████████| 87 kB 3.2 MB/s 
Building wheels for collected packages: human-eval, fire
  Building wheel for human-eval (setup.py) ... [?25l[?25hdone
  Created wheel for human-eval: filename=human_eval-1.0-py3-none-any.whl size=7446 sha256=ad3ee05f24b8af20f3e44126c16576373922992534af96274bf4455fb7db8a1b
  Stored in directory: /tmp/pip-ephem-wheel-cache-l44fg2ie/wheels/10/c6/41/a3d3cf28a68aa72be379d082afbafcc713353941c175b69b2d
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.4.0-py2.py3-none-any.whl size=115942 sha256=ba

In [None]:
# To make the evaluation actually run,
# we have to uncomment a line that they left commented as a safety.
# The sed expression removes the first "# " on line 58 of the installed
# human_eval/execution.py.
# NOTE(review): both the line number (58) and the python3.7 path are hard-coded;
# confirm they still match the installed human-eval version and runtime.
!sed -i '58s/# //' /usr/local/lib/python3.7/dist-packages/human_eval/execution.py

# Install Codex LLM

In [None]:
! pip install openai ratelimiter
import openai
openai.api_key = openai_api_key

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting openai
  Downloading openai-0.23.0.tar.gz (43 kB)
[K     |████████████████████████████████| 43 kB 1.5 MB/s 
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting ratelimiter
  Downloading ratelimiter-1.2.0.post0-py3-none-any.whl (6.6 kB)
Collecting pandas-stubs>=1.1.0.11
  Downloading pandas_stubs-1.2.0.62-py3-none-any.whl (163 kB)
[K     |████████████████████████████████| 163 kB 8.9 MB/s 
Building wheels for collected packages: openai
  Building wheel for openai (PEP 517) ... [?25l[?25hdone
  Created wheel for openai: filename=openai-0.23.0-py3-none-any.whl size=54478 sha256=63136fdcb615dbe5cc5cd20fb166b85d6113871c7632a10c9b866d5e8db5e8c6
  Stored in directory: /root/.cache/pip/wheels/70/d5/31/f9f67660319d89e4f54501d27b1e90f88a3309c42ea4fd734c
Succes

In [None]:
from google.colab import output
# Turn on Colab's custom widget manager.
# NOTE(review): presumably needed so the tqdm.auto progress bars used below can
# render as widgets — confirm.
output.enable_custom_widget_manager()

In [None]:
from copy import copy
from time import sleep
from tqdm.auto import trange, tqdm

import numpy as np

import ast
import astunparse

from pygments import highlight
from pygments.lexers import PythonLexer
from pygments.formatters import TerminalFormatter

from time import time
from ratelimiter import RateLimiter

def limited_cb(until):
    """Print how long the rate limiter is about to pause.

    Args:
        until: Timestamp in seconds, compared against the current time().
    """
    wait_seconds = int(round(until - time()))
    print(f'Rate limited, sleeping for {wait_seconds:d} seconds')

# Shared limiter wrapped around every OpenAI request in this notebook:
# configured for at most 15 calls per 60-second period, with limited_cb
# registered as the limiter's callback.
openai_rate_limiter = RateLimiter(max_calls=15, period=60, callback=limited_cb)

def exec_safe(code_str, gvars, lvars):
  """Execute generated code after a basic banned-phrase check.

  The check is a plain substring test, so it is a guard against obvious
  misuse rather than a real sandbox: the code is still run with exec().

  Args:
    code_str: Source code to execute.
    gvars: Globals made visible to the code. The dict itself is not modified;
      a merged copy is handed to exec().
    lvars: Locals dict; names defined by the code are stored here.

  Raises:
    AssertionError: If code_str contains 'import' or '__'.
  """
  banned_phrases = ['import', '__']
  for phrase in banned_phrases:
    # Raise explicitly instead of using `assert`, which is removed when Python
    # runs with -O and would silently disable this check.
    if phrase in code_str:
      raise AssertionError(f'banned phrase {phrase!r} found in code')

  # Shadow the exec/eval builtins with no-ops for the executed code.
  empty_fn = lambda *args, **kwargs: None
  custom_gvars = {**gvars, 'exec': empty_fn, 'eval': empty_fn}
  exec(code_str, custom_gvars, lvars)

# Default keyword arguments for openai.Completion.create. lmp and lmp_batch
# copy this dict and layer their per-call query_kwargs on top of the copy.
default_query_kwargs = {
    'engine': model_name,
    'max_tokens': 512,
    'temperature': 0,
    'frequency_penalty': 0,
    'logprobs': 1
}

# Maps the full prompt text to {'response': ..., 'mean_log_prob': ...};
# written by lmp and lmp_batch after each query.
lmp_cache = {}

def lmp(base_prompt, query, stop_tokens=None, log=False, return_response=False, strip=False, use_cache=False, query_kwargs=None):
    """Send one completion request to the OpenAI API and record it in lmp_cache.

    The prompt sent to the model is base_prompt, a newline, then query.

    Args:
        base_prompt: Text placed before the query.
        query: Text appended after base_prompt.
        stop_tokens: Passed to openai.Completion.create as `stop`.
        log: If True, print the query and the response.
        return_response: If True, return the response text.
        strip: If True, strip surrounding whitespace from a freshly queried
            response before it is cached.
        use_cache: If True, skip the API call when the prompt is already a key
            of lmp_cache.
        query_kwargs: Optional overrides applied on top of a copy of
            default_query_kwargs.

    Returns:
        The response text when return_response is True, otherwise None.
    """
    prompt = f'{base_prompt}\n{query}'
    
    # A fresh result is always written to lmp_cache, even when use_cache is False.
    if not use_cache or prompt not in lmp_cache:
      use_query_kwargs = copy(default_query_kwargs)
      if query_kwargs is not None:
        use_query_kwargs.update(query_kwargs)
      with openai_rate_limiter:
        # Retry without limit: any exception is printed and the request is
        # re-sent after a 10 second sleep.
        while True:
          try:
            result = openai.Completion.create(
                prompt=prompt, stop=stop_tokens, **use_query_kwargs
            )['choices'][0]
            
            response = result['text']
            # Mean of the per-token log-probabilities of the completion.
            logp = np.mean(result['logprobs']['token_logprobs'])
            break
          except Exception as e:
            print('got err')
            print(e)
            print('retrying after 10')
            sleep(10)
            continue

      if strip:
        response = response.strip()

      lmp_cache[prompt] = {
          'response': response,
          'mean_log_prob': logp
      }

    # Always read the result back from the cache so both paths share one source.
    response = lmp_cache[prompt]['response']

    if log:
      print(query)
      print(response)

    if return_response:
      return response

def lmp_fgen(prompt, f_name, f_sig, stop_tokens=['# define function:', '# example:'], recurse=False, 
            recurse_level=0, max_recurse_level=4, use_cache=False,
             context_vars=None, bug_fix=False, log=False, return_src=False, query_kwargs=None, strip=False, info=''):
    """Generate a function with the language model and optionally its undefined helpers.

    The model is queried with '# define function: <f_sig>.' appended to prompt.
    The returned source is executed with exec_safe and the function is looked
    up by f_name in the resulting locals. If that fails for any reason, a no-op
    function returning None is used instead and no recursion takes place.

    Args:
        prompt: Base prompt passed to lmp.
        f_name: Name under which the generated function is expected to be defined.
        f_sig: Signature / call string placed in the query.
        stop_tokens: Stop sequences forwarded to lmp.
        recurse: If True, also generate functions that the body calls but that
            are not resolvable in the current variables.
        recurse_level: Current recursion depth.
        max_recurse_level: Recursion only happens below this depth.
        use_cache: Forwarded to lmp.
        context_vars: Dict of names used as globals when executing the source;
            defaults to an empty dict.
        bug_fix: If True, pass the generated source through openai.Edit.create
            with a syntax-fixing instruction.
        log: If True, print the query and generated source with syntax highlighting.
        return_src: If True, also return the generated source strings.
        query_kwargs: Optional completion overrides; temperature is forced to 0
            on a copy of this dict.
        strip: If True, strip surrounding whitespace from the generated source.
        info: Optional text added to the query as a '# info: ...' line.

    Returns:
        A dict mapping function names (f_name and any generated children) to
        callables. When return_src is True, a tuple of that dict and a dict
        mapping the same names to their source strings.
    """
    query = f'# define function: {f_sig}.'
    if info:
      query = f'{query}\n# info: {info}.'
    if query_kwargs is not None:
      # Work on a copy so the caller's dict keeps its own temperature.
      query_kwargs = copy(query_kwargs)
      query_kwargs['temperature'] = 0
    f_src = lmp(prompt, query, stop_tokens=stop_tokens, log=False, return_response=True, use_cache=use_cache, query_kwargs=query_kwargs)
    if bug_fix:
        with openai_rate_limiter:
          f_src = openai.Edit.create(
            model='code-davinci-edit-001',
            input='# ' + f_src,
            temperature=0,
            instruction="Fix syntax errors. Keep same inputs and outputs. Only small changes. No comments.",
          )['choices'][0]['text']

    if strip:
      f_src = f_src.strip()

    if context_vars is None:
        context_vars = {}
    gvars = context_vars
    lvars = {}

    f_success = True
    try:
      exec_safe(f_src, gvars, lvars)
      f = lvars[f_name]
    except Exception as e:
      # print('error', f_sig)
      # print(e)
      # print(f_src)
      # print()
      # Fall back to a no-op so callers always receive a callable.
      f = lambda *args, **kargs: None   
      f_success = False 

    all_child_fs, all_child_f_srcs = {}, {}
    if recurse and recurse_level < max_recurse_level and f_success:
      # recursively define child_fs in the function body if needed
      f_def_body = None
      # If the source holds several top-level function definitions, the body of
      # the last one is the one that gets analysed.
      for node in ast.parse(f_src).body:
        if isinstance(node, ast.FunctionDef):
          f_def_body = astunparse.unparse(node.body)
      if f_def_body is not None:      
        potential_child_fs, potential_child_f_sigs = {}, {}
        f_parser = FunctionParser(potential_child_fs, potential_child_f_sigs)
        f_parser.visit(ast.parse(f_def_body))
        # Where a call appears as the value of an assignment, use the whole
        # assignment statement as that function's signature string.
        for potential_child_f_name, potential_child_f_sig in potential_child_f_sigs.items():
          if potential_child_f_name in potential_child_fs:
            potential_child_fs[potential_child_f_name] = potential_child_f_sig

        for child_f_name, child_f_sig in potential_child_fs.items():
          all_vars = merge_dicts([context_vars, all_child_fs, lvars])
          # Only generate names that cannot be resolved in the merged variables.
          if not var_exists(child_f_name, all_vars):
            child_fs, child_f_srcs = lmp_fgen(
                prompt, child_f_name, child_f_sig, 
                stop_tokens=stop_tokens, 
                context_vars=all_vars, 
                bug_fix=bug_fix,
                log=False, 
                recurse=True,
                recurse_level=recurse_level+1,
                return_src=True,
                use_cache=use_cache,
                query_kwargs=query_kwargs
              )

            all_child_fs.update(child_fs)
            all_child_f_srcs.update(child_f_srcs)

        if len(all_child_fs) > 0:
          # redefine parent f so newly created all_child_fs are in scope
          gvars = merge_dicts([context_vars, all_child_fs])
          lvars = {}
        
          exec_safe(f_src, gvars, lvars)
          
          f = lvars[f_name]

    if log:
        to_print = highlight(f'{query}\n{f_src}', PythonLexer(), TerminalFormatter())
        print(f'LMP FGEN created:\n\n{to_print}\n')

    fs = {
        f_name: f
    }
    fs.update(all_child_fs)

    if return_src:
        f_srcs = {
            f_name: f_src
        }
        f_srcs.update(all_child_f_srcs)

        return fs, f_srcs
    return fs

def lmp_batch(base_prompt, cmds, stop_tokens=['# define'], strip=False, batch_size=20, query_kwargs=None, ret_logprobs=False, use_cache=False):
    """Query completions for many commands, sending the prompts in batches.

    Each prompt is base_prompt, a newline, then the command. Results are
    stored in lmp_cache keyed by the full prompt text and read back from
    there, so identical prompts share a single cache entry.

    Args:
        base_prompt: Text placed before every command.
        cmds: Iterable of command strings, one completion is returned per entry.
        stop_tokens: Passed to openai.Completion.create as `stop`.
        strip: If True, strip surrounding whitespace from each response.
        batch_size: Number of prompts sent per API request.
        query_kwargs: Optional overrides applied on top of a copy of
            default_query_kwargs.
        ret_logprobs: If True, also return the mean token log-probability of
            each response.
        use_cache: If True, only prompts missing from lmp_cache are queried.

    Returns:
        A list of response strings in the order of cmds, or a tuple of that
        list and the list of mean log-probabilities when ret_logprobs is True.
    """
    prompts = [f'{base_prompt}\n{cmd}' for cmd in cmds]

    if use_cache:
        prompts_use_idxs = [
            idx for idx, prompt in enumerate(prompts) if prompt not in lmp_cache
        ]
    else:
        prompts_use_idxs = list(range(len(prompts)))

    use_query_kwargs = copy(default_query_kwargs)
    if query_kwargs is not None:
        use_query_kwargs.update(query_kwargs)

    for start_idx in trange(0, len(prompts_use_idxs), batch_size, leave=False):
        end_idx = min(start_idx + batch_size, len(prompts_use_idxs))
        batch_idxs = prompts_use_idxs[start_idx : end_idx]
        batch_prompts = [prompts[idx] for idx in batch_idxs]

        with openai_rate_limiter:
            # Retry without limit: any exception is printed and the request is
            # re-sent after a 10 second sleep.
            while True:
                try:
                    raw_responses_batch = openai.Completion.create(
                        prompt=batch_prompts, stop=stop_tokens, **use_query_kwargs
                    )
                    break
                except Exception as e:
                    print('got err')
                    print(e)
                    print('retrying after 10')
                    sleep(10)
                    continue

        # Completions of a batched request are not guaranteed to arrive in
        # prompt order; each choice carries the index of its prompt. Order by
        # it before pairing with batch_prompts. The sort is stable, so choices
        # without an index keep the order in which they were returned.
        choices = sorted(
            raw_responses_batch['choices'],
            key=lambda choice: choice.get('index', 0)
        )

        responses_batch = [choice['text'] for choice in choices]
        # Mean of the per-token log-probabilities of each completion.
        mean_logprobs_batch = [
            np.mean(choice['logprobs']['token_logprobs'])
            for choice in choices
        ]

        if strip:
            responses_batch = [response.strip() for response in responses_batch]

        for p, r, logp in zip(batch_prompts, responses_batch, mean_logprobs_batch):
            lmp_cache[p] = {
                'response': r,
                'mean_log_prob': logp
            }

    responses = [lmp_cache[p]['response'] for p in prompts]

    if ret_logprobs:
        mean_log_probs = [lmp_cache[p]['mean_log_prob'] for p in prompts]
        return responses, mean_log_probs

    return responses

def lmp_fgen_batch(prompt, prompt_with_comment, queries, stop_tokens=['# define function:', '# example:'], 
                   recurse=False, context_vars=None, log=False, strip=False, query_kwargs=None, ret_logprobs=False):
    """Generate one completion per query in batch and optionally append missing helpers.

    Each completion is prefixed with its query, so every returned entry is the
    query text followed by the generated continuation. With recurse=True, every
    entry is executed and parsed; functions called in its body that cannot be
    resolved are generated with lmp_fgen and their sources appended after it.

    Args:
        prompt: Base prompt used for the batched top-level completions.
        prompt_with_comment: Base prompt used when generating child functions
            through lmp_fgen.
        queries: List of query strings, one per solution.
        stop_tokens: Stop sequences for the batched completions.
        recurse: If True, generate undefined helper functions.
        context_vars: Names treated as already defined when deciding which
            helpers to generate; defaults to an empty dict.
        log: If True, print each final source with syntax highlighting.
        strip: If True, strip surrounding whitespace from each source.
        query_kwargs: Optional completion overrides forwarded to lmp_batch and
            lmp_fgen.
        ret_logprobs: If True, also return the mean log-probabilities reported
            by lmp_batch.

    Returns:
        The list of source strings, or a tuple of that list and the
        log-probabilities when ret_logprobs is True.
    """

    # The top-level batch is always re-queried (use_cache=False).
    f_srcs_list = lmp_batch(prompt, queries, stop_tokens=stop_tokens, query_kwargs=query_kwargs, ret_logprobs=ret_logprobs, use_cache=False)
    if ret_logprobs:
      f_srcs_list, logprobs = f_srcs_list
    # Prepend each query so the entry contains the query text plus its completion.
    for idx, (query, f_src) in enumerate(zip(queries, f_srcs_list)):
      f_srcs_list[idx] = query + f_src
    
    if strip:
      for idx, f_src in enumerate(f_srcs_list):
        f_srcs_list[idx] = f_src.strip()

    if recurse:
      if context_vars is None:
        context_vars = {}

      # recursively define child_fs in the function body if needed
      for idx, f_src in enumerate(f_srcs_list):
        try:
          lvars = {}
          # NOTE(review): this runs model-generated code with plain exec() and
          # empty globals, without the exec_safe phrase check.
          exec(f_src, {}, lvars)

          f_def_body = None
          # With several top-level function definitions, the body of the last
          # one is the one that gets analysed.
          for node in ast.parse(f_src).body:
            if isinstance(node, ast.FunctionDef):
              f_def_body = astunparse.unparse(node.body)
          assert f_def_body is not None
        except Exception as e:
          # print('err recurse')
          # print(e)
          # print(f_src)
          # print()
          # Sources that fail to execute or parse are left as generated.
          continue

        potential_child_fs, potential_child_f_sigs = {}, {}
        f_parser = FunctionParser(potential_child_fs, potential_child_f_sigs)
        f_parser.visit(ast.parse(f_def_body))
        # Where a call appears as the value of an assignment, use the whole
        # assignment statement as that function's signature string.
        for potential_child_f_name, potential_child_f_sig in potential_child_f_sigs.items():
          if potential_child_f_name in potential_child_fs:
            potential_child_fs[potential_child_f_name] = potential_child_f_sig

        all_child_fs, all_child_f_srcs = {}, {}
        for child_f_name, child_f_sig in potential_child_fs.items():
          all_vars = merge_dicts([context_vars, all_child_fs, lvars])
          # Only generate names that cannot be resolved in the merged variables.
          if not var_exists(child_f_name, all_vars):
            child_fs, child_f_srcs = lmp_fgen(
                prompt_with_comment, child_f_name, child_f_sig,
                context_vars=all_vars, 
                log=False,
                recurse=True,
                return_src=True,
                use_cache=True,
                query_kwargs=query_kwargs
              )

            all_child_fs.update(child_fs)
            all_child_f_srcs.update(child_f_srcs)

        if len(all_child_fs) > 0:
          # Append the helper sources after the parent source.
          child_f_srcs_str = "\n".join(all_child_f_srcs.values())
          f_srcs_list[idx] = f'{f_src}\n{child_f_srcs_str}'
          
    if log:
      # `query` is not used in this loop; only the final source is printed.
      for query, f_src in zip(queries, f_srcs_list):
        to_print = highlight(f_src, PythonLexer(), TerminalFormatter())
        print(f'LMP FGEN created:\n\n{to_print}\n')

    if ret_logprobs:
      return f_srcs_list, logprobs

    return f_srcs_list

class FunctionParser(ast.NodeTransformer):
    """Collect the functions called in a syntax tree.

    Results are written into the two dicts supplied by the caller:
    `fs` maps the name of each function called through a plain name to the
    source text of the call, and `f_assigns` maps the callee text of each call
    that is the value of an assignment to the source text of that assignment.
    Nodes are returned unmodified.
    """

    def __init__(self, fs, f_assigns):
        super().__init__()
        self._fs = fs
        self._f_assigns = f_assigns

    def visit_Call(self, node):
        """Record a call whose callee is a plain name; visit nested nodes first."""
        self.generic_visit(node)
        if not isinstance(node.func, ast.Name):
            return node
        call_text = astunparse.unparse(node).strip()
        callee = astunparse.unparse(node.func).strip()
        self._fs[callee] = call_text
        return node

    def visit_Assign(self, node):
        """Record an assignment whose value is a call; visit nested nodes first."""
        self.generic_visit(node)
        if not isinstance(node.value, ast.Call):
            return node
        statement_text = astunparse.unparse(node).strip()
        callee = astunparse.unparse(node.value.func).strip()
        self._f_assigns[callee] = statement_text
        return node

def var_exists(name, all_vars):
    """Return whether `name` can be evaluated using `all_vars` as globals.

    Python builtins count as existing, because eval() makes them available.

    Args:
        name: Expression text to evaluate, normally a bare identifier.
        all_vars: Dict of known names. It is not modified.

    Returns:
        True if the evaluation succeeds, False if it raises an Exception.
    """
    try:
        # SECURITY: eval() runs `name` as code; only pass trusted identifiers.
        # A shallow copy is used because eval() inserts '__builtins__' into the
        # globals dict it receives.
        eval(name, dict(all_vars))
    except Exception:
        # Deliberately not a bare `except:` so that KeyboardInterrupt and
        # SystemExit still propagate.
        return False
    return True

def merge_dicts(dicts):
    """Combine several dicts into one new dict.

    Args:
        dicts: Iterable of dicts, processed in order.

    Returns:
        A new dict holding every key; for a key present in more than one
        input, the value from the later dict is kept.
    """
    merged = {}
    for mapping in dicts:
        for key, value in mapping.items():
            merged[key] = value
    return merged

# Run Benchmark

## Prompts

In [None]:
# Few-shot prompt for hierarchical generation: type-hinted examples with
# docstrings, each closed by an '# end of function' marker. The second example
# calls a helper (get_mean) that the prompt itself does not define.
prompt_f_gen_hier = '''
def get_total(xs: List[float]) -> float:
    """Find the sum of a list of numbers called xs.
    """
    return sum(xs)
# end of function

def get_abs_diff_between_means(xs0: List[float], xs1: List[float]) -> float:
    """Get the absolute difference between the means of two lists of numbers.
    """
    m0 = get_mean(xs0)
    m1 = get_mean(xs1)
    return abs(m0 - m1)
# end of function
'''.strip()


# Few-shot prompt in the '# define function: ...' comment format that lmp_fgen
# uses for its queries; passed as prompt_with_comment when child functions are
# generated.
prompt_f_gen_hier_comment = '''
# define function: total = get_total(xs).
def get_total(xs):
    return sum(xs)

# define function: diff = get_abs_diff_between_means(xs0, xs1).
def get_abs_diff_between_means(xs0, xs1):
    m0 = get_mean_pure_python(xs0)
    m1 = get_mean_pure_python(xs1)
    return abs(m0 - m1)
'''.strip()

# Few-shot prompt for flat generation: the same two examples, but the means
# are computed inline so no helper function is called.
prompt_f_gen_flat = '''
def get_total(xs: List[float]) -> float:
    """Find the sum of a list of numbers called xs.
    """
    return sum(xs)

def get_abs_diff_between_means(xs0: List[float], xs1: List[float]) -> float:
    """Get the absolute difference between the means of two lists of numbers.
    """
    m0 = sum(xs0) / len(xs0)
    m1 = sum(xs1) / len(xs1)
    return abs(m0 - m1)
'''.strip()

## Load Problems

In [None]:
from human_eval.data import write_jsonl, read_problems

# Load the HumanEval problems; the result is indexed by task id below.
problems = read_problems()

# Parallel lists of task ids and prompt texts, in the dict's iteration order.
task_ids = [task_id for task_id in problems]
task_prompts = [f"{entry['prompt']}" for entry in problems.values()]

# Quick look: number of problems and the fields of the first one.
print(len(problems))
print(problems['HumanEval/0'].keys())

164
dict_keys(['task_id', 'prompt', 'entry_point', 'canonical_solution', 'test'])


In [None]:
# Smoke test on a single problem: generate a hierarchical solution for
# HumanEval/<idx> and print it (log=True) before launching the full runs.
idx = 0
problem = problems[f'HumanEval/{idx}']
solutions = lmp_fgen_batch(prompt_f_gen_hier, prompt_f_gen_hier_comment, [problem['prompt']], stop_tokens=['def', 'if __name__', '# end of function'], recurse=True, log=True)

## Hier greedy

In [None]:
# Hierarchical generation with the default query settings (temperature 0):
# one completion per task, with undefined helper functions generated recursively.
# NOTE(review): unlike the single-problem demo and the "Hier samples" cell, the
# stop tokens here do not include '# end of function' — confirm this is intended.
solutions = lmp_fgen_batch(prompt_f_gen_hier, prompt_f_gen_hier_comment, task_prompts, stop_tokens=['def', 'if __name__'], recurse=True)
# One record per task, using the 'task_id' and 'completion' keys.
results = [
    {
        'task_id': task_id,
        'completion': solution
    }
    for task_id, solution in zip(task_ids, solutions)
]
result_path = results_path / "results_hier_greedy.jsonl"
write_jsonl(result_path, results)

# Score the completions with the evaluation script from the cloned human-eval repo.
! python3 human-eval/human_eval/evaluate_functional_correctness.py "$result_path"

## Flat w/ Prompt greedy

In [None]:
# Flat generation with the few-shot prompt and the default query settings
# (temperature 0): one completion per task, no recursive helper generation.
solutions = lmp_fgen_batch(prompt_f_gen_flat, '', task_prompts, stop_tokens=['def', 'if __name__'], recurse=False)
# One record per task, using the 'task_id' and 'completion' keys.
results = [
    {
        'task_id': task_id,
        'completion': solution
    }
    for task_id, solution in zip(task_ids, solutions)
]

result_path = results_path / "results_flat_with_prompt_greedy.jsonl"
write_jsonl(result_path, results)

# Score the completions with the evaluation script from the cloned human-eval repo.
! python3 human-eval/human_eval/evaluate_functional_correctness.py "$result_path"

## Flat w/o Prompt greedy

In [None]:
# Flat generation with an empty base prompt and the default query settings
# (temperature 0): one completion per task, no recursive helper generation.
solutions = lmp_fgen_batch('', '', task_prompts, stop_tokens=['def', 'if __name__'], recurse=False)
# One record per task, using the 'task_id' and 'completion' keys.
results = [
    {
        'task_id': task_id,
        'completion': solution
    }
    for task_id, solution in zip(task_ids, solutions)
]

result_path = results_path / "results_flat_without_prompt_greedy.jsonl"

write_jsonl(result_path, results)

# Score the completions with the evaluation script from the cloned human-eval repo.
! python3 human-eval/human_eval/evaluate_functional_correctness.py "$result_path"

## Hier samples

In [None]:
# Hierarchical generation, sampled: 100 rounds over every task at temperature
# 0.8, keeping the mean log-probability of each top-level completion.
# Child functions are still generated at temperature 0, which lmp_fgen enforces.
# NOTE(review): results are only written after all 100 rounds finish, so an
# interruption loses every sample collected so far.
results = []
for _ in trange(100):
  solutions, logprobs = lmp_fgen_batch(prompt_f_gen_hier, prompt_f_gen_hier_comment, task_prompts, stop_tokens=['def', 'if __name__', '# end of function'], recurse=True, query_kwargs={'temperature': 0.8}, ret_logprobs=True)
  results.extend([
      {
          'task_id': task_id,
          'completion': solution,
          'logprob': logprob
      }
      for task_id, solution, logprob in zip(task_ids, solutions, logprobs)
  ])

result_path = results_path / "results_hier_samples.jsonl"

write_jsonl(result_path, results)

# Score the completions with the evaluation script from the cloned human-eval repo.
! python3 human-eval/human_eval/evaluate_functional_correctness.py "$result_path"

## Flat w/ Prompt samples

In [None]:
# Flat generation with the few-shot prompt, sampled: 100 rounds over every task
# at temperature 0.8, keeping the mean log-probability of each completion.
# NOTE(review): results are only written after all 100 rounds finish, so an
# interruption loses every sample collected so far.
results = []
for _ in trange(100):
  solutions, logprobs = lmp_fgen_batch(prompt_f_gen_flat, '', task_prompts, stop_tokens=['def', 'if __name__'], recurse=False, query_kwargs={'temperature': 0.8}, ret_logprobs=True)
  results.extend([
      {
          'task_id': task_id,
          'completion': solution,
          'logprob': logprob
      }
      for task_id, solution, logprob in zip(task_ids, solutions, logprobs)
  ])

result_path = results_path / "results_flat_with_prompt_samples.jsonl"

write_jsonl(result_path, results)

# Score the completions with the evaluation script from the cloned human-eval repo.
! python3 human-eval/human_eval/evaluate_functional_correctness.py "$result_path"

## Flat w/o Prompt samples

In [None]:
# Flat generation with an empty base prompt, sampled: 100 rounds over every
# task at temperature 0.8, keeping the mean log-probability of each completion.
# NOTE(review): results are only written after all 100 rounds finish, so an
# interruption loses every sample collected so far.
results = []
for _ in trange(100):
  solutions, logprobs = lmp_fgen_batch('', '', task_prompts, stop_tokens=['def', 'if __name__'], recurse=False, query_kwargs={'temperature': 0.8}, ret_logprobs=True)
  results.extend([
      {
          'task_id': task_id,
          'completion': solution,
          'logprob': logprob
      }
      for task_id, solution, logprob in zip(task_ids, solutions, logprobs)
  ])

result_path = results_path / "results_flat_without_prompt_samples.jsonl"

write_jsonl(result_path, results)

# Score the completions with the evaluation script from the cloned human-eval repo.
! python3 human-eval/human_eval/evaluate_functional_correctness.py "$result_path"