In [1]:
from utils import ChatOpenRouter
from tools.math_tool import get_math_tool

from typing import Sequence

from langchain import hub
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableBranch
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from utils import LLMCompilerPlanParser, Task

import dotenv

# Load environment variables (e.g. API keys) before any client is created.
dotenv.load_dotenv()

# Pull the LLMCompiler planner prompt template from the LangChain Hub.
prompt = hub.pull("wfh/llm-compiler")
# pretty_print() writes the template to stdout itself and returns None, so it
# must not be wrapped in print() (that would emit an extra "None" line).
prompt.pretty_print()


Given a user query, create a plan to solve it with the utmost parallelizability. Each plan should comprise an action from the following [33;1m[1;3m{num_tools}[0m types:
[33;1m[1;3m{tool_descriptions}[0m
[33;1m[1;3m{num_tools}[0m. join(): Collects and combines results from prior actions.

 - An LLM agent is called upon invoking join() to either finalize the user query or wait until the plans are executed.
 - join should always be the last action in the plan, and will be called in two scenarios:
   (a) if the answer can be determined by gathering the outputs from tasks to generate the final response.
   (b) if the answer cannot be determined in the planning phase before you execute the plans. Guidelines:
 - Each action described above contains input/output types and description.
    - You must strictly adhere to the input and output types for each action.
    - The action descriptions contain the guidelines. You MUST strictly follow those guidelines when you use the actions.
 -



In [2]:
def create_planner(
    llm: BaseChatModel, tools: Sequence[BaseTool], base_prompt: ChatPromptTemplate
):
    """Build the planner runnable: prompt selection -> LLM -> plan parser.

    Two variants of ``base_prompt`` are prepared: one for the initial plan
    (empty ``replan`` section) and one for re-planning (with extra
    instructions). A ``RunnableBranch`` picks the re-planning variant when the
    last message of the incoming state is a ``SystemMessage``.

    Args:
        llm: Chat model that generates the plan text.
        tools: Tools the planner may use; their descriptions are numbered
            from 1 and injected into the prompt.
        base_prompt: Prompt template accepting ``replan``, ``num_tools`` and
            ``tool_descriptions`` partial variables.

    Returns:
        A runnable that takes a list of messages and yields the output of
        ``LLMCompilerPlanParser``.
    """
    import copy

    # Tools are listed starting at 1; join() is appended as one extra action,
    # which is why the advertised tool count is len(tools) + 1.
    num_tools = len(tools) + 1
    tool_descriptions = "\n".join(
        f"{i+1}. {tool.description}\n" for i, tool in enumerate(tools)
    )
    planner_prompt = base_prompt.partial(
        replan="",
        num_tools=num_tools,
        tool_descriptions=tool_descriptions,
    )
    replanner_prompt = base_prompt.partial(
        replan=' - You are given "Previous Plan" which is the plan that the previous agent created along with the execution results '
        "(given as Observation) of each plan and a general thought (given as Thought) about the executed results."
        'You MUST use these information to create the next plan under "Current Plan".\n'
        ' - When starting the Current Plan, you should start with "Thought" that outlines the strategy for the next plan.\n'
        " - In the Current Plan, you should NEVER repeat the actions that are already executed in the Previous Plan.\n"
        " - You must continue the task index from the end of the previous one. Do not repeat task indices.",
        num_tools=num_tools,
        tool_descriptions=tool_descriptions,
    )

    def should_replan(state: list):
        # Re-planning context is passed as a trailing system message.
        # An empty state can never be a re-plan.
        return bool(state) and isinstance(state[-1], SystemMessage)

    def wrap_messages(state: list):
        return {"messages": state}

    def wrap_and_get_last_index(state: list):
        # Continue numbering after the most recent executed task.
        next_task = 0
        for message in reversed(state):
            if isinstance(message, FunctionMessage):
                # Coerce like _get_observations does, in case idx is a string.
                next_task = int(message.additional_kwargs["idx"]) + 1
                break
        # Annotate a shallow copy so the caller's message object (which may
        # live in shared graph state) is not modified in place.
        last_message = copy.copy(state[-1])
        last_message.content = (
            last_message.content + f" - Begin counting at : {next_task}"
        )
        return {"messages": [*state[:-1], last_message]}

    return (
        RunnableBranch(
            (should_replan, wrap_and_get_last_index | replanner_prompt),
            wrap_messages | planner_prompt,
        )
        | llm
        | LLMCompilerPlanParser(tools=tools)
    )

In [7]:
from langchain_community.tools import DuckDuckGoSearchResults

# Web search tool; the description is the signature shown to the planner.
search = DuckDuckGoSearchResults(
    description='search(query="the search query") - a search engine.',
    max_results=1,
    name="search",
)

# Math tool, driven by its own chat model instance.
calculate = get_math_tool(
    ChatOpenRouter(model="openai/gpt-4o-mini-2024-07-18"),
)

# Order matters: it determines the numbering in the planner prompt.
tools = [search, calculate]

In [8]:
# Chat model used by the planner.
llm = ChatOpenRouter(model="openai/gpt-4o-mini-2024-07-18")
# The planner is the primary "agent" of this application.
planner = create_planner(llm=llm, tools=tools, base_prompt=prompt)

In [9]:
example_question = "What's the temperature in SF raised to the 3rd power?"

# Stream the plan and show each task's tool and arguments, followed by a
# separator line.
for task in planner.stream([HumanMessage(content=example_question)]):
    print(task["tool"], task["args"], end="\n---\n")

name='search' description='search(query="the search query") - a search engine.' api_wrapper=DuckDuckGoSearchAPIWrapper(region='wt-wt', safesearch='moderate', time='y', max_results=5, backend='auto', source='text') {'query': 'current temperature in San Francisco'}
---
name='math' description='math(problem: str, context: Optional[list[str]]) -> float:\n - Solves the provided math problem.\n - `problem` can be either a simple math problem (e.g. "1 + 3") or a word problem (e.g. "how many apples are there if there are 3 apples and 2 apples").\n - You cannot calculate multiple expressions in one call. For instance, `math(\'1 + 3, 2 + 4\')` does not work. If you need to calculate multiple expressions, you need to call them separately like `math(\'1 + 3\')` and then `math(\'2 + 4\')`\n - Minimize the number of `math` actions as much as possible. For instance, instead of calling 2. math("what is the 10% of $1") and then call 3. math("$1 + $2"), you MUST call 2. math("what is the 110% of $1") in

In [10]:
import re
import time
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Dict, Iterable, List, Union

from langchain_core.runnables import (
    chain as as_runnable,
)
from typing_extensions import TypedDict


def _get_observations(messages: List[BaseMessage]) -> Dict[int, Any]:
    """Collect earlier tool outputs from the history, keyed by task index.

    Only ``FunctionMessage`` entries contribute. The history is walked newest
    first and every message overwrites any value already stored for its
    index, so when an index appears more than once the oldest message's
    content is the one that remains.
    """
    return {
        int(entry.additional_kwargs["idx"]): entry.content
        for entry in reversed(messages)
        if isinstance(entry, FunctionMessage)
    }


class SchedulerInput(TypedDict):
    """Input payload consumed by ``schedule_tasks``."""

    # Message history; earlier FunctionMessage entries seed the observations.
    messages: List[BaseMessage]
    # Tasks to run; may be a lazy stream, as it is only iterated once.
    tasks: Iterable[Task]


def _execute_task(task, observations, config):
    tool_to_use = task["tool"]
    if isinstance(tool_to_use, str):
        return tool_to_use
    args = task["args"]
    try:
        if isinstance(args, str):
            resolved_args = _resolve_arg(args, observations)
        elif isinstance(args, dict):
            resolved_args = {
                key: _resolve_arg(val, observations) for key, val in args.items()
            }
        else:
            # This will likely fail
            resolved_args = args
    except Exception as e:
        return (
            f"ERROR(Failed to call {tool_to_use.name} with args {args}.)"
            f" Args could not be resolved. Error: {repr(e)}"
        )
    try:
        return tool_to_use.invoke(resolved_args, config)
    except Exception as e:
        return (
            f"ERROR(Failed to call {tool_to_use.name} with args {args}."
            + f" Args resolved to {resolved_args}. Error: {repr(e)})"
        )


def _resolve_arg(arg: Union[str, Any], observations: Dict[int, Any]):
    # $1 or ${1} -> 1
    ID_PATTERN = r"\$\{?(\d+)\}?"

    def replace_match(match):
        # If the string is ${123}, match.group(0) is ${123}, and match.group(1) is 123.

        # Return the match group, in this case the index, from the string. This is the index
        # number we get back.
        idx = int(match.group(1))
        return str(observations.get(idx, match.group(0)))

    # For dependencies on other tasks
    if isinstance(arg, str):
        return re.sub(ID_PATTERN, replace_match, arg)
    elif isinstance(arg, list):
        return [_resolve_arg(a, observations) for a in arg]
    else:
        return str(arg)


@as_runnable
def schedule_task(task_inputs, config):
    """Execute one task and record its observation under the task's index.

    Args:
        task_inputs: Mapping with ``"task"`` (the task to run) and
            ``"observations"`` (dict of results keyed by task index, updated
            in place).
        config: Runnable config forwarded to the task execution.

    Returns:
        None. The result is written into ``observations[task["idx"]]``.
    """
    task: Task = task_inputs["task"]
    observations: Dict[int, Any] = task_inputs["observations"]
    try:
        observation = _execute_task(task, observations, config)
    except Exception:
        import traceback

        # format_exc() renders the exception currently being handled.
        # format_exception() needs the exception passed explicitly, so calling
        # it with no arguments raises TypeError and hides the original error.
        observation = traceback.format_exc()
    observations[task["idx"]] = observation


def schedule_pending_task(
    task: Task, observations: Dict[int, Any], retry_after: float = 0.2
):
    """Wait until every dependency of ``task`` has an observation, then run it.

    The dependency list is re-checked every ``retry_after`` seconds. There is
    no timeout: the call keeps polling for as long as a dependency is missing
    from ``observations``.
    """

    def _still_blocked() -> bool:
        required = task["dependencies"]
        return bool(required) and not all(dep in observations for dep in required)

    while _still_blocked():
        time.sleep(retry_after)
    schedule_task.invoke({"task": task, "observations": observations})


@as_runnable
def schedule_tasks(scheduler_input: SchedulerInput) -> List[FunctionMessage]:
    """Run the planned tasks, respecting dependencies, and report the results.

    Tasks are consumed one at a time from ``scheduler_input["tasks"]``. A task
    whose dependencies are all present in the observations is executed
    immediately on the calling thread; otherwise it is handed to a thread pool
    where it polls until its dependencies appear.

    Returns:
        One ``FunctionMessage`` per task executed in this call (results that
        were already in the message history are excluded), ordered by task
        index.
    """
    # For streaming, we are making a few simplifying assumptions:
    # 1. The LLM does not create cyclic dependencies
    # 2. The LLM will not generate tasks with future deps
    # If these cease to be good assumptions, you can either
    # switch to a proper topological sort (non-streaming)
    # or use a more complicated data structure
    tasks = scheduler_input["tasks"]
    args_for_tasks = {}
    messages = scheduler_input["messages"]
    # If we are re-planning, we may have calls that depend on previous
    # plans. Start with those.
    observations = _get_observations(messages)
    task_names = {}
    # Indices already observed before this call; used below to report only
    # the results produced here.
    originals = set(observations)
    # `observations` is shared with the worker threads. We assume each task
    # inserts a different key into it to avoid race conditions...
    futures = []
    retry_after = 0.25  # Retry every quarter second
    with ThreadPoolExecutor() as executor:
        for task in tasks:
            deps = task["dependencies"]
            # Tools given as plain strings are recorded as-is; tool objects
            # are recorded by their name.
            task_names[task["idx"]] = (
                task["tool"] if isinstance(task["tool"], str) else task["tool"].name
            )
            args_for_tasks[task["idx"]] = task["args"]
            if (
                # Depends on other tasks that have not produced a result yet
                deps and (any([dep not in observations for dep in deps]))
            ):
                futures.append(
                    executor.submit(
                        schedule_pending_task, task, observations, retry_after
                    )
                )
            else:
                # No deps or all deps satisfied: can schedule now.
                # Note: this runs synchronously on the calling thread rather
                # than being submitted to the executor.
                schedule_task.invoke(dict(task=task, observations=observations))

        # All tasks have been submitted or enqueued
        # Wait for them to complete
        wait(futures)
    # Convert observations to new tool messages to add to the state
    new_observations = {
        k: (task_names[k], args_for_tasks[k], observations[k])
        for k in sorted(observations.keys() - originals)
    }
    tool_messages = [
        FunctionMessage(
            name=name,
            content=str(obs),
            additional_kwargs={"idx": k, "args": task_args},
            tool_call_id=k,
        )
        for k, (name, task_args, obs) in new_observations.items()
    ]
    return tool_messages

In [11]:
import itertools


@as_runnable
def plan_and_schedule(state):
    """Plan for the current messages and execute the plan as it streams in.

    Returns a dict whose ``"messages"`` entry holds the messages produced by
    ``schedule_tasks`` for the tasks that were run.
    """
    history = state["messages"]
    task_stream = planner.stream(history)
    # Pull the first task eagerly so the planner begins executing immediately;
    # an exhausted stream simply means there is nothing to schedule.
    nothing = object()
    first_task = next(task_stream, nothing)
    if first_task is nothing:
        pending = iter([])
    else:
        pending = itertools.chain([first_task], task_stream)
    new_messages = schedule_tasks.invoke({"messages": history, "tasks": pending})
    return {"messages": new_messages}

In [13]:
# Run one plan-and-execute pass on the example question and keep only the
# messages produced by the executed tasks.
tool_messages = plan_and_schedule.invoke(
    {"messages": [HumanMessage(content=example_question)]}
)["messages"]

In [14]:
# Display the messages returned by the plan-and-execute pass above.
tool_messages

[FunctionMessage(content='snippet: Current conditions at SAN FRANCISCO DOWNTOWN (SFOC1) Lat: 37.77056°N Lon: 122.42694°W Elev: 150.0ft., title: 7-Day Forecast 37.77N 122.41W - National Weather Service, link: https://forecast.weather.gov/zipcity.php?inputstring=San+Francisco,CA, snippet: San Francisco Weather Forecasts. Weather Underground provides local & long-range weather forecasts, weatherreports, maps & tropical weather conditions for the San Francisco area., title: San Francisco, CA Weather Conditions | Weather Underground, link: https://www.wunderground.com/weather/us/ca/san-francisco, snippet: San Francisco, California - Current temperature and weather conditions. Detailed hourly weather forecast for today - including weather conditions, temperature, pressure, humidity, precipitation, dewpoint, wind, visibility, and UV index data., title: Weather today - San Francisco, CA, link: https://www.weather-us.com/en/california-usa/san-francisco, snippet: Last Map Update: Sun, Mar 9, 202

In [16]:
from langchain_core.messages import AIMessage

from pydantic import BaseModel, Field


# Joiner action used when a final answer can be given.
# NOTE(review): this model is part of the structured-output schema; its
# docstring is presumably sent to the model as a description - confirm before
# rewording it.
class FinalResponse(BaseModel):
    """The final response/answer."""

    # Answer text; _parse_joiner_output uses it as the final AIMessage content.
    response: str


# Joiner action used when another planning round is needed.
# _parse_joiner_output turns `feedback` into a SystemMessage.
class Replan(BaseModel):
    feedback: str = Field(
        description="Analysis of the previous attempts and recommendations on what needs to be fixed."
    )


# Schema passed to with_structured_output for the joiner model.
class JoinOutputs(BaseModel):
    """Decide whether to replan or whether you can return the final response."""

    # Reasoning text; emitted as the "Thought: ..." AIMessage.
    thought: str = Field(
        description="The chain of thought reasoning for the selected action"
    )
    # Either a final answer or a request to re-plan.
    action: Union[FinalResponse, Replan]


# Pull the joiner prompt from the LangChain Hub; no few-shot examples are
# supplied here.
joiner_prompt = hub.pull("wfh/llm-compiler-joiner").partial(
    examples=""
)  # You can optionally add examples
# Note: this rebinds the global `llm`; `planner` was already built with the
# earlier instance and is not affected.
llm = ChatOpenRouter(model_name="openai/gpt-4o-mini-2024-07-18")

# Prompt -> model constrained to return a JoinOutputs object.
runnable = joiner_prompt | llm.with_structured_output(
    JoinOutputs, method="function_calling"
)

def _parse_joiner_output(decision: JoinOutputs) -> Dict[str, List[BaseMessage]]:
    """Convert the joiner's structured decision into a state update.

    Args:
        decision: Parsed joiner output holding a thought and an action.

    Returns:
        ``{"messages": [thought, follow_up]}``. ``thought`` is an
        ``AIMessage`` carrying the reasoning. ``follow_up`` is a
        ``SystemMessage`` with the feedback when the action is ``Replan``,
        otherwise an ``AIMessage`` with the final response.
    """
    thought = AIMessage(content=f"Thought: {decision.thought}")
    if isinstance(decision.action, Replan):
        # A trailing SystemMessage is what signals "re-plan" downstream.
        follow_up = SystemMessage(
            content=f"Context from last attempt: {decision.action.feedback}"
        )
    else:
        follow_up = AIMessage(content=decision.action.response)
    return {"messages": [thought, follow_up]}


def select_recent_messages(state) -> dict:
    """Keep the messages from the most recent ``HumanMessage`` onward.

    If the history contains no ``HumanMessage``, every message is kept.
    """
    history = state["messages"]
    start = 0
    # Scan backwards for the latest human turn.
    for position in range(len(history) - 1, -1, -1):
        if isinstance(history[position], HumanMessage):
            start = position
            break
    return {"messages": list(history[start:])}


# Joiner pipeline: trim history to the latest human turn, ask the model for a
# structured decision, then convert that decision into messages.
joiner = select_recent_messages | runnable | _parse_joiner_output



In [17]:
# Rebuild a conversation for the joiner: the question followed by the tool
# results gathered earlier.
input_messages = [HumanMessage(content=example_question)] + tool_messages

In [18]:
# Ask the joiner whether the gathered results suffice or a re-plan is needed.
joiner.invoke({"messages": input_messages})

{'messages': [AIMessage(content='Thought: I found a weather snippet about San Francisco, but I did not extract the actual current temperature from it. Therefore, I cannot calculate the temperature raised to the 3rd power.', additional_kwargs={}, response_metadata={}),
  SystemMessage(content='Context from last attempt: I need to replan to obtain the current temperature for San Francisco.', additional_kwargs={}, response_metadata={})]}

In [19]:
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import add_messages
from typing import Annotated


# Graph state shared by all nodes.
class State(TypedDict):
    # Conversation history; `add_messages` is the reducer attached to this key.
    messages: Annotated[list, add_messages]


graph_builder = StateGraph(State)

# 1. Define vertices
# plan_and_schedule and joiner were defined in earlier cells; each is
# registered here as a graph node under the given name.
graph_builder.add_node("plan_and_schedule", plan_and_schedule)
graph_builder.add_node("join", joiner)


# 2. Define edges
# After planning/executing, always hand over to the joiner.
graph_builder.add_edge("plan_and_schedule", "join")

### This condition determines looping logic


def should_continue(state):
    """Pick the next step after ``join``.

    Ends the graph when the last message is an ``AIMessage``; otherwise loops
    back to ``plan_and_schedule``.
    """
    last_message = state["messages"][-1]
    return END if isinstance(last_message, AIMessage) else "plan_and_schedule"


graph_builder.add_conditional_edges(
    "join",
    # should_continue decides which node runs after "join" (or END).
    should_continue,
)
# Entry point: every run starts with planning.
graph_builder.add_edge(START, "plan_and_schedule")
chain = graph_builder.compile()

In [20]:
# Simple single-lookup question: print each graph step followed by a separator.
for step in chain.stream(
    {"messages": [HumanMessage(content="What's the GDP of New York?")]}
):
    print(step, end="\n---\n")

{'plan_and_schedule': {'messages': [FunctionMessage(content='snippet: In 2023, the real gross domestic product (GDP) of New York was about 1.78 trillion U.S. dollars. This is a slight increase from the previous year, when the state\'s GDP stood at 1.76 trillion U.S ..., title: Real GDP New York U.S. 2023 | Statista, link: https://www.statista.com/statistics/188087/gdp-of-the-us-federal-state-of-new-york-since-1997/, snippet: About $1.8 trillion in 2023. Gross domestic product (GDP) measures the value of goods and services a country or state produces — it\'s the sum of consumer spending, business investment, government spending, and net exports. It is often used to quantify the size of its economy. The $1.8 trillion is the "real GDP," which is adjusted to account for inflation to make it easier to compare ..., title: What is the gross domestic product (GDP) in New York, link: https://usafacts.org/answers/what-is-the-gross-domestic-product-gdp/state/new-york/, snippet: As of 2023, Canada

In [21]:
# Multi-hop question; the second argument raises the graph recursion limit.
steps = chain.stream(
    {
        "messages": [
            HumanMessage(
                content="What's the oldest parrot alive, and how much longer is that than the average?"
            )
        ]
    },
    {"recursion_limit": 100},
)
for step in steps:
    print(step, end="\n---\n")

{'plan_and_schedule': {'messages': [FunctionMessage(content="snippet: The oldest recorded bird in captivity was Cookie, who was a Pink Cockatoo that lived until 83! ... Parrots are famous for living for a long time, so it shouldn't be a surprise that 13 out of the 15 birds on our list are a species of parrot. And aren't we lucky about that! No matter the pet, we love them and want them to stay with us for as ..., title: 15 Pet Birds That Live a Long Time: Lifespans & Details ... - PangoVet, link: https://pangovet.com/pet-lifestyle/birds/pet-birds-that-live-a-long-time/, snippet: 9. Senegal Parrot Magda Ehlers/Pexels. Senegal Parrots have a compact, charming appearance and can live around 25 to 30 years. Their quieter nature makes them well-suited for apartment living. Known for their playful antics and loving personalities, Senegals enjoy time with their owners and can learn a range of tricks., title: 15 Parrot Breeds with the Longest Lifespans - pawdown.com, link: https://pawdown.com/

In [22]:
# `step` is the last value left over from the loop in the previous cell; its
# "join" entry holds the final messages.
print(step["join"]["messages"][-1].content)

The oldest recorded parrot, Cookie, lived to be 82 years old. The average lifespan of parrots is generally around 20 to 50 years, depending on the species. Therefore, Cookie lived approximately 32 to 62 years longer than the average lifespan of a parrot.


In [23]:
# Multi-step math question: print every graph step as it is produced.
math_question = HumanMessage(
    content="What's ((3*(4+5)/0.5)+3245) + 8? What's 32/4.23? What's the sum of those two values?"
)
for step in chain.stream({"messages": [math_question]}):
    print(step)

{'plan_and_schedule': {'messages': [FunctionMessage(content='3299.0', additional_kwargs={'idx': 1, 'args': {'problem': '(3*(4+5)/0.5)+3245'}}, response_metadata={}, name='math', id='096beb05-3242-4467-be6f-a0db96352335', tool_call_id=1), FunctionMessage(content='7.565011820330969', additional_kwargs={'idx': 2, 'args': {'problem': '32/4.23'}}, response_metadata={}, name='math', id='b279c330-394f-40b8-9032-6e3f4ec0385d', tool_call_id=2), FunctionMessage(content='join', additional_kwargs={'idx': 3, 'args': ()}, response_metadata={}, name='join', id='abf4c2d4-6196-4f42-9a14-51057470f024', tool_call_id=3)]}}
{'join': {'messages': [AIMessage(content='Thought: I have calculated both required values: the first value is 3299.0 and the second value is approximately 7.565. Therefore, I can simply sum these two values to get the final answer.', additional_kwargs={}, response_metadata={}, id='dd279bdc-aae6-4492-a121-ef330b623300'), AIMessage(content='The sum of the two values, 3299.0 and approximat

In [24]:
# `step` is the last value left over from the loop in the previous cell; its
# "join" entry holds the final messages.
print(step["join"]["messages"][-1].content)

The sum of the two values, 3299.0 and approximately 7.565, is 3306.565.
