## Meta-generation algorithms: `Treefinement`

We show a meta-generation algorithm called **Treefinement** [[Aggarwal et al 2024]()] that combines parallel sampling, refinement, and tree search.

---
Authors: Sean Welleck, Pranjal Aggarwal

#### Problem setting: verified code generation

We assume access to a verifier $V(s)$ that tells us if we have reached a successful terminal state.

Concretely, our states contain Rust programs, and $V$ is a program verifier (called Verus) that either succeeds or returns error messages.

Given a program specification $x$, the task is to generate a program $y$ that passes the verifier, $V(y)=1$.
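
To make the setting concrete, here is a minimal sketch of the interface. It is not the Verus integration used in this notebook: `VerifierResult`, `toy_verifier`, and `first_verified` are hypothetical names, and the toy verifier accepts any program containing the word `invariant` so that the example runs without Verus installed.

```python
from dataclasses import dataclass
from typing import Callable, Iterable, Optional


@dataclass
class VerifierResult:
    verified: bool  # True corresponds to V(y) = 1
    errors: str     # verifier feedback; empty on success


def toy_verifier(program: str) -> VerifierResult:
    # Hypothetical stand-in for Verus: "verifies" only programs that mention an invariant.
    if "invariant" in program:
        return VerifierResult(True, "")
    return VerifierResult(False, "error: missing loop invariant")


def first_verified(
    candidates: Iterable[str],
    verifier: Callable[[str], VerifierResult],
) -> Optional[str]:
    # Return the first candidate y with V(y) = 1, or None if no candidate verifies.
    for y in candidates:
        if verifier(y).verified:
            return y
    return None


candidates = ["loop { }", "loop invariant i <= n { }"]
print(first_verified(candidates, toy_verifier))
```

The error string carried by a failed result is what the refinement step later feeds back to the generator.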

In [1]:
import utils
import multiprocessing
import os
from functools import partial

# Please see the README for information on the VERUS_PATH 
VERUS_PATH = os.environ['VERUS_PATH']

## Treefinement

<img src="treefinement-background.png" width="600px">

**Treefinement** receives a specification as input, and begins by generating multiple program candidates and running the verifier to receive error messages. These candidates seed iterative refinement. During iterative refinement, a node $s$ corresponds to an output program candidate $y_s$ and its error messages $e(y_s)$. Each edge corresponds to refining the program, i.e. mapping from $(x, y_s, e(y_s))$ to a new program candidate $y_s'$, and running the verifier.

**Treefinement** performs tree search on this environment using a search strategy such as breadth-first search or depth-first search. We use the [$\mathrm{Rebase}$ [Wu et al 2024]](https://arxiv.org/abs/2408.00724) tree search strategy, which assumes access to a heuristic **value function** $v(s)$ that assigns a scalar score to each node. Given nodes $s_1,\ldots,s_B$ at the current depth, $\mathrm{Rebase}$ expands each node $s_i$ in proportion to the softmax of its value, giving it $w_i$ children:
\begin{align}
    w_i=\text{Round}\left(B\frac{\exp\left(v(s_i)/\tau\right)}{\sum_j \exp\left(v(s_j)/\tau\right)}\right),
\end{align}
where $v(\cdot)$ is the heuristic value function, $B$ is a total expansion width hyperparameter, and $\tau$ is a temperature hyperparameter.
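
As a worked example of the formula above, the following sketch computes the expansion widths $w_i$ for three nodes. The function name `rebase_widths` and the example values are illustrative assumptions, not part of the notebook's `utils` module.

```python
import math


def rebase_widths(values, budget, tau):
    """Return w_i = Round(B * softmax(v / tau)_i) for each node value."""
    scaled = [v / tau for v in values]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [round(budget * e / total) for e in exps]


values = [0.60, 0.63, 0.90]  # hypothetical node values v(s_i)
widths = rebase_widths(values, budget=32, tau=0.1)
print(widths)  # [1, 2, 29]
```

With a low temperature such as $\tau=0.1$, most of the budget goes to the highest-valued node. Because each width is rounded independently, the widths need not sum exactly to $B$ in general.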

### Implementation

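First, we assume that an LLM generator is accessible through the standard OpenAI API. The cell below uses the client's default endpoint; to target another OpenAI-compatible server, pass its `base_url` when constructing the client.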


<!-- We use SGLang to do so, which lets us benefit from its features such as prefix caching. -->

In [2]:
import openai

config = {
    'generation_params': {
        'temperature': 0.7,
        'max_tokens': 1024,
    },
    'batch_size': 32,
    'model': "gpt-4o"
}

client = openai.Client()

#### Input specification

As our running example, we will use this specification that comes from the [HumanEval-Verus dataset](https://github.com/secure-foundations/human-eval-verus/blob/main/tasks/human_eval_042.rs):

```rust
// Return list with elements incremented by 1.
fn incr_list(l: Vec<i32>) -> (result: Vec<i32>)
    requires
        forall|i: int| 0 <= i < l.len() ==> l[i] + 1 <= i32::MAX,
    ensures
        result.len() == l.len(),
        forall|i: int| 0 <= i < l.len() ==> #[trigger] result[i] == l[i] + 1,
{
```

In [3]:
input_program = """use vstd::prelude::*;

verus! {

// Return list with elements incremented by 1.
fn incr_list(l: Vec<i32>) -> (result: Vec<i32>)
    requires
        forall|i: int| 0 <= i < l.len() ==> l[i] + 1 <= i32::MAX,  // avoid overflow

    ensures
        result.len() == l.len(),
        forall|i: int| 0 <= i < l.len() ==> #[trigger] result[i] == l[i] + 1,
{
    """


#### 1. Initialization

We perform the initialization step using parallel sampling, i.e. generating program candidates $y_1,\ldots,y_N$.

First, we implement the generator as an API call that provides a system prompt and an instruction containing the specification:

In [4]:
def parallel_generator(config, specification):
    paren = '{}'
    paren_close = '}'
    messages=[
        { "role": "system", "content": utils.system_prompt() },
        { "role": "user", "content": f"Consider the following verus code:\n```{specification}\n```\n\nThe code contains the relevant spec functions and the preconditions (requires) and postconditions (ensures) for the main function. Your goal is to complete the function, by adding necessary procedure, along with proof statements (such as invariants, asserts, proof blocks etc) to prove the program. Only output the new program and not the entire code. You are not allowed to create new functions, however can use any functions already defined if within the scope. Remember to just output the completion without the function signature, requires and ensures. Only the body of the function is required. Remember to end in: \n```rust\n{paren_close} // End of function\n{paren_close} // verus!\nfn main() {paren}\n```\n\n"},
    ]
    
    response = client.chat.completions.create(
        model=config['model'],
        messages = messages,
        temperature=config['generation_params']['temperature'],
        max_tokens=config['generation_params']['max_tokens'],
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        response_format={
            "type": "text"
        },
        n=config['batch_size']
    )
    generations = [choice.message.content for choice in response.choices]
    return generations



Now let's run the parallel generation, and run the verifier to check if any of the generated programs were verified. Note that the verifier returns `0` when a program is syntactically valid, but does not verify.

In [5]:
generations = parallel_generator(config, input_program)

valid, verified = [], []
for generation in generations:
    out = utils.check(utils.parse_generation(input_program, generation), VERUS_PATH)
    if out['verified'] == 0:
        valid.append(out)

    if out['verified'] == 1:
        verified.append(out)

print(len(valid)/config['batch_size'], 'Syntactically valid', sep=' ')
print(len(verified)/config['batch_size'], 'Verified', sep=' ')

0.6875 Syntactically valid
0.0 Verified


Above, we see that some of the programs were syntactically valid, but did not pass the verifier.

Hence, we will pass these candidate programs as initial programs in Treefinement, which will perform iterative refinement as a tree search.
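
The tallying step above can be sketched in isolation with mock verifier outputs. The status codes `0` and `1` follow the note above; the `-1` used here for programs that fail to parse is an assumption made for illustration, and `tally` is a hypothetical helper rather than a function in `utils.py`.

```python
from collections import Counter


def tally(outcomes, batch_size):
    # Count outcomes by verifier status: 0 = valid but unverified, 1 = verified.
    counts = Counter(o["verified"] for o in outcomes)
    valid = counts[0]
    verified = counts[1]
    other = len(outcomes) - valid - verified  # anything else, e.g. parse failures
    return {
        "valid": valid / batch_size,
        "verified": verified / batch_size,
        "other": other / batch_size,
    }


# Mock batch: 22 valid-but-unverified programs and 10 assumed-invalid ones.
mock = [{"verified": 0}] * 22 + [{"verified": -1}] * 10
print(tally(mock, batch_size=32))
```

With 22 of 32 candidates valid, this reproduces the 0.6875 fraction reported above.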

#### 2. Treefinement

Next we implement the **Treefinement** tree search. This involves a few components:

- A `refinement_generator` that calls an LLM. We implement a parallel version (`async_refinement_generator`) that expands the current node set using a process pool.
- An initial message function (`initialize`), containing instructions, the initial program, and its errors.
- Messages for subsequent refinements (`refinement_message`), containing the errors of the current program.
- The value function (`evaluate_code`). We use the symbolic value function from [Aggarwal et al 2024](), which is based on the number of verified functions and errors reported by the verifier.
- Logic that calls the value function and verifier, and implements the softmax expansion from $\mathrm{Rebase}$ (here, by sampling nodes with softmax weights).

Additional utilities (e.g., parsing, the actual verifier call) are in `utils.py`.
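
Before the real implementation, here is a toy sketch of how these components fit together. It makes strong simplifying assumptions: a node is just an integer count of remaining errors, `toy_refine` stands in for the LLM call, and `toy_value` stands in for the verifier-based value function. All names are hypothetical.

```python
import math
import random


def softmax(xs, tau):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp((x - m) / tau) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def toy_refine(errors, rng):
    # Stand-in for the LLM refinement call: a step may fix an error,
    # change nothing, or introduce a new error.
    return max(0, errors + rng.choice([-1, 0, 1]))


def toy_value(errors):
    # Stand-in for the value function: 1.0 means verified.
    return 1.0 / (1.0 + errors)


def toy_treefinement(initial_errors, width, tau, max_iters, seed=0):
    rng = random.Random(seed)
    # Pad the initial candidates so the frontier has `width` nodes.
    frontier = [initial_errors[i % len(initial_errors)] for i in range(width)]
    for iteration in range(max_iters):
        children = [toy_refine(e, rng) for e in frontier]  # expand
        values = [toy_value(e) for e in children]          # score
        if any(v >= 1.0 for v in values):                  # verified node found
            return iteration, children
        weights = softmax(values, tau)
        # Resample the next frontier with softmax weights (with replacement).
        frontier = rng.choices(children, weights=weights, k=width)
    return None, frontier


found_at, nodes = toy_treefinement([3, 2], width=8, tau=0.1, max_iters=20)
print(found_at, nodes)
```

The loop has the same expand, score, resample shape as the implementation below; only the generator and value function are replaced by toys.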

In [7]:
# This is the refinement generator
def async_refinement_generator(history, model, temperature, max_tokens, n, batch_size):
    run_llm_with_kwargs = partial(
        utils.refinement_generator, 
        model=model, temperature=temperature, max_tokens=max_tokens, n=n
    )
    with multiprocessing.Pool(batch_size) as pool:
        generations = pool.map(run_llm_with_kwargs, history)
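    # Keep the first completion of each call. Note: empty results are dropped, so the
    # returned list can be shorter than `history`; callers that pair generations with
    # histories by index rely on no call coming back empty.
    generations = [x[0] for x in generations if len(x) > 0]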
    return generations


In [8]:
# This initializes the conversations that will be sent to the refinement generator.
# Each conversation consists of two messages:
#   - a system prompt
#   - an instruction containing the program and its error message
# The list is padded with duplicates so that at least `expansion_width` conversations are available.
def initialize(programs, error_messages, expansion_width):
    paren_close = '}'
    paren = '{}'
    messages = []
    for program, error_message in zip(programs, error_messages):
        message = {
        'role': 'user',
        'content': [
                {
                    'type': 'text',
                    'text': f"""Given a Verus program with function signature, preconditions, postconditions, and code, fix the errors present in the code. Your task is to provide only the corrected body of the function, without repeating the function signature, requires, ensures, or specs. Focus on fixing proof statements or adjusting the code to ensure correct compilation. Do not modify function signatures, requires clauses, ensures clauses, or specs.
```rust
<function_body>
{paren_close} // End of function
fn main() {paren}
{paren_close} // verus!
```

Below is the program::
```rust
{program}
```

The program has following error message:
```
{error_message}
```

Solution Format:
[Thoughts on Error Message]
[Thoughts on Error Resolution]
[Thoughts on Corner Cases, such as Overflow etc.]
```rust
[Function Body, with closing parentheses]
```""",
                }
            ]
        }
        messages.append([{
                "role": "system",
                "content": utils.system_prompt() 
            }, message])
        
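    # Fill in so that we have at least `expansion_width` candidates.
    # This runs once, after every program has been added, so that padding
    # does not crowd out the remaining initial programs.
    while messages and len(messages) < expansion_width:
        messages.append(messages[-1])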
    return messages

def refinement_message(error_message):
    return [{
        'role' : 'user',
        'content' : [
            {
                'type': 'text',
                'text': f"I got the following errors:\n ```{error_message}```\n Follow the previous format, and fix all the errors. Then give the complete updated code."
            }   
        ]
    }]
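
To illustrate the message layout these helpers produce, the sketch below builds a node's conversation history over two refinement rounds. `start_history` and `extend_history` are hypothetical simplifications of `initialize` and `refinement_message`, with the prompt text shortened.

```python
def start_history(system_prompt, first_instruction):
    # A node starts as: system prompt, then the user instruction with program and errors.
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": [{"type": "text", "text": first_instruction}]},
    ]


def extend_history(history, model_reply, error_message):
    # Each refinement appends the model's attempt, then the verifier feedback on it.
    feedback = {
        "role": "user",
        "content": [{"type": "text", "text": "I got the following errors:\n" + error_message}],
    }
    return history + [{"role": "assistant", "content": model_reply}, feedback]


node = start_history("You are a Verus expert.", "Fix this program ...")
node = extend_history(node, "attempt 1", "error: assertion failed")
node = extend_history(node, "attempt 2", "")
roles = [m["role"] for m in node]
print(roles)
# ['system', 'user', 'assistant', 'user', 'assistant', 'user']
```

Under this layout the model's attempts sit at even indices from 2 onward, and the verifier feedback for each attempt is the message that immediately follows it.
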

Now we implement and perform the tree search.

In [9]:
import torch
from torch.nn import Softmax
import random
from copy import deepcopy
import numpy as np
import pickle

max_iters = 10
rebase_temperature = 0.1
save_dir = './'
verus_path = VERUS_PATH

done = False


# Get the initial programs and error messages
programs = [x['extracted_code'] for x in valid]
scores = [utils.evaluate_code(x['extracted_code'], VERUS_PATH) for x in valid]
error_messages = [s[1] for s in scores]

# Initialize
history = initialize(programs, error_messages, config['batch_size'])

for iteration_number in range(max_iters):
    print('Iteration Number:', iteration_number)

    # Expand
    generations = async_refinement_generator(
        history=history,
        model=config['model'],
        temperature=config['generation_params']['temperature'],
        max_tokens=config['generation_params']['max_tokens'],
        n=1,
        batch_size=config['batch_size']
    )
    new_states = [history[i] + [{'role': 'assistant', 'content': x}] for i, x in enumerate(generations)]

    # Score expanded nodes with the value function
    scores = [utils.evaluate_node(input_program, state, verus_path) for state in new_states]
    error_messages = [x[1] for x in scores]
    scores = [x[0] for x in scores]

    # Filter out invalid states
    index_to_keep = [i for i, x in enumerate(scores) if x >= 0]
    if len(index_to_keep) == 0:
        print('No valid states found, retrying')
        continue
    new_states = [new_states[i] for i in index_to_keep]
    scores = [scores[i] for i in index_to_keep]
    error_messages = [error_messages[i] for i in index_to_keep]

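    # REBASE expansion: sample `batch_size` instances (with replacement) from new_states such that
    # the probability of sampling an instance is given by the softmax of score / rebase_temperature.
    # This is a sampling-based variant of the rounded widths w_i in the formula above.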
    softmax_scores = Softmax(dim=0)(torch.tensor([x/rebase_temperature for x in scores])).tolist()
    new_indices = random.choices(
        [i for i in range(len(new_states))], 
        weights = softmax_scores, 
        k=config['batch_size']
    )
    new_indices = sorted(new_indices)
    new_states = [deepcopy(new_states[i]) for i in new_indices]
    scores = [scores[i] for i in new_indices]
    error_messages = [deepcopy(error_messages[i]) for i in new_indices]

    # Make a refinement message for the next iteration's expansion stage.
    for i in range(len(new_states)):
        new_states[i] += refinement_message(error_messages[i])
    history = new_states

    print("Scores (avg %.3f): %s" % (np.mean(scores), str(scores)))

    # Check if any programs were successfully verified.
    if any([x>=1 for x in scores]):
        print('Found a verified program!')
        done = True
        break


Iteration Number: 0
Scores (avg 0.632): [0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6]
Iteration Number: 1
Scores (avg 0.600): [0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6]
Iteration Number: 2
Scores (avg 0.626): [0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333

## Visualize a successful trajectory

**Treefinement** found a verified program! Let's visualize the initial program, refined programs, and their error messages on a successful trajectory.


In [11]:
import re
from IPython.display import Markdown

def display_rust_program_with_errors(program: str, errors: str):
    rust_code_md = f"```rust\n{program}\n```"
    error_message_md = f"**Error Message:**\n```\n{errors}\n```"
    return rust_code_md + "\n\n" + error_message_md


verified_index = np.argmax(scores)
md = "## Successful trajectory\n\n### Initialization\n\n"


for i in range(0, len(history[verified_index])):
    if i % 2 == 1:
        continue
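    if i < 2:
        first_user_text = history[verified_index][1]['content'][0]['text']
        errors_ = re.findall(r"The program has following error message:\n```(.*?)```", first_user_text, re.DOTALL)[0]
        # Recover the initial program from the prompt itself: after resampling,
        # `verified_index` indexes `history`, not the original `programs` list.
        program_ = re.findall(r"Below is the program::\n```rust\n(.*?)\n```", first_user_text, re.DOTALL)[0]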
    else:
        program_ = input_program.strip()[:-1]+ utils.extract_code(history[verified_index][i]['content'])
        md += "\n\n---------------\n\n### Refinement %d\n\n" % (i // 2)
        errors_ = history[verified_index][i+1]['content'][0]['text']
        errors_ = re.findall(r"I got the following errors:\n ```(.*?)```", errors_, re.DOTALL)[0]
    md += display_rust_program_with_errors(program_, errors_)

Markdown(md)


## Successful trajectory

### Initialization

```rust
use vstd::prelude::*;

verus! {

// Return list with elements incremented by 1.
fn incr_list(l: Vec<i32>) -> (result: Vec<i32>)
    requires
        forall|i: int| 0 <= i < l.len() ==> l[i] + 1 <= i32::MAX,  // avoid overflow

    ensures
        result.len() == l.len(),
        forall|i: int| 0 <= i < l.len() ==> #[trigger] result[i] == l[i] + 1,
{
    let len: usize = l.len();
    let mut result = Vec::with_capacity(len);
    for i in 0..len
        invariant
            result.len() == i,
            len == l.len(),
            forall|j: int| 0 <= j < i ==> result[j] == l[j] + 1,
    {
        let elem = l[i];
        result.push(elem + 1);
    }
    result
} // End of function
} // verus!
fn main() {}
```

**Error Message:**
```

error: invariant not satisfied at end of loop body
  --> temp.rs:20:13
   |
20 |             forall|j: int| 0 <= j < i ==> result[j] == l[j] + 1,
   |             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: possible arithmetic underflow/overflow
  --> temp.rs:23:21
   |
23 |         result.push(elem + 1);
   |                     ^^^^^^^^

note: automatically chose triggers for this expression:
  --> temp.rs:20:13
   |
20 |             forall|j: int| 0 <= j < i ==> result[j] == l[j] + 1,
   |             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

note:   trigger 1 of 1:
  --> temp.rs:20:43
   |
20 |             forall|j: int| 0 <= j < i ==> result[j] == l[j] + 1,
   |                                           ^^^^^^^^^

note: Verus printed one or more automatically chosen quantifier triggers
      because it had low confidence in the chosen triggers.
      To suppress these messages, do one of the following:
        (1) manually annotate a single desired trigger using #[trigger]
            (example: forall|i: int, j: int| f(i) && #[trigger] g(i) && #[trigger] h(j)),
        (2) manually annotate multiple desired triggers using #![trigger ...]
            (example: forall|i: int| #![trigger f(i)] #![trigger g(i)] f(i) && g(i)),
        (3) accept the automatically chosen trigger using #![auto]
            (example: forall|i: int, j: int| #![auto] f(i) && g(i) && h(j))
        (4) use the --triggers-silent command-line option to suppress all printing of triggers.
      (Note: triggers are used by the underlying SMT theorem prover to instantiate quantifiers;
      the theorem prover instantiates a quantifier whenever some expression matches the
      pattern specified by one of the quantifier's triggers.)
  --> temp.rs:20:13
   |
20 |             forall|j: int| 0 <= j < i ==> result[j] == l[j] + 1,
   |             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: aborting due to 2 previous errors



```

---------------

### Refinement 1

```rust
use vstd::prelude::*;

verus! {

// Return list with elements incremented by 1.
fn incr_list(l: Vec<i32>) -> (result: Vec<i32>)
    requires
        forall|i: int| 0 <= i < l.len() ==> l[i] + 1 <= i32::MAX,  // avoid overflow

    ensures
        result.len() == l.len(),
        forall|i: int| 0 <= i < l.len() ==> #[trigger] result[i] == l[i] + 1,
{
    let len: usize = l.len();
    let mut result = Vec::with_capacity(len);
    for i in 0..len
        invariant
            result.len() == i,
            len == l.len(),
            forall|j: int| 0 <= j < i ==> #[trigger] result[j] == l[j] + 1,
    {
        let elem = l[i];
        assert(elem + 1 <= i32::MAX);  // Reinforce the precondition within the loop
        result.push(elem + 1);
    }
    result
} // End of function

fn main() {}
```

**Error Message:**
```
error: assertion failed
  --> temp.rs:22:16
   |
22 |         assert(elem + 1 <= i32::MAX);  // Reinforce the precondition within the loop
   |                ^^^^^^^^^^^^^^^^^^^^ assertion failed

error: aborting due to 1 previous error


```

---------------

### Refinement 2

```rust
use vstd::prelude::*;

verus! {

// Return list with elements incremented by 1.
fn incr_list(l: Vec<i32>) -> (result: Vec<i32>)
    requires
        forall|i: int| 0 <= i < l.len() ==> l[i] + 1 <= i32::MAX,  // avoid overflow

    ensures
        result.len() == l.len(),
        forall|i: int| 0 <= i < l.len() ==> #[trigger] result[i] == l[i] + 1,
{
    let len: usize = l.len();
    let mut result = Vec::with_capacity(len);
    for i in 0..len
        invariant
            result.len() == i,
            len == l.len(),
            forall|j: int| 0 <= j < i ==> #[trigger] result[j] == l[j] + 1,
    {
        let elem = l[i];
        result.push(elem + 1);  // The precondition ensures this won't overflow
    }
    result
} // End of function

fn main() {}
} // verus!
```

**Error Message:**
```
error: invariant not satisfied at end of loop body
  --> temp.rs:19:13
   |
19 |             forall|j: int| 0 <= j < i ==> #[trigger] result[j] == l[j] + 1,
   |             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: possible arithmetic underflow/overflow
  --> temp.rs:22:21
   |
22 |         result.push(elem + 1);  // The precondition ensures this won't overflow
   |                     ^^^^^^^^

error: aborting due to 2 previous errors


```

---------------

### Refinement 3

```rust
use vstd::prelude::*;

verus! {

// Return list with elements incremented by 1.
fn incr_list(l: Vec<i32>) -> (result: Vec<i32>)
    requires
        forall|i: int| 0 <= i < l.len() ==> l[i] + 1 <= i32::MAX,  // avoid overflow

    ensures
        result.len() == l.len(),
        forall|i: int| 0 <= i < l.len() ==> #[trigger] result[i] == l[i] + 1,
{
    let len: usize = l.len();
    let mut result = Vec::with_capacity(len);
    for i in 0..len
        invariant
            result.len() == i,
            len == l.len(),
            forall|j: int| 0 <= j < i ==> #[trigger] result[j] == l[j] + 1,
    {
        let elem = l[i];
        // Assert the precondition within the loop to ensure no overflow
        assert(elem + 1 <= i32::MAX);
        result.push(elem + 1);
    }
    result
} // End of function

fn main() {}
} // verus!
```

**Error Message:**
```
error: assertion failed
  --> temp.rs:23:16
   |
23 |         assert(elem + 1 <= i32::MAX);
   |                ^^^^^^^^^^^^^^^^^^^^ assertion failed

error: aborting due to 1 previous error


```

---------------

### Refinement 4

```rust
use vstd::prelude::*;

verus! {

// Return list with elements incremented by 1.
fn incr_list(l: Vec<i32>) -> (result: Vec<i32>)
    requires
        forall|i: int| 0 <= i < l.len() ==> l[i] + 1 <= i32::MAX,  // avoid overflow

    ensures
        result.len() == l.len(),
        forall|i: int| 0 <= i < l.len() ==> #[trigger] result[i] == l[i] + 1,
{
    let len: usize = l.len();
    let mut result = Vec::with_capacity(len);
    for i in 0..len
        invariant
            result.len() == i,
            len == l.len(),
            forall|j: int| 0 <= j < i ==> #[trigger] result[j] == l[j] + 1,
            forall|j: int| 0 <= j < len ==> l[j] + 1 <= i32::MAX, // Reinforce precondition as part of the invariant
    {
        let elem = l[i];
        result.push(elem + 1);
    }
    result
} // End of function

fn main() {}
} // verus!
```

**Error Message:**
```

```

### Further reading

The implementation above is a simplified demonstration of Treefinement based on [[Aggarwal et al 2024]()]. In particular, the version here relies on a fixed base language model (such as GPT-4), which places an upper bound on its performance. 

Check out [[Aggarwal et al 2024]()] for a version that iteratively improves by updating the generator based on feedback from collected trajectories, which enables a much more powerful Treefinement search than the one shown above, even with openly available models (e.g., Llama 3.1).

#### Refinement vs. additional parallel sampling

A natural question is whether we actually need refinement: could we just sample more in the initialization step?

This plot from [[Aggarwal et al 2024]()] suggests that refinement provides gains that additional sampling does not. Namely, it compares:
1. initialization with 256 parallel samples followed by Treefinement with 256 total generations ("Tree Search")
2. initialization with 512 parallel samples ("Exploration")

<img src="./treefinement_vs_parallel.png" width="400px">

We can see that the tree-search based refinement gives an additional jump in performance that is not obtainable with parallel sampling. Namely, parallel sampling saturates, while Treefinement leads to an additional improvement.


#### Tree search vs. linear refinement

A second question is whether structuring refinement as a tree search is needed: could we just refine once, or just sample a few chains and then iteratively refine each chain?

[[Aggarwal et al 2024]()] shows that tree search refinement indeed outperforms both one-step refinement and 8-step refinement, holding the generation budget constant across methods (256 total generations).