## Meta-generation algorithms: `Treefinement`

We show a meta-generation algorithm called **Treefinement** [[Aggarwal et al 2024]()] that combines parallel sampling, refinement, and tree search.

#### Problem setting: verified code generation

We assume access to a verifier $V(s)$ that tells us if we have reached a successful terminal state.

Concretely, our states contain Rust programs, and $V$ is a program verifier (called Verus) that either succeeds or returns error messages.

Given a program specification $x$, the task is to generate a program $y$ that passes the verifier, $V(y)=1$.

In [1]:
VERUS_PATH = "/Users/wellecks/projects/verus/source/target-verus/release/verus"  # path to your local Verus binary

import utils  # local helper module: prompts, parsing, and verifier calls

## Treefinement

<img src="treefinement-background.png" width="600px">

**Treefinement** receives a specification as input and begins by generating multiple program candidates and running the verifier on each to obtain error messages. One of these candidates is selected for iterative refinement. During iterative refinement, a node $s$ corresponds to a program candidate $y_s$ and its error messages $e(y_s)$. Each edge corresponds to refining the program, i.e. mapping $(x, y_s, e(y_s))$ to a new program candidate $y_s'$, and running the verifier on it.
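To make the node/edge structure concrete, here is a minimal sketch. `Node`, `expand`, `path_to_root`, and the `refine`/`verify` callables are hypothetical names used only for illustration; in the implementation below a node is instead represented as a chat history, and `refine` is an LLM call.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Node:
    """A search state s: a program candidate y_s and its verifier errors e(y_s)."""
    program: str
    errors: str
    parent: Optional["Node"] = None
    depth: int = 0

    @property
    def verified(self) -> bool:
        # Assumption: an empty error string means the verifier accepted the program.
        return self.errors == ""


def expand(node: Node, spec: str,
           refine: Callable[[str, str, str], str],
           verify: Callable[[str], str]) -> Node:
    """One edge: map (x, y_s, e(y_s)) to a new candidate y_s' and re-run the verifier."""
    new_program = refine(spec, node.program, node.errors)
    return Node(new_program, verify(new_program), parent=node, depth=node.depth + 1)


def path_to_root(node: Optional[Node]) -> list:
    """Recover the refinement trajectory (root first) that led to `node`."""
    path = []
    while node is not None:
        path.append(node)
        node = node.parent
    return path[::-1]
```

Keeping a parent pointer on each node is what later lets us read off the successful trajectory once some node verifies.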

**Treefinement** performs tree search in this environment and can use any tree search strategy, such as breadth-first or depth-first search. We use the [$\mathrm{Rebase}$ [Wu et al 2024]](https://arxiv.org/abs/2408.00724) tree search strategy, which assumes access to a heuristic **value function** $v(s)$ that assigns a scalar score to each node. Given nodes $s_1,\ldots,s_B$ at the current depth, $\mathrm{Rebase}$ expands each node $s_i$ into $w_i$ children, allocating more children to higher-value nodes:
\begin{align}
    w_i=\text{Round}\left(B\frac{\exp\left(v(s_i)/\tau\right)}{\sum_j \exp\left(v(s_j)/\tau\right)}\right),
\end{align}
where $v(\cdot)$ is the heuristic value function, $B$ is a total expansion width hyperparameter, and $\tau$ is a temperature hyperparameter.
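The expansion widths $w_i$ can be computed directly from the formula above. This is a small sketch; `rebase_widths` is a hypothetical name, and the example values $0.6$ and $0.6333$ are taken from the scores printed later in this notebook.

```python
import numpy as np


def rebase_widths(values, B: int, tau: float) -> np.ndarray:
    """Expansion width w_i = Round(B * softmax(v / tau)_i) for each node."""
    v = np.asarray(values, dtype=float) / tau
    p = np.exp(v - v.max())  # subtract the max for numerical stability
    p /= p.sum()
    return np.rint(B * p).astype(int)


print(rebase_widths([0.6, 0.6333], B=32, tau=0.1))   # [13 19]
print(rebase_widths([0.6, 0.6333], B=32, tau=0.01))  # [ 1 31]
```

Lowering $\tau$ concentrates the budget on the best node, while a larger $\tau$ spreads it more evenly. Because of rounding, the widths need not sum exactly to $B$. The search loop below takes a slightly different route: it draws $B$ indices with replacement from the same softmax distribution (`random.choices`), which matches these widths in expectation.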

### Implementation

First, we assume that an LLM generator is served at `base_url` and exposes an OpenAI-compatible API. If `base_url` is left empty, the client falls back to the default OpenAI endpoint.


<!-- We use SGLang to do so, which lets us benefit from its features such as prefix caching. -->

In [62]:
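import openai

config = {
    'generation_params': {
        'temperature': 0.7,
        'max_tokens': 1024,
    },
    'batch_size': 32,   # number of parallel candidates N, and expansion width B
    'base_url': "",     # empty: use the default OpenAI endpoint
    'api_key': "",      # empty: read OPENAI_API_KEY from the environment
    'model': "gpt-4o"
}

# Pass the configured endpoint and key to the client (empty strings become None).
client = openai.OpenAI(
    base_url=config['base_url'] or None,
    api_key=config['api_key'] or None,
)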

#### Input specification

As our running example, we will use this specification that comes from the [HumanEval-Verus dataset](https://github.com/secure-foundations/human-eval-verus/blob/main/tasks/human_eval_042.rs):

In [98]:
input_program = """use vstd::prelude::*;

verus! {

// Return list with elements incremented by 1.
fn incr_list(l: Vec<i32>) -> (result: Vec<i32>)
    requires
        forall|i: int| 0 <= i < l.len() ==> l[i] + 1 <= i32::MAX,  // avoid overflow

    ensures
        result.len() == l.len(),
        forall|i: int| 0 <= i < l.len() ==> #[trigger] result[i] == l[i] + 1,
{
    """


#### 1. Initialization

We perform the initialization step using best-of-$N$, i.e. generating program candidates $y_1,\ldots,y_N$ and selecting the program candidate with the highest value $v(y)$.

We use the symbolic value function from [Aggarwal et al 2024](), which is based on the number of verified conditions and errors reported by the verifier.
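As an illustration of what a symbolic value function and best-of-$N$ selection might look like, here is a hypothetical stand-in; it is not the exact scoring rule from the paper or from `utils.evaluate_code`. It assumes the verifier prints a summary line of the form `verification results:: N verified, M errors` and scores a candidate by the fraction of verification conditions that passed, returning $-1$ when no summary is found (e.g. the program does not compile).

```python
import re
from typing import Callable, Sequence

# Assumed format of the verifier's summary line.
SUMMARY = re.compile(r"verification results::\s*(\d+)\s+verified,\s*(\d+)\s+errors?")


def symbolic_value(verifier_output: str) -> float:
    """Hypothetical v(y): fraction of verification conditions that passed."""
    m = SUMMARY.search(verifier_output)
    if m is None:
        return -1.0  # no summary line: treat as invalid (e.g. compile error)
    verified, errors = int(m.group(1)), int(m.group(2))
    total = verified + errors
    if total == 0:
        return 1.0  # no obligations and no errors: nothing left to prove
    return verified / total


def best_of_n(candidates: Sequence[str], value: Callable[[str], float]) -> tuple:
    """Return the highest-value candidate and its value (first one wins ties)."""
    scores = [value(c) for c in candidates]
    best = max(range(len(candidates)), key=scores.__getitem__)
    return candidates[best], scores[best]
```

Here `best_of_n` is applied to verifier outputs; in the notebook, the same idea appears as `np.argmax` over the scores returned by `utils.evaluate_code`.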

In [99]:
def parallel_generator(config, program):
    """Sample config['batch_size'] candidate completions of `program` in one request."""
    # Literal braces interpolated into the f-string prompt below.
    paran = '{}'
    paran_close = '}'
    messages = [
        { "role": "system", "content": utils.system_prompt() },
        { "role": "user", "content": f"Consider the following verus code:\n```{program}\n```\n\nThe code contains the relevant spec functions and the preconditions (requires) and postconditions (ensures) for the main function. Your goal is to complete the function, by adding necessary procedure, along with proof statements (such as invariants, asserts, proof blocks etc) to prove the program. Only output the new program and not the entire code. You are not allowed to create new functions, however can use any functions already defined if within the scope. Remember to just output the completion without the function signature, requires and ensures. Only the body of the function is required. Remember to end in: \n```rust\n{paran_close} // End of function\n{paran_close} // verus!\nfn main() {paran}\n```\n\n"},
    ]

    response = client.chat.completions.create(
        model=config['model'],
        messages=messages,
        temperature=config['generation_params']['temperature'],
        max_tokens=config['generation_params']['max_tokens'],
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        seed=0,
        response_format={"type": "text"},
        n=config['batch_size'],  # N samples for best-of-N
    )
    return [choice.message.content for choice in response.choices]



In [100]:
generations = parallel_generator(config, input_program)

# out['verified'] == 1: passes Verus; == 0: compiles but fails verification.
valid, verified = [], []
for generation in generations:
    out = utils.check(utils.parse_generation(input_program, generation), VERUS_PATH)
    if out['verified'] == 0:
        valid.append(out)
    elif out['verified'] == 1:
        verified.append(out)

print(len(valid) / config['batch_size'], 'Syntactically valid')
print(len(verified) / config['batch_size'], 'Verified')

0.625 Syntactically valid
0.0 Verified


Above, 62.5% of the candidates (20 of 32) were syntactically valid, but none passed the verifier. Let's select the one with the highest value for the next stage.

In [112]:
import numpy as np

# utils.evaluate_code returns (score, error_messages); select by score, not by the error text.
scores = [utils.evaluate_code(x['extracted_code'], VERUS_PATH) for x in valid]
best_idx = np.argmax([x[0] for x in scores])
program = valid[best_idx]['extracted_code']
error_messages = scores[best_idx][1]
print("/* CODE: */\n%s" % program)
print("/* ERRORS:\n%s\n*/" % error_messages)


/* CODE: */
use vstd::prelude::*;

verus! {

// Return list with elements incremented by 1.
fn incr_list(l: Vec<i32>) -> (result: Vec<i32>)
    requires
        forall|i: int| 0 <= i < l.len() ==> l[i] + 1 <= i32::MAX,  // avoid overflow

    ensures
        result.len() == l.len(),
        forall|i: int| 0 <= i < l.len() ==> #[trigger] result[i] == l[i] + 1,
{
    let mut result = Vec::with_capacity(l.len());
    let l_len: usize = l.len();
    for i in 0..l_len
        invariant
            result.len() == i,
            l_len == l@.len(),
            forall|j: int| 0 <= j < i ==> #[trigger] result[j] == l[j] + 1,
    {
        let element = l[i];
        result.push(element + 1);
        assert(result[i as int] == l[i as int] + 1);
    }
    result
} // End of function
} // verus!
fn main() {}
/* ERRORS:
error: possible arithmetic underflow/overflow
  --> temp.rs:23:21
   |
23 |         result.push(element + 1);
   |                     ^^^^^^^^^^^

error: assertion failed
  --> tem

#### 2. Treefinement

In [103]:

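import multiprocessing
from functools import partial

def async_refinement_generator(history, model, temperature, max_tokens, n, batch_size):
    """Refine every state in `history` in parallel; returns one generation per state."""
    run_llm_with_kwargs = partial(
        utils.refinement_generator,
        model=model, temperature=temperature, max_tokens=max_tokens, n=n
    )
    with multiprocessing.Pool(batch_size) as pool:
        generations = pool.map(run_llm_with_kwargs, history)
    # Keep exactly one entry per state so generations[i] stays aligned with history[i];
    # an empty string stands in for a call that returned nothing.
    generations = [x[0] if len(x) > 0 else "" for x in generations]
    return generations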


In [104]:
from copy import deepcopy

def initialize_history(program, error_message, expansion_width):
    """Create `expansion_width` copies of the root conversation asking the model to fix `program`."""
    paran_close = '}'
    paran = '{}'
    hints_fill_format = {
        'role': 'user',
        'content': [
            {
                'type': 'text',
                'text': f"""Given a Verus program with function signature, preconditions, postconditions, and code, fix the errors present in the code. Your task is to provide only the corrected body of the function, without repeating the function signature, requires, ensures, or specs. Focus on fixing proof statements or adjusting the code to ensure correct compilation. Do not modify function signatures, requires clauses, ensures clauses, or specs.
```rust
<function_body>
{paran_close} // End of function
fn main() {paran}
{paran_close} // verus!
```

Below is the program:
```rust
{program}
```

The program has the following error message:
```
{error_message}
```

Solution Format:
[Thoughts on Error Message]
[Thoughts on Error Resolution]
[Thoughts on Corner Cases, such as Overflow etc.]
```rust
[Function Body, with closing braces]
```""",
            }
        ]
    }
    system_message = {"role": "system", "content": utils.system_prompt()}
    # deepcopy so that each root conversation can be extended independently.
    return [deepcopy([system_message, hints_fill_format]) for _ in range(expansion_width)]


In [109]:

import os
import pickle
import random
from copy import deepcopy

import numpy as np

max_iters = 7
rebase_temperature = 0.1  # tau in the REBASE formula
save_dir = './'
verus_path = VERUS_PATH

done = False

# Verify the initial program once to get the error messages that seed the search.
file_name = utils.save_code_to_file(program, '43')
_, error_message = utils.run_code(verus_path, file_name)

# Root of the tree: B copies of the same (program, errors) state.
history = initialize_history(program, error_message, config['batch_size'])

for iteration_number in range(max_iters):
    print('Iteration Number:', iteration_number)
    # Expand: refine every current state once (n=1), in parallel.
    generations = async_refinement_generator(
        history=history,
        model=config['model'],
        temperature=config['generation_params']['temperature'],
        max_tokens=config['generation_params']['max_tokens'],
        n=1,
        batch_size=config['batch_size']
    )
    new_states = [history[i] + [{'role': 'assistant', 'content': x}] for i, x in enumerate(generations)]

    # Score each child with the value function; a negative score marks an invalid program.
    results = [utils.evaluate_node(input_program, state, verus_path) for state in new_states]
    scores = [score for score, _ in results]
    error_messages = [errors for _, errors in results]

    index_to_keep = [i for i, x in enumerate(scores) if x >= 0]
    if len(index_to_keep) == 0:
        print('No valid states found, retrying')
        continue

    new_states = [new_states[i] for i in index_to_keep]
    scores = [scores[i] for i in index_to_keep]
    error_messages = [error_messages[i] for i in index_to_keep]

    # REBASE expansion: draw B states with replacement, with probability given by a
    # temperature-scaled softmax over their scores. In expectation, state i is kept
    # B * softmax(v / tau)_i times, i.e. the width w_i before rounding.
    logits = np.array(scores) / rebase_temperature
    softmax_scores = np.exp(logits - logits.max())
    softmax_scores /= softmax_scores.sum()
    new_indices = sorted(random.choices(
        range(len(new_states)), weights=softmax_scores.tolist(), k=config['batch_size']
    ))

    new_states = [deepcopy(new_states[i]) for i in new_indices]
    scores = [scores[i] for i in new_indices]
    error_messages = [error_messages[i] for i in new_indices]

    # Append the verifier feedback so the next refinement call sees it.
    for i in range(len(new_states)):
        new_states[i] += [{
            'role': 'user',
            'content': [
                {
                    'type': 'text',
                    'text': f"I got the following errors:\n ```{error_messages[i]}```\n Follow the previous format, and fix all the errors. Then give the complete updated code."
                }
            ]
        }]
    history = new_states

    print("Scores (avg %.3f): %s" % (np.mean(scores), str(scores)))

    if any(x >= 1 for x in scores):
        print('Found a verified program!')
        for i, score in enumerate(scores):
            if score >= 1:
                with open(os.path.join(save_dir, f'solution_{iteration_number}_{i}.pkl'), 'wb') as f:
                    pickle.dump(history[i], f)
                # history[i][-1] is the feedback message; [-2] holds the verified program.
                code = utils.extract_code(history[i][-2]['content'])
                with open(os.path.join(save_dir, f'correct_code_{i}.rs'), 'w') as f:
                    f.write(code)
        done = True
        break


Iteration Number: 0
Scores (avg 0.616): [0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6, 0.6333333333333333, 0.6, 0.6333333333333333, 0.6, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6333333333333333, 0.6333333333333333, 0.6, 0.6, 0.6, 0.6, 0.6]
Iteration Number: 1
Scores (avg 0.622): [0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6, 0.6, 0.6, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333, 0.6333333333333333]
Iteration Number: 2
Scores (avg 0.989): [1, 1, 1, 1, 1, 1,

## Visualize a successful trajectory

**Treefinement** found a verified program! Let's visualize the initial program, refined programs, and their error messages on a successful trajectory.
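Each state in `history` is a chat transcript, so a trajectory can be read off by walking it in (assistant, feedback) pairs, which is what the visualization cell does with its even/odd indexing. A hypothetical helper, `refinement_steps`, that makes this assumed layout explicit:

```python
def refinement_steps(history: list) -> list:
    """Pair each assistant turn with the verifier feedback that followed it.

    Assumed layout, as built by `initialize_history` and the search loop:
    [system, user prompt, assistant, user feedback, assistant, user feedback, ...]
    """
    steps = []
    for i in range(2, len(history) - 1, 2):
        assistant, feedback = history[i], history[i + 1]
        if assistant["role"] != "assistant" or feedback["role"] != "user":
            raise ValueError(f"unexpected roles at positions {i} and {i + 1}")
        text = feedback["content"]
        # Feedback messages store their text as [{"type": "text", "text": ...}].
        if isinstance(text, list):
            text = "".join(part["text"] for part in text)
        steps.append((assistant["content"], text))
    return steps
```

A trailing assistant turn with no feedback after it is skipped.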


In [145]:
import re
from IPython.display import Markdown

def display_rust_program_with_errors(program: str, errors: str) -> str:
    """Format a Rust program and its verifier errors as a Markdown string."""
    rust_code_md = f"```rust\n{program}\n```"
    error_message_md = f"**Error Message:**\n```\n{errors}\n```"
    return rust_code_md + "\n\n" + error_message_md


verified_index = np.argmax(scores)
md = "## Successful trajectory\n\n### Initialization\n\n"

# `error_message` holds the verifier output for the initial program
# (`error_messages` was overwritten by the search loop).
md += display_rust_program_with_errors(program, error_message)

# History layout: [system, prompt, assistant, feedback, assistant, feedback, ...],
# so even indices from 2 onward are refined programs and i + 1 is their feedback.
for i in range(2, len(history[verified_index]), 2):
    md += "\n\n---------------\n\n### Refinement %d\n\n" % (i // 2)
    program_ = input_program.strip()[:-1] + utils.extract_code(history[verified_index][i]['content'])
    errors_ = history[verified_index][i + 1]['content'][0]['text']
    errors_ = re.findall(r"I got the following errors:\n ```(.*?)```", errors_, re.DOTALL)[0]
    md += display_rust_program_with_errors(program_, errors_)

Markdown(md)


## Successful trajectory

### Initialization

```rust
use vstd::prelude::*;

verus! {

// Return list with elements incremented by 1.
fn incr_list(l: Vec<i32>) -> (result: Vec<i32>)
    requires
        forall|i: int| 0 <= i < l.len() ==> l[i] + 1 <= i32::MAX,  // avoid overflow

    ensures
        result.len() == l.len(),
        forall|i: int| 0 <= i < l.len() ==> #[trigger] result[i] == l[i] + 1,
{
    let mut result = Vec::with_capacity(l.len());
    let l_len: usize = l.len();
    for i in 0..l_len
        invariant
            result.len() == i,
            l_len == l@.len(),
            forall|j: int| 0 <= j < i ==> #[trigger] result[j] == l[j] + 1,
    {
        let element = l[i];
        result.push(element + 1);
        assert(result[i as int] == l[i as int] + 1);
    }
    result
} // End of function
} // verus!
fn main() {}
```

**Error Message:**
```
error: possible arithmetic underflow/overflow
  --> temp.rs:23:21
   |
23 |         result.push(element + 1);
   |                     ^^^^^^^^^^^

error: assertion failed
  --> temp.rs:24:16
   |
24 |         assert(result[i as int] == l[i as int] + 1);
   |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ assertion failed

error: aborting due to 2 previous errors


```

---------------

### Refinement 1

```rust
use vstd::prelude::*;

verus! {

// Return list with elements incremented by 1.
fn incr_list(l: Vec<i32>) -> (result: Vec<i32>)
    requires
        forall|i: int| 0 <= i < l.len() ==> l[i] + 1 <= i32::MAX,  // avoid overflow

    ensures
        result.len() == l.len(),
        forall|i: int| 0 <= i < l.len() ==> #[trigger] result[i] == l[i] + 1,
{
    let mut result = Vec::with_capacity(l.len());
    let l_len: usize = l.len();
    for i in 0..l_len
        invariant
            result.len() == i,
            l_len == l@.len(),
            forall|j: int| 0 <= j < i ==> #[trigger] result[j] == l[j] + 1,
    {
        let element = l[i];
        assert(element + 1 <= i32::MAX); // Reassert the precondition for clarity
        result.push(element + 1);
        assert(result[i as int] == l[i as int] + 1);
    }
    result
} // End of function

fn main() {}
```

**Error Message:**
```
error: assertion failed
  --> temp.rs:22:16
   |
22 |         assert(element + 1 <= i32::MAX); // Reassert the precondition for clarity
   |                ^^^^^^^^^^^^^^^^^^^^^^^ assertion failed

error: aborting due to 1 previous error


```

---------------

### Refinement 2

```rust
use vstd::prelude::*;

verus! {

// Return list with elements incremented by 1.
fn incr_list(l: Vec<i32>) -> (result: Vec<i32>)
    requires
        forall|i: int| 0 <= i < l.len() ==> l[i] + 1 <= i32::MAX,  // avoid overflow

    ensures
        result.len() == l.len(),
        forall|i: int| 0 <= i < l.len() ==> #[trigger] result[i] == l[i] + 1,
{
    let mut result = Vec::with_capacity(l.len());
    let l_len: usize = l.len();
    for i in 0..l_len
        invariant
            result.len() == i,
            l_len == l@.len(),
            forall|j: int| 0 <= j < i ==> #[trigger] result[j] == l[j] + 1,
    {
        let element = l[i];
        // Removed the redundant assertion of element + 1 <= i32::MAX
        result.push(element + 1);
        assert(result[i as int] == l[i as int] + 1);
    }
    result
} // End of function

fn main() {}
} // verus!
```

**Error Message:**
```
error: possible arithmetic underflow/overflow
  --> temp.rs:23:21
   |
23 |         result.push(element + 1);
   |                     ^^^^^^^^^^^

error: assertion failed
  --> temp.rs:24:16
   |
24 |         assert(result[i as int] == l[i as int] + 1);
   |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ assertion failed

error: aborting due to 2 previous errors


```

---------------

### Refinement 3

```rust
use vstd::prelude::*;

verus! {

// Return list with elements incremented by 1.
fn incr_list(l: Vec<i32>) -> (result: Vec<i32>)
    requires
        forall|i: int| 0 <= i < l.len() ==> l[i] + 1 <= i32::MAX,  // avoid overflow

    ensures
        result.len() == l.len(),
        forall|i: int| 0 <= i < l.len() ==> #[trigger] result[i] == l[i] + 1,
{
    let mut result = Vec::with_capacity(l.len());
    let l_len: usize = l.len();
    for i in 0..l_len
        invariant
            result.len() == i,
            l_len == l@.len(),
            forall|j: int| 0 <= j < i ==> #[trigger] result[j] == l[j] + 1,
            // This invariant explicitly states the precondition within the loop
            forall|j: int| 0 <= j < l.len() ==> l[j] + 1 <= i32::MAX,
    {
        let element = l[i];
        // Explicitly ensuring the precondition holds within the loop
        assert(element + 1 <= i32::MAX);
        result.push(element + 1);
        assert(result[i as int] == l[i as int] + 1);
    }
    result
} // End of function

fn main() {}
} // verus!
```

**Error Message:**
```

```

### Further reading

The implementation above is a simplified demonstration of Treefinement based on [[Aggarwal et al 2024]()]. In particular, the version here relies on a fixed base language model (here, GPT-4o), which places an upper bound on its performance.

Check out [[Aggarwal et al 2024]()] for a version that iteratively improves the generator based on feedback from collected trajectories. This enables a much more powerful Treefinement search than the one shown above, even with openly available models (e.g., Llama 3.1).

The paper also compares Treefinement search, parallel sampling, and one-step refinement at equal compute budgets, and finds that Treefinement gives the best cost-performance tradeoff by a wide margin.