## Meta-generation case study: Python code generation

This notebook demonstrates __best-of-n__, __minimum Bayes risk (MBR) decoding__, and __self-repair__ for code generation on the Mostly Basic Python Problems (MBPP) dataset.

In [1]:
import os
import re
import copy

from litellm import completion, ModelResponse
import datasets
import jellyfish
import numpy as np

from typing import Callable, Union
from pprint import pprint

from execute import (
    check_correctness,
    execute_tests,
    execute_codes,
)

from utils import (
    make_prompt,
    extract_code,
    extract_func_calls,
)

First, let's load in the MBPP data. MBPP consists of basic algorithmic Python questions, such as the following:


```python
# Specification
"""
Write a python function to remove first and last occurrence of a given character from the string.
"""

# Ground truth function
def remove_Occ(s,ch): 
    for i in range(len(s)): 
        if (s[i] == ch): 
            s = s[0 : i] + s[i + 1:] 
            break
    for i in range(len(s) - 1,-1,-1):  
        if (s[i] == ch): 
            s = s[0 : i] + s[i + 1:] 
            break
    return s 

# Test Cases
assert remove_Occ("hello","l") == "heo"
assert remove_Occ("abcda","a") == "bcd"
assert remove_Occ("PHP","P") == "H"
```

We'll evaluate outputs by their execution accuracy on each problem's test cases. For sampling-based methods, which do not rank their samples, we'll report the average accuracy across all samples; for meta-decoding methods, we'll report the accuracy of the top-1 returned output.
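
The two ways of scoring can be sketched in a few lines. This is a minimal illustration, not part of the notebook's `execute` module: the function names and the pass/fail flags are hypothetical.

```python
def mean_accuracy(passed: list[bool]) -> float:
    """Average accuracy over all samples (used for plain sampling)."""
    return sum(passed) / len(passed)


def top1_accuracy(passed_ranked: list[bool]) -> float:
    """Accuracy of the top-ranked sample (used for meta-decoding methods)."""
    return float(passed_ranked[0])


# Hypothetical pass/fail flags for 5 samples of one problem, ranked best-first
passed = [False, True, True, False, True]
print(mean_accuracy(passed))  # 0.6
print(top1_accuracy(passed))  # 0.0
```

The same set of samples can therefore score well on average and still score zero at top-1 if the ranking puts a failing sample first.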

In [2]:
# load mbpp data
mbpp = datasets.load_dataset("mbpp", split="test")

# pprint("MBPP example:")
# pprint(mbpp[0])
# print()

prompts = []

# make zero shot prompts for MBPP
for example in mbpp:
    prompts.append(make_prompt(example))

### Sampling

As a baseline, consider simply sampling a set of outputs from the model, using temperature and top-$p$ sampling.

In [55]:
# code for sampling from generator

MODEL_NAME = "gpt-3.5-turbo"

def generate_code(prompt: Union[str, list[dict]], **generate_kwargs) -> tuple[list[str], ModelResponse]:
    '''
    Generates code sample(s) for a prompt.
    Returns a list of code samples and an OpenAI-like response object.

    This function accepts additional keyword arguments for the generation
    parameters supported by the OpenAI API, including temperature, top_p, logprobs, etc.
    '''
    if isinstance(prompt, str):
        messages = [{"role": "user", "content": prompt}]
    else:
        assert isinstance(prompt, list) and isinstance(prompt[0], dict)
        messages = prompt

    response = completion(MODEL_NAME, messages=messages, **generate_kwargs)

    # post-processing to extract code from chat model outputs
    codes = []
    for choice in response.choices:
        text = choice.message.content
        code = extract_code(text)
        codes.append(code)
    return codes, response

In [56]:
codes, response = generate_code(prompts[0], n=10, temperature=0.8, top_p=0.95, logprobs=True, top_logprobs=1)

assert all(len(c) > 0 for c in codes)

In [59]:
execution_results = execute_tests(codes, mbpp[0]['test_list'])
print("mean pass@1:", sum(result['passed'] for result in execution_results) / len(execution_results))
[result['result'] for result in execution_results]

mean pass@1: 0.6


['failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'passed',
 'passed']
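
To illustrate what the `temperature` and `top_p` arguments control, here is a minimal standard-library sketch of one sampling step. It is an assumption-laden illustration of the general technique, not the API's actual implementation; `sample_top_p` and the example logits are hypothetical.

```python
import math
import random


def sample_top_p(logits, temperature, top_p, rng):
    """Sample one token id using temperature scaling followed by top-p filtering."""
    # Temperature: divide logits before the softmax (<1 sharpens, >1 flattens)
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Top-p: keep the smallest set of most-probable tokens whose mass reaches top_p
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break

    # Sample among the kept tokens, weighted by their (unnormalized) probabilities
    weights = [probs[i] for i in kept]
    return rng.choices(kept, weights=weights, k=1)[0]


rng = random.Random(0)
logits = [2.0, 1.0, 0.0, -1.0]  # hypothetical next-token logits
print([sample_top_p(logits, 0.8, 0.95, rng) for _ in range(10)])
```

Lower temperatures and smaller `top_p` values concentrate samples on the most likely tokens, which reduces diversity across the $n$ samples.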

### Best-of-$n$

$$\hat{y} = \arg \max_{y \in \mathcal{Y}} V(y)$$

As our first meta-decoding algorithm, we consider best-of-$n$, using a sample's mean per-token log probability under the generator as its value $V(y)$.

In [4]:
def best_of_n_logprob(prompt: str, **generate_kwargs) -> tuple[list[str], list[float]]:
    '''
    Runs best-of-n generation, using mean sequence logprob as value
    Returns list of codes in order of decreasing value, along with their values
    '''
    generate_kwargs["top_logprobs"] = 1
    generate_kwargs["logprobs"] = True
    # generate samples
    codes, response = generate_code(prompt, **generate_kwargs)
    scores = []
    # compute mean logprob for each sequence
    for choice in response.choices:
        logprobs = choice.logprobs['content']
        mean_logprob = sum(lp['logprob'] for lp in logprobs) / len(logprobs)
        scores.append(mean_logprob)

    sorted_indices = np.argsort(scores)[::-1]
    # arrange codes by decreasing mean log probability
    scores = [scores[i] for i in sorted_indices]
    codes = [codes[i] for i in sorted_indices]
    return codes, scores

In [5]:
codes, scores = best_of_n_logprob(prompts[0], n=10, temperature=0.9)

In [6]:
print(scores)
print(codes[0])

[-0.03816219730931771, -0.046072972968882075, -0.04632711074250333, -0.0723968563913914, -0.0797575224169243, -0.09274121795797267, -0.09437068446480579, -0.12788232105058547, -0.12788232105058547, -0.14468262731788623]
def remove_Occ(s, char):
    first_occ = s.find(char)
    last_occ = s.rfind(char)
    if first_occ != -1:
        s = s[:first_occ] + s[first_occ+1:]
    if last_occ != -1:
        s = s[:last_occ] + s[last_occ+1:]
    return s


In [13]:
execution_results = execute_tests(codes, mbpp[0]['test_list'])
[a['result'] for a in execution_results]

['failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ']
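
Note that in the run above the highest-likelihood sample fails its first test, so likelihood is only a proxy for correctness. The choice of mean rather than summed log probability also matters; the following sketch, with hypothetical per-token log probabilities, shows how the sum favors shorter sequences while the mean normalizes for length.

```python
def sum_logprob(token_logprobs):
    return sum(token_logprobs)


def mean_logprob(token_logprobs):
    return sum(token_logprobs) / len(token_logprobs)


def rank_by_value(candidates, value_fn):
    """Return candidate names sorted by decreasing value."""
    return sorted(candidates, key=lambda name: value_fn(candidates[name]), reverse=True)


# Hypothetical per-token log probabilities for two samples
candidates = {
    "short": [-0.5, -0.5],  # 2 tokens: sum -1.0, mean -0.5
    "long": [-0.2] * 10,    # 10 tokens: sum about -2.0, mean about -0.2
}
print(rank_by_value(candidates, sum_logprob))   # ['short', 'long']
print(rank_by_value(candidates, mean_logprob))  # ['long', 'short']
```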

### MBR

\begin{align*}
\hat{y} &= \arg \min_{y \in \mathcal{Y}} R(y) = \arg \max_{y \in \mathcal{Y}} G(y) \\
R(y) &= \sum_{y' \in \mathcal{Y}_e} \ell(y, y') = 1 - G(y)
\end{align*}

In this section, we consider minimum Bayes risk (MBR) decoding, which selects the sample that agrees most, on average, with the other samples. We consider two choices of gain function: edit similarity and execution equivalence [(Shi et al.)](https://arxiv.org/abs/2204.11454).
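
One way to read the objective concretely is the following. It rests on two assumptions: the loss is one minus a pairwise similarity $s(y, y') \in [0, 1]$, and it is averaged over the evidence set $\mathcal{Y}_e$ (here, the same set of samples).

\begin{align*}
G(y) &= \frac{1}{|\mathcal{Y}_e|} \sum_{y' \in \mathcal{Y}_e} s(y, y') \\
R(y) &= \frac{1}{|\mathcal{Y}_e|} \sum_{y' \in \mathcal{Y}_e} \bigl(1 - s(y, y')\bigr) = 1 - G(y)
\end{align*}

Under this reading, $G(y)$ is the row mean of the pairwise similarity matrix, so minimizing risk and maximizing gain select the same sample.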

In [34]:
PairwiseMetric = Callable[[str, str], float]

def mbr(prompt: str, metric_fn: PairwiseMetric, **generate_kwargs) -> tuple[list[str], list[float]]:
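    '''
    Runs MBR decoding with a custom pairwise metric (assumed to be symmetric)
    Returns list of codes in order of decreasing gain, along with their gains
    '''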
    codes, response = generate_code(prompt, **generate_kwargs)
    pairwise_scores = np.zeros((len(codes), len(codes)))
    for i1, code1 in enumerate(codes):
        for i2, code2 in enumerate(codes):
            if i1 <= i2:
                sim = metric_fn(code1, code2)
                pairwise_scores[i1, i2] = sim
                pairwise_scores[i2, i1] = sim
    gains = pairwise_scores.mean(axis=-1)
    sorted_indices = np.argsort(gains)[::-1]

    gains = [float(gains[i]) for i in sorted_indices]
    codes = [codes[i] for i in sorted_indices]
    return codes, gains

def edit_sim(code1, code2):
    edit_distance = jellyfish.levenshtein_distance(code1, code2)
    return 1 - edit_distance / max(len(code1), len(code2))

def make_exec_metric(test_list):
    '''
    Make pairwise similarity metric for specific example based on that example's test cases
    '''
    # extract the calls to this example's function
    test_func_calls = extract_func_calls(test_list)
    
    def exec_sim(code1, code2):
        '''
        Runs code1 and code2 on test_func_calls (from the closure of this function)
        Returns proportion of func calls with same execution result
        '''
        result1, result2 = execute_codes([code1, code2], test_func_calls)
        n_same = 0
        for r1, r2 in zip(result1, result2):
            if isinstance(r1, Exception) or isinstance(r2, Exception):
                continue
            if r1 == r2:
                n_same += 1
        similarity = n_same / len(test_func_calls)
        return similarity
    return exec_sim

In [41]:
# MBR-edit-sim
codes, gains = mbr(prompts[0], edit_sim, n=10, temperature=0.9)
print(gains)
print(codes[0])

[0.6921033971312858, 0.6921033971312858, 0.6415651261880847, 0.6163579617256353, 0.6122981366459628, 0.6076617274005611, 0.6027653814183951, 0.5051370324289156, 0.4704761904761905, 0.42899317679638865]
def remove_Occ(s, char):
    first_occ = s.find(char)
    last_occ = s.rfind(char)
    
    if first_occ != -1:
        s = s[:first_occ] + s[first_occ+1:]
    if last_occ != -1:
        s = s[:last_occ] + s[last_occ+1:]
    
    return s


In [42]:
execution_results = execute_tests(codes, mbpp[0]['test_list'])
[a['result'] for a in execution_results]

['failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'passed',
 'passed',
 'passed']
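
In the run above, the two top-ranked samples have identical gains and both fail: edit similarity rewards shared surface form, not shared behavior. To make the metric itself concrete, here is a pure-Python sketch of Levenshtein distance and the derived similarity. It stands in for `jellyfish.levenshtein_distance` for illustration only; the empty-string guard is an added assumption.

```python
def levenshtein(a: str, b: str) -> int:
    """Edit distance via the standard dynamic program, keeping one row at a time."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def edit_similarity(a: str, b: str) -> float:
    longest = max(len(a), len(b))
    if longest == 0:
        return 1.0  # two empty strings are treated as identical
    return 1 - levenshtein(a, b) / longest


print(levenshtein("kitten", "sitting"))  # 3
# Renaming a variable lowers the similarity even though behavior is unchanged
print(edit_similarity("first_occ = s.find(char)", "first_index = s.find(char)"))
```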

In [37]:
# MBR-exec
codes, gains = mbr(prompts[0], make_exec_metric(mbpp[0]['test_list']), n=10, temperature=0.9)
print(gains)
print(codes[0])

[0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.36666666666666664, 0.36666666666666664, 0.36666666666666664, 0.29999999999999993]
def remove_Occ(s, char):
    first_index = s.find(char)
    last_index = s.rfind(char)
    
    if first_index != -1:
        s = s[:first_index] + s[first_index+1:]
    if last_index != -1:
        s = s[:last_index] + s[last_index+1:]
    
    return s


In [38]:
execution_results = execute_tests(codes, mbpp[0]['test_list'])
[a['result'] for a in execution_results]

['failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'passed',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ']
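
In the run above, the six samples that share the top gain of 0.6 all fail: MBR-exec selects the majority behavior, which is only correct when most samples are. The following self-contained sketch reproduces this effect with plain Python callables instead of sandboxed execution. The helper names and the 3-versus-2 pool are hypothetical.

```python
def remove_first_then_last(s, ch):
    """Buggy variant seen in the samples: both indices are computed before any removal."""
    first, last = s.find(ch), s.rfind(ch)
    if first != -1:
        s = s[:first] + s[first + 1:]
    if last != -1:
        s = s[:last] + s[last + 1:]
    return s


def remove_reference(s, ch):
    """Reference behavior, following the MBPP ground truth."""
    i = s.find(ch)
    if i != -1:
        s = s[:i] + s[i + 1:]
    j = s.rfind(ch)
    if j != -1:
        s = s[:j] + s[j + 1:]
    return s


def exec_agreement(f, g, inputs):
    """Fraction of inputs on which f and g return the same value (errors never match)."""
    same = 0
    for args in inputs:
        try:
            same += f(*args) == g(*args)
        except Exception:
            pass
    return same / len(inputs)


def mbr_select(candidates, inputs):
    """Return the index of the highest-gain candidate and the list of all gains."""
    gains = [
        sum(exec_agreement(f, g, inputs) for g in candidates) / len(candidates)
        for f in candidates
    ]
    best = max(range(len(candidates)), key=gains.__getitem__)
    return best, gains


inputs = [("hello", "l"), ("abcda", "a"), ("PHP", "P")]
pool = [remove_first_then_last] * 3 + [remove_reference] * 2
best, gains = mbr_select(pool, inputs)
print(gains)                # [0.6, 0.6, 0.6, 0.4, 0.4]
print(pool[best].__name__)  # remove_first_then_last
```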

### Self-repair

Self-repair (or self-debugging) is a refinement strategy in which the code LLM iteratively repairs its own previously-generated code based on compiler or execution error messages. Here, the __initial generation__ is the code LLM's initial code samples and __feedback__ is derived from the code execution environment. Finally, the model __refines__ its generation by (optionally) explaining the previous turn's error and writing an improved piece of code.

\begin{align*}
    y^{(0)}&\sim g_0(y|x) \tag{initial generation}\\
    z^{(t)}&\sim h(z|x,y^{(<t)},z^{(<t)}) \tag{feedback}\\
    y^{(t)}&\sim g(y|x,y^{(<t)},z^{(\leq t)}) \tag{refinement}
\end{align*}

Self-repair has been shown to be effective on a variety of code generation tasks [(Chen et al.)](https://arxiv.org/abs/2304.05128). On more challenging tasks, however, it has known limitations and does not always outperform best-of-$n$ [(Olausson et al.)](https://arxiv.org/abs/2306.09896).
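
Before the LLM-backed implementation, the three steps can be sketched as a generic loop over callables. This is a simplified illustration: it conditions each step only on the most recent $y$ and $z$ rather than the full history, and `refine_loop` and the toy stand-ins are hypothetical.

```python
from typing import Callable, Optional


def refine_loop(
    x: str,
    generate: Callable[[str], str],
    feedback: Callable[[str, str], Optional[str]],
    refine: Callable[[str, str, str], str],
    max_rounds: int,
) -> list[tuple[str, Optional[str]]]:
    """Return the history of (generation, feedback) pairs; feedback None means success."""
    y = generate(x)               # initial generation
    history = []
    for _ in range(max_rounds):
        z = feedback(x, y)        # feedback on the current generation
        history.append((y, z))
        if z is None:             # nothing left to repair
            return history
        y = refine(x, y, z)       # refinement
    history.append((y, feedback(x, y)))
    return history


# Toy stand-ins: the "program" is a version string, and only "v2" passes
generate = lambda x: "v0"
feedback = lambda x, y: None if y == "v2" else f"{y} failed"
refine = lambda x, y, z: f"v{int(y[1:]) + 1}"

history = refine_loop("task", generate, feedback, refine, max_rounds=3)
print(history)  # [('v0', 'v0 failed'), ('v1', 'v1 failed'), ('v2', None)]
```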

In [43]:
# Example repair prompt from Olausson et al (https://arxiv.org/abs/2306.09896).
repair_template = [
    {
        "role": "system",
        "content": "You are a helpful programming assistant and an expert Python programmer. "
                   "You are helping a user write a program to solve a problem. "
                   "The user has written some code, but it has some errors and is not passing the tests. "
                   "You will help the user by first giving a concise (at most 2-3 sentences) textual explanation "
                   "of what is wrong with the code. After you have pointed out what is wrong with the code, "
                   "you will then generate a fixed version of the program. "
                   "Put your fixed program within code delimiters, for example: ```\n# YOUR CODE HERE\n```.",
    },
    {
        "role": "user",
        "content": "### Problem\n"
                   "{problem}\n\n"
                   "### Incorrect Code:\n"
                   "{code}\n\n"
                   "###Error\n"
                   "{error}"
    }
]

def run_repair(code: str, problem: str, error_msg: str, **generate_kwargs) -> tuple[list[str], ModelResponse]:
    '''
    Runs a single iteration of refinement given the previous round's code and its error message
    Returns the list of repaired code samples and the response object
    '''
    repair_prompt = copy.deepcopy(repair_template)
    repair_prompt[-1]['content'] = repair_prompt[-1]['content'].format(problem=problem, code=code, error=error_msg)

    new_codes, response = generate_code(repair_prompt, **generate_kwargs)
    return new_codes, response


def self_repair(prompt: str, tests: list[str], max_rounds: int, **generate_kwargs) -> list[tuple]:
    '''
    Runs generation with self-repair for a given prompt and tests.
    Returns one (codes, responses, execution_results) tuple per round, starting with the initial generation.
    '''
    problem = prompt.rsplit("\n", 1)[0] # remove the additional instruction at the end of the prompt
    codes, response = generate_code(prompt, **generate_kwargs)

    # execute_tests returns dicts with 'passed' and 'result' keys, which the loop below relies on
    execution_results = execute_tests(codes, tests)

    refine_history = []

    refine_history.append((codes, response, copy.deepcopy(execution_results)))
    for _ in range(max_rounds):
        refined_codes = []
        refined_responses = []
        for code, exec_result in zip(codes, execution_results):
            if exec_result['passed']:
                # carry passing code forward unchanged; no repair call is needed
                refined_codes.append(code)
                refined_responses.append(None)
                continue

            # pass the problem statement along with the failing code and its error message
            refined_code, refined_response = run_repair(code, problem, exec_result['result'], n=1)
            refined_codes.append(refined_code[0])
            refined_responses.append(refined_response)

        execution_results = execute_tests(refined_codes, tests)
        codes = refined_codes
        refine_history.append((refined_codes, refined_responses, copy.deepcopy(execution_results)))

    return refine_history

Observe that the top-ranked sample from the earlier MBR edit-similarity run did not pass the tests:

In [51]:
print("Code:\n" + codes[0] + "\n")
print("Result:\n" + execution_results[0]['result'])

Code:
def remove_Occ(s, char):
    first_occ = s.find(char)
    last_occ = s.rfind(char)
    
    if first_occ != -1:
        s = s[:first_occ] + s[first_occ+1:]
    if last_occ != -1:
        s = s[:last_occ] + s[last_occ+1:]
    
    return s

Result:
failed:  assert remove_Occ("hello","l") == "heo" 


Now, let's try an iteration of refinement:

In [52]:
problem = prompts[0].rsplit("\n", 1)[0]
repair_code, repair_response = run_repair(codes[0], problem, execution_results[0]['result'], n=1)

The new result:

In [53]:
execute_tests(repair_code, mbpp[0]['test_list'])

[{'task_id': None, 'passed': True, 'result': 'passed', 'completion_id': None}]

And the LLM's reasoning for its refinement:

In [54]:
print(repair_response.choices[0].message.content)

The issue with the code is that it removes the last occurrence of the character before removing the first occurrence. This causes the removal of the first occurrence to shift the index of the last occurrence, resulting in the incorrect output. To fix this, you should remove the last occurrence first and then the first occurrence.

```python
def remove_Occ(s, char):
    last_occ = s.rfind(char)
    first_occ = s.find(char)
    
    if last_occ != -1:
        s = s[:last_occ] + s[last_occ+1:]
    if first_occ != -1:
        s = s[:first_occ] + s[first_occ+1:]
    
    return s
```
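
As a quick sanity check outside the sandboxed executor, the failing and repaired functions can be compared directly in plain Python. This is a minimal sketch: the function names are hypothetical, and the last line probes an input that is not in the MBPP test list.

```python
def remove_occ_original(s, char):
    """The failing sample: both indices are computed, then the first is removed first."""
    first_occ = s.find(char)
    last_occ = s.rfind(char)
    if first_occ != -1:
        s = s[:first_occ] + s[first_occ + 1:]
    if last_occ != -1:
        s = s[:last_occ] + s[last_occ + 1:]  # last_occ is stale after the first removal
    return s


def remove_occ_repaired(s, char):
    """The repaired sample: removing the last occurrence first keeps first_occ valid."""
    last_occ = s.rfind(char)
    first_occ = s.find(char)
    if last_occ != -1:
        s = s[:last_occ] + s[last_occ + 1:]
    if first_occ != -1:
        s = s[:first_occ] + s[first_occ + 1:]
    return s


tests = [(("hello", "l"), "heo"), (("abcda", "a"), "bcd"), (("PHP", "P"), "H")]


def passes_all(fn):
    return all(fn(*args) == expected for args, expected in tests)


print(passes_all(remove_occ_original))        # False
print(passes_all(remove_occ_repaired))        # True
print(repr(remove_occ_repaired("abc", "b")))  # 'a'
```

With a single occurrence, the repaired function deletes two characters and returns `'a'`, whereas the MBPP ground-truth function shown earlier returns `'ac'`. Passing the available tests is therefore only as strong a signal as the tests themselves.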
