# Language Agent Tree Search


The basic idea is to combine reflection/evaluation and search (specifically monte-carlo trees search) to get better overall performance.

As with all search-based methods, you trade off some inference-time compute and execution time for quality, and you can use langsmith them to extract the best trajectory for finetuning to reduce the need for as expensive search in the future..

#### Expand and simulate
1. Execute N actions/tasks
2. Get observations from tasks
3. Evaluate the response -> give value of action V
4. Update global state with the next step from this rollout
#### Reflect
For all terminal states (final responses), 
1. Get the reward
2. Assign `c` = reflection IFF reward isn't a 'sucess'
#### Select
(from (Upper Confidence bounds applied to Trees))
1. next action is the one that maximizes the value of V + w sqrt(ln (counts(prev state) )/counts(current state))


#### "Backgpropagate"

```
V(s) = (V_{old}(s)(N(s) - 1)+r) / N(s)
```
Nodes in the path are equivalent to the trajectories.

Then after the backpropagation, we pick the 

## 0. Prerequisites

Install `langgraph` (for the framework), `langchain_openai` (for the LLM), and `langchain` + `tavily-python` (for the search engine).

We will use tavily search as a tool. You can get an API key [here](https://app.tavily.com/sign-in) or replace with a different tool of your choosing.

In [None]:
# %pip install -U --quiet  langchain langgraph langchain_openai
# %pip install -U --quiet tavily-python

In [1]:
import getpass
import os


def _set_if_undefined(var: str) -> None:
    if os.environ.get(var):
        return
    os.environ[var] = getpass.getpass(var)


# Optional: Configure tracing to visualize and debug the agent
_set_if_undefined("LANGCHAIN_API_KEY")
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "LATS"

_set_if_undefined("OPENAI_API_KEY")
_set_if_undefined("TAVILY_API_KEY")

## Graph State

LATS is based on a  (greedy)  Monte-Carlo tree search. For each step, it picks N candidates, scores them, and then adds them to the tree. In future iterations, it picks nodes with the highest upper confidence bound, which is a fancy 

In [27]:
from __future__ import annotations

from typing import List, Optional

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage


class Node:
    def __init__(
        self,
        messages: List[BaseMessage],
        reflection: Reflection,
        parent: Optional[Node] = None,
    ):
        self.messages = messages
        self.parent = parent
        self.children = []
        self.value = 0
        self.visits = 0
        self.reflection = reflection
        self._is_solved = reflection.found_solution if reflection else False
        if self._is_solved:
            self._mark_tree_as_solved()

    def __repr__(self) -> str:
        return (
            f"<Node value={self.value}, visits={self.visits},"
            f" solution={self.messages} reflection={self.reflection}/>"
        )

    @property
    def is_solved(self):
        """If any solutions exist, we can end the search."""
        return self._is_solved

    @property
    def best_child(self):
        """Select the child with the highest UCT to search next."""
        if not self.children:
            return None
        return max(self.children, key=lambda child: child.upper_confidence_bound())

    @property
    def best_child_score(self):
        """Return the child with the highest value."""
        if not self.children:
            return None
        return max(self.children, key=lambda child: child.value)

    @property
    def depth_below(self) -> int:
        """Check for how far we've rolled out the tree."""
        if self.children:
            return 1 + max([child.depth_below for child in self.children])
        return 1

    def upper_confidence_bound(self, exploration_weight=1.0):
        """Return the UCT score. This helps balance exploration vs. exploitation of a branch."""
        if self.parent is None:
            raise ValueError("Cannot obtain UCT from root node")
        if self.visits == 0:
            return self.value
        # Encourages exploitation of high-value trajectories
        average_reward = self.value / self.visits
        # Encourages exploration of less-visited trajectories
        exploration_term = math.sqrt(math.log(self.parent.visits) / self.visits)
        return average_reward + exploration_weight * exploration_term

    def update_reward(self, reward: float):
        """Update the score of this node and its parents."""
        self.visits += 1
        self.value += reward
        parent = self.parent
        while parent:
            parent.value += reward
            parent = parent.parent

    def get_messages(self) -> List[BaseMessage]:
        """Get messages representing this search branch."""
        messages = []
        parent = self.parent
        while parent:
            messages.extend(
                [
                    HumanMessage(content=self.reflection),
                    self.messages,
                ]
            )
        return messages[::-1]  # root solution, reflection, child 1, ...

    def _mark_tree_as_solved(self):
        parent = self.parent
        while parent:
            parent._is_solved = True
            parent = parent.parent

In [28]:
from typing_extensions import TypedDict


class TreeState(TypedDict):
    # The full tree
    root: Node
    # The rolled-out steps so far
    input: str
    solution: AIMessage

## Tools

In [29]:
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation

search = TavilySearchAPIWrapper()
tavily_tool = TavilySearchResults(api_wrapper=search, max_results=5)
tools = [tavily_tool]
tool_executor = ToolExecutor(tools=tools)

## Define Agent

In [30]:
# from collections import defaultdict
# from typing import List

from langchain.output_parsers.openai_tools import (
    JsonOutputToolsParser,
    PydanticToolsParser,
)

#### Generate the initial candidate

In [71]:
import datetime
from typing import List

from langchain_core.prompt_values import ChatPromptValue
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
from langchain_core.runnables import RunnableConfig
from langchain_openai import ChatOpenAI

prompt_template = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are an AI assistant.",
        ),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="messages", optional=True),
    ]
)

llm = ChatOpenAI(model="gpt-3.5-turbo")

initial_answer_chain = prompt_template | llm.bind_tools(tools=tools)

In [72]:
initial_response = initial_answer_chain.invoke(
    {"input": "Write a research report on lithium pollution."}
)
initial_response

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_qmSe7jJJ3yIcjNqqJHSPlGWk', 'function': {'arguments': '{"query":"lithium pollution research report"}', 'name': 'tavily_search_results_json'}, 'type': 'function'}]})

#### Generate N candidates

In [57]:
# This generates N candidate values
# for a single input to sample actions from the environment


def generate_candidates(messages: ChatPromptValue, config: RunnableConfig):
    n = config["configurable"].get("N", 5)
    bound_kwargs = llm.bind_tools(tools=tools).kwargs
    chat_result = llm.generate(
        [messages.to_messages()], n=n, callbacks=config["callbacks"], **bound_kwargs
    )
    return [gen.message for gen in chat_result.generations[0]]


expansion_chain = prompt_template | generate_candidates

In [58]:
res = expansion_chain.invoke({"input": "Write a research report on lithium pollution."})

#### Reflect on outputs

In [60]:
from langchain.chains import create_structured_output_runnable


class Reflection(BaseModel):
    reflections: str = Field(
        description="The critique and reflections on the sufficiency, superfluency,"
        " and general quality of the response"
    )
    score: int = Field(
        description="Score from 1-10 on the quality of the candidate response."
    )
    found_solution: bool = Field(
        description="Whether the response has fully solved the question or task."
    )


prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Reflect and grade the assistant response to the user question below.",
        ),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="candidate"),
        (
            "system",
            "Reflect on the assistant response above, critique, and score the response.",
        ),
    ]
)
reflection_chain = create_structured_output_runnable(
    output_schema=Reflection, llm=llm, prompt=prompt
)

In [61]:
# n = 5  # number of generated actions (beam width)
# w = 1  # Exploration weight.
# L = 5  # Depth limit / max rollout
# K = 3  # Number of rollouts

## Define Graph

Now we may construct the actual graph.


In [76]:
import json
import operator
from collections import deque

from langgraph.graph import StateGraph

parser = JsonOutputToolsParser(return_id=True)


def start(state: TreeState) -> dict:
    res = initial_answer_chain.invoke({"input": state["input"]})
    parsed = parser.invoke(res)
    tool_responses = tool_executor.batch(
        [ToolInvocation(tool=r["type"], tool_input=r["args"]) for r in parsed]
    )
    output_messages = [res] + [
        ToolMessage(content=json.dumps(resp), tool_call_id=tool_call["id"])
        for resp, tool_call in zip(tool_responses, parsed)
    ]
    print(output_messages)
    reflection = reflection_chain.invoke(
        {"input": state["input"], "candidate": output_messages}
    )
    root = Node(output_messages, reflection=reflection)
    return {
        **state,
        "root": root,
    }


def expand(state: TreeState, config: RunnableConfig) -> dict:
    root = state["root"]
    best_candidate: Node = root.best_child if root.children else root
    messages = best_candidate.get_messages()
    # Generate N candidates from the single child candidate
    new_candidates = expansion_chain.invoke(
        {"input": state["input"], "messages": messages}, config
    )
    parsed = parser.batch(new_candidates)
    tool_responses = tool_executor.batch(
        [ToolInvocation(tool=r["type"], tool_input=r["args"]) for r in res]
    )

    # Reflect on each candidate
    # For tasks with external validation, you'd add that here.
    reflections = reflection_chain.batch(
        [{"input": state["input"], "candidate": [msg]} for msg in new_candidates],
        config,
    )
    # Grow tree
    child_nodes = [
        Node(cand, parent=best_candidate, reflection=reflection)
        for cand, reflection in zip(new_candidates, reflections)
    ]
    best_candidate.children.extend(child_nodes)
    # We have already extended the tree directly, so we just return the state
    return state


def select_solution(state: TreeState):
    all_nodes = [state["root"]]
    nodes = deque()
    nodes.append(state["root"])
    while nodes:
        node = nodes.popleft()
        all_nodes.extend(node.children)
        for n in node.children:
            nodes.append(n)
    # TODO: Diff between value and reward?
    best_node = max(all_nodes, key=lambda node: node.value)
    return {**state, "solution": best_node.solution}


def should_loop(state: TreeState):
    root = state["root"]
    if root.is_solved:
        return "select_solution"
    if root.depth_below > 5:
        return "select_solution"
    return "expand"


builder = StateGraph(TreeState)
builder.add_node("start", start)
builder.add_node("expand", expand)
builder.add_node("select_solution", select_solution)
builder.set_entry_point("start")


builder.add_conditional_edges(
    "start",
    # Either expand/rollout or finish
    should_loop,
)
builder.add_conditional_edges(
    "expand",
    # Either continue to rollout or finish
    should_loop,
)

builder.set_finish_point("select_solution")
graph = builder.compile()

In [77]:
res = graph.invoke({"input": "What's the score of the 49'rs chiefs game?"})

[AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_w2aqRiOdWwXbxytaNcQRKeBU', 'function': {'arguments': '{"query":"49ers vs Chiefs current score"}', 'name': 'tavily_search_results_json'}, 'type': 'function'}]}), ToolMessage(content='[{"url": "https://www.cbssports.com/nfl/news/2024-super-bowl-chiefs-vs-49ers-score-patrick-mahomes-leads-ot-comeback-as-k-c-wins-back-to-back-titles/live/", "content": "The championship-winning drive, which included a fourth-and-1 scramble from Mahomes and a clutch 7-yard catch from tight end Travis Kelce, was a must-score for K.C. The NFL\'s new playoff overtime rules -- both teams are guaranteed at least one possession in the extra period -- were in effect for the first time, and the Chiefs needed to answer the Niners\' field goal.\\n Held out of the end zone until that point, Kansas City grabbed its first lead of the game at 13-10.\\nJennings\' touchdown receiving (followed by a missed extra point) concluded a 75-yard drive that put th

BadRequestError: Error code: 400 - {'error': {'message': "Invalid parameter: 'tool_calls' cannot be used when 'functions' are present. Please use 'tools' instead of 'functions'.", 'type': 'invalid_request_error', 'param': 'messages.[2].tool_calls', 'code': None}}