## Tree of Thoughts for problem solving with large language models

TL;DR: This blog post uses "Tree of Thoughts" (ToT), a tree-based search framework, to solve the Game of 24 task with a large language model.

Tree of Thoughts (ToT) is a framework used by LLMs to solve complex reasoning problems. The intermediate steps in a reasoning process are split into “thoughts”, with the ToT algorithm encouraging exploration of these thoughts through search algorithms.


### 1. Load Model

We'll be using Hugging Face ```transformers``` to generate text with our LLMs. First, we import the necessary libraries.

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

We'll use the popular open-source language model, Mistral-7B. We can load the model and the tokenizer as follows:



In [None]:
model_id = "mistralai/Mistral-7B-v0.3"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Next, we create a function called ```mistral``` which we'll use to feed in our prompts and receive completions.

In [None]:
def mistral(prompt):
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=20)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

mistral("Hi! My name is ")

Alternatively, we can use OpenAI's GPT-3.5/4 models. We can set up the API client as follows:

In [54]:
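import os
from openai import OpenAI
from constants import OPENAI_API_KEY  # constants.py is assumed to be a local module holding a fallback key

# Prefer the environment variable; fall back to the value from constants.py
api_key = os.getenv("OPENAI_API_KEY", OPENAI_API_KEY)

if not api_key:
    print("Warning: OPENAI_API_KEY is not set")

# The client takes the key directly, so no module-level openai.api_key is needed
client = OpenAI(api_key=api_key)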

In [55]:
def gpt(prompt, model="gpt-4", temperature=0.7, max_tokens=1000, n=1, stop=None) -> list:
    # Send a single-turn chat request and return the text of all n completions
    messages = [{"role": "user", "content": prompt}]

    res = client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)

    return [choice.message.content for choice in res.choices]

### 2. Implementing Tree of Thoughts (ToT) 

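ToT can be broken down into 4 key steps:

(a) Thought Decomposition
- The problem is split into intermediate steps ("thoughts"). For the Game of 24, each thought is one line of arithmetic that combines two of the remaining numbers.

(b) Thought Generation
- In this step, the LLM is prompted to generate thoughts in one of two ways:
    - Sample: The thoughts are sampled i.i.d. from a Chain of Thought prompt.
    - Propose: The thoughts are proposed sequentially from a single "propose prompt", conditioned on the current state.

(c) Thought Evaluation
- The LLM is prompted to evaluate the thoughts generated in the previous step, by either:
    - Value: Each thought is scored independently, e.g. as sure/likely/impossible to reach 24.
    - Vote: The LLM compares the candidate thoughts and votes for the most promising one.

(d) Search Algorithm
- A search algorithm such as breadth-first search (BFS) or depth-first search (DFS) explores the tree, keeping the most promising thoughts at each step.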


In this tutorial, we'll be using ToT with Mistral to solve the Game of 24.

The Game of 24 is a task where given a sequence of 4 numbers, we’ll need to find the correct mathematical operations (add, subtract, multiply, divide) that’ll lead to the number 24. For example, if the sequence is {4, 9, 10, 13}, the correct operations using the 4 numbers are: (10 - 4) * (13 - 9) = 24. Each number in the sequence can only be used once.
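The rules above can be checked mechanically. The following is a minimal sketch, assuming a hypothetical helper named `check_answer` (it is not part of the original notebook) that accepts the puzzle input and an answer string in the same `expression = 24` form used by the prompts below.

```python
import re
from collections import Counter

def check_answer(numbers: str, answer: str, target: int = 24) -> bool:
    """Return True if `answer` uses each input number exactly once and equals `target`."""
    expression = answer.split("=")[0].strip()
    # Only digits, whitespace, parentheses and the four basic operators are allowed.
    if not re.fullmatch(r"[\d\s()+\-*/]+", expression):
        return False
    # Reject power and floor division, which the character check alone would let through.
    if "**" in expression or "//" in expression:
        return False
    # The numbers in the expression must match the input as a multiset.
    if Counter(re.findall(r"\d+", expression)) != Counter(numbers.split()):
        return False
    try:
        return abs(eval(expression) - target) < 1e-6
    except (SyntaxError, ZeroDivisionError, TypeError):
        return False

print(check_answer("4 9 10 13", "(10 - 4) * (13 - 9) = 24"))
print(check_answer("4 9 10 13", "(13 - 4) * (10 - 9) = 24"))
```

The first call uses every number once and evaluates to 24; the second uses the right numbers but evaluates to 9, so it is rejected.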


In [56]:
# 5-shot
standard_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to obtain 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Input: 1 4 8 8
Answer: (8 / 4 + 1) * 8 = 24
Input: 5 5 5 9
Answer: 5 + 5 + 5 + 9 = 24
Input: {input}
'''

# 5-shot
cot_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.
Input: 4 4 6 8
Steps:
4 + 8 = 12 (left: 4 6 12)
6 - 4 = 2 (left: 2 12)
2 * 12 = 24 (left: 24)
Answer: (6 - 4) * (4 + 8) = 24
Input: 2 9 10 12
Steps:
12 * 2 = 24 (left: 9 10 24)
10 - 9 = 1 (left: 1 24)
24 * 1 = 24 (left: 24)
Answer: (12 * 2) * (10 - 9) = 24
Input: 4 9 10 13
Steps:
13 - 10 = 3 (left: 3 4 9)
9 - 3 = 6 (left: 4 6)
4 * 6 = 24 (left: 24)
Answer: 4 * (9 - (13 - 10)) = 24
Input: 1 4 8 8
Steps:
8 / 4 = 2 (left: 1 2 8)
1 + 2 = 3 (left: 3 8)
3 * 8 = 24 (left: 24)
Answer: (1 + 8 / 4) * 8 = 24
Input: 5 5 5 9
Steps:
5 + 5 = 10 (left: 5 9 10)
10 + 5 = 15 (left: 9 15)
15 + 9 = 24 (left: 24)
Answer: ((5 + 5) + 5) + 9 = 24
Input: {input}
'''

# 1-shot
propose_prompt = '''Input: 2 8 8 14
Possible next steps:
2 + 8 = 10 (left: 8 10 14)
8 / 2 = 4 (left: 4 8 14)
14 + 2 = 16 (left: 8 8 16)
2 * 8 = 16 (left: 8 14 16)
8 - 2 = 6 (left: 6 8 14)
14 - 8 = 6 (left: 2 6 8)
14 /  2 = 7 (left: 7 8 8)
14 - 2 = 12 (left: 8 8 12)
Input: {input}
Possible next steps:
'''

value_prompt = '''Evaluate if given numbers can reach 24 (sure/likely/impossible)
10 14
10 + 14 = 24
sure
11 12
11 + 12 = 23
12 - 11 = 1
11 * 12 = 132
11 / 12 = 0.91
impossible
4 4 10
4 + 4 + 10 = 8 + 10 = 18
4 * 10 - 4 = 40 - 4 = 36
(10 - 4) * 4 = 6 * 4 = 24
sure
4 9 11
9 + 11 + 4 = 20 + 4 = 24
sure
5 7 8
5 + 7 + 8 = 12 + 8 = 20
(8 - 5) * 7 = 3 * 7 = 21
I cannot obtain 24 now, but numbers are within a reasonable range
likely
5 6 6
5 + 6 + 6 = 17
(6 - 5) * 6 = 1 * 6 = 6
I cannot obtain 24 now, but numbers are within a reasonable range
likely
10 10 11
10 + 10 + 11 = 31
(11 - 10) * 10 = 10
10 10 10 are all too big
impossible
1 3 3
1 * 3 * 3 = 9
(1 + 3) * 3 = 12
1 3 3 are all too small
impossible
{input}
'''

value_last_step_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Given an input and an answer, give a judgement (sure/impossible) if the answer is correct, i.e. it uses each input exactly once and no other numbers, and reach 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Judge: 
sure
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Judge: 
sure
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Judge: 
sure
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) + 1 = 25
Judge: 
impossible
Input: 2 9 10 12
Answer: 2 * (12 - 10) = 24
Judge: 
impossible
Input: 4 9 10 13
Answer: (13 - 4) * (10 - 9) = 24
Judge: 
impossible
Input: {input}
Answer: {answer}
Judge:'''

Next, we'll start implementing our ToT algorithm. We'll define a function for each core part of the ToT algorithm.



First, we'll define functions necessary for "Thought Generation".

In [57]:
def get_current_numbers(y: str) -> str:
    last_line = y.strip().split('\n')[-1]
    return last_line.split('left: ')[-1].split(')')[0]

In [47]:
def propose_prompt_wrap(x: str, y: str='') -> str:
    current_numbers = get_current_numbers(y if y else x)
    if current_numbers == '24':
        # Only 24 is left, so ask for the final answer with the CoT prompt
        prompt = cot_prompt.format(input=x) + 'Steps:' + y
    else:
        prompt = propose_prompt.format(input=current_numbers)
    return prompt


# Generation
def generate_thoughts(thought: str) -> list:
    # `thought` is either the raw puzzle input or a step ending in "(left: ...)"
    current_numbers = get_current_numbers(thought)
    prompt = propose_prompt.format(input=current_numbers)

    # One proposal per line; drop blank lines from the completion
    return [line for line in gpt(prompt)[0].split('\n') if line.strip()]


def get_proposals(x, y):
    # Like generate_thoughts, but keeps the history `y` in front of each new step
    prompt = propose_prompt_wrap(x, y)
    proposals = gpt(prompt, n=1, stop=None)[0].split('\n')
    return [y + _ + '\n' for _ in proposals]
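Proposals come back as free text, so a step can be malformed or arithmetically wrong. The sketch below assumes a hypothetical filter named `is_valid_step` (not part of the original notebook) that checks a proposed line against the numbers it was generated from, using the same `a op b = c (left: ...)` format as `propose_prompt`.

```python
import operator
import re
from collections import Counter

NUMBER = r"-?\d+(?:\.\d+)?"
STEP_PATTERN = re.compile(
    rf"^\s*({NUMBER})\s*([+\-*/])\s*({NUMBER})\s*=\s*({NUMBER})\s*\(left:\s*(.*?)\)\s*$"
)
OPERATIONS = {"+": operator.add, "-": operator.sub, "*": operator.mul, "/": operator.truediv}

def is_valid_step(current_numbers: str, step: str) -> bool:
    """Check that `step` is well formed, arithmetically right, and updates the numbers correctly."""
    match = STEP_PATTERN.match(step)
    if match is None:
        return False
    a, op, b, result = float(match[1]), match[2], float(match[3]), float(match[4])
    if op == "/" and b == 0:
        return False
    # Allow the two-decimal rounding used in the prompts (e.g. 11 / 12 = 0.91).
    if abs(OPERATIONS[op](a, b) - result) > 0.01:
        return False
    try:
        before = Counter(float(n) for n in current_numbers.split())
        after = Counter(float(n) for n in match[5].split())
    except ValueError:
        return False
    used = Counter([a, b])
    # Both operands must be available among the current numbers.
    if used - before:
        return False
    # The remaining numbers are the old ones minus the operands plus the result.
    return before - used + Counter([result]) == after

print(is_valid_step("2 8 8 14", "2 + 8 = 10 (left: 8 10 14)"))
print(is_valid_step("2 8 8 14", "2 + 8 = 11 (left: 8 11 14)"))
```

A filter like this could be applied to the list of proposals before evaluation, so that the evaluator only scores steps that are actually reachable.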


In [59]:
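num_of_steps = 1
thoughts = ['4 5 6 10']  # start from the puzzle input rather than an empty string

for _ in range(num_of_steps):
    # generate_thoughts() takes a single string, so expand each thought in turn
    thoughts = [t for thought in thoughts for t in generate_thoughts(thought)]
    print('Thoughts: ', thoughts)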

In [None]:
import itertools

num_of_steps = 1
thoughts = ['4 5 6 10']  # start from the puzzle input

for step in range(num_of_steps):

    # Thought Generation: expand every current thought and flatten the results
    thoughts = list(itertools.chain.from_iterable(generate_thoughts(t) for t in thoughts))
    ids = list(range(len(thoughts)))

    print('THOUGHTS: ', thoughts)

    # Thought evaluation
    #values = evaluate_thoughts(thoughts)

    # Search algorithm


Next, we'll create the functions necessary for "Thought Evaluation", where each thought is evaluated by the LLM.

In [None]:
# get_current_numbers() is reused from the "Thought Generation" cells above.

# def value_prompt_wrap(x: str, y: str) -> str:
#     last_line = y.strip().split('\n')[-1]
#     if 'left: ' not in last_line:  # last step
#         ans = last_line.lower().replace('answer: ', '')
#         return value_last_step_prompt.format(input=x, answer=ans)
#     current_numbers = get_current_numbers(y)
#     return value_prompt.format(input=current_numbers) # This replaces the input term


def get_value(thought):

    current_numbers = get_current_numbers(thought)
    # Use a new name: assigning to `value_prompt` here would shadow the global template
    prompt = value_prompt.format(input=current_numbers)
    # mistral() returns one string; wrap it in a list (gpt(prompt) already returns a list)
    value_outputs = [mistral(prompt)]

    # The verdict (sure/likely/impossible) is the last line of each output
    value_names = [_.strip().split('\n')[-1] for _ in value_outputs]
    value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}

    # Weighted count of the verdicts across all outputs
    return sum(value * value_names.count(name) for name, value in value_map.items())


def evaluate_thoughts(thoughts):
    # Score every thought and return the scores in the same order
    return [get_value(thought) for thought in thoughts]
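The scoring can be tried without calling a model. The sketch below assumes a hypothetical helper named `score_outputs` and hand-written sample outputs; it only reuses the weights from `value_map` (0.001 for impossible, 1 for likely, 20 for sure) and the convention that the verdict is the last line of each output.

```python
from typing import List

VALUE_MAP = {"impossible": 0.001, "likely": 1, "sure": 20}

def score_outputs(value_outputs: List[str]) -> float:
    """Sum the weights of the verdicts found on the last line of each output."""
    verdicts = [out.strip().split("\n")[-1].strip().lower() for out in value_outputs]
    return sum(score * verdicts.count(name) for name, score in VALUE_MAP.items())

# Hand-written stand-ins for three model outputs, one per verdict.
samples = [
    "10 + 14 = 24\nsure",
    "5 + 7 + 8 = 20\nI cannot obtain 24 now, but numbers are within a reasonable range\nlikely",
    "1 * 3 * 3 = 9\n1 3 3 are all too small\nimpossible",
]
print(score_outputs(samples))  # one of each verdict: 20 + 1 + 0.001
```

The large gap between the weights means a single "sure" outweighs many "likely" verdicts, while "impossible" contributes almost nothing.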


Finally, we'll implement the "Search Algorithm" which will be used to search through the thoughts generated by the LLM.

In [9]:
# TODO: Implement search algorithm
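One way to fill in this step is a breadth-first search that keeps only the best few thoughts at each depth. The following is a sketch under stated assumptions, not a finished implementation: `bfs_search`, `generate`, `evaluate` and `beam_width` are hypothetical names, and the demo replaces the LLM calls with toy functions so that it runs offline.

```python
from typing import Callable, List

def bfs_search(
    root: str,
    generate: Callable[[str], List[str]],
    evaluate: Callable[[str], float],
    num_steps: int = 3,
    beam_width: int = 5,
) -> List[str]:
    """Expand every thought in the frontier, then keep the `beam_width` best children."""
    frontier = [root]
    for _ in range(num_steps):
        candidates = [child for thought in frontier for child in generate(thought)]
        if not candidates:
            break
        # Highest score first; sorted() is stable, so ties keep generation order.
        frontier = sorted(candidates, key=evaluate, reverse=True)[:beam_width]
    return frontier

# Toy stand-ins for the LLM: each thought is a string, children append "a" or "b",
# and the score is simply the number of "a" characters.
def toy_generate(thought: str) -> List[str]:
    return [thought + "a", thought + "b"]

def toy_evaluate(thought: str) -> float:
    return thought.count("a")

best = bfs_search("", toy_generate, toy_evaluate, num_steps=3, beam_width=2)
print(best)
```

In the notebook, `generate` and `evaluate` would be backed by the thought generation and thought evaluation functions, and the Game of 24 needs three steps to go from four numbers to one.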

### 3. Run ToT with sample data

We'll test our implementation with some sample data, i.e., the sequence 4 5 6 10. If ToT works successfully, it should output the operations that can be performed to reach 24.

In [None]:
data = "4 5 6 10"
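To know what a correct result looks like for this input, a brute-force solver can serve as ground truth. The sketch below assumes a hypothetical function named `solve_24` (not part of the original notebook) that tries every pair of numbers and every operation, using exact fractions to avoid rounding errors.

```python
from fractions import Fraction
from itertools import combinations

def solve_24(numbers, target=24):
    """Return one expression that reaches `target`, or None if none exists."""

    def search(items):
        # items is a list of (value, expression) pairs still available.
        if len(items) == 1:
            value, expr = items[0]
            return expr if value == target else None
        for i, j in combinations(range(len(items)), 2):
            (a, ea), (b, eb) = items[i], items[j]
            rest = [items[k] for k in range(len(items)) if k not in (i, j)]
            candidates = [
                (a + b, f"({ea} + {eb})"),
                (a * b, f"({ea} * {eb})"),
                (a - b, f"({ea} - {eb})"),
                (b - a, f"({eb} - {ea})"),
            ]
            # Division is only tried when the divisor is non-zero.
            if b != 0:
                candidates.append((a / b, f"({ea} / {eb})"))
            if a != 0:
                candidates.append((b / a, f"({eb} / {ea})"))
            for value, expr in candidates:
                found = search(rest + [(value, expr)])
                if found is not None:
                    return found
        return None

    return search([(Fraction(n), str(n)) for n in numbers])

data = "4 5 6 10"
solution = solve_24([int(n) for n in data.split()])
print(solution)
```

For this input at least one solution exists, for example 4 * 5 + (10 - 6) = 24, so the solver returns an expression rather than `None`.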

In [53]:
import itertools

num_of_steps = 1
thoughts = [data]  # start the search from the puzzle input

for step in range(num_of_steps):

    # Thought Generation: generate_thoughts() expects one string, so expand each thought in turn
    thoughts = list(itertools.chain.from_iterable(generate_thoughts(t) for t in thoughts))
    ids = list(range(len(thoughts)))

    print('THOUGHTS: ', thoughts)

    # Thought evaluation
    #values = evaluate_thoughts(thoughts)

    # Search algorithm

In [None]:
# import itertools

# num_of_steps = 4

# for step in num_of_steps:
#     thoughts = 
    
#     # Generation (Propose / Sample)
#     new_ys = [get_proposals(x, y) for y in ys]
#     new_ys = list(itertools.chain(*new_ys))
#     ids = list(range(len(new_ys)))

#     # Evaluation (Value / Vote)
#     values = get_values(task, x, new_ys, args.n_evaluate_sample)
    
#     # Selection (Sample/Greedy)
#     select_ids = sorted(ids, key=lambda x: values[x], reverse=True)[:args.n_select_sample]
#     select_new_ys = [new_ys[select_id] for select_id in select_ids]
    
#     #infos.append({'step': step, 'x': x, 'ys': ys, 'new_ys': new_ys, 'values': values, 'select_new_ys': select_new_ys})
#     ys = select_new_ys
