In [1]:
%%capture --no-stderr
%pip install -U --quiet langchain_openai langsmith langgraph langchain numexpr

In [1]:
import os
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate

# Load environment variables
load_dotenv()
# API keys are read from the process environment; `load_dotenv()` above has
# already merged any values from a local .env file into `os.environ`, so there
# is nothing to copy. Re-assigning `os.environ[k] = os.getenv(k)` would be a
# no-op when the key exists and would raise an opaque
# `TypeError: str expected, not NoneType` when it does not, so we only check
# for presence and report what is missing.
for _key_name in ("OPENAI_API_KEY", "ANTHROPIC_API_KEY"):
    if not os.getenv(_key_name):
        print(f"Warning: {_key_name} is not set; calls that need it will fail.")

## Define Tools

In [2]:
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_openai import ChatOpenAI
from math_tools import get_math_tool
from forecast_tools import get_forecast_tool

In [3]:
calculate = get_math_tool(ChatOpenAI(model="gpt-4-turbo-preview"))
# search = DuckDuckGoSearchRun()

In [40]:
from langchain.tools import StructuredTool
from langchain_community.tools import DuckDuckGoSearchRun
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any

# Input schema for the `search` tool (passed to StructuredTool as `args_schema`):
# a single required `query` string. No class docstring is added on purpose so
# the schema handed to the tool stays exactly as it was.
class SearchSchema(BaseModel):
    # Required; an input dict without this key fails validation before
    # `search_func` is ever called.
    query: str = Field(
        description="The search query string to look up information. Example: 'current temperature in San Francisco'"
    )

def _extract_query(query: Optional[str], extras: Dict[str, Any]) -> Optional[str]:
    """Recover a query from loosely shaped input; return a falsy value if none is found."""
    if query:
        return query

    nested = extras.get("args")
    if isinstance(nested, dict):
        # Prefer an explicit "query" entry; otherwise fall back to joining
        # whatever truthy values the nested dict carries.
        found = nested.get("query")
        if not found and nested:
            found = " ".join(str(v) for v in nested.values() if v)
        return found
    if isinstance(nested, (list, tuple)):
        return " ".join(str(x) for x in nested if x)

    # Last resort: use the first truthy keyword value, whatever its name.
    for value in extras.values():
        if value:
            return str(value)
    return None


def search_func(query: Optional[str] = None, **kwargs: Any) -> str:
    """Run a DuckDuckGo web search and return the raw result text.

    Args:
        query: The search string. When empty/None, a query is recovered from
            ``kwargs`` on a best-effort basis: ``args={"query": ...}``,
            ``args=[...]`` (items joined with spaces), or the first truthy
            keyword value.
        **kwargs: Alternative carriers for the query, see above.

    Returns:
        The text returned by ``DuckDuckGoSearchRun.run``.

    Raises:
        ValueError: If no query could be determined, if the search backend
            raised (the original exception is chained as ``__cause__``), or if
            the backend returned an empty or ``error``-prefixed result.
    """
    query = _extract_query(query, kwargs)

    if not query:
        raise ValueError(
            "Search query cannot be empty. Please provide a query string. "
            "Example: {'query': 'GDP of New York'} or {'args': {'query': 'GDP of New York'}}"
        )

    try:
        result = DuckDuckGoSearchRun().run(query)
    except Exception as e:
        # Chain the cause so the backend traceback is not lost.
        raise ValueError(f"Search failed for query '{query}'. Error: {str(e)}") from e

    # Validate outside the try-block: raising inside it would be caught by the
    # blanket handler above and re-wrapped into a doubled error message.
    if not result or str(result).lower().startswith('error'):
        raise ValueError(f"Search failed to return valid results for query: {query}")
    return result

# The description is what the planner prompt shows the model for this tool.
# It uses the same "name(arg: type) -> type:" header plus bullet layout as the
# math tool's description so both tools are presented uniformly, and it only
# documents the input shape that `SearchSchema` actually accepts: a required
# `query` string. (An input such as {"args": {...}} has no top-level `query`
# and is rejected by schema validation before `search_func` runs.)
search = StructuredTool(
    name="search",
    func=search_func,
    description=(
        "search(query: str) -> str:\n"
        " - Searches the internet for current information and returns the result text.\n"
        " - `query` is a required, non-empty search string.\n"
        ' - Example: search(query="current GDP of New York")\n'
    ),
    args_schema=SearchSchema,
)

In [41]:
# Example of correct usage:
search.invoke({"query": "GDP of New York"})

"In 2023, the real gross domestic product (GDP) of New York was about 1.78 trillion U.S. dollars. This is a slight increase from the previous year, when the state's GDP stood at 1.76 trillion U.S ... In 2023, GDP per person in New York was $91.5 thousand, up 2.0% from 2022. In 2023, real GDP was equivalent to $91,523 per person. Real gross domestic product per person in New York, chained 2017 dollars Learn how financial services, health care, professional and business services, retail trade, manufacturing, and educational services contribute to New York's $1.78 trillion GDP in 2023. Find out the number of workers, salaries, and products in each sector and how they rank nationally and globally. Find data on GDP by industry for New York state from 1997 to 2023. See the latest estimates, trends, and sources from the U.S. Bureau of Economic Analysis. Graph and download economic data for Total Gross Domestic Product for New York-Newark-Jersey City, NY-NJ-PA (MSA) (NGMP35620) from 2001 to 20

In [42]:
tools = [search, calculate]


In [7]:
calculate.invoke(
    {
        "problem": "What's the temp of sf + 5?",
        "context": ["Thet empreature of sf is 32 degrees"],
    }
)

'37'

In [8]:
forecast = get_forecast_tool(ChatOpenAI(model="gpt-4-turbo-preview"))


In [9]:
# forecast.invoke({
#     "problem": "What's the temp of sf + 5? Given temperature series [32]",
#     "context": ["The temperature of sf is 32 degrees"],
#     "forecast": "temperature forecast",
#     "Answer": "",
#     "Forecast results with confidence intervals": "",
#     "Question with forecasting problem.": "What's the temp of sf + 5?",
#     "code": "",
#     "forecasting expression that processes the data": ""
# })

## Planner

In [13]:
from typing import Sequence

from langchain import hub
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableBranch
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from output_parser import LLMCompilerPlanParser, Task

# Pull the shared planner prompt from the LangChain Hub.
prompt = hub.pull("wfh/llm-compiler")
# pretty_print() writes the prompt to stdout itself and returns None, so it
# must not be wrapped in print() (that would emit an extra "None" line).
prompt.pretty_print()




Given a user query, create a plan to solve it with the utmost parallelizability. Each plan should comprise an action from the following [33;1m[1;3m{num_tools}[0m types:
[33;1m[1;3m{tool_descriptions}[0m
[33;1m[1;3m{num_tools}[0m. join(): Collects and combines results from prior actions.

 - An LLM agent is called upon invoking join() to either finalize the user query or wait until the plans are executed.
 - join should always be the last action in the plan, and will be called in two scenarios:
   (a) if the answer can be determined by gathering the outputs from tasks to generate the final response.
   (b) if the answer cannot be determined in the planning phase before you execute the plans. Guidelines:
 - Each action described above contains input/output types and description.
    - You must strictly adhere to the input and output types for each action.
    - The action descriptions contain the guidelines. You MUST strictly follow those guidelines when you use the actions.
 -

In [14]:
def create_planner(
    llm: BaseChatModel, tools: Sequence[BaseTool], base_prompt: ChatPromptTemplate
):
    """Build the planner runnable: prompt selection -> LLM -> plan parser.

    Args:
        llm: Chat model that writes the plan.
        tools: Tools the plan may call; their descriptions are numbered from 1
            and injected into the prompt.
        base_prompt: Prompt template with ``replan``, ``num_tools`` and
            ``tool_descriptions`` variables and a ``messages`` input.

    Returns:
        A runnable that takes a list of messages. If the last message is a
        ``SystemMessage`` (replanning context) the replanner prompt is used,
        otherwise the initial planner prompt. The LLM output is parsed by
        ``LLMCompilerPlanParser``.
    """
    # Number tools from 1 so the list reads naturally in the prompt.
    tool_descriptions = "\n".join(
        f"{i+1}. {tool.description}\n" for i, tool in enumerate(tools)
    )
    # +1 because the prompt appends join() as the final numbered action.
    num_tools = len(tools) + 1

    planner_prompt = base_prompt.partial(
        replan="",
        num_tools=num_tools,
        tool_descriptions=tool_descriptions,
    )
    replan_instructions = (
        ' - You are given "Previous Plan" which is the plan that the previous agent created along with the execution results '
        # Trailing space added: the adjacent literals previously fused into "results.You".
        "(given as Observation) of each plan and a general thought (given as Thought) about the executed results. "
        'You MUST use these information to create the next plan under "Current Plan".\n'
        ' - When starting the Current Plan, you should start with "Thought" that outlines the strategy for the next plan.\n'
        " - In the Current Plan, you should NEVER repeat the actions that are already executed in the Previous Plan.\n"
        " - You must continue the task index from the end of the previous one. Do not repeat task indices."
    )
    replanner_prompt = base_prompt.partial(
        replan=replan_instructions,
        num_tools=num_tools,
        tool_descriptions=tool_descriptions,
    )

    def should_replan(state: list):
        # Replanning context is passed as a trailing system message.
        return isinstance(state[-1], SystemMessage)

    def wrap_messages(state: list):
        return {"messages": state}

    def wrap_and_get_last_index(state: list):
        # Continue numbering after the most recent executed task, if any.
        next_task = 0
        for message in reversed(state):
            if isinstance(message, FunctionMessage):
                # int() mirrors _get_observations, which also coerces idx.
                next_task = int(message.additional_kwargs["idx"]) + 1
                break
        # Work on a copy: editing state[-1] in place would silently alter the
        # caller's message object and stack another suffix on every re-run.
        last = state[-1]
        hinted = last.model_copy(
            update={"content": last.content + f" - Begin counting at : {next_task}"}
        )
        return {"messages": [*state[:-1], hinted]}

    return (
        RunnableBranch(
            (should_replan, wrap_and_get_last_index | replanner_prompt),
            wrap_messages | planner_prompt,
        )
        | llm
        | LLMCompilerPlanParser(tools=tools)
    )

In [15]:
llm = ChatOpenAI(model="gpt-4-turbo-preview")

# This is the primary "agent" in our application
planner = create_planner(llm, tools, prompt)

In [16]:
example_question = "What's the temperature in SF raised to the 3rd power?"

for task in planner.stream([HumanMessage(content=example_question)]):
    print(task["tool"], task["args"])
    print("---")

name='search' description='Search the internet for current information. \n    Example input: {"query": "current temperature in San Francisco"}\n    Example output: "The current temperature in San Francisco is 65°F" ' args_schema=<class '__main__.SearchSchema'> func=<function search_func at 0x00000282E5B296C0> {}
---
name='math' description='math(problem: str, context: Optional[list[str]]) -> float:\n - Solves the provided math problem.\n - `problem` can be either a simple math problem (e.g. "1 + 3") or a word problem (e.g. "how many apples are there if there are 3 apples and 2 apples").\n - You cannot calculate multiple expressions in one call. For instance, `math(\'1 + 3, 2 + 4\')` does not work. If you need to calculate multiple expressions, you need to call them separately like `math(\'1 + 3\')` and then `math(\'2 + 4\')`\n - Minimize the number of `math` actions as much as possible. For instance, instead of calling 2. math("what is the 10% of $1") and then call 3. math("$1 + $2"), 

## Task Fetching Unit

{
    tool: BaseTool,
    dependencies: number[],
}

In [17]:
import re
import time
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Dict, Iterable, List, Union

from langchain_core.runnables import (
    chain as as_runnable,
)
from typing_extensions import TypedDict


def _get_observations(messages: List[BaseMessage]) -> Dict[int, Any]:
    """Collect earlier tool results keyed by integer task index.

    Only ``FunctionMessage`` entries are considered. Messages are visited
    newest-first, so if an index appears more than once the content of the
    oldest message with that index is the one kept.
    """
    return {
        int(msg.additional_kwargs["idx"]): msg.content
        for msg in reversed(messages)
        if isinstance(msg, FunctionMessage)
    }


# Input payload consumed by `schedule_tasks`.
class SchedulerInput(TypedDict):
    # Conversation so far; scanned for FunctionMessage results of earlier plans.
    messages: List[BaseMessage]
    # Tasks to run. `schedule_tasks` iterates this exactly once, in order.
    tasks: Iterable[Task]


def _execute_task(task, observations, config):
    tool_to_use = task["tool"]
    if isinstance(tool_to_use, str):
        return tool_to_use
    args = task["args"]
    try:
        if isinstance(args, str):
            resolved_args = _resolve_arg(args, observations)
        elif isinstance(args, dict):
            resolved_args = {
                key: _resolve_arg(val, observations) for key, val in args.items()
            }
        else:
            # This will likely fail
            resolved_args = args
    except Exception as e:
        return (
            f"ERROR(Failed to call {tool_to_use.name} with args {args}.)"
            f" Args could not be resolved. Error: {repr(e)}"
        )
    try:
        return tool_to_use.invoke(resolved_args, config)
    except Exception as e:
        return (
            f"ERROR(Failed to call {tool_to_use.name} with args {args}."
            + f" Args resolved to {resolved_args}. Error: {repr(e)})"
        )


def _resolve_arg(arg: Union[str, Any], observations: Dict[int, Any]):
    # $1 or ${1} -> 1
    ID_PATTERN = r"\$\{?(\d+)\}?"

    def replace_match(match):
        # If the string is ${123}, match.group(0) is ${123}, and match.group(1) is 123.

        # Return the match group, in this case the index, from the string. This is the index
        # number we get back.
        idx = int(match.group(1))
        return str(observations.get(idx, match.group(0)))

    # For dependencies on other tasks
    if isinstance(arg, str):
        return re.sub(ID_PATTERN, replace_match, arg)
    elif isinstance(arg, list):
        return [_resolve_arg(a, observations) for a in arg]
    else:
        return str(arg)


@as_runnable
def schedule_task(task_inputs, config):
    """Execute one task and record its observation under the task's index.

    Args:
        task_inputs: Mapping with ``task`` (the task to run) and
            ``observations`` (shared dict of results keyed by task index,
            updated in place).
        config: Runnable config forwarded to the tool invocation.

    Returns:
        None. The result, or the formatted traceback text if execution raised,
        is stored in ``observations[task["idx"]]``.
    """
    task: Task = task_inputs["task"]
    observations: Dict[int, Any] = task_inputs["observations"]
    try:
        observation = _execute_task(task, observations, config)
    except Exception:
        import traceback

        # format_exc() formats the exception currently being handled.
        # format_exception() cannot be called without arguments and would
        # itself raise TypeError here, hiding the real failure.
        observation = traceback.format_exc()
    observations[task["idx"]] = observation


def schedule_pending_task(
    task: Task, observations: Dict[int, Any], retry_after: float = 0.2
):
    """Block until every dependency of ``task`` has an observation, then run it.

    Polls ``observations`` every ``retry_after`` seconds. There is no timeout:
    the call only returns once all dependency indices are present.
    """

    def _still_blocked():
        required = task["dependencies"]
        return bool(required) and any(dep not in observations for dep in required)

    while _still_blocked():
        # Dependencies not yet satisfied
        time.sleep(retry_after)
    schedule_task.invoke({"task": task, "observations": observations})


@as_runnable
def schedule_tasks(scheduler_input: SchedulerInput) -> List[FunctionMessage]:
    """Run a stream of planned tasks and return their results as messages.

    Tasks are consumed in the order the input iterable yields them. A task
    whose dependencies are all already present in ``observations`` is executed
    inline on the calling thread; any other task is handed to a thread-pool
    worker that polls until its dependencies appear.

    Args:
        scheduler_input: Mapping with ``messages`` (prior conversation, used
            to recover observations from earlier plans) and ``tasks`` (the
            tasks to run).

    Returns:
        One ``FunctionMessage`` for each observation index added during this
        call, sorted by index, with the task's ``idx`` and ``args`` stored in
        ``additional_kwargs``.
    """
    # For streaming, we make a few simplifying assumptions:
    # 1. The LLM does not create cyclic dependencies
    # 2. The LLM will not generate tasks with future deps
    # If these cease to be good assumptions, you can either
    # adjust to do a proper topological sort (non-streaming)
    # or use a more complicated data structure
    tasks = scheduler_input["tasks"]
    args_for_tasks = {}
    messages = scheduler_input["messages"]
    # If we are re-planning, we may have calls that depend on previous
    # plans. Start with those.
    observations = _get_observations(messages)
    task_names = {}
    # Snapshot of the indices that existed before this call, used below to
    # report only the observations produced here.
    originals = set(observations)
    # ^^ We assume each task inserts a different key into `observations` to
    # avoid race conditions...
    futures = []
    retry_after = 0.25  # Retry every quarter second
    with ThreadPoolExecutor() as executor:
        for task in tasks:
            deps = task["dependencies"]
            # Tool may be a plain string or a tool object exposing `.name`.
            task_names[task["idx"]] = (
                task["tool"] if isinstance(task["tool"], str) else task["tool"].name
            )
            args_for_tasks[task["idx"]] = task["args"]
            if (
                # Depends on other tasks that have not produced a result yet
                deps and (any([dep not in observations for dep in deps]))
            ):
                futures.append(
                    executor.submit(
                        schedule_pending_task, task, observations, retry_after
                    )
                )
            else:
                # No deps or all deps satisfied: run it now, synchronously,
                # on the calling thread.
                schedule_task.invoke(dict(task=task, observations=observations))
                # Disabled alternative that would run ready tasks on the pool instead:
                # futures.append(executor.submit(schedule_task.invoke, dict(task=task, observations=observations)))

        # All tasks have been run inline or submitted to the pool.
        # Wait for the pooled ones to complete.
        wait(futures)
    # Convert observations to new tool messages to add to the state
    new_observations = {
        k: (task_names[k], args_for_tasks[k], observations[k])
        for k in sorted(observations.keys() - originals)
    }
    tool_messages = [
        FunctionMessage(
            name=name,
            content=str(obs),
            additional_kwargs={"idx": k, "args": task_args},
            tool_call_id=k,
        )
        for k, (name, task_args, obs) in new_observations.items()
    ]
    return tool_messages

In [18]:
import itertools


@as_runnable
def plan_and_schedule(state):
    """Stream a plan from the planner and execute it as it arrives.

    Reads ``state["messages"]``, feeds them to ``planner.stream`` and passes
    the resulting task stream to ``schedule_tasks``. Returns
    ``{"messages": [...]}`` holding the messages ``schedule_tasks`` produced.
    """
    history = state["messages"]
    task_stream = planner.stream(history)
    # Pull the first task eagerly so the planner starts producing output
    # before scheduling begins.
    try:
        head = next(task_stream)
    except StopIteration:
        # The planner produced no tasks at all.
        pending = iter([])
    else:
        pending = itertools.chain([head], task_stream)
    results = schedule_tasks.invoke({"messages": history, "tasks": pending})
    return {"messages": results}

## Example Plan

In [19]:
tool_messages = plan_and_schedule.invoke(
    {"messages": [HumanMessage(content=example_question)]}
)["messages"]


In [43]:
tool_messages


[FunctionMessage(content='Critical to Extremely Critical Fire Weather Conditions in Southern California; Heavy Lake Effect Snow Downwind of Lakes Erie and Ontario ... Current conditions at SAN FRANCISCO DOWNTOWN (SFOC1) Lat: 37.77056°NLon: 122.42694°WElev: 150.0ft. NA. 51°F. 11°C. Humidity: 43%: Wind Speed: NA NA MPH: Barometer: NA: Dewpoint: 29°F (-2°C ... San Francisco Weather Forecasts. Weather Underground provides local & long-range weather forecasts, weatherreports, maps & tropical weather conditions for the San Francisco area. See the latest San Francisco weather forecast, current conditions, and live radar. Keep up to date on all San Francisco weather news with KRON4. Current Hazards. Daily Briefing; Submit Report; Current Outlooks; Detailed Hazards; Tsunami; Graphical Hazardous Weather Outlook; Current Conditions. Hydro Maps Data; ... National Weather Service San Francisco Bay Area, CA 21 Grace Hopper Ave, Stop 5 Monterey, CA 93943-5505 (831) 656-1725 Comments? Questions? Pleas

## Joiner

In [21]:
from langchain_core.messages import AIMessage

from pydantic import BaseModel, Field


# Joiner action chosen when the question can be answered from the results so far.
class FinalResponse(BaseModel):
    """The final response/answer."""

    # Answer text; `_parse_joiner_output` wraps it in the closing AIMessage.
    response: str


# Joiner action chosen when another planning round is needed.
# (No class docstring is added so the generated schema stays unchanged.)
class Replan(BaseModel):
    # `_parse_joiner_output` forwards this text in a SystemMessage prefixed
    # with "Context from last attempt: ".
    feedback: str = Field(
        description="Analysis of the previous attempts and recommendations on what needs to be fixed."
    )


# Structured output schema the joiner model is asked to fill in.
class JoinOutputs(BaseModel):
    """Decide whether to replan or whether you can return the final response."""

    # Emitted to the conversation as "Thought: ..." by `_parse_joiner_output`.
    thought: str = Field(
        description="The chain of thought reasoning for the selected action"
    )
    # Exactly one of the two action models: finish or replan.
    action: Union[FinalResponse, Replan]


# Pull the joiner prompt from the LangChain Hub and bind its `examples`
# variable to an empty string.
joiner_prompt = hub.pull("wfh/llm-compiler-joiner").partial(
    examples=""
)  # You can optionally add examples
# Note: this rebinds the module-level `llm` name that was defined earlier for the planner.
llm = ChatOpenAI(model="gpt-4-turbo-preview")

# Prompt -> model configured to return a JoinOutputs object via function calling.
runnable = joiner_prompt | llm.with_structured_output(
    JoinOutputs, method="function_calling"
)



In [22]:
def _parse_joiner_output(decision: JoinOutputs) -> Dict[str, List[BaseMessage]]:
    """Convert the joiner's structured decision into a ``{"messages": [...]}`` update.

    The list always starts with an ``AIMessage`` carrying the thought. For a
    ``Replan`` action it is followed by a ``SystemMessage`` with the feedback;
    otherwise by an ``AIMessage`` with the final response text.
    """
    response = [AIMessage(content=f"Thought: {decision.thought}")]
    if isinstance(decision.action, Replan):
        # A trailing SystemMessage is what `should_continue` (not an AIMessage
        # -> loop) and the planner's `should_replan` check key off.
        return {
            "messages": response
            + [
                SystemMessage(
                    content=f"Context from last attempt: {decision.action.feedback}"
                )
            ]
        }
    else:
        # A trailing AIMessage makes `should_continue` end the graph.
        return {"messages": response + [AIMessage(content=decision.action.response)]}


def select_recent_messages(state) -> dict:
    """Keep only the messages from the most recent ``HumanMessage`` onward.

    Returns ``{"messages": [...]}`` in original order. If the history holds no
    ``HumanMessage`` at all, every message is kept.
    """
    history = state["messages"]
    # Scan backwards for the latest human turn; default to the very beginning.
    start = 0
    for pos in range(len(history) - 1, -1, -1):
        if isinstance(history[pos], HumanMessage):
            start = pos
            break
    return {"messages": list(history[start:])}


joiner = select_recent_messages | runnable | _parse_joiner_output

In [23]:
input_messages = [HumanMessage(content=example_question)] + tool_messages


In [45]:
joiner.invoke({"messages": input_messages})


{'messages': [AIMessage(content="Thought: The temperature in San Francisco is reported as 51°F. To answer the user's question, this value needs to be cubed. However, the provided actions did not include the computation of 51°F raised to the 3rd power but instead showed a mathematical operation result unrelated to the user's query. Therefore, the missing computation needs to be performed to answer the question correctly.", additional_kwargs={}, response_metadata={}),
  SystemMessage(content="Context from last attempt: The mathematical computation needed to answer the question - cubing the current temperature in San Francisco - was not performed. The result provided does not match the user's request.", additional_kwargs={}, response_metadata={})]}

## Compose using LangGraph

In [25]:
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import add_messages
from typing import Annotated


# Graph state: a single `messages` list. The `add_messages` annotation is the
# reducer LangGraph uses to merge each node's returned messages into the list.
class State(TypedDict):
    messages: Annotated[list, add_messages]


graph_builder = StateGraph(State)

# 1. Define vertices
# Both runnables were defined in earlier cells; each returns a
# {"messages": [...]} update for the shared state.
graph_builder.add_node("plan_and_schedule", plan_and_schedule)
graph_builder.add_node("join", joiner)


# 2. Define edges
# Planning/execution always hands off to the joiner.
graph_builder.add_edge("plan_and_schedule", "join")

### This condition determines looping logic


def should_continue(state):
    """Route after the joiner: finish on an ``AIMessage``, otherwise plan again."""
    latest = state["messages"][-1]
    return END if isinstance(latest, AIMessage) else "plan_and_schedule"


# After "join", `should_continue` picks the next step: END or another
# "plan_and_schedule" round.
graph_builder.add_conditional_edges(
    "join",
    # Routing function; its return value names the next node (or END).
    should_continue,
)
# Entry point: every run starts with planning.
graph_builder.add_edge(START, "plan_and_schedule")
chain = graph_builder.compile()

### Simple question

In [33]:
for step in chain.stream(
    {"messages": [HumanMessage(content="What's the GDP of New York?")]}
):
    print(step)
    print("---")

{'plan_and_schedule': {'messages': [FunctionMessage(content='ERROR(Failed to call search with args {}. Args resolved to {}. Error: 1 validation error for SearchSchema\nquery\n  Field required [type=missing, input_value={}, input_type=dict]\n    For further information visit https://errors.pydantic.dev/2.10/v/missing)', additional_kwargs={'idx': 1, 'args': {}}, response_metadata={}, name='search', id='f2bbffc7-af5c-4760-a3d0-18835c1417ce', tool_call_id=1)]}}
---
{'join': {'messages': [AIMessage(content="Thought: The search action failed due to a missing query parameter, so no information about New York's GDP was retrieved.", additional_kwargs={}, response_metadata={}, id='25cb9aee-7613-4395-a465-8d96fedb7da7'), SystemMessage(content="Context from last attempt: The plan to obtain New York's GDP was not executed successfully due to a missing query parameter in the search action. A query specifying the intent to find the GDP of New York needs to be provided for a successful execution.", ad

In [34]:
# Final answer
print(step["join"]["messages"][-1].content)

In 2023, the real gross domestic product (GDP) of New York was about 1.78 trillion U.S. dollars.


### Multi-hop question

In [35]:
steps = chain.stream(
    {
        "messages": [
            HumanMessage(
                content="What's the oldest parrot alive, and how much longer is that than the average?"
            )
        ]
    },
    {
        "recursion_limit": 100,
    },
)
for step in steps:
    print(step)
    print("---")

{'plan_and_schedule': {'messages': [FunctionMessage(content='ERROR(Failed to call search with args {}. Args resolved to {}. Error: 1 validation error for SearchSchema\nquery\n  Field required [type=missing, input_value={}, input_type=dict]\n    For further information visit https://errors.pydantic.dev/2.10/v/missing)', additional_kwargs={'idx': 1, 'args': {}}, response_metadata={}, name='search', id='1331d2cb-9361-43eb-aa54-bad8b9ac0d97', tool_call_id=1), FunctionMessage(content='ERROR(Failed to call search with args {}. Args resolved to {}. Error: 1 validation error for SearchSchema\nquery\n  Field required [type=missing, input_value={}, input_type=dict]\n    For further information visit https://errors.pydantic.dev/2.10/v/missing)', additional_kwargs={'idx': 2, 'args': {}}, response_metadata={}, name='search', id='2303e92a-c29e-47e0-bade-5faf7c09bbb6', tool_call_id=2), FunctionMessage(content="ERROR(Failed to call math with args {'context': ['oldest parrot alive age', 'average lifesp

In [36]:
# Final answer
print(step["join"]["messages"][-1].content)

I'm unable to find the current information on the oldest parrot alive and how much longer that is than the average lifespan of parrots due to search execution errors. For the most accurate and recent information, I recommend checking reputable sources or databases that track animal records directly.


### Multi-step math

In [37]:
for step in chain.stream(
    {
        "messages": [
            HumanMessage(
                content="What's ((3*(4+5)/0.5)+3245) + 8? What's 32/4.23? What's the sum of those two values?"
            )
        ]
    }
):
    print(step)

{'plan_and_schedule': {'messages': [FunctionMessage(content='3307.0', additional_kwargs={'idx': 1, 'args': {'problem': '((3*(4+5)/0.5)+3245) + 8'}}, response_metadata={}, name='math', id='4ef65a07-db80-4097-a7b6-fdd82da3337f', tool_call_id=1), FunctionMessage(content='7.565011820330969', additional_kwargs={'idx': 2, 'args': {'problem': '32/4.23'}}, response_metadata={}, name='math', id='8902cb4f-5ab8-40d2-9364-d31347542c28', tool_call_id=2), FunctionMessage(content='join', additional_kwargs={'idx': 3, 'args': ()}, response_metadata={}, name='join', id='92c23ba5-bacb-4c3d-952c-f71d675bf4a8', tool_call_id=3)]}}
{'join': {'messages': [AIMessage(content='Thought: The computation results for the two separate expressions provided by the user have been obtained: ((3*(4+5)/0.5)+3245) + 8 = 3307.0 and 32/4.23 = 7.565011820330969. To answer the final query, the sum of these two values is needed.', additional_kwargs={}, response_metadata={}, id='534f5f04-4fec-4143-98ca-1fc2190fb085'), AIMessage(c

In [38]:
# Final answer
print(step["join"]["messages"][-1].content)

The answer to the first calculation is 3307.0, the answer to the second calculation is approximately 7.57 (rounded to two decimal places), and the sum of those two values is approximately 3314.57.


### Complex Replanning Example

In [39]:
for step in chain.stream(
    {
        "messages": [
            HumanMessage(
                content="Find the current temperature in Tokyo, then, respond with a flashcard summarizing this information"
            )
        ]
    }
):
    print(step)

{'join': {'messages': [AIMessage(content='Thought: The relevant information needed to create a flashcard summarizing the current temperature in Tokyo has been gathered.', additional_kwargs={}, response_metadata={}, id='657fe944-4890-4e4c-9b67-9626bbe9feab'), AIMessage(content='**Flashcard: Current Temperature in Tokyo**\n\n- **Temperature:** 8°C (46.4°F)\n- **Feels Like:** 5°C (due to wind, humidity, and other conditions)\n- **Condition:** Clear\n- **Humidity:** 54%\n- **Wind:** 10.1 km/h\n- **Pressure:** 1016 mb\n- **Visibility:** 10 km\n- **Current Weather Assessment:** Fair, with a chance of rain at 0%.', additional_kwargs={}, response_metadata={}, id='e7c11974-e976-43c5-9ab5-fc060b1dd26c')]}}
