# Language Agent Tree Search


The basic idea is to combine reflection/evaluation and search (specifically monte-carlo trees search) to get better overall performance.

As with all search-based methods, you trade off some inference-time compute and execution time for quality, and you can use langsmith them to extract the best trajectory for finetuning to reduce the need for as expensive search in the future..

#### Expand and simulate
1. Execute N actions/tasks
2. Get observations from tasks
3. Evaluate the response -> give value of action V
4. Update global state with the next step from this rollout
#### Reflect
For all terminal states (final responses), 
1. Get the reward
2. Assign `c` = reflection IFF reward isn't a 'sucess'
#### Select
(from (Upper Confidence bounds applied to Trees))
1. next action is the one that maximizes the value of V + w sqrt(ln (counts(prev state) )/counts(current state))


#### "Backgpropagate"

```
V(s) = (V_{old}(s)(N(s) - 1)+r) / N(s)
```
Nodes in the path are equivalent to the trajectories.

Then after the backpropagation, we pick the 

## 0. Prerequisites

Install `langgraph` (for the framework), `langchain_openai` (for the LLM), and `langchain` + `tavily-python` (for the search engine).

We will use tavily search as a tool. You can get an API key [here](https://app.tavily.com/sign-in) or replace with a different tool of your choosing.

In [1]:
# %pip install -U --quiet  langchain langgraph langchain_openai
# %pip install -U --quiet tavily-python

In [126]:
import getpass
import os


def _set_if_undefined(var: str) -> None:
    if os.environ.get(var):
        return
    os.environ[var] = getpass.getpass(var)


# Optional: Configure tracing to visualize and debug the agent
_set_if_undefined("LANGCHAIN_API_KEY")
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Reflexion"

_set_if_undefined("OPENAI_API_KEY")
_set_if_undefined("TAVILY_API_KEY")

## Minimal example

Either do code generation or search. Maybe let's just to web research since we don't want to do full code rollouts.

In [127]:
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper

search = TavilySearchAPIWrapper()
tavily_tool = TavilySearchResults(api_wrapper=search, max_results=5)

In [128]:
from collections import defaultdict
from typing import List

from langchain.output_parsers.openai_tools import (
    JsonOutputToolsParser,
    PydanticToolsParser,
)
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage

from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation

In [129]:
import datetime
from typing import List

from langchain_core.prompt_values import ChatPromptValue
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI
from langsmith import traceable

prompt_template = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are an AI assistant.",
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
).partial(
    time=lambda: datetime.datetime.now().isoformat(),
)


llm = ChatOpenAI(model="gpt-4-turbo-preview")

initial_answer_chain = prompt_template | llm

In [130]:
# This generates N candidate values
# for a single input
def generate_candidates(messages: ChatPromptValue, config: RunnableConfig):
    n = config["configurable"].get("N", 5)
    chat_result = llm.generate([messages.to_messages()], n=n)
    return [AIMessage(content=gen.text) for gen in chat_result.generations[0]]


expansion_chain = prompt_template | generate_candidates

In [133]:
from langchain.chains import create_structured_output_runnable


class Reflection(BaseModel):
    reflections: str = Field(
        description="The critique and reflections on the sufficiency, superfluency,"
        " and general quality of the response"
    )
    score: int = Field(
        description="Score from 1-10 on the quality of the candidate response."
    )
    found_solution: bool = Field(
        description="Whether the response has fully solved the question or task."
    )


prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Reflect and grade the assistant response to the user question below.",
        ),
        MessagesPlaceholder(variable_name="input"),
        MessagesPlaceholder(variable_name="candidate"),
        (
            "system",
            "Reflect on the assistant response above, critique, and score the response.",
        ),
    ]
)
reflection_chain = create_structured_output_runnable(output_schema=Reflection, llm=llm)

In [118]:
from typing import Optional


class Node:
    def __init__(
        self,
        solution: AIMessage,
        parent: Optional[Node] = None,
        reflection: Optional[Reflection] = None,
    ):
        self.solution = solution
        self.parent = parent
        self.children = []
        self.value = 0
        self.visits = 0
        self.reflection = reflection
        self._is_solved = reflection.found_solution
        if self._is_solved:
            self._mark_tree_as_solved()

    @property
    def is_solved(self):
        return self._is_solved

    def _mark_tree_as_solved(self):
        parent = self.parent
        while parent:
            parent._is_solved = True
            parent = parent.parent

    def add_reflection(self, reflection: Reflection):
        if self.reflection is not None:
            raise ValueError("Cannot overwrite exisitng reflection")
        self.reflection = reflection
        if reflection.is_solved:
            self._is_solved = reflection.is_solved
            self._mark_tree_as_solved()

    def uct(self, exploration_weight=1.0):
        if self.visits == 0:
            return self.value
        return (self.value / self.visits) + exploration_weight * math.sqrt(
            math.log(self.parent.visits) / self.visits
        )

    def best_child(self):
        if not self.children:
            return None
        return max(self.children, key=lambda child: child.uct())

    def best_child_value(self):
        if not self.children:
            return None
        return max(self.children, key=lambda child: child.value)

    def update(self, reward: float):
        self.visits += 1
        self.value += reward

    def get_messages(self) -> List[BaseMessage]:
        messages = []
        parent = self.parent
        while parent:
            messages.extend(
                [
                    HumanMessage(content=self.reflection),
                    self.solution,
                ]
            )
        return messages[::-1]  # root solution, reflection, child 1, ...

    @property
    def max_depth(self) -> int:
        if self.children:
            return 1 + max([child.max_depth for child in self.children])
        return 1

In [90]:
# def get_score():
#     (self.value / self.visits) + exploration_weight * math.sqrt(
#         math.log(self.parent.visits) / self.visits
#     )

In [None]:
n=5 # number of generated actions (beam width)
w=1 # Exploration weight. 
L = 5 # Depth limit / max rollout
K = 3 # Number of rollouts
c.. # context

```
State:
- root

Nodes:
# first_attempt: llm -> generate code
# Rollout
 - Selection
 - Expansion
# Reflection / Simulation
# Backpropagation
```

In [138]:
import operator
from collections import deque

from typing_extensions import Annotated, TypedDict

from langgraph.graph import StateGraph


class TreeState(TypedDict):
    # The full tree
    root: Node
    # The rolled-out steps so far
    input: str
    solution: AIMessage


builder = StateGraph(TreeState)


def start(state: TreeState) -> dict:
    user_input = [HumanMessage(content=state["input"])]
    res = initial_answer_chain.invoke({"messages": user_input})
    reflection = reflection_chain.invoke({"input": user_input, "candidate": [res]})
    root = Node(res, reflection)
    return {
        **state,
        "root": root,
    }


def expand(state: TreeState, config: RunnableConfig) -> dict:
    root = state["root"]
    user_input = [HumanMessage(content=state["input"])]
    best_candidate: Node = root.best_child()
    messages = user_input + best_candidate.get_messages()
    # Generate N candidates from the single child candidate
    new_candidates = expansion_chain.invoke({"messages": messages}, config)
    # Reflect on each candidate
    # For tasks with external validation, you'd add that here.
    reflections = reflection_chain.batch(
        [[{"input": user_input, "candidate": [msg]}] for msg in new_candidates],
        config,
    )
    # Grow tree
    child_nodes = [
        Node(cand, reflection) for cand, reflection in zip(new_candidates, reflections)
    ]
    best_candidate.children.extend(child_nodes)
    # We have already extended the tree directly, so we just return the state
    return state


def select_solution(state: TreeState):
    all_nodes = []
    nodes = deque()
    nodes.append(state["root"])
    while nodes:
        node = nodes.popleft()
        all_nodes.extend(node.children)
        for n in node.children:
            nodes.append(n)
    # TODO: Diff between value and reward?
    best_node = max(all_nodes, key=lambda node: node.value)
    return {**state, "solution": best_node.solution}


def should_loop(state: TreeState):
    root = state["root"]
    if root.is_solved:
        return "select_solution"
    if root.max_depth() > 5:
        return "select_solution"
    return "expand"


builder.add_node("start", start)
builder.add_node("expand", expand)
builder.add_node("select_solution", select_solution)
builder.set_entry_point("start")


builder.add_conditional_edges(
    "start",
    # Either expand/rollout or finish
    should_loop,
)
builder.add_conditional_edges(
    "expand",
    # Either continue to rollout or finish
    should_loop,
)

builder.set_finish_point("select_solution")
graph = builder.compile()

In [140]:
graph.invoke({"input": "what's the capital of ninevah"})

ValueError: Invalid input type <class 'dict'>. Must be a PromptValue, str, or list of BaseMessages.

In [123]:
# with langsmith.trace("mcts", inputs={"item": item}) as trace:
#             if is_leetcode:
#                 tests_i = item["visible_tests"]
#             else:
#                 tests_i = gen.internal_tests(item["prompt"], test_model, 6)

#             with langsmith.trace(
#                 "first_attempt", inputs={"item": item}
#             ) as trace:
#                 while cur_func_impl is None:
#                     cur_func_impl = gen.func_impl(item["prompt"], model, "simple")
#                 trace.end(outputs={"cur_func_impl": cur_func_impl})
#                 root = Node(cur_func_impl)  # initial solution (for pass@1 metric)

#                 # Lists for logging
#                 reflections = []
#                 implementations = []
#                 test_feedback = []
#                 is_solved = False

#                 # first attempt

#                 implementations.append(cur_func_impl)
#                 assert isinstance(cur_func_impl, str)
#                 is_passing, feedback, _ = exe.execute(cur_func_impl, tests_i)
#                 trace.end(outputs={"is_passing": is_passing, "feedback": feedback})
#             test_feedback.append(feedback)
#             with langsmith.trace(
#                 "self_reflection", inputs={"fun": cur_func_impl}
#             ) as trace:
#                 reflection = gen.self_reflection(cur_func_impl, feedback, model)
#                 trace.end(outputs={"reflection": reflection})
#             reflections += [reflection]
#             root.test_feedback = feedback
#             root.reflection = reflection

#             for cur_iter in range(max_iters):
#                 # Selection

#                 node = root
#                 trajectory = {"solutions": [], "feedbacks": []}

#                 while node.children:
#                     node = node.best_child()
#                     trajectory["solutions"].append(node.solution)

#                 # Expansion
#                 for _ in range(n):
#                     new_solution = None
#                     strategy = "mcts"
#                     prev_func_impl = node.solution
#                     feedback = node.test_feedback
#                     reflection = node.reflection
#                     acc_feedback, acc_reflection = gather_context_from_tree(node)
#                     with langsmith.trace(
#                         f"expansion-{_}", inputs={"func_sig": item["prompt"], "model": model}
#                     ) as trace:
#                         while new_solution is None:

#                             new_solution = gen.func_impl(
#                                 # func_sig=item["prompt"],
#                                 # model=model,
#                                 # strategy=strategy,
#                                 prev_func_impl=prev_func_impl,
#                                 feedback=feedback,
#                                 self_reflection=reflection,
#                                 # is this the stuff that's unique?
#                                 acc_feedback=acc_feedback,
#                                 acc_reflection=acc_reflection,
#                             )

#                     combined_context = "\nPrevious Trial\n\n" + new_solution

#                     child = Node(
#                         new_solution,
#                         parent=node,
#                         context=combined_context,
#                         depth=node.depth + 1,
#                     )
#                     node.children.append(child)

#                     # Simulation
#                     reward_real = 0
#                     for child in node.children:
#                         is_passing_internal, feedback_internal, _ = exe.execute(
#                             child.solution, tests_i
#                         )
#                         if not is_passing_internal:
#                             reflection = gen.self_reflection(
#                                 child.solution, feedback_internal, model
#                             )
#                             reflections.append(reflection)
#                             child.reflection = reflection
#                             child.test_feedback = feedback_internal
#                             child.context += (
#                                 "\n\nPrevious Trial\n\n"
#                                 + child.solution
#                                 + "\n\nTest results: \n"
#                                 + feedback_internal
#                                 + "\n\nSelf-reflection: "
#                                 + reflection
#                             )
#                         else:
#                             child.context += (
#                                 "\n\nPrevious Trial\n\n"
#                                 + child.solution
#                                 + "\n\nTest results: \n"
#                                 + feedback_internal
#                             )
#                             child.reflection = ""
#                             child.test_feedback = feedback_internal

#                         if "Tested passed:" in feedback_internal:
#                             # Split at "Tests failed:" and get the part before it (which contains the passed tests)
#                             passed_section = feedback_internal.split("Tests failed:")[0]
#                             # Split at "Tested passed:" and get the part after it, then count the non-empty lines
#                             reward_internal = len(
#                                 [
#                                     line
#                                     for line in passed_section.split("Tested passed:")[
#                                         1
#                                     ].splitlines()
#                                     if line.strip() != ""
#                                 ]
#                             )
#                             reward_internal = reward_internal / len(tests_i)
#                         else:
#                             reward_internal = 0
#                         if is_passing_internal or cur_iter == max_iters - 1:
#                             is_passing = exe.evaluate(
#                                 item["entry_point"],
#                                 child.solution,
#                                 item["test"],
#                                 timeout=10,
#                             )
#                             if is_passing:
#                                 item["solution"] = child.solution
#                                 is_solved = True
#                                 reward_real = 1
#                             break

#                     if is_solved:
#                         break

#                     print(reward_internal)
#                     print(reward_real)
#                     reward = reward_internal + reward_real
#                     child.update(reward)

#                     # Backpropagation
#                     temp = child
#                     while temp.parent:
#                         temp = temp.parent
#                         temp.update(reward)

#         # Choose the best solution after all iterations
#         if is_solved:
#             best_solution = item["solution"]
#         else:
#             best_solution = root.best_child_value().solution
#             item["solution"] = best_solution

#         is_passing, cur_feedback, _ = exe.execute(new_solution, tests_i)
#         test_feedback.append(cur_feedback)
#         is_passing = exe.evaluate(
#             item["entry_point"], best_solution, item["test"], timeout=10
#         )
#         if is_passing:
#             num_success += 1

#         reflections.append("MCTS reflections")
#         implementations.append(best_solution)

#         item["is_solved"] = is_passing
#         item["reflections"] = reflections
#         item["implementations"] = implementations
#         item["test_feedback"] = test_feedback
#         item["acc"] = round(num_success / (idx + 1), 2)
#         write_jsonl(log_path, [item], append=True)

#         print_v(f"completed {idx+1}/{num_items}: acc = {round(num_success/(idx+1), 2)}")

In [None]:
from datasets import load_dataset

dataset = load_dataset("deepmind/code_contests", split="valid")
dataset["validation"]