# Language Agent Tree Search


The basic idea is to combine reflection/evaluation and search (specifically Monte Carlo tree search) to get better overall performance.

As with all search-based methods, you trade off some inference-time compute and execution time for quality. You can then use LangSmith to extract the best trajectory for fine-tuning, which reduces the need for such an expensive search in the future.

#### Expand and simulate
1. Execute N actions/tasks
2. Get observations from tasks
3. Evaluate the response -> give value of action V
4. Update global state with the next step from this rollout
#### Reflect
For all terminal states (final responses), 
1. Get the reward
2. Assign `c` = reflection IFF reward isn't a 'success'
#### Select
(from UCT: Upper Confidence bounds applied to Trees)
1. The next action is the one that maximizes the value of `V + w * sqrt(ln(counts(prev state)) / counts(current state))`


#### "Backpropagate"

```
V(s) = (V_{old}(s)(N(s) - 1)+r) / N(s)
```
Nodes in the path are equivalent to the trajectories.

Then, after the backpropagation, we pick the best solution found across all rollouts.

## 0. Prerequisites

Install `langgraph` (for the framework), `langchain_openai` (for the LLM), and `langchain` + `tavily-python` (for the search engine).

We will use tavily search as a tool. You can get an API key [here](https://app.tavily.com/sign-in) or replace with a different tool of your choosing.

In [1]:
# %pip install -U --quiet  langchain langgraph langchain_openai
# %pip install -U --quiet tavily-python

In [2]:
import getpass
import os


def _set_if_undefined(var: str) -> None:
    """Prompt for ``var`` and store it in the environment unless already set.

    A variable that is missing or set to an empty string counts as
    undefined. The prompt text is the variable name itself.
    """
    already_set = bool(os.environ.get(var))
    if not already_set:
        os.environ[var] = getpass.getpass(var)


# Optional: LangSmith tracing, useful for visualizing and debugging the agent.
_set_if_undefined("LANGCHAIN_API_KEY")
os.environ.update(
    {
        "LANGCHAIN_TRACING_V2": "true",
        "LANGCHAIN_PROJECT": "LATS",
    }
)

# Credentials for the LLM provider and the search tool.
_set_if_undefined("OPENAI_API_KEY")
_set_if_undefined("TAVILY_API_KEY")

## Minimal example

Either do code generation or search. Let's just do web research, since we don't want to do full code rollouts.

In [6]:
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper

# Tavily API client wrapper; presumably reads TAVILY_API_KEY from the
# environment configured earlier in the notebook — verify.
search = TavilySearchAPIWrapper()
# Search tool built on that wrapper, configured with max_results=5.
tavily_tool = TavilySearchResults(api_wrapper=search, max_results=5)

In [21]:
from collections import defaultdict
from typing import List

from langchain.output_parsers.openai_tools import (
    JsonOutputToolsParser,
    PydanticToolsParser,
)
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage

from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation

In [160]:
import datetime
from typing import List

from langchain.output_parsers.openai_tools import JsonOutputToolsParser
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
from langchain_openai import ChatOpenAI

# Prompt for the first attempt: a fixed system message, the user's `input`,
# and an optional `messages` placeholder for any additional messages.
prompt_template = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant.",
        ),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="messages", optional=True),
    ]
)


# Chat model used by the chain below and by `expand`.
llm = ChatOpenAI(model="gpt-3.5-turbo")
# Prompt -> model -> string output parser.
initial_answer_chain = prompt_template | llm | StrOutputParser()


def expand(prompt_value: PromptValue, config):
    """Sample several candidate responses for one prompt and drop duplicates.

    Args:
        prompt_value: The formatted prompt; converted to messages for the LLM.
        config: Runnable config. The number of samples is read from
            ``config["configurable"]["n"]`` and defaults to 5 when the
            config, the ``configurable`` entry, or ``n`` is absent.

    Returns:
        A list of ``AIMessage`` objects, one per distinct generated text, in
        first-seen order.
    """
    # `configurable` is optional in a runnable config, so fall back to an
    # empty mapping instead of raising KeyError/TypeError.
    configurable = (config or {}).get("configurable") or {}
    n = configurable.get("n", 5)
    llm_res = llm.generate([prompt_value.to_messages()], n=n, temperature=1.0)
    # Keying by text collapses identical generations.
    # Could consider scoring duplicated values higher.
    dedupped = {gen.text: gen for gen in llm_res.generations[0]}
    return [AIMessage(content=gen.text) for gen in dedupped.values()]


# Prompt for expansion steps: same layout as the initial prompt, but the
# system message asks for an improved response based on reflections.
expansion_prompt_template = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant. Generate an "
            "improved response based on the provided reflections.",
        ),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="messages", optional=True),
    ]
)
# Piping into the plain function `expand` makes it the final step of the chain.
expansion_chain = expansion_prompt_template | expand

In [148]:
from langchain_core.pydantic_v1 import BaseModel


class Reflection(BaseModel):
    """Structured critique of a candidate response.

    Holds a free-text critique, a flag saying whether the response fully
    resolves the request, and an integer score constrained to 1-10.
    """

    critique: str = Field(
        description="~50-100 word critique of the current response, "
        "outline parts that are missing, superfluous, or low quality."
    )
    is_finished: bool = Field(
        description="Whether this response completely and accurately resolves the user's request."
    )
    # `ge`/`le` are the bound keywords pydantic recognises. The previous
    # `gte`/`lte` were not constraint keywords, so the range was not enforced.
    score: int = Field(ge=1, le=10, description="Score of the response.")

In [164]:
from typing import Optional


class Node:
    """One candidate solution in the search tree.

    Each node tracks its parent and children, the accumulated reward
    (``value``), the number of times it was updated (``visits``), an optional
    reflection, and whether a finished solution exists at or below it.
    """

    def __init__(
        self,
        solution: str,
        parent: Optional["Node"] = None,
        reflection: Optional["Reflection"] = None,
    ):
        """Create a node.

        Args:
            solution: The candidate solution text stored on this node.
            parent: The parent node, or ``None`` for the root. The new node
                is NOT appended to ``parent.children``; callers do that.
            reflection: Optional reflection. If it reports ``is_finished``,
                this node and all of its ancestors are marked finished.
        """
        # Annotations are string forward references: `Node` is not bound yet
        # while the class body is still executing.
        self.solution = solution
        self.parent = parent
        self.children = []
        self.value = 0  # sum of rewards passed to update()
        self.visits = 0  # number of update() calls
        self.reflection = reflection
        self.is_finished = False
        if reflection is not None and reflection.is_finished:
            self._mark_as_finished()

    def upper_confidence_bound(self, exploration_weight: float = 1.0):
        """Return the UCT score of this node relative to its parent.

        The score is ``value / visits + exploration_weight *
        sqrt(ln(parent.visits) / visits)``. A node with zero visits returns
        its raw ``value``.

        Raises:
            ValueError: If called on a node without a parent.
        """
        import math  # local import: the enclosing cell does not import math

        if self.parent is None:
            raise ValueError("Cannot compute UCT for root node.")
        if self.visits == 0:
            return self.value
        average_reward = self.value / self.visits
        exploration = math.sqrt(math.log(self.parent.visits) / self.visits)
        return average_reward + exploration_weight * exploration

    @property
    def best_child(self):
        """Child with the highest UCT score, or ``None`` if there are none."""
        if not self.children:
            return None
        return max(self.children, key=lambda child: child.upper_confidence_bound())

    @property
    def best_child_score(self):
        """Child with the highest accumulated ``value``, or ``None``.

        Despite the name, this returns the child node, not a number.
        """
        if not self.children:
            return None
        return max(self.children, key=lambda child: child.value)

    @property
    def depth_at_node(self):
        """Number of levels in the subtree rooted here (1 for a leaf)."""
        if self.children:
            return 1 + max(child.depth_at_node for child in self.children)
        return 1

    def add_reflection(self, reflection: "Reflection"):
        """Attach a reflection to a node that does not have one yet.

        If the reflection reports ``is_finished``, this node and all of its
        ancestors are marked finished.

        Raises:
            ValueError: If the node already has a reflection.
        """
        if self.reflection is not None:
            raise ValueError("Cannot overwrite existing reflection")
        self.reflection = reflection
        if reflection is not None and reflection.is_finished:
            self._mark_as_finished()

    def update(self, reward: float):
        """Record one visit and add ``reward`` to the accumulated value."""
        self.visits += 1
        self.value += reward

    def _mark_as_finished(self):
        """Flag this node and every ancestor as finished."""
        self.is_finished = True
        parent = self.parent
        while parent:
            parent.is_finished = True
            parent = parent.parent

In [None]:
# Search hyperparameters (draft values).
n = 5  # number of generated actions (beam width)
w = 1  # Exploration weight.
L = 5  # Depth limit / max rollout
K = 3  # Number of rollouts
c = None  # context; placeholder, the original cell left this unfinished (`c..`)

In [None]:
from 

```
State:
- root

Nodes:
# first_attempt: llm -> generate code
# Rollout
 - Selection
 - Expansion
# Reflection / Simulation
# Backpropagation
```

In [None]:
# with langsmith.trace("mcts", inputs={"item": item}) as trace:
#         # Get the first output
#         cur_func_impl = gen.func_impl(item["prompt"], model, "simple")
#         root = Node(cur_func_impl)  # initial solution (for pass@1 metric)
#         test_feedback.append(feedback)
#         with langsmith.trace(
#             "self_reflection", inputs={"fun": cur_func_impl}
#         ) as trace:
#             reflection = gen.self_reflection(cur_func_impl, feedback, model)
#             trace.end(outputs={"reflection": reflection})
#         reflections += [reflection]
#         root.test_feedback = feedback
#         root.reflection = reflection

#         for cur_iter in range(max_iters):
#             # Selection

#             node = root
#             trajectory = {"solutions": [], "feedbacks": []}

#             while node.children:
#                 node = node.best_child()
#                 trajectory["solutions"].append(node.solution)

#             # Expansion
#             for _ in range(n):
#                 new_solution = None
#                 strategy = "mcts"
#                 prev_func_impl = node.solution
#                 feedback = node.test_feedback
#                 reflection = node.reflection
#                 acc_feedback, acc_reflection = gather_context_from_tree(node)
#                 new_solution = gen.func_impl(
#                     func_sig=item["prompt"],
#                     model=model,
#                     strategy=strategy,
#                     prev_func_impl=prev_func_impl,
#                     feedback=feedback,
#                     self_reflection=reflection,
#                     acc_feedback=acc_feedback,
#                     acc_reflection=acc_reflection,
#                     )

#                 combined_context = "\nPrevious Trial\n\n" + new_solution

#                 child = Node(
#                     new_solution,
#                     parent=node,
#                     context=combined_context,
#                     depth=node.depth + 1,
#                 )
#                 node.children.append(child)

#                 # Simulation
#                 reward_real = 0
#                 for child in node.children:
#                     reflection = gen.self_reflection(
#                         child.solution, feedback_internal, model
#                     )
#                     reflections.append(reflection)
#                     child.reflection = reflection
#                     child.test_feedback = feedback_internal
#                     child.context += (
#                         "\n\nPrevious Trial\n\n"
#                         + child.solution
#                         + "\n\nTest results: \n"
#                         + feedback_internal
#                         + "\n\nSelf-reflection: "
#                         + reflection
#                     )

#                     if "Tested passed:" in feedback_internal:
#                         # Split at "Tests failed:" and get the part before it (which contains the passed tests)
#                         passed_section = feedback_internal.split("Tests failed:")[0]
#                         # Split at "Tested passed:" and get the part after it, then count the non-empty lines
#                         reward_internal = len(
#                             [
#                                 line
#                                 for line in passed_section.split("Tested passed:")[
#                                     1
#                                 ].splitlines()
#                                 if line.strip() != ""
#                             ]
#                         )
#                         reward_internal = reward_internal / len(tests_i)
#                     else:
#                         reward_internal = 0
#                     if is_passing_internal or cur_iter == max_iters - 1:
#                         is_passing = exe.evaluate(
#                             item["entry_point"],
#                             child.solution,
#                             item["test"],
#                             timeout=10,
#                         )
#                         if is_passing:
#                             item["solution"] = child.solution
#                             is_solved = True
#                             reward_real = 1
#                         break

#                 if is_solved:
#                     break
#                 reward = reward_internal + reward_real
#                 child.update(reward)

#                 # Backpropagation
#                 temp = child
#                 while temp.parent:
#                     temp = temp.parent
#                     temp.update(reward)

#     # Choose the best solution after all iterations
#     best_solution = root.best_child_value().solution
#     item["solution"] = best_solution
#     reflections.append("MCTS reflections")
#     implementations.append(best_solution)

In [144]:
# from datasets import load_dataset

# dataset = load_dataset("deepmind/code_contests", split="valid")
# example = dataset[70]