In [1]:
import functools
import getpass
import operator
import os
from typing import Annotated, Sequence, TypedDict, Literal

from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    ToolMessage,
)
from langchain_core.messages import AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_experimental.utilities import PythonREPL
from langchain_openai import ChatOpenAI

from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolNode

# Define Tools

In [2]:
# Never commit API keys to a notebook. Take the Tavily key from the
# environment and only prompt for it (without echoing) when it is missing.
# NOTE(review): the key that was previously hardcoded here should be treated
# as leaked and revoked.
if not os.environ.get("TAVILY_API_KEY"):
    os.environ["TAVILY_API_KEY"] = getpass.getpass("Tavily API key: ")

# Web-search tool, capped at 5 results per query.
tavily_tool = TavilySearchResults(max_results=5)

# REPL used by the `python_repl` tool below to run code passed in by the
# agent. It executes arbitrary code, so only run this notebook in an
# environment where that is acceptable.
repl = PythonREPL()

@tool
def python_repl(
    code: Annotated[str, "The python code to execute to generate your chart."],
):
    """Use this to execute python code. If you want to see the output of a value,
    you should print it out with `print(...)`. This is visible to the user."""
    # NOTE(review): the docstring and the Annotated text above are presumably
    # surfaced to the model by the @tool decorator, so they are left verbatim.
    try:
        stdout = repl.run(code)
    except BaseException as exc:
        # Report every failure back as text instead of raising, so the caller
        # always receives a string describing what happened.
        return f"Failed to execute. Error: {repr(exc)}"
    report = f"Successfully executed:\n```python\n{code}\n```\nStdout: {stdout}"
    # Append the stop hint so the agent knows how to signal completion.
    return report + "\n\nIf you have completed all tasks, respond with FINAL ANSWER."

# State of Graph
The state is a list of messages, along with a key that tracks the most recent sender.

In [3]:
# This defines the object that is passed between each node
# in the graph. We will create different nodes for each agent and tool
class AgentState(TypedDict):
    """Shared state handed from node to node while the graph runs."""
    # Conversation history. The `operator.add` annotation is presumably used
    # by the graph library as the merge rule, so each node's returned
    # messages are concatenated onto the existing sequence — TODO confirm.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the agent node that produced the latest update; the tool node's
    # outgoing edge reads it to route back to the agent that called the tool.
    sender: str

# Define Nodes
Agent Nodes, Tool Nodes

In [4]:
"""
Agent Nodes
"""
def create_agent(llm, tools, system_message: str):
    """Build an agent: a shared collaboration prompt piped into a tool-bound LLM.

    Args:
        llm: Chat model exposing ``bind_tools``.
        tools: Tools the agent may call; each must expose a ``name`` attribute.
        system_message: Agent-specific instructions appended to the shared
            system prompt.

    Returns:
        The composition ``prompt | llm.bind_tools(tools)``. It still expects a
        ``messages`` value when invoked.
    """
    tool_names = ", ".join(t.name for t in tools)

    base_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful AI assistant, collaborating with other assistants."
                " Use the provided tools to progress towards answering the question."
                " If you are unable to fully answer, that's OK, another assistant with different tools "
                " will help where you left off. Execute what you can to make progress."
                " If you or any of the other assistants have the final answer or deliverable,"
                " prefix your response with FINAL ANSWER so the team knows to stop."
                " You have access to the following tools: {tool_names}.\n{system_message}",
            ),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )

    # Pre-fill both template variables; only the `messages` placeholder is
    # left to be supplied at invocation time.
    filled_prompt = base_prompt.partial(
        system_message=system_message,
        tool_names=tool_names,
    )

    # `|` chains the prompt into the model that has the tools bound to it.
    return filled_prompt | llm.bind_tools(tools)

def agent_node(state, agent, name):
    """Invoke ``agent`` on ``state`` and wrap its reply as a graph state update.

    Args:
        state: Current graph state, passed straight to ``agent.invoke``.
        agent: Runnable produced by ``create_agent``.
        name: Node name; recorded both on the message and as ``sender``.

    Returns:
        A dict with ``messages`` (a one-element list holding the reply) and
        ``sender`` (``name``).
    """
    reply = agent.invoke(state)

    # A ToolMessage is forwarded untouched. Anything else is rebuilt as an
    # AIMessage carrying this node's name, dropping the original type/name.
    if not isinstance(reply, ToolMessage):
        fields = reply.dict(exclude={"type", "name"})
        reply = AIMessage(**fields, name=name)

    # `sender` lets the tool node's edge route back to whoever made the call.
    return {"messages": [reply], "sender": name}


# Never hardcode the OpenAI key in the notebook. Take it from the environment
# and only prompt for it (without echoing) when it is missing.
# NOTE(review): the key that was previously hardcoded here should be treated
# as leaked and revoked.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")

# Chat model shared by both agents.
llm = ChatOpenAI(model="gpt-4-1106-preview", api_key=os.environ["OPENAI_API_KEY"])

# Researcher: has the web-search tool and feeds data to the chart generator.
research_agent = create_agent(
    llm=llm,
    tools=[tavily_tool],
    system_message="You should provide accurate data for the chart_generator to use.",
)
research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")

# Chart generator: has the Python REPL tool for producing charts.
chart_agent = create_agent(
    llm=llm,
    tools=[python_repl],
    system_message="Any charts you display will be visible by the user.",
)
chart_node = functools.partial(agent_node, agent=chart_agent, name="chart_generator")

# Tool node: one node that can execute calls to either agent's tool.
tools = [tavily_tool, python_repl]
tool_node = ToolNode(tools)

# Define Edge Logic
Edge logic that decides which node runs next, based on the latest agent output.

In [5]:
def router(state) -> Literal["call_tool", "__end__", "continue"]:
    """Pick the next step after an agent node has produced a message.

    Args:
        state: Graph state; only the last entry of ``state["messages"]`` is
            inspected.

    Returns:
        ``"call_tool"`` when the last message requests a tool call,
        ``"__end__"`` when its content contains ``FINAL ANSWER`` (either
        agent may end the run), otherwise ``"continue"`` to hand over to the
        other agent.
    """
    last_message = state["messages"][-1]

    # `agent_node` forwards ToolMessage results unchanged, and a message may
    # not define `tool_calls` at all, so read it defensively rather than
    # raising AttributeError.
    if getattr(last_message, "tool_calls", None):
        # The previous agent is invoking a tool.
        return "call_tool"

    if "FINAL ANSWER" in last_message.content:
        # Any agent decided the work is done.
        return "__end__"

    return "continue"

# Define Graph

In [6]:
# Build the graph over the shared AgentState schema.
workflow = StateGraph(AgentState)

# Two agent nodes plus one shared node that executes tool calls.
workflow.add_node("Researcher", research_node)
workflow.add_node("chart_generator", chart_node)
workflow.add_node("call_tool", tool_node)

# After each agent runs, `router` returns one of "continue", "call_tool" or
# "__end__"; the mapping turns that into the next node. "continue" hands
# control to the *other* agent.
workflow.add_conditional_edges(
    "Researcher",
    router,
    {"continue": "chart_generator", "call_tool": "call_tool", "__end__": END},
)
workflow.add_conditional_edges(
    "chart_generator",
    router,
    {"continue": "Researcher", "call_tool": "call_tool", "__end__": END},
)

workflow.add_conditional_edges(
    "call_tool",
    # Each agent node writes its own name into the 'sender' field and the
    # tool-calling node does not overwrite it, so this edge routes the tool
    # result back to the agent that requested the tool.
    lambda x: x["sender"], # Researcher or chart_generator
    {
        "Researcher": "Researcher",
        "chart_generator": "chart_generator",
    },
)
# Every run starts with the Researcher.
workflow.set_entry_point("Researcher")
graph = workflow.compile()

In [46]:
import json
from typing import List, Dict, Any, Union

def extract_content_and_urls(value: Dict[str, Any]) -> List[Dict[str, Union[str, Dict[str, str]]]]:
    """Summarise one streamed graph event as a list of content/url records.

    Only events keyed by ``call_tool`` or ``Researcher`` are considered (in
    that order); the first key present in ``value`` is used and only the
    first message of its ``messages`` list is inspected.

    Args:
        value: A single event, mapping a node name to that node's state
            update (a mapping that may contain a ``messages`` list).

    Returns:
        * ``[{"url": ..., "content": ...}, ...]`` when the message content is
          a JSON array of objects (one record per object; missing fields
          become ``None``).
        * ``[{"content": <raw content>}]`` for any other content, including
          JSON that is not an array of objects and non-string content.
        * ``[]`` when no supported key is present or there is no message.
    """
    for key in ("call_tool", "Researcher"):
        if key not in value:
            continue
        data = value[key]
        messages = data["messages"] if "messages" in data else None
        if isinstance(messages, list) and messages:
            return _records_from_content(messages[0].content)
        # Stop after the first key that is present, even if it had no messages.
        return []
    return []


def _records_from_content(content: Any) -> List[Dict[str, Any]]:
    """Turn a message's content into url/content records, or wrap it as-is."""
    if isinstance(content, str):
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            parsed = None
        # Only a JSON array of objects is treated as structured results.
        # Scalars, objects and arrays of non-objects fall through to the raw
        # branch instead of raising TypeError/AttributeError.
        if isinstance(parsed, list) and all(isinstance(item, dict) for item in parsed):
            return [
                {"url": item.get("url"), "content": item.get("content")}
                for item in parsed
            ]
    return [{"content": content}]

In [47]:
# Interactive loop: read a question, stream the graph's events for it, and
# print a summary of each event until the user asks to quit.
while True:
    user_input = input("User: ")  # e.g. what is the population of tokyo in 2020?
    if user_input.lower() in ("quit", "exit", "q"):
        print("Goodbye!")
        break

    initial_state = {"messages": [HumanMessage(content=user_input)]}
    # The second argument is the run config; recursion_limit caps the
    # maximum number of steps to take in the graph.
    for event in graph.stream(initial_state, {"recursion_limit": 10}):
        print(extract_content_and_urls(event))
        print("----")

{'Researcher': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_YiNNpRS2p0KuZql5j9Dn7oAd', 'function': {'arguments': '{"query":"population of Tokyo in 2020"}', 'name': 'tavily_search_results_json'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 24, 'prompt_tokens': 206, 'total_tokens': 230}, 'model_name': 'gpt-4-1106-preview', 'system_fingerprint': None, 'finish_reason': 'tool_calls', 'logprobs': None}, name='Researcher', id='run-b35d61a5-8962-4900-8983-f252f47f4a14-0', tool_calls=[{'name': 'tavily_search_results_json', 'args': {'query': 'population of Tokyo in 2020'}, 'id': 'call_YiNNpRS2p0KuZql5j9Dn7oAd'}], usage_metadata={'input_tokens': 206, 'output_tokens': 24, 'total_tokens': 230})], 'sender': 'Researcher'}}
!!
[{'content': ''}]
----
{'call_tool': {'messages': [ToolMessage(content='[{"url": "https://www.macrotrends.net/global-metrics/cities/21671/tokyo/population", "content": "The metro area population of Tokyo in