# Tree of Thoughts


The Tree of Thoughts (ToT) algorithm combines Large Language Models (LLMs) with heuristic search, as presented in a [paper](https://arxiv.org/pdf/2305.10601.pdf) by Princeton University and Google DeepMind.

ToT aims to address the limitations of the Chain of Thought (CoT) approach. Rather than committing to a single chain of reasoning, ToT lets the LLM evaluate its thoughts at each stage, abandon unpromising approaches, and explore alternatives, as shown in the figure below:

![Image adapted from the original paper](../figs/TOT.png)

In this example, we implement the ToT algorithm to solve a classic problem, the Game of 24:  
> Given 4 numbers, use each of them exactly once with the basic arithmetic operations (+, -, *, /) to obtain 24 in a single equation.

A set of Game of 24 puzzles can be downloaded from [here](https://github.com/princeton-nlp/tree-of-thought-llm/blob/master/src/tot/data/24/24.csv). The code below expects the file to be saved as `puzzles/24.csv`.
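Because each puzzle has only four numbers, an exhaustive search can serve as a ground-truth checker for any answer an LLM produces. The sketch below is not part of ToT or AgentScope; `solve24` is a hypothetical helper that tries every ordered pair of remaining values with each operation and recurses on the shorter list, using `Fraction` so that division stays exact:

```python
from fractions import Fraction
from itertools import permutations


def solve24(numbers, target=24):
    """Return one expression that reaches `target` using every number once, or None.

    Exhaustive search: pick an ordered pair of values, combine them with
    +, -, * or /, and recurse on the shorter list. Fractions keep division
    exact, so e.g. 8 / (3 - 8/3) is evaluated without rounding error.
    """
    items = [(Fraction(n), str(n)) for n in numbers]

    def search(items):
        if len(items) == 1:
            value, expr = items[0]
            return expr if value == target else None
        for i, j in permutations(range(len(items)), 2):
            (a, ea), (b, eb) = items[i], items[j]
            rest = [items[k] for k in range(len(items)) if k not in (i, j)]
            candidates = [(a + b, f"({ea}+{eb})"),
                          (a - b, f"({ea}-{eb})"),
                          (a * b, f"({ea}*{eb})")]
            if b != 0:  # skip division by zero
                candidates.append((a / b, f"({ea}/{eb})"))
            for value, expr in candidates:
                found = search(rest + [(value, expr)])
                if found:
                    return found
        return None

    return search(items)


print(solve24([3, 4, 4, 7]))
```

ToT replaces this blind enumeration with steps proposed by an LLM and states scored by an LLM, which is what the rest of this example builds.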

```python
import pandas as pd

# Assumes 24.csv from the link above has been saved as puzzles/24.csv
puzzles_df = pd.read_csv('puzzles/24.csv')
puzzles_df.sample(10)
```

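|      | Rank | Puzzles    | AMT (s) | Solved rate | 1-sigma Mean (s) | 1-sigma STD (s) |
|------|------|------------|---------|-------------|------------------|-----------------|
| 8    | 9    | 2 2 10 10  | 4.85    | 98.20%      | 5.13             | 1.63            |
| 208  | 209  | 5 5 11 13  | 6.15    | 96.10%      | 6.75             | 2.52            |
| 543  | 544  | 2 2 7 8    | 7.53    | 93.70%      | 7.88             | 2.98            |
| 437  | 438  | 1 2 3 5    | 7.1     | 95.40%      | 7.39             | 2.38            |
| 99   | 100  | 5 5 7 7    | 5.65    | 97.80%      | 5.97             | 1.91            |
| 220  | 221  | 1 4 6 8    | 6.21    | 94.40%      | 6.57             | 2.52            |
| 182  | 183  | 3 9 9 9    | 6.04    | 97.60%      | 6.47             | 2.16            |
| 541  | 542  | 1 2 2 4    | 7.52    | 95.50%      | 7.83             | 2.44            |
| 1054 | 1055 | 4 4 6 9    | 12.24   | 77.40%      | 10.48            | 5.57            |
| 1148 | 1149 | 3 3 9 10   | 14.94   | 78.10%      | 13.6             | 6.1             |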
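The `Puzzles` column stores each puzzle as a space-separated string, and `Solved rate` is a percentage string. Code that works with the numbers directly has to convert them first. Here is a small standard-library sketch using three rows copied from the sample above; `load_puzzles` and `SAMPLE_CSV` are illustrative names, not part of the dataset or the notebook:

```python
import csv
import io

# Three rows copied from the sample above, with the column names shown there.
SAMPLE_CSV = """Rank,Puzzles,AMT (s),Solved rate,1-sigma Mean (s),1-sigma STD (s)
9,2 2 10 10,4.85,98.20%,5.13,1.63
1055,4 4 6 9,12.24,77.40%,10.48,5.57
1149,3 3 9 10,14.94,78.10%,13.6,6.1
"""


def load_puzzles(text):
    """Convert each row into (rank, list of ints, solved rate as a fraction)."""
    rows = []
    for row in csv.DictReader(io.StringIO(text)):
        numbers = [int(n) for n in row["Puzzles"].split()]
        solved = float(row["Solved rate"].rstrip("%")) / 100
        rows.append((int(row["Rank"]), numbers, solved))
    return rows


for rank, numbers, solved in load_puzzles(SAMPLE_CSV):
    print(rank, numbers, f"{solved:.3f}")
```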


LLMs do not perform well on this task.
For example, given 3 4 4 7, a human can readily find the solution ((7-4)+3)*4 = 24. Let's test whether an advanced LLM (GPT-4) can solve this puzzle.

```python
import agentscope
from agentscope.agents import DialogAgent

# Model configuration, referenced below by its config_name ("gpt-4")
MODEL_CONFIGURATION = {
    "config_name": "gpt-4",
    "model_type": "openai_chat",
    "model_name": "gpt-4",
    "api_key": "YOUR_API_KEY",  # replace with your own OpenAI API key
    "generate_args": {
        "temperature": 0.5,
    },
}
agentscope.init(model_configs=MODEL_CONFIGURATION)
```

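Hard-coding the key makes it easy to leak when the notebook is shared. A minimal alternative, assuming the key has been exported as the environment variable `OPENAI_API_KEY` (the variable name is an assumption; any name works), builds the same configuration dict from the environment:

```python
import os

# Assumption: the key lives in the OPENAI_API_KEY environment variable.
# The value is an empty string if the variable is unset.
MODEL_CONFIGURATION = {
    "config_name": "gpt-4",
    "model_type": "openai_chat",
    "model_name": "gpt-4",
    "api_key": os.environ.get("OPENAI_API_KEY", ""),
    "generate_args": {"temperature": 0.5},
}
```

The resulting dict is passed to `agentscope.init(model_configs=MODEL_CONFIGURATION)` in the same way.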

```python
from agentscope.message import Msg

PUZZLE = "Input numbers are 3 4 4 7"
# A plain string: there are no placeholders, so no f-string is needed
QUESTION_PROMPT = """
    Given an input of 4 numbers
    you need to use these 4 numbers and basic arithmetic operations (+-*/) to obtain 24 in 1 equation
"""
agent = DialogAgent(
    name="assistant",
    model_config_name="gpt-4",  # matches "config_name" in MODEL_CONFIGURATION
    sys_prompt=QUESTION_PROMPT,
)
question = Msg(name="user", content=PUZZLE)
res = agent(question)
print(res.content)
```

```
An equation to obtain 24 using the numbers 3, 4, 4, and 7 could be:

4 * 4 * 3 - 7 = 24
```


This is clearly incorrect. Now let's implement a ToT algorithm to solve the puzzle.
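Before doing so, it helps to make the failure concrete. Evaluating both expressions in plain Python shows that the model used each number once but got the arithmetic wrong:

```python
# The model's answer versus the human solution for 3 4 4 7
model_answer = 4 * 4 * 3 - 7
human_answer = ((7 - 4) + 3) * 4
print(model_answer)  # 41
print(human_answer)  # 24
```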

```python
# "{}" is filled with the question via str.format; "{{" and "}}" become literal braces.
THINKER_PROMPT = (
    "You're TreeOfThoughts, a superintelligent AI model devoted to helping humans by any means necessary. "
    "Your purpose is to generate a series of solutions that comply with the user's instructions. "
    "Generate solutions by determining the most reliable solution in the shortest amount of time, "
    "while taking rejected solutions into account and learning from them. "
    "Consider the following question: "
    "{} "
    "Think step by step and propose possible next steps to solve the problem. "
    "You can put any arithmetic operation between any two numbers; "
    "the intermediate results can also be negative or fractions. "
    "Try to think of all possible next steps, but choose the most reliable solution in the shortest amount of time. "
    "For example, if the input is 2 8 8 14, possible next steps could be: "
    "2 + 8 = 10 (left: 8 14) "
    "8 / 2 = 4 (left: 8 14) "
    "14 + 2 = 16 (left: 8 8) "
    "2 * 8 = 16 (left: 8 14) "
    "8 - 2 = 6 (left: 8 14) "
    "14 - 8 = 6 (left: 2 8) "
    "14 / 2 = 7 (left: 8 8) "
    "14 - 2 = 12 (left: 8 8) "
    "2 - 8 = -6 (left: 8 14) "
    "8 - 14 = -6 (left: 2 8) "
    "and so on. "
    "Please try to think of all possible next steps and choose the most reliable solution in the shortest amount of time. "
    "Respond in the following format, which can be loaded by Python's json.loads(): "
    "{{\n"
    '    "state": a list of possible next steps, each in the format "number1 operation number2 = result (left: remaining unused numbers)",\n'
    '    "thought": "thought summary to say to others"\n'
    "}}"
)
```
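`THINKER_PROMPT` is filled in with `str.format`, so the single `{}` receives the question while the doubled `{{` and `}}` survive as the literal braces of the JSON template. The sketch below shows both conventions on a hypothetical, shortened template, then parses a hypothetical reply of the requested shape with `json.loads`. The notebook's later code reads `res.content["state"]`, so the thinker's reply is expected to have this shape:

```python
import json

# Hypothetical, shortened template using the same conventions as THINKER_PROMPT
TEMPLATE = 'Question: {} Respond as JSON: {{"state": [...], "thought": "..."}}'
print(TEMPLATE.format("use 3 4 4 7 to make 24"))
# Question: use 3 4 4 7 to make 24 Respond as JSON: {"state": [...], "thought": "..."}

# Hypothetical reply in the requested shape
reply = '{"state": ["7 * 3 = 21 (left: 4 4)", "4 + 4 = 8 (left: 3 7)"], "thought": "21 is close to 24"}'
parsed = json.loads(reply)
print(parsed["state"][0])  # 7 * 3 = 21 (left: 4 4)
```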

```python
EVALUATOR_PROMPT = (
    "Consider the following question: {}. "
    "To achieve the goal, pessimistically value the context of the past solutions "
    "and, more importantly, the latest generated solution AS A FLOAT BETWEEN 0 AND 1.\n"
    "If the solution is not directly and concretely making fast progress toward the goal, give it a lower score. "
    "Evaluate all solutions AS A FLOAT BETWEEN 0 AND 1. DO NOT RETURN ANYTHING ELSE. "
    "Respond in the following format, which can be loaded by Python's json.loads():\n"
    "{{\n"
    '    "score": the score of the solution as a float between 0 and 1\n'
    "}}"
)
```
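Because the score comes from an LLM, it may arrive as a string such as `"0.35"`, fall outside the requested range, or not be a number at all. A defensive sketch, using a hypothetical `parse_score` helper, converts and clamps the value and falls back to 0.0, in keeping with the prompt's instruction to score pessimistically:

```python
import math


def parse_score(raw):
    """Convert an evaluator score to a float clamped to [0, 1].

    Hypothetical helper: strings such as "0.35" are converted, out-of-range
    values are clamped, and anything unparseable (or NaN) becomes 0.0.
    """
    try:
        score = float(raw)
    except (TypeError, ValueError):
        return 0.0
    if math.isnan(score):
        return 0.0
    return min(max(score, 0.0), 1.0)


print(parse_score("0.35"), parse_score(1.2), parse_score("high"))  # 0.35 1.0 0.0
```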


```python
from agentscope.agents import DictDialogAgent

# Both agents reply in JSON, which DictDialogAgent parses into a dict
# (read below via res.content["state"] and res.content["score"]).
thinker = DictDialogAgent(
    name="thinker",
    model_config_name="gpt-4",
    sys_prompt=THINKER_PROMPT.format(QUESTION_PROMPT),
)
evaluator = DictDialogAgent(
    name="evaluator",
    model_config_name="gpt-4",
    sys_prompt=EVALUATOR_PROMPT.format(QUESTION_PROMPT),
)
```

- The thinker agent generates possible next moves.
- The evaluator agent assesses each move and assigns it a heuristic score.
- We also need a tree structure to store the explored states, together with a priority queue that always selects the highest-scoring node for further exploration.

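```python
import heapq


class Node:
    """A state in the search tree, linked to its parent for path reconstruction."""

    def __init__(self, state, heuristic_value=None, parent=None):
        self.state = state
        self.parent = parent
        self.heuristic_value = heuristic_value

    def __lt__(self, other):
        # heapq is a min-heap; inverting the comparison makes the node with
        # the highest heuristic value pop first.
        return self.heuristic_value > other.heuristic_value
```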
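`heapq` implements a min-heap, which is why `Node.__lt__` is inverted: the node with the highest heuristic value compares as the "smallest" and is popped first. A common equivalent idiom, shown here as an alternative rather than the notebook's approach, pushes `(-score, counter, state)` tuples; the counter (from `itertools.count`) breaks ties so that states themselves are never compared. The scores below are taken from the evaluator output shown later in this notebook:

```python
import heapq
from itertools import count

# Scores taken from the evaluator output shown later in this notebook
scored_states = [("3 + 4 = 7 (left: 4 7)", 0.1),
                 ("7 * 3 = 21 (left: 4 4)", 0.4),
                 ("4 * 4 = 16 (left: 3 7)", 0.35)]

tie_breaker = count()  # keeps entries with equal scores orderable
open_set = []
for state, score in scored_states:
    # Negate the score so the min-heap pops the highest score first
    heapq.heappush(open_set, (-score, next(tie_breaker), state))

order = [heapq.heappop(open_set)[2] for _ in range(len(open_set))]
print(order[0])  # 7 * 3 = 21 (left: 4 4)
```

This matches the run at the end of the notebook, where the first state goal-tested is `7 * 3 = 21`, the one with the highest score (0.4).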


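```python
def generate_possible_states(current_state):
    """Ask the thinker agent for candidate next steps from current_state."""
    prompt = "The current state is: {}, what is the next step?".format(current_state)
    msg = Msg(name="user", content=prompt)
    res = thinker(msg)
    states = res.content["state"]
    # The prompt asks for a list, but guard against a single-string reply
    if isinstance(states, str):
        states = [states]
    print('Exploring states:', states)
    return states
```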
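The thinker is an LLM, and the direct GPT-4 attempt above already showed that it can get arithmetic wrong, so proposed steps are not guaranteed to be correct. A hypothetical `step_is_correct` helper (not part of the notebook or AgentScope) recomputes the `number1 operation number2` part of a state and compares it with the stated result; the tolerance absorbs rounded values such as `7 / 3 = 2.3333`:

```python
import operator
import re

OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul, "/": operator.truediv}

# "number1 operation number2 = result"; the "(left: ...)" suffix is ignored
STEP = re.compile(
    r"^\s*([+-]?\d+(?:\.\d+)?)\s*([-+*/])\s*([+-]?\d+(?:\.\d+)?)\s*=\s*([+-]?\d+(?:\.\d+)?)"
)


def step_is_correct(state, tol=1e-3):
    """Return True if the arithmetic in a proposed state checks out."""
    match = STEP.match(state)
    if match is None:
        return False  # not in the expected format
    a, op, b, result = match.groups()
    a, b, result = float(a), float(b), float(result)
    if op == "/" and b == 0:
        return False
    return abs(OPS[op](a, b) - result) <= tol


for s in ["7 / 3 = 2.3333 (left: 4 4)", "4 * 4 = 15 (left: 3 7)"]:
    print(s, step_is_correct(s))
```

One option is to drop states that fail this check before they reach `evaluate_state`, which also saves evaluator calls.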

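```python
def evaluate_state(state):
    """Ask the evaluator agent to score a state between 0 and 1."""
    prompt = "Current state is {}, evaluate the state.".format(state)
    msg = Msg(name="thinker", content=prompt)
    res = evaluator(msg)
    # Cast to float: the score may come back as a string, and Node.__lt__
    # needs comparable numbers.
    score = float(res.content["score"])
    print('Evaluating state:', state, score)
    return score
```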

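```python
import re


def goal_test(state):
    """A state "a op b = result (left: ...)" is a goal only if result is 24
    and no unused numbers remain."""
    match = re.search(r'=\s*([+-]?\d+(?:\.\d+)?)\s*\(left:\s*([^)]*)\)', state)
    if match is None:  # malformed state from the LLM
        return False
    res = float(match.group(1))
    leftover = re.findall(r'[+-]?\d+(?:\.\d+)?', match.group(2))
    print('Goal test:', res)
    return abs(res - 24) < 1e-6 and not leftover
```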

```python
def heuristic_search(initial_state):
    """Best-first search: always expand the highest-scoring state found so far."""
    open_set = []
    visited = set()

    # Score the first batch of proposals and seed the priority queue
    for state in generate_possible_states(initial_state):
        heapq.heappush(open_set, Node(state, evaluate_state(state)))

    while open_set:
        current_node = heapq.heappop(open_set)
        current_state = current_node.state

        if goal_test(current_state):
            return reconstruct_path(current_node)

        # The same state can be queued more than once; expand it only once
        if current_state in visited:
            continue
        visited.add(current_state)

        for next_state in generate_possible_states(current_state):
            if next_state not in visited:
                heuristic_value = evaluate_state(next_state)
                next_node = Node(next_state, heuristic_value, current_node)
                heapq.heappush(open_set, next_node)

    return None


def reconstruct_path(node):
    """Follow parent links back to the root and return the states in order."""
    path = []
    while node is not None:
        path.append(node.state)
        node = node.parent
    return path[::-1]
```
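To exercise the search loop without any API calls, the LLM agents can be swapped for deterministic stand-ins. In this sketch (hypothetical helpers, not part of the notebook), `propose` plays the thinker by enumerating every way to combine two remaining numbers, `score` plays the evaluator with a crude closeness-to-24 heuristic, and `best_first_search` is the same best-first loop as `heuristic_search`, with an expansion budget added. Unlike the notebook's `(left: ...)` strings, each state here is a sorted tuple of the unused numbers plus the new result:

```python
import heapq
from fractions import Fraction
from itertools import combinations, count


def propose(numbers):
    """Stand-in thinker: every state reachable by combining two numbers."""
    for i, j in combinations(range(len(numbers)), 2):
        a, b = numbers[i], numbers[j]
        rest = [numbers[k] for k in range(len(numbers)) if k not in (i, j)]
        options = [(a + b, f"{a} + {b}"), (a * b, f"{a} * {b}"),
                   (a - b, f"{a} - {b}"), (b - a, f"{b} - {a}")]
        if b != 0:
            options.append((a / b, f"{a} / {b}"))
        if a != 0:
            options.append((b / a, f"{b} / {a}"))
        for value, step in options:
            # The child state holds the unused numbers plus the new result
            yield tuple(sorted(rest + [value])), f"{step} = {value}"


def score(numbers):
    """Stand-in evaluator: higher when some number is close to 24."""
    closest = min(abs(n - 24) for n in numbers)
    return 1 / (1 + float(closest))


def best_first_search(start, max_expansions=10_000):
    """Best-first search over states; returns the list of steps or None."""
    root = tuple(sorted(Fraction(n) for n in start))
    tie = count()  # breaks score ties so heap entries never compare states
    open_set = [(-score(root), next(tie), root, [])]
    visited = set()
    while open_set and len(visited) < max_expansions:
        _, _, numbers, path = heapq.heappop(open_set)
        if len(numbers) == 1 and numbers[0] == 24:
            return path
        if numbers in visited:
            continue
        visited.add(numbers)
        for child, step in propose(numbers):
            if child not in visited:
                heapq.heappush(open_set, (-score(child), next(tie), child, path + [step]))
    return None


print(best_first_search([3, 4, 4, 7]))
```

Carrying the path in each heap entry replaces the parent pointers used by `Node`; both recover the same sequence of steps. With an LLM in the loop, a budget such as `max_expansions` matters more, since every expansion costs thinker and evaluator calls.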

```python
initial_state = "3 4 4 7"
heuristic_search(initial_state)
```

```
Exploring states: ['3 + 4 = 7 (left: 4 7)', '4 + 4 = 8 (left: 3 7)', '7 + 3 = 10 (left: 4 4)', '3 * 4 = 12 (left: 4 7)', '4 * 4 = 16 (left: 3 7)', '7 * 3 = 21 (left: 4 4)', '3 - 4 = -1 (left: 4 7)', '4 - 4 = 0 (left: 3 7)', '7 - 3 = 4 (left: 4 4)', '3 / 4 = 0.75 (left: 4 7)', '4 / 4 = 1 (left: 3 7)', '7 / 3 = 2.3333 (left: 4 4)']
Evaluating state: 3 + 4 = 7 (left: 4 7) 0.1
Evaluating state: 4 + 4 = 8 (left: 3 7) 0.15
Evaluating state: 7 + 3 = 10 (left: 4 4) 0.2
Evaluating state: 3 * 4 = 12 (left: 4 7) 0.3
Evaluating state: 4 * 4 = 16 (left: 3 7) 0.35
Evaluating state: 7 * 3 = 21 (left: 4 4) 0.4
Evaluating state: 3 - 4 = -1 (left: 4 7) 0.05
Evaluating state: 4 - 4 = 0 (left: 3 7) 0
Evaluating state: 7 - 3 = 4 (left: 4 4) 0.1
Evaluating state: 3 / 4 = 0.75 (left: 4 7) 0.05
Evaluating state: 4 / 4 = 1 (left: 3 7) 0.05
Evaluating state: 7 / 3 = 2.3333 (left: 4 4) 0.1
Goal test: 21.0
```


---