<img src="../images/define-the-state.png" alt="Diagram: defining the agent state" width="850" height="500">

In [None]:
from dotenv import load_dotenv
# Load credentials from a local .env file into the process environment.
# ChatOpenAI reads OPENAI_API_KEY from the environment; LANGCHAIN_API_KEY and
# LANGCHAIN_PROJECT are presumably set for LangSmith tracing (see the last cell).
load_dotenv()

In [None]:
import operator
import json
from typing import TypedDict, Annotated, Sequence

from langgraph.prebuilt import ToolExecutor
from langgraph.prebuilt import ToolInvocation
from langchain_core.utils.function_calling import format_tool_to_openai_function
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.messages import FunctionMessage

from langchain_openai import ChatOpenAI
from langchain_community.tools.ddg_search import DuckDuckGoSearchRun


In [None]:

# Setup the tools
tools = [DuckDuckGoSearchRun()]
tool_executor = ToolExecutor(tools)

# Setup the llm
llm = ChatOpenAI(temperature=0, streaming=True)
functions = [format_tool_to_openai_function(t) for t in tools]
model = llm.bind_functions(functions)

<img src="../images/define-the-state.png" alt="Diagram: defining the agent state" width="850" height="500">

In [None]:
class AgentState(TypedDict):
    """State carried between the graph's nodes.

    ``messages`` is annotated with ``operator.add`` as its merge rule.
    When a node returns ``{"messages": [...]}``, those messages are therefore
    appended to the existing history rather than replacing it.
    """

    # StateGraph(AgentState) reads this Annotated[..., operator.add] form to
    # decide how node updates are merged; keep it intact.
    messages: Annotated[Sequence[BaseMessage], operator.add]

<img src="../images/agent-loop.png" alt="Diagram: the agent / tool-call loop" width="550" height="500">

In [None]:
# Define the Nodes

def agent_node(state):
    """Ask the model for its next step, given the conversation so far.

    Args:
        state: Graph state whose ``"messages"`` list is sent to the model as-is.

    Returns:
        ``{"messages": [reply]}``: a one-element update that the state reducer
        appends to the history.
    """
    reply = model.invoke(state["messages"])
    return {"messages": [reply]}


def should_continue(state):
    """Decide where the graph goes after the agent node.

    Returns ``"continue"`` (route to the tool node) when the newest message
    carries a ``function_call`` in its ``additional_kwargs``. Otherwise it
    returns ``"end"``.
    """
    latest = state["messages"][-1]
    wants_tool = "function_call" in latest.additional_kwargs
    return "continue" if wants_tool else "end"


def action(state):
    """Execute the function call requested by the most recent model message.

    Reads the ``function_call`` entry from the last message's
    ``additional_kwargs``; that entry holds a ``name`` and JSON-encoded
    ``arguments``. The call is run through ``tool_executor``, and the result
    is wrapped in a ``FunctionMessage`` so the model sees it on its next turn.

    Args:
        state: Graph state. ``state["messages"]`` must end with a message whose
            ``additional_kwargs`` contain ``"function_call"``. This node is
            only reached when ``should_continue`` found one.

    Returns:
        ``{"messages": [FunctionMessage]}``: the tool output, stringified and
        tagged with the tool name.

    Raises:
        json.JSONDecodeError: if the model emitted non-empty ``arguments``
            that are not valid JSON.
    """
    function_call = state["messages"][-1].additional_kwargs["function_call"]
    # Named `invocation` rather than `action` so it does not shadow this node.
    invocation = ToolInvocation(
        tool=function_call["name"],
        # json.loads("") raises, so an empty arguments string is treated as {}.
        tool_input=json.loads(function_call["arguments"] or "{}"),
    )
    response = tool_executor.invoke(invocation)
    function_message = FunctionMessage(content=str(response), name=invocation.tool)
    return {"messages": [function_message]}

In [None]:
# Define the graph (Nodes + Edges)

from langgraph.graph import StateGraph, END
# Nodes: `agent` calls the LLM, `action` runs whichever tool it requested.
workflow = StateGraph(AgentState)
for node_name, node_fn in [("agent", agent_node), ("action", action)]:
    workflow.add_node(node_name, node_fn)

# Every run starts by asking the model what to do.
workflow.set_entry_point("agent")

# After the model answers, should_continue picks the next hop:
# run the requested tool, or stop.
route_map = {"continue": "action", "end": END}
workflow.add_conditional_edges("agent", should_continue, route_map)

# Tool output always goes back to the model for another turn.
workflow.add_edge("action", "agent")

# NOTE: despite the name, this is compiled without a checkpointer, so nothing
# persists between invoke() calls. History only carries over when the caller
# passes the previous messages back in.
agent_executor_with_mem = workflow.compile()

In [None]:
# Render the compiled graph's topology as ASCII art.
graph_diagram = agent_executor_with_mem.get_graph().draw_ascii()
print(graph_diagram)

In [None]:
# First turn: a fresh conversation containing a single user question.
first_question = HumanMessage(content="what is the population size of canada?")
output = agent_executor_with_mem.invoke({"messages": [first_question]})
# Display only the final message (the model's answer).
output["messages"][-1]

In [None]:
# Full message history after the first turn, including any intermediate
# function-call and FunctionMessage steps.
output["messages"]

In [None]:
# Follow-up turn: the graph has no checkpointer, so "memory" comes from
# resending the previous history with the new question appended.
follow_up = output["messages"] + [HumanMessage(content="and for china?")]
output = agent_executor_with_mem.invoke({"messages": follow_up})
output["messages"][-1]

In [None]:

# Stream the run step by step: each event maps a node name to that node's output.
inputs = {"messages": [HumanMessage(content="what is the weather in sf")]}
for step in agent_executor_with_mem.stream(inputs):
    for node_name, node_output in step.items():
        print(f"Output from node '{node_name}':", "---", node_output, sep="\n")
    print("\n---\n")

In [None]:
from langgraph.checkpoint.sqlite import SqliteSaver

# Checkpointer backed by an in-memory SQLite database: per-thread state
# survives between calls, but is lost when the kernel restarts.
memory = SqliteSaver.from_conn_string(":memory:")
agent_executor_with_persistence_mem = workflow.compile(checkpointer=memory)

# Start conversation thread "2" and print every node update except the end marker.
inputs = {"messages": [HumanMessage(content="what do you think about apple stock?")]}
thread_config = {"configurable": {"thread_id": "2"}}
for event in agent_executor_with_persistence_mem.stream(inputs, thread_config):
    for node_name, update in event.items():
        if node_name == "__end__":
            continue
        print(update)

In [None]:
# Dump the message history checkpointed for thread "2".
# .get("messages", []) avoids iterating over None (TypeError) when the thread
# has no saved messages yet, e.g. if this cell runs before the streaming cell.
thread_values = agent_executor_with_persistence_mem.get_state({"configurable": {"thread_id": "2"}}).values
for msg in thread_values.get("messages", []):
    print(msg.__class__)
    print(msg)
    print("#####")

In [None]:
# Same thread "2": the checkpointer restores the earlier messages, so this
# short follow-up is sent along with the previous Apple conversation.
inputs = {"messages": [HumanMessage(content="and microsoft?")]}
thread_config = {"configurable": {"thread_id": "2"}}
for event in agent_executor_with_persistence_mem.stream(inputs, thread_config):
    for node_name, update in event.items():
        if node_name == "__end__":
            continue
        print(update)

In [None]:
# Dump thread "2" again; it should now also contain the Microsoft follow-up.
# .get("messages", []) avoids iterating over None (TypeError) when the thread
# has no saved messages yet.
thread_values = agent_executor_with_persistence_mem.get_state({"configurable": {"thread_id": "2"}}).values
for msg in thread_values.get("messages", []):
    print(msg.__class__)
    print(msg)
    print("#####")



In [None]:
# Inspect the traced runs above in the LangSmith UI.