In [9]:
pip install langchain-anthropic==0.3.10 langchain-community==0.3.21 langchain-openai==0.3.12 langgraph==0.3.25 langsmith==0.3.24 tavily-python==0.5.4

Collecting langchain-community==0.3.21
  Downloading langchain_community-0.3.21-py3-none-any.whl.metadata (2.4 kB)
Collecting langchain-openai==0.3.12
  Downloading langchain_openai-0.3.12-py3-none-any.whl.metadata (2.3 kB)
Collecting langgraph==0.3.25
  Downloading langgraph-0.3.25-py3-none-any.whl.metadata (7.7 kB)
Collecting langsmith==0.3.24
  Downloading langsmith-0.3.24-py3-none-any.whl.metadata (15 kB)
Collecting tavily-python==0.5.4
  Downloading tavily_python-0.5.4-py3-none-any.whl.metadata (91 kB)
Collecting langchain<1.0.0,>=0.3.23 (from langchain-community==0.3.21)
  Downloading langchain-0.3.24-py3-none-any.whl.metadata (7.8 kB)
Collecting SQLAlchemy<3,>=1.4 (from langchain-community==0.3.21)
  Downloading sqlalchemy-2.0.40-cp313-cp313-macosx_11_0_arm64.whl.metadata (9.6 kB)
Collecting aiohttp<4.0.0,>=3.8.3 (from langchain-community==0.3.21)
  Downloading aiohttp-3.11.18-cp313-cp313-macosx_11_0_arm64.whl.metadata (7.7 kB)
Collecting dataclasses-json<0.7,>=0.5.7 (from langc

In [10]:
import math
from collections import deque
from typing import Optional

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage

from pydantic import BaseModel, Field

In [11]:
# Structured critique of one candidate response. The model is bound to the
# LLM as a tool schema (see `reflection_llm_chain`) and parsed back into this
# class, so its fields are what the grader LLM is asked to produce.
# NOTE(review): documentation is kept in `#` comments rather than a class
# docstring because a docstring would presumably be surfaced as the tool
# description sent to the LLM -- confirm before converting.
class Reflection(BaseModel):
    reflections: str = Field(
        description="The critique and reflections on the sufficiency, superfluency,"
        " and general quality of the response"
    )
    # `ge` / `le` are pydantic's inclusive numeric bounds. The previous
    # `gte` / `lte` spellings are not constraint keywords, so the 0-10 range
    # was never validated (and needed a `# type: ignore` to hide the error).
    score: int = Field(
        description="Score from 0-10 on the quality of the candidate response.",
        ge=0,
        le=10,
    )
    found_solution: bool = Field(
        description="Whether the response has fully solved the question or task."
    )

    def as_message(self):
        """Render the critique and its score as a HumanMessage.

        Returns:
            A ``HumanMessage`` whose content is the reasoning text followed by
            the raw (0-10) score.
        """
        return HumanMessage(
            content=f"Reasoning: {self.reflections}\nScore: {self.score}"
        )

    @property
    def normalized_score(self) -> float:
        """The score rescaled from the 0-10 range to 0.0-1.0."""
        return self.score / 10.0

class Node:
    """One candidate step in the search tree.

    A node keeps the messages produced for its step, the reflection that
    graded them, links to its parent and children, and two running statistics:
    ``visits`` and ``value`` (the running mean of every reward propagated
    through the node). Constructing a node immediately propagates its
    normalized reflection score up to the root and, when the reflection
    reports a solution, flags every ancestor as solved.
    """

    def __init__(
        self,
        messages: list[BaseMessage],
        reflection: Reflection,
        parent: Optional["Node"] = None,
    ):
        self.messages = messages
        self.parent = parent
        self.children = []
        self.value = 0
        self.visits = 0
        self.reflection = reflection
        # The root sits at depth 1; every child is one level below its parent.
        if parent is None:
            self.depth = 1
        else:
            self.depth = parent.depth + 1
        if reflection:
            self._is_solved = reflection.found_solution
        else:
            self._is_solved = False
        if self._is_solved:
            self._mark_tree_as_solved()
        # Seed the statistics of this node and all ancestors with its own score.
        self.backpropagate(reflection.normalized_score)

    def __repr__(self) -> str:
        stats = f"<Node value={self.value}, visits={self.visits},"
        payload = f" solution={self.messages} reflection={self.reflection}/>"
        return stats + payload

    @property
    def is_solved(self):
        """True once this node, or any node below it, has reported a solution."""
        return self._is_solved

    @property
    def is_terminal(self):
        """True when the node has no children (it is a leaf)."""
        return len(self.children) == 0

    @property
    def best_child_score(self):
        """Return the direct child ranked highest, or None without children.

        Children are ranked by their value when solved; unsolved children
        rank as zero.
        """
        if len(self.children) == 0:
            return None

        def solved_value(child):
            return child.value if child.is_solved else 0

        return max(self.children, key=solved_value)

    @property
    def height(self) -> int:
        """Number of levels in the sub-tree rooted here (a leaf has height 1)."""
        levels = 0
        frontier = [self]
        # Walk the sub-tree one level at a time until a level is empty.
        while frontier:
            levels += 1
            frontier = [child for node in frontier for child in node.children]
        return levels

    def upper_confidence_bound(self, exploration_weight=1.0):
        """Return the UCT score used to trade off exploitation and exploration.

        Raises:
            ValueError: when called on the root, which has no parent visits
                to compare against.
        """
        if self.parent is None:
            raise ValueError("Cannot obtain UCT from root node")
        if self.visits == 0:
            return self.value
        # Exploitation: reward accumulated so far on this branch.
        exploitation = self.value / self.visits
        # Exploration: grows for branches visited rarely relative to the parent.
        exploration = math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploitation + exploration_weight * exploration

    def backpropagate(self, reward: float):
        """Fold ``reward`` into the running mean of this node and every ancestor."""
        current = self
        while current is not None:
            seen_before = current.visits
            current.visits = seen_before + 1
            current.value = (current.value * seen_before + reward) / current.visits
            current = current.parent

    def get_messages(self, include_reflections: bool = True):
        """Return this node's messages, optionally followed by its reflection message."""
        if not include_reflections:
            return self.messages
        return self.messages + [self.reflection.as_message()]

    def get_trajectory(self, include_reflections: bool = True) -> list[BaseMessage]:
        """Return the messages along the path from the root down to this node."""
        # Collect the path leaf -> root, then replay it root -> leaf.
        lineage = []
        current = self
        while current is not None:
            lineage.append(current)
            current = current.parent
        trajectory = []
        for ancestor in reversed(lineage):
            trajectory.extend(
                ancestor.get_messages(include_reflections=include_reflections)
            )
        return trajectory  # root solution, reflection, child 1, ...

    def _get_all_children(self):
        """Return every descendant of this node in breadth-first order."""
        descendants = []
        pending = deque([self])
        while pending:
            current = pending.popleft()
            descendants.extend(current.children)
            pending.extend(current.children)
        return descendants

    def get_best_solution(self):
        """Return the best solution from within the current sub-tree."""

        def solution_value(node):
            # Only solved leaves compete; every other node ranks as zero.
            if node.is_terminal and node.is_solved:
                return node.value
            return 0

        candidates = [self]
        candidates.extend(self._get_all_children())
        return max(candidates, key=solution_value)

    def _mark_tree_as_solved(self):
        """Flag every ancestor of this node as solved."""
        ancestor = self.parent
        while ancestor is not None:
            ancestor._is_solved = True
            ancestor = ancestor.parent

In [12]:
from typing_extensions import TypedDict

class TreeState(TypedDict):
    """State passed between the graph nodes of the tree search."""

    root: Node  # root of the full search tree
    input: str  # the original user input

In [13]:
from langchain_openai import ChatOpenAI

# Chat model shared by the candidate-generation and reflection chains in this notebook.
llm = ChatOpenAI(model="gpt-4o")

from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.prebuilt import ToolNode

# Tavily web-search tool, configured with max_results=5.
tavily_tool = TavilySearchResults(max_results=5)
# The tool list is bound to the LLM for candidate generation and handed to ToolNode.
tools = [tavily_tool]
# ToolNode is invoked with {"messages": [AIMessage(tool_calls=...)]} and its
# result is read back through result["messages"].
tool_node = ToolNode(tools=tools)


from langchain_core.output_parsers.openai_tools import (
    JsonOutputToolsParser,
    PydanticToolsParser,
)
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import chain as as_runnable

# Grader prompt: a system instruction, the user's question ({input}), then the
# candidate messages to grade, injected through the "candidate" placeholder.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Reflect and grade the assistant response to the user question below.",
        ),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="candidate"),
    ]
)

# Reflection pipeline: prompt -> LLM with the Reflection schema bound as its
# only tool and selected via tool_choice="Reflection" -> parser that turns the
# resulting tool calls into Reflection instances. The run is named
# "Reflection" via with_config.
reflection_llm_chain = (
    prompt
    | llm.bind_tools(tools=[Reflection], tool_choice="Reflection").with_config(
        run_name="Reflection"
    )
    | PydanticToolsParser(tools=[Reflection])
)


@as_runnable
def reflection_chain(inputs) -> Reflection:
    """Grade a candidate and return the first Reflection produced by the LLM.

    ``inputs`` must provide the "candidate" message list (plus whatever the
    reflection prompt needs). A candidate whose final message is not an
    AIMessage is never reported as a solution, regardless of the LLM's verdict.
    """
    reflection = reflection_llm_chain.invoke(inputs)[0]
    final_message = inputs["candidate"][-1]
    if isinstance(final_message, AIMessage):
        return reflection
    # The trajectory does not end on an assistant message, so override the verdict.
    reflection.found_solution = False
    return reflection

In [14]:
from langchain_core.prompt_values import ChatPromptValue
from langchain_core.runnables import RunnableConfig

# Generation prompt shared by the initial answer and the expansion step.
# The "messages" placeholder is optional, so the first call can omit it and
# later calls can pass the trajectory accumulated so far.
prompt_template = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are an AI assistant.",
        ),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="messages", optional=True),
    ]
)


# Chain that produces the very first candidate; the LLM has the search tools
# bound so it can answer with tool calls.
initial_answer_chain = prompt_template | llm.bind_tools(tools=tools).with_config(
    run_name="GenerateInitialCandidate"
)

# Parses tool calls out of an AIMessage; return_id=True keeps each call's id so
# the call can be replayed through the tool node.
parser = JsonOutputToolsParser(return_id=True)
# Example invocation to inspect the raw model response (displayed below).
initial_response = initial_answer_chain.invoke(
    {"input": "Write a research report on lithium pollution."}
)
initial_response

AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_deJYRta1OC207ayRQkKz0j1B', 'function': {'arguments': '{"query":"lithium pollution environmental impact research 2023"}', 'name': 'tavily_search_results_json'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 28, 'prompt_tokens': 93, 'total_tokens': 121, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_d8864f8b6b', 'id': 'chatcmpl-BPMKSBD5xdVxm9JuspQ5aIMvYuTvf', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-0366564b-fb66-492f-9b4b-cf8dd2e17cd7-0', tool_calls=[{'name': 'tavily_search_results_json', 'args': {'query': 'lithium pollution environmental impact research 2023'}, 'id': 'call_deJYRta1OC207ayRQkKz0j1B', 'type': 'tool_call'}], usage_metadata

{
    "AIMessage": {
        "content": "",
        "additional_kwargs": {
            "tool_calls": [
                {
                    "id": "call_deJYRta1OC207ayRQkKz0j1B",
                    "function": {
                        "arguments": "{\"query\":\"lithium pollution environmental impact research 2023\"}",
                        "name": "tavily_search_results_json"
                    },
                    "type": "function"
                }
            ],
            "refusal": null
        },
        "response_metadata": {
            "token_usage": {
                "completion_tokens": 28,
                "prompt_tokens": 93,
                "total_tokens": 121,
                "completion_tokens_details": {
                    "accepted_prediction_tokens": 0,
                    "audio_tokens": 0,
                    "reasoning_tokens": 0,
                    "rejected_prediction_tokens": 0
                },
                "prompt_tokens_details": {
                    "audio_tokens": 0,
                    "cached_tokens": 0
                }
            },
            "model_name": "gpt-4o-2024-08-06",
            "system_fingerprint": "fp_d8864f8b6b",
            "id": "chatcmpl-BPMKSBD5xdVxm9JuspQ5aIMvYuTvf",
            "finish_reason": "tool_calls",
            "logprobs": null
        },
        "id": "run-0366564b-fb66-492f-9b4b-cf8dd2e17cd7-0",
        "tool_calls": [
            {
                "name": "tavily_search_results_json",
                "args": {
                    "query": "lithium pollution environmental impact research 2023"
                },
                "id": "call_deJYRta1OC207ayRQkKz0j1B",
                "type": "tool_call"
            }
        ],
        "usage_metadata": {
            "input_tokens": 93,
            "output_tokens": 28,
            "total_tokens": 121,
            "input_token_details": {
                "audio": 0,
                "cache_read": 0
            },
            "output_token_details": {
                "audio": 0,
                "reasoning": 0
            }
        }
    }
}

In [16]:
# Graph node that seeds the search tree with its root
def generate_initial_response(state: TreeState) -> dict:
    """Create the first candidate, run its tool calls, grade it, and store it as the root.

    Returns the incoming state with "root" set to the newly built Node.
    """
    first_reply = initial_answer_chain.invoke({"input": state["input"]})
    tool_responses = []
    for call in parser.invoke(first_reply):
        # Replay each parsed tool call as its own single-call AIMessage.
        request = AIMessage(
            content="",
            tool_calls=[{"name": call["type"], "args": call["args"], "id": call["id"]}],
        )
        tool_responses.append(tool_node.invoke({"messages": [request]}))
    # The candidate is the model reply followed by the first message of each tool result.
    output_messages = [first_reply]
    for response in tool_responses:
        output_messages.append(response["messages"][0])
    reflection = reflection_chain.invoke(
        {"input": state["input"], "candidate": output_messages}
    )
    updated_state = dict(state)
    updated_state["root"] = Node(output_messages, reflection=reflection)
    return updated_state

In [17]:
# Samples N candidate continuations for a single prompt,
# i.e. the actions the search can take from the current node
def generate_candidates(messages: ChatPromptValue, config: RunnableConfig):
    """Return N alternative AI messages for the given prompt value.

    N is read from config["configurable"]["N"] and defaults to 5. The tool
    definitions are taken from a tool-bound copy of the LLM and forwarded to
    ``llm.generate`` so every sample can emit tool calls.
    """
    sample_count = config["configurable"].get("N", 5)
    tool_kwargs = llm.bind_tools(tools=tools).kwargs
    result = llm.generate(
        [messages.to_messages()],
        n=sample_count,
        callbacks=config["callbacks"],
        run_name="GenerateCandidates",
        **tool_kwargs,
    )
    # One prompt was submitted, so all samples live in the first generations entry.
    candidates = []
    for generation in result.generations[0]:
        candidates.append(generation.message)
    return candidates


# Expansion pipeline: format the prompt, then sample N candidates from it.
expansion_chain = prompt_template | generate_candidates

In [18]:
# Try the expansion chain on its own. No "N" is configured here, so
# generate_candidates falls back to its default of 5 candidates.
res = expansion_chain.invoke({"input": "Write a research report on lithium pollution."})
res

[AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_mtjPL7h0h2YAeSwIY8U4krwG', 'function': {'arguments': '{"query":"lithium pollution research 2023"}', 'name': 'tavily_search_results_json'}, 'type': 'function'}], 'refusal': None}, response_metadata={'finish_reason': 'tool_calls', 'logprobs': None}, id='run-5c52bdfe-2eff-415d-a09d-df1272212667-0', tool_calls=[{'name': 'tavily_search_results_json', 'args': {'query': 'lithium pollution research 2023'}, 'id': 'call_mtjPL7h0h2YAeSwIY8U4krwG', 'type': 'tool_call'}], usage_metadata={'input_tokens': 93, 'output_tokens': 129, 'total_tokens': 222, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}}),
 AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_eEzazPhsbh13NObVARwuy7YX', 'function': {'arguments': '{"query":"lithium pollution research report"}', 'name': 'tavily_search_results_json'}, 'type': 'function'}]}, response_metadata={'finish_reason': 'tool_c

[
    {
        "content": "",
        "additional_kwargs": {
            "tool_calls": [
                {
                    "id": "call_mtjPL7h0h2YAeSwIY8U4krwG",
                    "function": {
                        "arguments": "{\"query\":\"lithium pollution research 2023\"}",
                        "name": "tavily_search_results_json"
                    },
                    "type": "function"
                }
            ],
            "refusal": null
        },
        "response_metadata": {
            "finish_reason": "tool_calls",
            "logprobs": null
        },
        "id": "run-5c52bdfe-2eff-415d-a09d-df1272212667-0",
        "tool_calls": [
            {
                "name": "tavily_search_results_json",
                "args": {
                    "query": "lithium pollution research 2023"
                },
                "id": "call_mtjPL7h0h2YAeSwIY8U4krwG",
                "type": "tool_call"
            }
        ],
        "usage_metadata": {
            "input_tokens": 93,
            "output_tokens": 129,
            "total_tokens": 222,
            "input_token_details": {
                "audio": 0,
                "cache_read": 0
            },
            "output_token_details": {
                "audio": 0,
                "reasoning": 0
            }
        }
    },
    {
        "content": "",
        "additional_kwargs": {
            "tool_calls": [
                {
                    "id": "call_eEzazPhsbh13NObVARwuy7YX",
                    "function": {
                        "arguments": "{\"query\":\"lithium pollution research report\"}",
                        "name": "tavily_search_results_json"
                    },
                    "type": "function"
                }
            ]
        },
        "response_metadata": {
            "finish_reason": "tool_calls",
            "logprobs": null
        },
        "id": "run-5c52bdfe-2eff-415d-a09d-df1272212667-1",
        "tool_calls": [
            {
                "name": "tavily_search_results_json",
                "args": {
                    "query": "lithium pollution research report"
                },
                "id": "call_eEzazPhsbh13NObVARwuy7YX",
                "type": "tool_call"
            }
        ],
        "usage_metadata": {
            "input_tokens": 93,
            "output_tokens": 129,
            "total_tokens": 222,
            "input_token_details": {
                "audio": 0,
                "cache_read": 0
            },
            "output_token_details": {
                "audio": 0,
                "reasoning": 0
            }
        }
    },
    {
        "content": "",
        "additional_kwargs": {
            "tool_calls": [
                {
                    "id": "call_z3Za6Ferb1kyfTjIBmey6GmW",
                    "function": {
                        "arguments": "{\"query\":\"Lithium pollution research report 2023\"}",
                        "name": "tavily_search_results_json"
                    },
                    "type": "function"
                }
            ]
        },
        "response_metadata": {
            "finish_reason": "tool_calls",
            "logprobs": null
        },
        "id": "run-5c52bdfe-2eff-415d-a09d-df1272212667-2",
        "tool_calls": [
            {
                "name": "tavily_search_results_json",
                "args": {
                    "query": "Lithium pollution research report 2023"
                },
                "id": "call_z3Za6Ferb1kyfTjIBmey6GmW",
                "type": "tool_call"
            }
        ],
        "usage_metadata": {
            "input_tokens": 93,
            "output_tokens": 129,
            "total_tokens": 222,
            "input_token_details": {
                "audio": 0,
                "cache_read": 0
            },
            "output_token_details": {
                "audio": 0,
                "reasoning": 0
            }
        }
    },
    {
        "content": "",
        "additional_kwargs": {
            "tool_calls": [
                {
                    "id": "call_UNiZlo365zfKjICV6KUyEfzi",
                    "function": {
                        "arguments": "{\"query\":\"lithium pollution research report 2023\"}",
                        "name": "tavily_search_results_json"
                    },
                    "type": "function"
                }
            ]
        },
        "response_metadata": {
            "finish_reason": "tool_calls",
            "logprobs": null
        },
        "id": "run-5c52bdfe-2eff-415d-a09d-df1272212667-3",
        "tool_calls": [
            {
                "name": "tavily_search_results_json",
                "args": {
                    "query": "lithium pollution research report 2023"
                },
                "id": "call_UNiZlo365zfKjICV6KUyEfzi",
                "type": "tool_call"
            }
        ],
        "usage_metadata": {
            "input_tokens": 93,
            "output_tokens": 129,
            "total_tokens": 222,
            "input_token_details": {
                "audio": 0,
                "cache_read": 0
            },
            "output_token_details": {
                "audio": 0,
                "reasoning": 0
            }
        }
    },
    {
        "content": "",
        "additional_kwargs": {
            "tool_calls": [
                {
                    "id": "call_6IPmBOxetczY4uwA4mO9wGbi",
                    "function": {
                        "arguments": "{\"query\":\"lithium pollution research 2023\"}",
                        "name": "tavily_search_results_json"
                    },
                    "type": "function"
                }
            ]
        },
        "response_metadata": {
            "finish_reason": "tool_calls",
            "logprobs": null
        },
        "id": "run-5c52bdfe-2eff-415d-a09d-df1272212667-4",
        "tool_calls": [
            {
                "name": "tavily_search_results_json",
                "args": {
                    "query": "lithium pollution research 2023"
                },
                "id": "call_6IPmBOxetczY4uwA4mO9wGbi",
                "type": "tool_call"
            }
        ],
        "usage_metadata": {
            "input_tokens": 93,
            "output_tokens": 129,
            "total_tokens": 222,
            "input_token_details": {
                "audio": 0,
                "cache_read": 0
            },
            "output_token_details": {
                "audio": 0,
                "reasoning": 0
            }
        }
    }
]

In [19]:
from collections import defaultdict


def select(root: "Node") -> "Node":
    """Descend from ``root`` to a leaf, taking the highest-UCT child at each level.

    Args:
        root: Node to start the descent from.

    Returns:
        The leaf reached by repeatedly following the child with the largest
        ``upper_confidence_bound()``; ``root`` itself when it has no children.
        Ties go to the earliest child in ``children``.
    """
    node = root
    # A childless root skips the loop entirely, so no separate early return is needed.
    while node.children:
        node = max(node.children, key=lambda child: child.upper_confidence_bound())
    return node


def expand(state: TreeState, config: RunnableConfig) -> dict:
    """Grow the tree by one level below the currently most promising leaf.

    The leaf chosen by ``select`` is expanded into N candidate continuations;
    each candidate's tool calls are executed, each candidate is graded, and
    the resulting nodes are attached as children of that leaf. The tree is
    mutated in place and the same ``state`` object is returned.
    """
    best_candidate: Node = select(state["root"])
    trajectory = best_candidate.get_trajectory()
    # Sample N continuations of the selected branch.
    new_candidates = expansion_chain.invoke(
        {"input": state["input"], "messages": trajectory}, config
    )

    # Run every tool call, remembering which candidate issued it.
    tool_responses = []
    for index, tool_calls in enumerate(parser.batch(new_candidates)):
        for tool_call in tool_calls:
            request = AIMessage(
                content="",
                tool_calls=[
                    {
                        "name": tool_call["type"],
                        "args": tool_call["args"],
                        "id": tool_call["id"],
                    }
                ],
            )
            tool_responses.append((index, tool_node.invoke({"messages": [request]})))

    # Group the first message of each tool result under its candidate.
    responses_by_candidate = defaultdict(list)
    for index, response in tool_responses:
        responses_by_candidate[index].append(response["messages"][0])
    output_messages = [
        [candidate] + responses_by_candidate[index]
        for index, candidate in enumerate(new_candidates)
    ]

    # Grade each candidate.
    # Tasks with external validation would plug that check in here.
    reflection_inputs = []
    for candidate_messages in output_messages:
        reflection_inputs.append(
            {"input": state["input"], "candidate": candidate_messages}
        )
    reflections = reflection_chain.batch(reflection_inputs, config)

    # Attach the graded candidates as children of the selected leaf.
    child_nodes = []
    for candidate_messages, reflection in zip(output_messages, reflections):
        child_nodes.append(
            Node(candidate_messages, parent=best_candidate, reflection=reflection)
        )
    best_candidate.children.extend(child_nodes)
    # The tree was extended in place, so the state object itself is unchanged.
    return state

In [20]:
from typing import Literal

from langgraph.graph import END, StateGraph, START


def should_loop(state: TreeState):
    """Route the graph: stop once solved or too deep, otherwise expand again.

    Returns END when the root is marked solved or the tree height exceeds 5,
    and the node name "expand" in every other case.
    """
    root = state["root"]
    search_finished = root.is_solved or root.height > 5
    return END if search_finished else "expand"


# Assemble the search graph: "start" builds the root, "expand" grows the tree.
builder = StateGraph(TreeState)
builder.add_node("start", generate_initial_response)
builder.add_node("expand", expand)
builder.add_edge(START, "start")


# After either node, should_loop decides between another expansion and END.
builder.add_conditional_edges(
    "start",
    # Either expand/rollout or finish
    should_loop,
    ["expand", END],
)
builder.add_conditional_edges(
    "expand",
    # Either continue to rollout or finish
    should_loop,
    ["expand", END],
)

# Compile the builder into the runnable graph used below.
graph = builder.compile()

In [23]:
# Run the search on a sample question and report how tall the tree is after each step.
question = "Generate a table with the average size and weight, as well as the oldest recorded instance for each of the top 5 most common birds."
# Keeps the most recent streamed step so it is still available after the loop.
last_step = None
for step in graph.stream({"input": question}):
    last_step = step
    # Only the first (node name, state) pair of each streamed step is read here.
    step_name, step_state = next(iter(step.items()))
    print(step_name)
    print("rolled out: ", step_state["root"].height)
    print("---")

start
rolled out:  1
---
expand
rolled out:  2
---
