In [None]:
# Pull environment variables from a local .env file into os.environ, keeping
# any values already set in the environment (presumably this is where the
# OpenAI API key used by ChatOpenAI comes from — confirm for your setup).
from dotenv import load_dotenv

load_dotenv(override=False)

In [None]:
from typing import Literal, TypedDict

from langchain_core.messages import AnyMessage, HumanMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode


# Define a tool for weather information
@tool
def get_weather(location: str):
    """Call to get the current weather."""
    if location.lower() in ["munich"]:
        return "It's 15 degrees Celsius and cloudy."
    else:
        return "It's 32 degrees Celsius and sunny."


# The tool list is shared by the model binding below and by the ToolNode that
# executes the calls, so both always agree on which tools exist.
tools = [get_weather]

# Chat model used by the agent node; bound to `tools` so it can emit tool calls.
MODEL_NAME = "gpt-4o-mini"
base_llm = ChatOpenAI(model=MODEL_NAME)
model = base_llm.bind_tools(tools)


def call_model(state: MessagesState):
    """Agent node: send the conversation so far to the tool-bound model.

    Args:
        state: Graph state; only its ``messages`` list is read.

    Returns:
        A partial state update holding just the model's reply. The
        ``MessagesState`` reducer merges it into the existing history.
    """
    reply = model.invoke(state["messages"])
    return {"messages": [reply]}


def should_continue(state: MessagesState) -> Literal["tools", END]:
    """Routing function evaluated after the agent node.

    Returns:
        ``"tools"`` if the most recent message requested tool calls, otherwise
        ``END`` so the subgraph finishes with the model's answer. The
        ``Literal`` return annotation lists both possible destinations.
    """
    newest = state["messages"][-1]
    return "tools" if newest.tool_calls else END


# Agent loop: START -> agent; agent -> tools while the model keeps requesting
# tool calls; tools -> agent to feed results back; agent -> END otherwise.
subgraph_workflow = StateGraph(MessagesState)
tool_node = ToolNode(tools)

# Nodes
subgraph_workflow.add_node("agent", call_model)
subgraph_workflow.add_node("tools", tool_node)

# Edges (entry point expressed as an edge from START, like the parent graphs)
subgraph_workflow.add_edge(START, "agent")
subgraph_workflow.add_conditional_edges("agent", should_continue)
subgraph_workflow.add_edge("tools", "agent")

subgraph = subgraph_workflow.compile()

In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

# Render the subgraph topology as a PNG. MermaidDrawMethod.API delegates the
# rendering to a remote Mermaid service, so this cell needs network access.
subgraph_png = subgraph.get_graph().draw_mermaid_png(
    draw_method=MermaidDrawMethod.API,
)
display(Image(subgraph_png))

In [None]:
# Run the subgraph on its own before embedding it in a parent graph; the cell
# displays the final state (full message history, including tool messages).
smoke_input = {"messages": [HumanMessage(content="How is the weather in Munich?")]}
subgraph.invoke(smoke_input)

In [None]:
from typing import TypedDict
from langgraph.graph import StateGraph


def start_node(state: MessagesState):
    """Pass-through entry node: returns the state it receives unchanged.

    It exists only so the parent graph has a step of its own before handing
    control to the subgraph.
    """
    return state


# Parent graph whose second node *is* the compiled subgraph. Both graphs use
# MessagesState, so the shared ``messages`` key flows into and out of the
# subgraph without any translation.
main_graph = StateGraph(MessagesState)
main_graph.add_node("start", start_node)
main_graph.add_node("subgraph", subgraph)
main_graph.add_edge(START, "start")
main_graph.add_edge("start", "subgraph")
# Make termination explicit (and visible when the graph is drawn) instead of
# relying on "subgraph" simply having no outgoing edge.
main_graph.add_edge("subgraph", END)

graph = main_graph.compile()

In [None]:
# Run the parent graph; `subgraphs=False` (the default) is spelled out to
# contrast with the later run that sets it to True.
question = HumanMessage(content="What's the weather in Munich?")
initial_state = {"messages": [question]}
result = graph.invoke(initial_state, subgraphs=False)

In [None]:
# Display the value returned by the preceding `graph.invoke` call.
result

In [None]:
class ParentState(TypedDict):
    """State of a parent graph that stores its history under its own key.

    Deliberately NOT named ``MessagesState``: reusing that name would shadow
    the langgraph ``MessagesState`` import, so re-running earlier cells after
    this one would silently build graphs on the wrong schema.
    """

    # Conversation history as message objects (HumanMessage, AIMessage, ...),
    # not plain strings. No reducer is attached, so every update replaces the
    # whole list.
    parent_messages: list[AnyMessage]


def parent_start_node(state: ParentState) -> ParentState:
    """Pass-through entry node for the parent graph; returns ``state`` unchanged.

    Annotated with ``ParentState`` so its type hint matches this graph's
    schema (the earlier ``start_node`` is annotated with ``MessagesState``).
    """
    return state


def invoke_subgraph(state: ParentState) -> dict:
    """Run the compiled ``subgraph`` on the parent's history.

    The parent keeps its messages under ``parent_messages`` while the subgraph
    expects ``messages``, so the key is translated on the way in and back out.

    Returns:
        A partial update replacing ``parent_messages`` with the subgraph's
        final message list. The incoming ``state`` is not mutated.
    """
    subgraph_output = subgraph.invoke({"messages": state["parent_messages"]})
    return {"parent_messages": subgraph_output["messages"]}


main_graph = StateGraph(ParentState)
main_graph.add_node("start", parent_start_node)
main_graph.add_node("invoke_subgraph", invoke_subgraph)
main_graph.add_edge(START, "start")
main_graph.add_edge("start", "invoke_subgraph")
# Explicit termination, matching the explicit START edge.
main_graph.add_edge("invoke_subgraph", END)

graph = main_graph.compile()

In [None]:
# Run the parent graph that wraps the subgraph inside a plain node function.
# NOTE(review): `subgraphs=True` is a streaming option forwarded by `invoke`;
# confirm how it shapes `result` for your langgraph version.
parent_question = HumanMessage(content="What's the weather in Munich?")
initial_state = {"parent_messages": [parent_question]}
result = graph.invoke(initial_state, subgraphs=True)

In [None]:
# Display the value returned by the preceding `graph.invoke(..., subgraphs=True)` call.
result