`The Art of Prompt Design`

# Tree Of Thoughts Design Pattern

The *Tree Of Thoughts* (ToT) prompting strategy (by [Yao et al.](https://arxiv.org/abs/2305.10601)) is a generalization of the *Chain Of Thought* (CoT) approach (by [Wei et al.](https://arxiv.org/abs/2201.11903)). In CoT, the model is prompted to generate a series of intermediate reasoning steps before reaching the final output. ToT uses a similar pattern, but it also allows LMs to perform deliberate decision making by considering multiple different reasoning paths and self-evaluating choices to decide the next course of action. This increases the chances of finding a valid thought chain compared to CoT, which commits to a single path early on.

ToT frames any problem as a search over a tree. It consists of 4 steps:
1. Define a way (based on the problem properties) to decompose the problem into a sequence of "thoughts" or steps.
2. Generate a set of candidate next "thoughts" or steps.
3. Evaluate the set of candidate steps and give a score to each one based on how likely it is to lead to the correct output.
4. Search over the generated tree of thoughts and select the best path (e.g. using BFS or DFS).
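
The four steps above can be sketched as a small, model-free breadth-first search. This is only an illustration under stated assumptions: `tree_of_thoughts_bfs`, `propose`, `score`, and the toy digit problem are hypothetical names invented here, not part of guidance or of the paper's code. In the rest of this notebook the proposal and scoring roles are played by LM calls.

```python
from typing import Callable, List


def tree_of_thoughts_bfs(
    root: str,
    propose: Callable[[str], List[str]],
    score: Callable[[str], float],
    n_steps: int,
    beam_width: int,
) -> List[str]:
    """Breadth-first search that keeps the `beam_width` best states per level."""
    frontier = [root]
    for _ in range(n_steps):
        # Step 2: expand every state in the frontier into candidate next states.
        candidates = [child for state in frontier for child in propose(state)]
        if not candidates:
            break
        # Step 3: score each candidate. Step 4: keep only the most promising ones.
        candidates.sort(key=score, reverse=True)
        frontier = candidates[:beam_width]
    return frontier


# Toy stand-ins for the model (hypothetical): a state is a string of digits,
# a "thought" appends one digit, and states whose digit sum is near 10 score best.
def toy_propose(state: str) -> List[str]:
    return [state + digit for digit in "123"]


def toy_score(state: str) -> float:
    return -abs(10 - sum(int(c) for c in state))


best = tree_of_thoughts_bfs("", toy_propose, toy_score, n_steps=4, beam_width=2)
print(best)
```

Step 1 (the decomposition) is implicit in the choice of state: here a state is the string of digits chosen so far, while in the Game of 24 example below it is the list of steps taken so far.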

![Image Description](../../docs/figures/tot.png)


## Setup

First, let's import the necessary modules and load the language model.


In [1]:
import guidance
from guidance import models, gen, user, assistant, system
path_to_model = "/Users/sam/models/mistral/mistral-7b-v0.1.Q3_K_M.gguf"
model = models.LlamaCpp(path_to_model, n_ctx=4096)

## Examples With Previous Prompting Patterns

For comparison, we'll first see how some of the prompting patterns illustrated in the image above perform.

Let's use the [Game of 24](https://www.4nums.com/game/difficulties/) mathematical reasoning challenge as an example problem. The goal is to find a sequence of arithmetic operations that can be applied to a set of 4 numbers to get the number 24. For example, given the input `4 9 10 13`, a solution output could be `(10 - 4) * (13 - 9) = 24`. This is a good problem for testing the limits of the different prompting strategies because it requires non-trivial planning and search.
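
To get a feel for the search involved, here is a brute-force solver sketch in plain Python. It is not part of the original notebook: `solve_24` and `_search` are hypothetical helpers that repeatedly combine two of the remaining numbers into one, using exact fractions so that division does not introduce rounding errors.

```python
from fractions import Fraction
from typing import List, Optional, Tuple


def solve_24(numbers: List[int], target: int = 24) -> Optional[str]:
    """Return one expression that reaches `target`, or None if none exists."""
    # Each item pairs an exact value with the expression that produced it.
    items = [(Fraction(n), str(n)) for n in numbers]
    return _search(items, Fraction(target))


def _search(items: List[Tuple[Fraction, str]], target: Fraction) -> Optional[str]:
    if len(items) == 1:
        value, expr = items[0]
        return expr if value == target else None
    # Pick an ordered pair, combine it, and recurse on the shorter list.
    for i in range(len(items)):
        for j in range(len(items)):
            if i == j:
                continue
            (a, ea), (b, eb) = items[i], items[j]
            rest = [items[k] for k in range(len(items)) if k not in (i, j)]
            options = [
                (a + b, f"({ea} + {eb})"),
                (a - b, f"({ea} - {eb})"),
                (a * b, f"({ea} * {eb})"),
            ]
            if b != 0:
                options.append((a / b, f"({ea} / {eb})"))
            for value, expr in options:
                found = _search(rest + [(value, expr)], target)
                if found is not None:
                    return found
    return None


solution = solve_24([4, 9, 10, 13])
print(solution)
```

An exhaustive search like this is cheap for four numbers, which makes it a convenient way to check whether a given input has a solution at all before judging a prompting strategy on it.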

### Input-Output Prompting

In this approach there are no intermediate thoughts bridging the input to the output. In other words, the model just responds directly without "thinking" about it.

In [2]:
io_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to obtain 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Input: 1 4 8 8
Answer: (8 / 4 + 1) * 8 = 24
Input: 5 5 5 9
Answer: 5 + 5 + 5 + 9 = 24
Input: {input}
'''

lm = model + io_prompt.format(input="3 3 4 6")
lm += gen("answer", suffix="\n")

As we can see, the output just matches the pattern of the examples in the prompt, but it's not a valid equation (and even if it were, it wouldn't result in the correct answer).

### Chain Of Thought Prompting

Here we prompt the model to generate a series of intermediate reasoning steps before reaching the final answer.

In [3]:
# 5-shot
cot_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to obtain 24. At each step, you are only allowed to choose two of the remaining numbers to obtain a new number.
Input: 4 4 6 8
Steps:
4 + 8 = 12 (left: 4 6 12)
6 - 4 = 2 (left: 2 12)
2 * 12 = 24 (left: 24)
Answer: (6 - 4) * (4 + 8) = 24
Input: 2 9 10 12
Steps:
12 * 2 = 24 (left: 9 10 24)
10 - 9 = 1 (left: 1 24)
24 * 1 = 24 (left: 24)
Answer: (12 * 2) * (10 - 9) = 24
Input: 4 9 10 13
Steps:
13 - 10 = 3 (left: 3 4 9)
9 - 3 = 6 (left: 4 6)
4 * 6 = 24 (left: 24)
Answer: 4 * (9 - (13 - 10)) = 24
Input: 1 4 8 8
Steps:
8 / 4 = 2 (left: 1 2 8)
1 + 2 = 3 (left: 3 8)
3 * 8 = 24 (left: 24)
Answer: (1 + 8 / 4) * 8 = 24
Input: 5 5 5 9
Steps:
5 + 5 = 10 (left: 5 9 10)
10 + 5 = 15 (left: 9 15)
15 + 9 = 24 (left: 24)
Answer: ((5 + 5) + 5) + 9 = 24
Input: {input}
Steps:
'''

To keep the model from drifting away from the expected format, we can use guidance to control the number of steps generated and *guide* the model to produce the answer exactly at step 4, as in the prompt's examples.

In [4]:

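@guidance
def run_cot(lm, input, steps=3):
    # extend the model state passed in by guidance instead of the global `model`
    lm += cot_prompt.format(input=input)
    # generate exactly `steps` intermediate steps, one per line
    for _ in range(steps):
        lm += gen("step", suffix="\n")
    # then force the final answer line
    lm += "Answer:" + gen("answer", suffix="\n")
    return lm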
        
numbers = '3 3 4 6'
lm = model
lm += run_cot(input=numbers)

That's... better. Each intermediate step is a valid equation and the "numbers left" are also correct.

However, the final equation:
1. Doesn't follow the rules (it contains an additional hallucinated '3').
2. Is not a valid equation.
3. Doesn't result in '24'.
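
These three failure modes can also be checked mechanically. The sketch below is an illustration in plain Python, not part of the original notebook: `check_answer` is a hypothetical helper, and the second example answer is made up for demonstration rather than copied from the model's output.

```python
import re
from collections import Counter


def check_answer(numbers: str, answer: str, target: int = 24) -> dict:
    """Check a line such as '(6 - 4) * (4 + 8) = 24' against the input numbers."""
    left, _, right = answer.partition("=")
    # Rule check: every input number is used exactly once, and nothing else is used.
    uses_inputs = Counter(re.findall(r"\d+", left)) == Counter(numbers.split())
    # Validity check: only digits, spaces, + - * / and parentheses are allowed.
    value = None
    allowed = re.fullmatch(r"[\d\s+\-*/()]+", left) and "**" not in left and "//" not in left
    if allowed:
        try:
            value = eval(left)  # tolerable only because the characters are whitelisted above
        except (SyntaxError, TypeError, ZeroDivisionError):
            value = None
        if not isinstance(value, (int, float)):
            value = None
    # Result check: the expression evaluates to the target and the claimed result agrees.
    reaches_target = (
        value is not None and abs(value - target) < 1e-9 and right.strip() == str(target)
    )
    return {"uses_inputs": uses_inputs, "is_valid": value is not None, "reaches_target": reaches_target}


print(check_answer("3 3 4 6", "6 * 4 + 3 - 3 = 24"))
print(check_answer("3 3 4 6", "3 * (3 + 4) + 3 = 24"))  # made-up answer with an extra 3
```

The first call passes all three checks. The second is a valid equation that evaluates to 24 but fails the rule check, because it uses 3 three times and never uses 6.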

# Tree of Thoughts Prompting (with Guidance)

ToT requires that we first *define a way to decompose the problem* into a sequence of individual "thoughts" or steps. A natural decomposition in this case is to treat a pairwise operation as a single "thought" (e.g. `6 - 4 = 2 (left: 3 3 2)`), exactly as we saw in the previous example with CoT.

We'll use this pattern in the following prompts to generate and evaluate candidate steps.

## Prompts

To generate the list of candidate next "thoughts" we can use a "propose" prompt. In the original paper, the prompt asked the model to generate the equation and the numbers left immediately after (as we did in the CoT prompt). However, since we are using a 7B model instead of GPT-4, we might find it useful to bridge the equations to the numbers left with an intermediate step that surrounds the operands in parentheses. This makes it easier to then generate the numbers left correctly.

Example:

input = `2 2 1` 

then instead of asking the model to generate:

`2 + 2 = 4 (left: 4 1)`

we can ask it to generate:

`2 + 2 = 4 -> (2 + 2) 1 -> 4 1`

(As we'll see later, guidance allows us to easily store the relevant parts of the output, i.e. `2 + 2 = 4` and `4 1`, in variables and reformat them as needed to match the input expected by the next step.)
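
The reformatting itself is simple string handling. As a plain-Python sketch (the helpers `parse_proposal` and `to_cot_step` are hypothetical and only mirror what the guidance captures achieve), a bridged line can be turned back into the CoT format like this:

```python
from typing import Tuple


def parse_proposal(line: str) -> Tuple[str, str]:
    """Split 'a op b = c -> (a op b) rest -> numbers left' into (operation, numbers left)."""
    # The middle part is the bridge step; it is only there to help the model.
    operation, _bridge, numbers_left = [part.strip() for part in line.split("->")]
    return operation, numbers_left


def to_cot_step(line: str) -> str:
    """Rewrite a bridged proposal in the 'equation (left: ...)' format of the CoT prompt."""
    operation, numbers_left = parse_proposal(line)
    return f"{operation} (left: {numbers_left})"


print(to_cot_step("2 + 2 = 4 -> (2 + 2) 1 -> 4 1"))
# 2 + 2 = 4 (left: 4 1)
```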


In [5]:
propose_prompt = '''Input: 2 8 8 14
Possible next steps:
2 + 8 = 10 -> (2 + 8) 8 14 -> 10 8 14
8 / 2 = 4 -> (8 / 2) 8 14 -> 4 8 14
14 + 2 = 16 -> (14 + 2) 8 8 -> 16 8 8
2 * 8 = 16 -> (2 * 8) 8 14 -> 16 8 14
8 - 2 = 6 -> (8 - 2) 8 14 -> 6 8 14
14 - 8 = 6 -> (14 - 8) 2 8 -> 6 2 8
14 / 2 = 7 -> (14 / 2) 8 8 -> 7 8 8
14 - 2 = 12 -> (14 - 2) 8 8 -> 12 8 8
Input: 1 3 8
Possible next steps:
1 + 3 = 4 -> (1 + 3) 8 -> 4 8
1 * 3 = 3 -> (1 * 3) 8 -> 3 8
3 + 8 = 11 -> (3 + 8) 1 -> 11 1
3 * 8 = 24 -> (3 * 8) 1 -> 24 1
Input: 2 12
Possible next steps:
2 * 12 = 24 -> (2 * 12) -> 24
2 + 12 = 14 -> (2 + 12) -> 14
12 - 2 = 10 -> (12 - 2) -> 10
12 / 2 = 6 -> (12 / 2) -> 6
Input: {input}
Possible next steps:
'''


We can use the following prompt to evaluate the numbers left after each proposed step and assign a label ('sure', 'likely' or 'impossible') based on how likely it is that 24 can be reached from them.

In [6]:
value_prompt = '''Evaluate if given numbers can reach 24. The operations allowed are + - * /.
10 14
10 + 14 = 24
sure
11 12
11 + 12 = 23
12 - 11 = 1
11 * 12 = 132
11 / 12 = 0.91
impossible
4 4 10
4 + 4 + 10 = 8 + 10 = 18
4 * 10 - 4 = 40 - 4 = 36
(10 - 4) * 4 = 6 * 4 = 24
sure
4 9 11
9 + 11 + 4 = 20 + 4 = 24
sure
5 7 8
5 + 7 + 8 = 12 + 8 = 20
(8 - 5) * 7 = 3 * 7 = 21
I cannot obtain 24 now, but numbers are within a reasonable range
likely
5 6 6
5 + 6 + 6 = 17
(6 - 5) * 6 = 1 * 6 = 6
I cannot obtain 24 now, but numbers are within a reasonable range
likely
10 10 11
10 + 10 + 11 = 31
(11 - 10) * 10 = 10
10 10 11 are all too big
impossible
1 3 3
1 * 3 * 3 = 9
(1 + 3) * 3 = 12
1 3 3 are all too small
impossible
17
17 is smaller than 24
impossible
24
24 is equal to 24
sure
8
8 is smaller than 24
impossible
26
26 is bigger than 24
impossible
{input}
'''

Once the final equation is generated we can check if it uses each of the input numbers exactly once. For the sake of consistency we can use the same labels as the value prompt above (i.e. 'sure' and 'impossible').

In [7]:
evaluate_operands = '''
Do the numbers in the equation match the numbers in the input? (sure/impossible)
numbers: 2 8 8 14
equation: (2 + 8) * (8 + 14)
numbers frequencies: frequencies('2 8 8 14') = {{2: 1, 8: 2, 14: 1}}
equation frequencies: frequencies('(2 + 8) * (8 + 14)') = {{2: 1, 8: 2, 14: 1}}
frequencies are the same
answer: sure

numbers: {input}
equation: {equation}
'''

Note how the prompt references a function called `frequencies()`. It can be passed as a tool to the `gen()` function, so the model can use it to reliably count the number of times each number is used in the equation.

## Tools

Since many of the prompts rely on arithmetic, it is useful to define a calculator tool that evaluates expressions reliably. To do that, we first define the context-free grammar that describes the format of the equations (as described in this project's [readme](../../README.md#context-free-grammars)); then we can pass the tool to the model so it can evaluate arbitrary equations.

In [8]:
from guidance import one_or_more, select, zero_or_more, capture, Tool

@guidance(stateless=True)
def number(lm):
    n = one_or_more(select(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']))
    return lm + select(['-' + n, n])

@guidance(stateless=True)
def operator(lm):
    return lm + select(['+' , '*', '**', '/', '-'])

@guidance(stateless=True)
def expression(lm):
    return lm + select([
        number(),
        expression() + zero_or_more(' ') +  operator() + zero_or_more(' ') +  expression(),
        '(' + expression() + ')'
    ])

@guidance
def calculator(lm):
    expression = lm['tool_args']
    lm += f' {int(eval(expression))}'
    return lm

@guidance(stateless=True)
def calculator_call(lm):
    return lm + capture(expression(), 'tool_args') + ' ='


calculator_tool = Tool(calculator_call(), calculator)
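
Note that the calculator above relies on `eval`, and that `int(...)` truncates fractional results (for example, `int(eval("11 / 12"))` is `0`). As an alternative sketch, not part of the original notebook, the arithmetic could be evaluated by walking the syntax tree with Python's standard `ast` module. `safe_eval` and `_eval_node` are hypothetical names, and wiring them into the tool is left as an assumption.

```python
import ast
import operator

# Only the four operations allowed in the Game of 24 are supported.
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def _eval_node(node):
    # Plain numbers (booleans are rejected even though they are ints in Python).
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)) \
            and not isinstance(node.value, bool):
        return node.value
    # Binary operations: evaluate both sides recursively.
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        return _OPS[type(node.op)](_eval_node(node.left), _eval_node(node.right))
    # Unary minus, so that negative numbers are accepted.
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
        return -_eval_node(node.operand)
    raise ValueError(f"unsupported expression: {ast.dump(node)}")


def safe_eval(expression: str):
    """Evaluate an arithmetic expression without executing arbitrary code."""
    return _eval_node(ast.parse(expression, mode="eval").body)


print(safe_eval("(10 - 4) * (13 - 9)"))
# 24
```

Anything other than numbers, parentheses and `+ - * /` (for instance a function call or `**`) raises `ValueError` instead of being executed.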


Let's define the frequency tool mentioned before in the same way, as a grammar plus a function. Note that, in addition to the final equation, this tool also takes the original input numbers (e.g. `3 3 4 6`) as input, so we also define a grammar for that.

In [9]:
@guidance(stateless=True)
def four_numbers(lm):
    return lm + number() + ' ' + number() + ' ' + number() + ' ' + number()

@guidance
def frequencies(lm):
    '''Calculate the frequency of each number in the input expression.'''
    import re
    expression = lm['tool_args']
    numbers =  [int(number) for number in re.findall(r'\d+', expression)]
    freqs = {n:numbers.count(n) for n in set(numbers)}
    lm += f' = {freqs}'
    return lm

@guidance(stateless=True)
def frequencies_call(lm):
    '''Tool can be called with a string of 4 numbers or a string with a mathematical expression.'''
    return lm + "frequencies('" + capture(select([expression(),  four_numbers()]), 'tool_args') + "')"

frequency_tool = Tool(frequencies_call(), frequencies)

## Flow

Now that we have all the necessary prompts and tools we can define the actual flow of the ToT strategy. We basically just need to define functions for the *generate* and *evaluate* steps. Since the *selection* step is relatively simple we can just do it inline.
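
To make the selection step concrete before it appears inline, here is a minimal sketch with made-up labels. `select_best` is a hypothetical helper (the notebook does this inline), and the numeric values are the ones used in the final cell of this notebook.

```python
from typing import List

# Same mapping from evaluation label to numeric value as in the final cell.
VALUE_MAP = {"impossible": 0.01, "likely": 1, "sure": 10}


def select_best(proposals: List[str], labels: List[str], k: int) -> List[str]:
    """Keep the `k` proposals whose evaluation labels map to the highest values."""
    values = [VALUE_MAP[label] for label in labels]
    # sorted() is stable, so proposals with equal values keep their original order.
    order = sorted(range(len(proposals)), key=lambda i: values[i], reverse=True)
    return [proposals[i] for i in order[:k]]


proposals = ["a", "b", "c", "d"]  # placeholders for "steps taken so far" strings
labels = ["likely", "sure", "impossible", "sure"]  # made-up evaluations
print(select_best(proposals, labels, k=2))
# ['b', 'd']
```

Because ties are broken by position, several proposals that all receive the same label are effectively selected in the order they were generated.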

In [10]:
def get_current_numbers(y: str) -> str:
    """Get the numbers that are left to be used in the equation."""
    last_line = y.strip().split('\n')[-1]
    return last_line.split('left: ')[-1].split(')')[0]


def generate_proposals(x, y, n=4):
    """
    Take original input "x" and steps taken so far "y" and concatenate "n" different proposed next steps to the steps taken so far.

    Args:
        x (str): The input string.
        y (str): The steps taken so far.
        n (int, optional): The number of proposals to generate. Defaults to 4.

    Returns:
        if there is only one number left:
            list: A list with a single string containing steps taken so far and the final answer. e.g. [y + 'Answer: 1 * 1 * (2 * 12) = 24']
        otherwise:
            list: A list of "steps taken + proposed next step" for "n" proposals. e.g. [y + proposal_1, y + proposal_2, ...]

    Examples:
        >>> generate_proposals('2 8 8 14', '2 + 8 = 10 (left: 10 8 14)', n=2)
        [
            '2 + 8 = 10 (left: 10 8 14)\n10 + 8 = 18 (left: 18 14)', 
            '2 + 8 = 10 (left: 10 8 14)\n10 * 8 = 80 (left: 80 14)'
        ]

        >>> generate_proposals('1 2 12', '2 * 12 = 24 (left: 1 24)\n1 * 24 = 24 (left: 24)')
        ['2 * 12 = 24 (left: 1 24)\n1 * 24 = 24 (left: 24)\nAnswer: 1 * (2 * 12) = 24']
    """
    current_numbers = get_current_numbers(y if y else x)
    if len(current_numbers.split(' ')) == 1:
        # Use few-shot cot prompt to generate the answer
        prompt = cot_prompt.format(input=x) + y + 'Answer: '
        # Use the calculator tool to guarantee the result of the equation is correct
        lm = model + prompt + gen('answer', suffix='\n', tools=[calculator_tool], max_tokens=40)
        return [y + f'Answer: {lm["answer"]}\n']
    else:
        lm = model + propose_prompt.format(input=current_numbers)
        for _ in range(n):
            # generate the equation and append it to the "operation" list; the regex guarantees a well-formed equation
            lm += gen('operation', suffix=' -> ', list_append=True, regex=r'\d+\s[-+*/]\s\d+ = \d+')  # optionally: tools=[calculator_tool]
            # bridge step replacing original numbers with the operation
            lm += gen(suffix=' -> ')
            # generate the new numbers left by replacing the operation with the result
            lm += gen('numbers_left', suffix='\n', list_append=True)
        # concatenate the steps taken so far with the proposed next steps in a format that matches the cot_prompt
        return [y + op + f' (left: {numbers_left})\n' for op, numbers_left in zip(lm['operation'], lm['numbers_left'])]

@guidance
def evaluate_proposal(lm, input, y):
    """Evaluate how likely it is that the numbers left can reach 24 or that the given answer is correct."""
    last_line = y.strip().split('\n')[-1]
    if 'left: ' in last_line:
        lm += value_prompt.format(input=get_current_numbers(y))
        lm += gen('evaluation', stop_regex='sure|likely|impossible', save_stop_text=True, tools=[calculator_tool])
        evaluation = lm['evaluation_stop_text']
    else:
        equation, result = last_line.lower().replace('answer: ', '').split('=')
        if result.strip() != '24':
            evaluation = 'impossible'
        else:
            lm += evaluate_operands.format(input=input, equation=equation)
            lm += gen('evaluation', stop_regex='sure|impossible', save_stop_text=True, tools=[frequency_tool])
            evaluation = lm['evaluation_stop_text']
    return lm.set('evaluation', evaluation)


Finally we can just *ToT*!

In [12]:
import itertools
model.echo = False
N_PROPOSALS = 3
N_SELECTED = 2
N_STEPS = 4

# assign a value to each evaluation so that we can sort the proposals by value
value_map = {'impossible': 0.01, 'likely': 1, 'sure': 10}
x = '1 1 11 11'
# start with an empty list of proposals
ys = ['']
print(f"INPUT: {x}")
for step in range(N_STEPS):
    print(f"STEP {step}")
    # generate
    new_ys = [generate_proposals(x, y, n=N_PROPOSALS) for y in ys]
    new_ys = list(itertools.chain(*new_ys))
    ids = list(range(len(new_ys)))

    print("Proposals:")
    for p in new_ys:
        print(p)

    # evaluate 
    values = []
    for y in new_ys:
        lm = model + evaluate_proposal(input=x, y=y)
        values.append(value_map[lm['evaluation']])

    print("Values")
    for y, value in zip(new_ys, values):
        print(y, f' ---> {value}')

    # select
    select_ids = sorted(ids, key=lambda i: values[i], reverse=True)[:N_SELECTED]
    select_new_ys = [new_ys[select_id] for select_id in select_ids]
    print("Selected")
    for p in select_new_ys:
        print(p)
    ys = select_new_ys
    print("#" * 50)

print("Final proposals:")
for p in ys:
    print(p)

INPUT: 1 1 11 11
STEP 0
Proposals:
1 + 1 = 2 (left: 2 11 11)

1 + 11 = 12 (left: 12 1 11)

11 + 1 = 12 (left: 12 1 11)

Values
1 + 1 = 2 (left: 2 11 11)
  ---> 10
1 + 11 = 12 (left: 12 1 11)
  ---> 10
11 + 1 = 12 (left: 12 1 11)
  ---> 10
Selected
1 + 1 = 2 (left: 2 11 11)

1 + 11 = 12 (left: 12 1 11)

##################################################
STEP 1
Proposals:
1 + 1 = 2 (left: 2 11 11)
2 + 11 = 13 (left: 13 11)

1 + 1 = 2 (left: 2 11 11)
11 + 2 = 13 (left: 13 11)

1 + 1 = 2 (left: 2 11 11)
11 - 2 = 9 (left: 9 11)

1 + 11 = 12 (left: 12 1 11)
12 + 1 = 13 (left: 13 11)

1 + 11 = 12 (left: 12 1 11)
12 - 1 = 11 (left: 11 11)

1 + 11 = 12 (left: 12 1 11)
12 * 1 = 12 (left: 12 11)

Values
1 + 1 = 2 (left: 2 11 11)
2 + 11 = 13 (left: 13 11)
  ---> 10
1 + 1 = 2 (left: 2 11 11)
11 + 2 = 13 (left: 13 11)
  ---> 10
1 + 1 = 2 (left: 2 11 11)
11 - 2 = 9 (left: 9 11)
  ---> 0.01
1 + 11 = 12 (left: 12 1 11)
12 + 1 = 13 (left: 13 11)
  ---> 10
1 + 11 = 12 (left: 12 1 11)
12 - 1 = 11 (left: 1