`The Art of Prompt Design`

# Tree Of Thoughts 

In [1]:
import os

import guidance
from guidance import models, gen, user, assistant, system
# Location of the quantised Mistral-7B GGUF weights. Set the MISTRAL_GGUF_PATH
# environment variable to point at a local copy; when it is unset the
# original hard-coded location is used, so existing setups keep working.
path_to_model = os.environ.get(
    "MISTRAL_GGUF_PATH",
    "/Users/samuelrodriguez/models/mistral/mistral-7b-v0.1.Q3_K_M.gguf",
)
# Load the model through guidance's llama.cpp backend with a 4096-token context.
mistral7 = models.LlamaCpp(path_to_model, n_ctx=4096)
# `model` is the name the rest of the notebook uses for the base model state.
model = mistral7

In [2]:
# Few-shot "standard" (input -> answer) prompt for the Game of 24.
# Five solved examples are followed by an `{input}` placeholder that is filled
# with str.format(); the model is expected to continue with an "Answer:" line.
standard_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to obtain 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Input: 1 4 8 8
Answer: (8 / 4 + 1) * 8 = 24
Input: 5 5 5 9
Answer: 5 + 5 + 5 + 9 = 24
Input: {input}
'''

# 5-shot chain-of-thought prompt for the Game of 24.
# Each example lists three intermediate steps of the form
# "a op b = c (left: <remaining numbers>)" followed by the final "Answer:" line.
# The prompt ends with "Steps:" so the model continues with the first step for
# the `{input}` filled in via str.format().
cot_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.
Input: 4 4 6 8
Steps:
4 + 8 = 12 (left: 4 6 12)
6 - 4 = 2 (left: 2 12)
2 * 12 = 24 (left: 24)
Answer: (6 - 4) * (4 + 8) = 24
Input: 2 9 10 12
Steps:
12 * 2 = 24 (left: 9 10 24)
10 - 9 = 1 (left: 1 24)
24 * 1 = 24 (left: 24)
Answer: (12 * 2) * (10 - 9) = 24
Input: 4 9 10 13
Steps:
13 - 10 = 3 (left: 3 4 9)
9 - 3 = 6 (left: 4 6)
4 * 6 = 24 (left: 24)
Answer: 4 * (9 - (13 - 10)) = 24
Input: 1 4 8 8
Steps:
8 / 4 = 2 (left: 1 2 8)
1 + 2 = 3 (left: 3 8)
3 * 8 = 24 (left: 24)
Answer: (1 + 8 / 4) * 8 = 24
Input: 5 5 5 9
Steps:
5 + 5 = 10 (left: 5 9 10)
10 + 5 = 15 (left: 9 15)
15 + 9 = 24 (left: 24)
Answer: ((5 + 5) + 5) + 9 = 24
Input: {input}
Steps:
'''

# Instruction asking the model to answer in exactly the format of the few-shot
# examples, with no additional commentary.
system_prompt = 'You give answers strictly following the format of the examples below. With no extra words or explanations.'

# Standard Prompt

In [3]:
# Standard few-shot prompting: fill in the puzzle, then let the model write a
# single line, captured under the name "answer".
lm = model + standard_prompt.format(input="3 3 4 6") + gen("answer", suffix="\n")

# Chain of Thought Prompt

In [4]:
@guidance
def run_cot(lm, input, steps=3):
    """Solve one Game-of-24 puzzle with few-shot chain-of-thought prompting.

    Appends ``cot_prompt`` (filled with ``input``) to the incoming model state,
    generates ``steps`` intermediate lines, then the final answer line.

    Args:
        lm: guidance model state to extend.
        input: puzzle numbers as a space-separated string, e.g. ``'3 3 4 6'``.
        steps: number of intermediate step lines to generate before the answer.

    Returns:
        The extended model state. Every step is generated under the capture
        name ``"step"`` and the final line under ``"answer"``.
    """
    # Extend the state we were handed instead of restarting from the global
    # `model`; otherwise whatever the caller already put in `lm` is dropped.
    lm += cot_prompt.format(input=input)
    for _ in range(steps):
        lm += gen("step", suffix="\n")
    lm += "Answer:" + gen("answer", suffix="\n")
    return lm
        
# Chain-of-thought run on a single puzzle, starting from the bare model.
numbers = '3 3 4 6'
lm = model + run_cot(input=numbers)

# Tree of Thoughts

In [5]:
# Few-shot prompt used to propose candidate next operations.
# Every example line has three " -> "-separated columns:
#   "<a op b = c> -> <(a op b) and the untouched numbers> -> <numbers left>"
# The prompt ends with "Possible next steps:" for the `{input}` numbers, which
# are filled in with str.format().
propose_prompt = '''Input: 2 8 8 14
Possible next steps:
2 + 8 = 10 -> (8 + 2) 8 14 -> 10 8 14
8 / 2 = 4 -> (8 / 2) 8 14 -> 4 8 14
14 + 2 = 16 -> (14 + 2) 8 8 -> 16 8 8
2 * 8 = 16 -> (2 * 8) 8 14 -> 16 8 14
8 - 2 = 6 -> (8 - 2) 8 14 -> 6 8 14
14 - 8 = 6 -> (14 - 8) 2 8 -> 6 2 8
14 / 2 = 7 -> (14 / 2) 8 8 -> 7 8 8
14 - 2 = 12 -> (14 - 2) 8 8 -> 12 8 8
Input: 1 3 8
Possible next steps:
1 + 3 = 4 -> (1 + 3) 8 -> 4 8
1 * 3 = 3 -> (1 * 3) 8 -> 3 8
3 + 8 = 11 -> (3 + 8) 1 -> 11 1
3 * 8 = 24 -> (3 * 8) 1 -> 24 1
Input: 2 12
Possible next steps:
2 * 12 = 24 -> (2 * 12) -> 24
2 + 12 = 14 -> (2 + 12) -> 14
12 - 2 = 10 -> (12 - 2) -> 10
12 / 2 = 6 -> (12 / 2) -> 6
Input: {input}
Possible next steps:
'''

# Few-shot prompt used to judge whether the remaining numbers can still reach 24.
# Each example is: the remaining numbers, a few trial calculations, and a
# one-word verdict ("sure", "likely" or "impossible") on its own line.
# `{input}` is replaced with the numbers to evaluate via str.format().
value_prompt = '''Evaluate if given numbers can reach 24. The operations allowed are + - * /.
10 14
10 + 14 = 24
sure
11 12
11 + 12 = 23
12 - 11 = 1
11 * 12 = 132
11 / 12 = 0.91
impossible
4 4 10
4 + 4 + 10 = 8 + 10 = 18
4 * 10 - 4 = 40 - 4 = 36
(10 - 4) * 4 = 6 * 4 = 24
sure
4 9 11
9 + 11 + 4 = 20 + 4 = 24
sure
5 7 8
5 + 7 + 8 = 12 + 8 = 20
(8 - 5) * 7 = 3 * 7 = 21
I cannot obtain 24 now, but numbers are within a reasonable range
likely
5 6 6
5 + 6 + 6 = 17
(6 - 5) * 6 = 1 * 6 = 6
I cannot obtain 24 now, but numbers are within a reasonable range
likely
10 10 11
10 + 10 + 11 = 31
(11 - 10) * 10 = 10
10 10 11 are all too big
impossible
1 3 3
1 * 3 * 3 = 9
(1 + 3) * 3 = 12
1 3 3 are all too small
impossible
17
17 is smaller than 24
impossible
24
24 is equal to 24
sure
8
8 is smaller than 24
impossible
26
26 is bigger than 24
impossible
{input}
'''

# One-shot prompt used to check that a final equation uses exactly the input
# numbers. The example shows two calls of the form frequencies('...') = {...}
# followed by a verdict line.
# `{{` / `}}` are escaped braces so that str.format() leaves literal `{`/`}` in
# the example dicts; `{input}` and `{equation}` are the real placeholders.
evaluate_operands = '''
Do the numbers in the equation match the numbers in the input? (sure/impossible)
numbers: 2 8 8 14
equation: (2 + 8) * (8 + 14)
numbers frequencies: frequencies('2 8 8 14') = {{2: 1, 8: 2, 14: 1}}
equation frequencies: frequencies('(2 + 8) * (8 + 14)') = {{2: 1, 8: 2, 14: 1}}
frequencies are the same
answer: sure

numbers: {input}
equation: {equation}
'''

In [6]:
from guidance import one_or_more, select, zero_or_more, capture, Tool

@guidance(stateless=True)
def number(lm):
    """Grammar for an integer: one or more decimal digits, optionally preceded by '-'."""
    digit = select(list('0123456789'))
    digits = one_or_more(digit)
    return lm + select(['-' + digits, digits])

@guidance(stateless=True)
def operator(lm):
    """Grammar for one arithmetic operator token: '+', '*', '**', '/' or '-'."""
    symbols = ['+', '*', '**', '/', '-']
    return lm + select(symbols)

@guidance(stateless=True)
def expression(lm):
    """Recursive grammar for an arithmetic expression.

    An expression is one of:
      * a number,
      * two expressions joined by an operator, with any number of spaces
        (including none) on either side of the operator,
      * an expression wrapped in parentheses.
    """
    literal = number()
    binary = expression() + zero_or_more(' ') + operator() + zero_or_more(' ') + expression()
    grouped = '(' + expression() + ')'
    return lm + select([literal, binary, grouped])

@guidance(stateless=True)
def four_numbers(lm):
    """Grammar for exactly four numbers separated by single spaces."""
    out = lm + number()
    # Append the remaining three " <number>" parts one at a time.
    for _ in range(3):
        out = out + ' ' + number()
    return out

@guidance(stateless=True)
def calculator_call(lm):
    """Grammar for a calculator call: an expression, captured as 'tool_args', followed by ' ='."""
    captured_expression = capture(expression(), 'tool_args')
    return lm + captured_expression + ' ='

@guidance
def calculator(lm):
    """Append the value of the captured arithmetic expression to the model state.

    Reads the text stored under ``'tool_args'`` (captured by
    ``calculator_call``), evaluates it and appends ``' <result>'``.

    Whole-number results are written as integers. Any other result is rounded
    to two decimals instead of being truncated, so the written equality stays
    true (``7 / 2`` yields ``3.5``, not ``3``). Expressions that cannot be
    evaluated (division by zero, overflow, integer literals with leading
    zeros) append ``' undefined'`` instead of raising.
    """
    expr = lm['tool_args']
    try:
        # NOTE(review): eval() runs model-generated text. The capture grammar
        # only admits digits, operators, parentheses and spaces, and builtins
        # are stripped here as a second line of defence. A large '**' chain
        # can still be very slow to evaluate.
        value = eval(expr, {'__builtins__': {}}, {})
    except (ArithmeticError, SyntaxError):
        return lm + ' undefined'
    if isinstance(value, int) or value.is_integer():
        rendered = int(value)
    else:
        rendered = round(value, 2)
    lm += f' {rendered}'
    return lm

@guidance(stateless=True)
def frequencies_call(lm):
    '''Grammar for a call of the form frequencies('<argument>').

    The argument is either an arithmetic expression or four space-separated
    numbers; it is captured as 'tool_args'.
    '''
    argument = capture(select([expression(),  four_numbers()]), 'tool_args')
    return lm + "frequencies('" + argument + "')"

@guidance
def frequencies(lm):
    '''Append the frequency of each number found in the captured argument.

    Reads the text stored under 'tool_args', extracts every run of digits,
    and appends " = {number: count, ...}" to the model state. Keys are written
    in ascending order so that two arguments containing the same numbers
    always produce the same text, whatever order the numbers appear in.
    Signs are ignored: only the digit runs are counted.
    '''
    import re
    from collections import Counter

    text = lm['tool_args']
    counts = Counter(int(token) for token in re.findall(r'\d+', text))
    # Sort the keys: set/dict iteration order is not a canonical order, and
    # the model compares the two printed dicts as text.
    freqs = dict(sorted(counts.items()))
    lm += f' = {freqs}'
    return lm

# Pair each call grammar with the function that produces the tool's output.
# These tools are passed to gen(..., tools=[...]) further down in the notebook.
# NOTE(review): presumably guidance runs the function once the model's output
# matches the call grammar — confirm against the guidance Tool documentation.
calculator_tool = Tool(calculator_call(), calculator)
frequency_tool = Tool(frequencies_call(), frequencies)


In [7]:
def get_current_numbers(y: str) -> str:
    """Return the numbers still available after the last line of ``y``.

    Takes the last line of the stripped text and returns whatever follows the
    final ``'left: '`` marker, up to (not including) the first ``')'``. When
    the marker is absent the whole last line is used as the starting point.
    """
    final_line = y.strip().rsplit('\n', 1)[-1]
    _, _, tail = final_line.rpartition('left: ')
    remaining, _, _ = tail.partition(')')
    return remaining

@guidance
def evaluate_proposal(lm, input, y):
    """Rate a partial or finished solution as 'sure', 'likely' or 'impossible'.

    Args:
        lm: guidance model state to extend.
        input: the original puzzle numbers as a space-separated string.
        y: the solution text so far; only its last line is inspected.

    Returns:
        The model state with the verdict stored under ``'evaluation'``.

    A last line without ``'left: '`` is treated as a final answer: it is
    rejected outright unless it ends in ``= 24``, and otherwise the model is
    asked (with the frequency tool) whether the equation uses exactly the
    input numbers. Any other last line is an intermediate step: the remaining
    numbers are rated with ``value_prompt`` and the calculator tool.
    """
    last_line = y.strip().split('\n')[-1]
    if 'left: ' not in last_line:
        # Split on every '=': the first part is the equation, the last part is
        # the claimed result. This tolerates chained answers such as
        # "a = b = 24" and answers without '=', which previously raised
        # ValueError during tuple unpacking.
        parts = last_line.lower().replace('answer: ', '').split('=')
        equation, result = parts[0], parts[-1]
        if len(parts) < 2 or result.strip() != '24':
            evaluation = 'impossible'
        else:
            # Run the operand check on a separate state (op_lm); only the
            # verdict is carried back onto `lm`.
            op_lm = lm + evaluate_operands.format(input=input, equation=equation) + gen('result', tools=[frequency_tool], stop_regex='sure|impossible', save_stop_text=True)
            evaluation = op_lm['result_stop_text']
    else:
        lm += value_prompt.format(input=get_current_numbers(y))
        lm += gen('evaluation', stop_regex='sure|likely|impossible', save_stop_text=True, tools=[calculator_tool])
        evaluation = lm['evaluation_stop_text']
    return lm.set('evaluation', evaluation)

def generate_proposals(x, y, n=4):
    """Generate candidate continuations of the partial solution ``y``.

    Args:
        x: the original puzzle numbers as a space-separated string.
        y: the steps taken so far, one per line (empty string at the root).
        n: number of next-step proposals to generate.

    Returns:
        A list of strings, each being ``y`` extended by one line. When only
        one number is left, the list holds a single candidate ending in an
        ``'Answer: ...'`` line; otherwise it holds the proposals, each ending
        in an ``'<a op b = c> (left: <numbers>)'`` line.
    """
    current_numbers = get_current_numbers(y if y else x)
    if len(current_numbers.split(' ')) == 1:
        # A single number remains: ask for the final answer line.
        prompt = cot_prompt.format(input=x) + y + 'Answer: '
        lm = model + prompt + gen('answer', suffix='\n', tools=[calculator_tool])
        return [y + f'Answer: {lm["answer"]}\n']

    lm = model + propose_prompt.format(input=current_numbers)
    for _ in range(n):
        # Column 1 is the operation, constrained to "a op b = c" (raw string so
        # that \d and \s reach the regex engine without escape warnings).
        # Column 2, the unnamed gen, is generated but not kept.
        lm += gen('operation', suffix=' -> ', list_append=True, regex=r'\d+\s[-+*/]\s\d+ = \d+') + gen(suffix=' -> ')
        # Column 3: the numbers remaining after the operation.
        lm += gen('numbers_left', suffix='\n', list_append=True)
    return [y + op + f' (left: {numbers_left})\n' for op, numbers_left in zip(lm['operation'], lm['numbers_left'])]

In [8]:
import itertools
# Breadth-first Tree-of-Thoughts search: at every step each kept candidate is
# expanded into N_PROPOSALS continuations, every continuation is scored, and
# the N_SELECTED best ones are carried into the next step.
model.echo = True
N_PROPOSALS = 3
N_SELECTED = 2
# Numeric score assigned to each verdict returned by evaluate_proposal.
value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}
x = '1 1 11 11'
ys = ['']
print(f"INPUT: {x}")
for step in range(4):
    print(f"STEP {step}")

    # generate: expand every kept candidate and flatten into one list
    new_ys = list(itertools.chain.from_iterable(
        generate_proposals(x, y, n=N_PROPOSALS) for y in ys
    ))
    ids = list(range(len(new_ys)))

    print("Proposals:")
    for p in new_ys:
        print(p)

    # evaluate: one score per candidate, in the same order as new_ys
    values = []
    for y in new_ys:
        lm = model + evaluate_proposal(input=x, y=y)
        values.append(value_map[lm['evaluation']])

    print("Values")
    for y, value in zip(new_ys, values):
        print(y, f' --->: {value}')

    # select: indices of the highest-scoring candidates (stable for ties)
    select_ids = sorted(ids, key=values.__getitem__, reverse=True)[:N_SELECTED]
    select_new_ys = [new_ys[idx] for idx in select_ids]
    print("Selected")
    for p in select_new_ys:
        print(p)
    ys = select_new_ys
    print("#" * 50)

print("Final proposals:")
for p in ys:
    print(p)

Values
1 + 1 = 2 (left: 2 11 11)
2 + 11 = 13 (left: 13 11)
13 + 11 = 24 (left: 24)
Answer: (1 + 1) + (11 + 11) = 24

  --->: 20
1 + 1 = 2 (left: 2 11 11)
11 + 2 = 13 (left: 13 11)
13 + 11 = 24 (left: 24)
Answer: (1 + 1) + (11 + 11) = 24

  --->: 20
Selected
1 + 1 = 2 (left: 2 11 11)
2 + 11 = 13 (left: 13 11)
13 + 11 = 24 (left: 24)
Answer: (1 + 1) + (11 + 11) = 24


1 + 1 = 2 (left: 2 11 11)
11 + 2 = 13 (left: 13 11)
13 + 11 = 24 (left: 24)
Answer: (1 + 1) + (11 + 11) = 24


##################################################
Final proposals:
1 + 1 = 2 (left: 2 11 11)
2 + 11 = 13 (left: 13 11)
13 + 11 = 24 (left: 24)
Answer: (1 + 1) + (11 + 11) = 24


1 + 1 = 2 (left: 2 11 11)
11 + 2 = 13 (left: 13 11)
13 + 11 = 24 (left: 24)
Answer: (1 + 1) + (11 + 11) = 24


