## Meta-generation case study: Python code generation on MBPP

This notebook demonstrates __best-of-n__, __minimum Bayes risk decoding__, and __self-repair__ for code generation on the Mostly Basic Python Problems (MBPP) dataset.

In [1]:
import os
import re
import copy

from litellm import completion, ModelResponse
import datasets
import jellyfish # levenshtein distance for MBR utility
import numpy as np

from typing import Callable, Union
from pprint import pprint

from mbpp_utils import (
    # code execution helpers
    check_correctness,
    execute_tests,
    execute_codes,
    
    # string processing helpers
    make_prompt,
    extract_code,
    extract_func_calls,
)

First, let's load in the MBPP data. MBPP consists of basic algorithmic Python questions, such as the following:

```python
# Input: Specification
"""
Write a python function to remove first and last occurrence of a given character from the string.
"""

# Output: Python function
def remove_Occ(s,ch): 
    for i in range(len(s)): 
        if (s[i] == ch): 
            s = s[0 : i] + s[i + 1:] 
            break
    for i in range(len(s) - 1,-1,-1):  
        if (s[i] == ch): 
            s = s[0 : i] + s[i + 1:] 
            break
    return s 

# Public test cases
assert remove_Occ("hello","l") == "heo"
assert remove_Occ("abcda","a") == "bcd"
assert remove_Occ("PHP","P") == "H"

# Challenge test cases
assert remove_Occ("hellolloll","l") == "helollol"
assert remove_Occ("","l") == ""
```

For each example, we will only allow the model to have access to the first test case; we will keep the other tests hidden for evaluation only.

We'll evaluate outputs based on their execution accuracy on tests. For sampling-based methods, we'll consider the average execution accuracy across all samples (as there is no ranking over samples), while for meta-decoding methods, we'll consider only the execution accuracy of the top-1 returned output.
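To make the two evaluation modes concrete, here is a minimal sketch. It assumes each execution result is a dict with a boolean `passed` field, which matches the result dicts printed later in this notebook. The helper names `mean_pass_at_1` and `top_1_accuracy` are hypothetical and are not part of `mbpp_utils`.

```python
def mean_pass_at_1(results: list[dict]) -> float:
    """Average execution accuracy over an unranked set of samples."""
    return sum(r["passed"] for r in results) / len(results)


def top_1_accuracy(ranked_results: list[dict]) -> float:
    """Execution accuracy of the single highest-ranked sample."""
    return float(ranked_results[0]["passed"])


# toy results (made up for illustration): 2 of 4 samples pass,
# and the sample ranked first happens to fail
toy_results = [{"passed": False}, {"passed": True}, {"passed": True}, {"passed": False}]
print(mean_pass_at_1(toy_results))  # 0.5
print(top_1_accuracy(toy_results))  # 0.0
```

A ranking method only helps if it moves a passing sample to the front, so the two numbers can differ substantially on the same set of samples.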

In [2]:
# load mbpp data
mbpp = datasets.load_dataset("mbpp", split="test")

# pprint("MBPP example:")
# pprint(mbpp[0])
# print()

prompts = []

# make zero shot prompts for MBPP
for example in mbpp:
    prompts.append(make_prompt(example))

In [3]:
# inspect the first test-split example (the one used throughout this demo)
mbpp[0]

{'task_id': 11,
 'text': 'Write a python function to remove first and last occurrence of a given character from the string.',
 'code': 'def remove_Occ(s,ch): \r\n    for i in range(len(s)): \r\n        if (s[i] == ch): \r\n            s = s[0 : i] + s[i + 1:] \r\n            break\r\n    for i in range(len(s) - 1,-1,-1):  \r\n        if (s[i] == ch): \r\n            s = s[0 : i] + s[i + 1:] \r\n            break\r\n    return s ',
 'test_list': ['assert remove_Occ("hello","l") == "heo"',
  'assert remove_Occ("abcda","a") == "bcd"',
  'assert remove_Occ("PHP","P") == "H"'],
 'test_setup_code': '',
 'challenge_test_list': ['assert remove_Occ("hellolloll","l") == "helollol"',
  'assert remove_Occ("","l") == ""']}

### Parallel Sampling

First, consider simply sampling a set of outputs from the model in parallel, using temperature & top-$p$ sampling.
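As a rough illustration of what temperature and top-$p$ do to a next-token distribution, here is a toy sketch over a four-token vocabulary. In this notebook the sampling happens inside the model API, so the function `sample_token` and the logits below are purely hypothetical.

```python
import numpy as np


def sample_token(logits: np.ndarray, temperature: float, top_p: float,
                 rng: np.random.Generator) -> int:
    """Sample one token id with temperature scaling and nucleus (top-p) truncation."""
    # temperature scaling: a lower temperature sharpens the distribution
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()

    # keep the smallest set of most-probable tokens whose total mass reaches top_p
    order = np.argsort(probs)[::-1]
    cumulative = np.cumsum(probs[order])
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    keep = order[:cutoff]

    # renormalize over the kept tokens and sample one of them
    kept_probs = probs[keep] / probs[keep].sum()
    return int(rng.choice(keep, p=kept_probs))


rng = np.random.default_rng(0)
toy_logits = np.array([2.0, 1.0, 0.5, -3.0])  # made-up logits
toy_samples = [sample_token(toy_logits, temperature=0.7, top_p=0.95, rng=rng)
               for _ in range(200)]
print(sorted(set(toy_samples)))
```

With these values the lowest-probability token (id 3) falls outside the 0.95 nucleus, so it is never sampled.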

In [8]:
# Let's set some sampling parameters that we'll keep fixed throughout this demo
n = 30             # number of outputs to sample per input
temperature = 0.7  # sampling temperature
top_p = 0.95       # top-p cutoff
RANDOM_SEED = 1618 # fixed random seed to avoid nondeterminism

In [9]:
# code for sampling from generator

MODEL_NAME = "gpt-3.5-turbo-0125"

def generate_code(prompt: Union[str, list[dict]], **generate_kwargs) -> tuple[list[str], ModelResponse]:
    '''
    Generates code sample(s) for prompt
    Returns list of code samples and OpenAI-like response object.
    
    This function accepts additional keyword arguments for various generation 
    parameters supported by the openai api, including temperature, top_p, logprobs, etc.

    Note that we use the litellm library for model inference, which is compatible with a 
    wide array of model frameworks and APIs. See this link for more:
    https://github.com/BerriAI/litellm?tab=readme-ov-file#supported-providers-docs.
    While we assume the inference provider is OpenAI, the code in this demo could 
    also work for other backends with some minor modifications.
    '''
    if isinstance(prompt, str):
        messages = [{"role": "user", "content": prompt}]
    else:
        assert isinstance(prompt, list) and isinstance(prompt[0], dict)
        messages = prompt

    response = completion(MODEL_NAME, messages=messages, **generate_kwargs)

    # post-processing to extract code from chat model outputs
    codes = []
    for choice in response.choices:
        text = choice.message.content
        code = extract_code(text)
        codes.append(code)
    return codes, response

In [10]:
# Do parallel sampling
codes, response = generate_code(prompts[0], n=n, temperature=temperature, top_p=top_p, seed=RANDOM_SEED)

In [15]:
# Evaluate generated samples
execution_results = execute_tests(codes, mbpp[0]['test_list'])
print("mean pass@1:", sum(result['passed'] for result in execution_results) / len(execution_results))
[result['result'] for result in execution_results]

mean pass@1: 0.5333333333333333


['failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'passed',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'failed: substring not found',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'passed',
 'passed',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'passed',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed']

Of the 30 samples, only 16 pass all the tests. This isn't great: if we picked one of these programs at random, it would have only a slightly better than 50% chance of being correct :(

### Best-of-$n$

$$\hat{y} = \arg \max_{y \in \mathcal{Y}} v(y)$$

As our first meta-decoding algorithm, we consider best-of-$n$ using a sample's length-normalized log probability under the generator as its value, i.e. $v(y) = \frac{1}{|y|} \log p_\theta(y\,|\,x)$, where $|y|$ is the number of tokens in $y$.
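The sketch below shows why length normalization matters when ranking by log probability. The per-token log probabilities are made up for illustration, and `mean_logprob` and `best_of_n` are hypothetical helper names rather than functions from this notebook.

```python
def mean_logprob(token_logprobs: list[float]) -> float:
    """Length-normalized sequence log probability."""
    return sum(token_logprobs) / len(token_logprobs)


def best_of_n(candidates: list[str], values: list[float]) -> str:
    """Return the candidate with the highest value."""
    best_index = max(range(len(candidates)), key=lambda i: values[i])
    return candidates[best_index]


# made-up per-token log probabilities for two candidates
toy_logprobs = {
    "short": [-0.9, -0.9],  # sum = -1.8, mean = -0.9
    "long": [-0.3] * 8,     # sum = -2.4, mean = -0.3
}
names = list(toy_logprobs)
by_sum = best_of_n(names, [sum(lp) for lp in toy_logprobs.values()])
by_mean = best_of_n(names, [mean_logprob(lp) for lp in toy_logprobs.values()])
print(by_sum, by_mean)  # short long
```

Ranking by the summed log probability favors the shorter sequence even though each of its tokens is less likely; the mean removes that length bias.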

In [18]:
def best_of_n_logprob(prompt: str, **generate_kwargs) -> tuple[list[str], list[float]]:
    '''
    Runs best-of-n generation, using mean sequence logprob as value
    Returns list of codes in order of decreasing value
    '''
    # force these arguments to be passed 
    generate_kwargs["top_logprobs"] = 1
    generate_kwargs["logprobs"] = True
    # generate samples
    codes, response = generate_code(prompt, **generate_kwargs)
    scores = []
    # compute mean logprob for each sequence
    for choice in response.choices:
        logprobs = choice.logprobs['content']
        mean_logprob = sum(lp['logprob'] for lp in logprobs) / len(logprobs)
        scores.append(mean_logprob)

    sorted_indices = np.argsort(scores)[::-1]
    # arrange codes by decreasing mean log probability
    scores = [scores[i] for i in sorted_indices]
    codes = [codes[i] for i in sorted_indices]
    return codes, scores

In [19]:
# Run best-of-n
codes, scores = best_of_n_logprob(prompts[0], n=n, temperature=temperature, top_p=top_p, seed=RANDOM_SEED)

In [22]:
# Evaluate best-of-n
execution_results = execute_tests(codes, mbpp[0]['test_list'])
[result['result'] for result in execution_results]

['failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'passed',
 'passed',
 'passed',
 'passed',
 'passed',
 'passed',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'passed',
 'passed',
 'passed',
 'passed',
 'failed: substring not found']

Even so, the top-1 output is not correct; worse, the seven highest-ranked candidates all fail. As it turns out, generator probability has been found to be a relatively poor proxy for execution accuracy; in contrast, recent work has found greater success using mutual information [(Zhang et al., 2022)](https://arxiv.org/abs/2211.16490) or learned reward models [(Ni et al., 2023)](https://arxiv.org/abs/2302.08468) as value functions for best-of-$n$.

### Minimum Bayes Risk

\begin{align*}
\hat{y} &= \arg \max_{y' \in \mathcal{Y}_h} \sum_{y \in \mathcal{Y}_e} u(y, y')
\end{align*}

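In this section, we consider minimum Bayes risk (MBR) decoding, which selects the hypothesis with the highest total utility against a set of evidence samples. Here the hypothesis set and the evidence set are the same $n$ sampled programs. We consider two choices of pairwise utility function $u$: (1) character-level edit similarity and (2) code execution equivalence, referred to as MBR-Exec [(Shi et al., 2022)](https://arxiv.org/abs/2204.11454).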
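The selection rule can be sketched on a handful of toy strings. This sketch swaps in `difflib.SequenceMatcher` from the standard library as a stand-in similarity, which is an assumption made to keep the example dependency-free; it is not the Levenshtein-based utility used in the cells below. The names `toy_utility` and `mbr_select` are hypothetical.

```python
from difflib import SequenceMatcher


def toy_utility(a: str, b: str) -> float:
    """Stand-in similarity in [0, 1] (assumption: not the notebook's edit_sim)."""
    return SequenceMatcher(None, a, b).ratio()


def mbr_select(samples: list[str]) -> str:
    """Pick the sample with the highest total utility against all samples."""
    totals = [sum(toy_utility(y, y_prime) for y in samples) for y_prime in samples]
    best_index = max(range(len(samples)), key=lambda i: totals[i])
    return samples[best_index]


# made-up candidate outputs; the first two are identical
toy_candidates = [
    "return s[::-1]",
    "return s[::-1]",
    "return ''.join(reversed(s))",
    "return s",
]
print(mbr_select(toy_candidates))  # return s[::-1]
```

The duplicated candidate wins because it agrees perfectly with its copy, which is the consensus-seeking behavior MBR is designed to have.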

In [31]:
PairwiseUtility = Callable[[str, str], float]

def mbr(prompt: str, metric_fn: PairwiseUtility, **generate_kwargs) -> tuple[list[str], list[float]]:
    '''
    Runs MBR decoding with custom user-defined pairwise utility
    '''
    codes, response = generate_code(prompt, **generate_kwargs)
    pairwise_scores = np.zeros((len(codes), len(codes)))
    for i1, code1 in enumerate(codes):
        for i2, code2 in enumerate(codes):
            if i1 <= i2:
                sim = metric_fn(code1, code2)
                pairwise_scores[i1, i2] = sim
                pairwise_scores[i2, i1] = sim
    gains = pairwise_scores.mean(axis=-1)
    sorted_indices = np.argsort(gains)[::-1]

    gains = [float(gains[i]) for i in sorted_indices]
    codes = [codes[i] for i in sorted_indices]
    return codes, gains

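def edit_sim(code1: str, code2: str) -> float:
    # normalized Levenshtein similarity in [0, 1]; two empty strings count as identical
    max_len = max(len(code1), len(code2))
    if max_len == 0:
        return 1.0
    edit_distance = jellyfish.levenshtein_distance(code1, code2)
    return 1 - edit_distance / max_len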

def make_exec_metric(test_list: list[str]) -> PairwiseUtility:
    '''
    Make pairwise similarity metric for specific example based on that example's test cases
    '''
    # extract the calls to this example's function
    test_func_calls = extract_func_calls(test_list)
    
    def exec_sim(code1: str, code2: str) -> float:
        '''
        Runs code1 and code2 on test_func_calls (from the closure of this function)
        Returns proportion of func calls with same execution result
        '''
        result1, result2 = execute_codes([code1, code2], test_func_calls)
        n_same = 0
        for r1, r2 in zip(result1, result2):
            if isinstance(r1, Exception) or isinstance(r2, Exception):
                continue
            if r1 == r2:
                n_same += 1
        similarity = n_same / len(test_func_calls)
        return similarity
    return exec_sim

In [29]:
# MBR-edit-sim
codes, gains = mbr(prompts[0], edit_sim, n=n, temperature=temperature, top_p=top_p, seed=RANDOM_SEED)
print(gains)
print(codes[0])

[0.7784882503257926, 0.7784882503257926, 0.7784882503257926, 0.7784882503257926, 0.7784882503257926, 0.7784882503257926, 0.7784882503257926, 0.7714686999833471, 0.7714686999833471, 0.7571552423515044, 0.7571552423515044, 0.7571552423515044, 0.7571552423515044, 0.7571552423515044, 0.7571552423515044, 0.7457562750158142, 0.7255068827231455, 0.7245432998058251, 0.7073885141071726, 0.7067493957561581, 0.6783766468826269, 0.6759921165572798, 0.6759921165572798, 0.6679532163742689, 0.6546167808379094, 0.6527382634190173, 0.6405044385544196, 0.6219316478118512, 0.5927113899384178, 0.4310940143717774]
def remove_Occ(s, char):
    first_occ = s.find(char)
    last_occ = s.rfind(char)
    if first_occ != -1 and last_occ != -1:
        return s[:first_occ] + s[first_occ+1:last_occ] + s[last_occ+1:]
    else:
        return s


In [30]:
# Evaluate MBR-edit-sim
execution_results = execute_tests(codes, mbpp[0]['test_list'])
[a['result'] for a in execution_results]

['passed',
 'passed',
 'passed',
 'passed',
 'passed',
 'passed',
 'passed',
 'passed',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'passed',
 'passed',
 'passed',
 'passed',
 'passed',
 'failed:  assert remove_Occ("hello","l") == "heo" ',
 'passed',
 'failed: substring not found',
 'passed']

In [None]:
# MBR-exec
codes, gains = mbr(prompts[0], make_exec_metric(mbpp[0]['test_list'][:1]), n=n, temperature=temperature, top_p=top_p, seed=RANDOM_SEED)
print(gains)
print(codes[0])

In [None]:
execution_results = execute_tests(codes, mbpp[0]['test_list'])
[a['result'] for a in execution_results]

### Self-repair

Self-repair (or self-debugging) is a refinement strategy in which the code LLM iteratively repairs its own previously-generated code based on compiler or execution error messages. Here, the __initial generation__ is the code LLM's initial code samples and __feedback__ is derived from the code execution environment. Finally, the model __refines__ its generation by (optionally) explaining the previous turn's error and writing an improved piece of code.

\begin{align*}
    y^{(0)}&\sim g_0(y|x) \tag{initial generation}\\
    z^{(t)}&\sim h(z|x,y^{(<t)},z^{(<t)}) \tag{feedback}\\
    y^{(t)}&\sim g(y|x,y^{(<t)},z^{(\leq t)}) \tag{refinement}
\end{align*}

Self-repair has been shown to be effective for a variety of code generation tasks [(Chen et al., 2023)](https://arxiv.org/abs/2304.05128). However, for more challenging tasks, it has been found to have some limitations and may not always be more effective than best-of-N [(Olausson et al., 2024)](https://arxiv.org/abs/2306.09896).
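Before the LLM-based implementation, the generate / feedback / refine loop can be sketched with a stub in place of the model. Everything here is hypothetical: `run_tests` uses plain `exec` with no sandboxing (fine for a toy, unsafe for untrusted model output), and `stub_refine` stands in for an LLM call.

```python
from typing import Callable, Optional


def run_tests(code: str, tests: list[str]) -> Optional[str]:
    """Execute code and tests; return None on success, else an error message (feedback z)."""
    namespace: dict = {}
    try:
        exec(code, namespace)      # unsandboxed: for illustration only
        for test in tests:
            exec(test, namespace)
    except Exception as exc:
        return f"{type(exc).__name__}: {exc}"
    return None


def repair_loop(initial_code: str, tests: list[str],
                refine: Callable[[str, str], str], max_rounds: int) -> str:
    """Refine the code until the tests pass or max_rounds is reached."""
    code = initial_code
    for _ in range(max_rounds):
        error = run_tests(code, tests)   # feedback step
        if error is None:
            break
        code = refine(code, error)       # refinement step
    return code


def stub_refine(code: str, feedback: str) -> str:
    """Stub refiner standing in for an LLM; it ignores the feedback text."""
    return code.replace("a - b", "a + b")


buggy_code = "def add(a, b):\n    return a - b\n"
toy_tests = ["assert add(2, 3) == 5"]
fixed_code = repair_loop(buggy_code, toy_tests, stub_refine, max_rounds=2)
print(run_tests(fixed_code, toy_tests))  # None
```

The loop stops as soon as the tests pass, so samples that are already correct cost no extra model calls.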

In [43]:
# Example repair prompt from Olausson et al (https://arxiv.org/abs/2306.09896).
repair_template = [
    {
        "role": "system",
        "content": "You are a helpful programming assistant and an expert Python programmer. "
                   "You are helping a user write a program to solve a problem. "
                   "The user has written some code, but it has some errors and is not passing the tests. "
                   "You will help the user by first giving a concise (at most 2-3 sentences) textual explanation "
                   "of what is wrong with the code. After you have pointed out what is wrong with the code, "
                   "you will then generate a fixed version of the program. "
                   "Put your fixed program within code delimiters, for example: ```\n# YOUR CODE HERE\n```.",
    },
    {
        "role": "user",
        "content": "### Problem\n"
                   "{problem}\n\n"
                   "### Incorrect Code:\n"
                   "{code}\n\n"
                   "###Error\n"
                   "{error}"
    }
]

def run_repair(code: str, problem: str, error_msg: str, **generate_kwargs) -> tuple[list[str], ModelResponse]:
    '''
    Runs a single iteration of refinement given the previous round's code's error message
    '''
    repair_prompt = copy.deepcopy(repair_template)
    repair_prompt[-1]['content'] = repair_prompt[-1]['content'].format(problem=problem, code=code, error=error_msg)

    new_code, response = generate_code(repair_prompt, **generate_kwargs)
    return new_code, response


def self_repair(prompt: str, tests: list[str], max_rounds: int, **generate_kwargs) -> list[tuple]:
    '''
    Runs generation with self-repair for a given prompt and tests.
    Returns one (codes, responses, execution_results) tuple per round,
    starting with the initial generation.
    '''
    problem = prompt.rsplit("\n", 1)[0] # remove the additional instruction at the end of the prompt
    codes, response = generate_code(prompt, **generate_kwargs)

    # execute_tests returns dicts with the 'passed' and 'result' keys used below
    execution_results = execute_tests(codes, tests)

    refine_history = []

    refine_history.append((codes, response, copy.deepcopy(execution_results)))
    for _ in range(max_rounds):
        refined_codes = []
        refined_responses = []
        for code, exec_result in zip(codes, execution_results):
            if exec_result['passed']:
                # already correct: carry the code forward unchanged
                refined_codes.append(code)
                refined_responses.append(None)
                continue

            refined_code, refined_response = run_repair(code, problem, exec_result['result'], n=1)
            refined_codes.append(refined_code[0])
            refined_responses.append(refined_response)
        
        execution_results = execute_tests(refined_codes, tests)
        codes = refined_codes
        refine_history.append((refined_codes, refined_responses, copy.deepcopy(execution_results)))

    return refine_history

Observe that one of the model's initial generations did not pass the tests:

In [51]:
print("Code:\n" + codes[0] + "\n")
print("Result:\n" + execution_results[0]['result'])

Code:
def remove_Occ(s, char):
    first_occ = s.find(char)
    last_occ = s.rfind(char)
    
    if first_occ != -1:
        s = s[:first_occ] + s[first_occ+1:]
    if last_occ != -1:
        s = s[:last_occ] + s[last_occ+1:]
    
    return s

Result:
failed:  assert remove_Occ("hello","l") == "heo" 


Now, let's try an iteration of refinement:

In [52]:
problem = prompts[0].rsplit("\n", 1)[0]
repair_code, repair_response = run_repair(codes[0], problem, execution_results[0]['result'], n=1)

The new result:

In [53]:
execute_tests(repair_code, mbpp[0]['test_list'])

[{'task_id': None, 'passed': True, 'result': 'passed', 'completion_id': None}]

It passes the tests! And the LLM's reasoning for its refinement:

In [54]:
print(repair_response.choices[0].message.content)

The issue with the code is that it removes the last occurrence of the character before removing the first occurrence. This causes the removal of the first occurrence to shift the index of the last occurrence, resulting in the incorrect output. To fix this, you should remove the last occurrence first and then the first occurrence.

```python
def remove_Occ(s, char):
    last_occ = s.rfind(char)
    first_occ = s.find(char)
    
    if last_occ != -1:
        s = s[:last_occ] + s[last_occ+1:]
    if first_occ != -1:
        s = s[:first_occ] + s[first_occ+1:]
    
    return s
```
