In [None]:
import os
import random
import typing
from dotenv import load_dotenv

import langchain.messages
from langchain.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState, StateGraph
from langgraph.prebuilt import ToolNode

from langchain_openai import ChatOpenAI

load_dotenv();

# Client & Model
# Credentials/endpoint come from the environment (optionally populated from a
# .env file by load_dotenv above). OPENAI_BASE_URL is optional: when it is unset,
# base_url is None and ChatOpenAI uses its default endpoint instead of this cell
# failing with KeyError. The API key is supplied as a callable rather than as a
# stored string literal.
client = ChatOpenAI(
    model="gpt-4o-mini",
    base_url=os.environ.get("OPENAI_BASE_URL"),
    api_key=lambda: os.environ["OPENAI_API_KEY"],
)


In [None]:
@tool
def poem_theme_selector():
    """A tool for selecting a poem theme."""
    # NOTE: the docstring above is the tool description the model sees, so it
    # is intentionally left verbatim. The body picks one candidate uniformly.
    candidate_themes = ("space cowboys", "space pirates")
    picked = random.choice(candidate_themes)  # noqa: S311 - not security-sensitive
    return picked

def tools_or_next(
    state: MessagesState,
) -> typing.Literal["__tools__", "__next__"]:
    """Route the agent's output: loop back after tool execution, otherwise finish.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        ``"__tools__"`` when the last message has a ``tool_call_id`` attribute
        (the attribute carried by tool-result messages), so the agent runs again
        and can read the tool output. ``"__next__"`` otherwise, including when
        the message list is empty.
    """
    messages = state["messages"]
    # Guard against an empty history instead of raising IndexError.
    if not messages:
        return "__next__"
    if hasattr(messages[-1], "tool_call_id"):
        return "__tools__"
    return "__next__"

class AgentWithTools:
    """A chat-model agent that can optionally call tools.

    Every model call is prefixed with a fixed system message. If the model's
    reply requests tool calls, they are executed immediately through a
    ``ToolNode``. ``action`` is used as a LangGraph node over ``MessagesState``.
    """

    def __init__(self, name: str, instructions: str, client, tools: dict | None = None):
        """Create the agent.

        Args:
            name: Node name used when the agent is added to a graph.
            instructions: Content of the system message sent before every model call.
            client: Chat model. If ``tools`` is non-empty, ``bind_tools`` is called on it.
            tools: Optional mapping of tool name to tool. Only the values are used.
        """
        self.name = name
        self.system_message = langchain.messages.SystemMessage(
            content=instructions
        )
        self.client = client

        # Add tools if available
        self.tools = tools
        self.tool_node = None
        if self.tools:
            # Materialize the dict view once so both consumers get a real list.
            tool_list = list(self.tools.values())
            self.client = self.client.bind_tools(tool_list)
            self.tool_node = ToolNode(tool_list)

    def action(
        self, state: MessagesState
    ) -> MessagesState:
        """Graph action to make a tool call or generate a response.

        Calls the model on the system message plus the conversation so far. If
        the reply contains tool calls and a tool node is configured, the tools
        are run and their result messages follow the reply.

        Args:
            state: Graph state; ``state["messages"]`` is read but not modified.

        Returns:
            A partial state update ``{"messages": [...]}`` holding only the
            messages produced by this step. The ``MessagesState`` reducer
            appends them to the existing history.
        """
        # Build the prompt as a new list; do not append to the graph-owned
        # ``state["messages"]`` in place.
        reply = self.client.invoke([self.system_message, *state["messages"]])
        new_messages = [reply]

        # Invoke tools (only possible when tools were configured in __init__)
        if reply.tool_calls and self.tool_node is not None:
            tool_output = self.tool_node.invoke({"messages": [reply]})
            new_messages.extend(tool_output["messages"])
        return {"messages": new_messages}


In [None]:
# Agents
# A single agent that writes a poem and may call the theme-selector tool.
tools = {"poem_theme_selector": poem_theme_selector}
agent_with_tools = AgentWithTools(
    "poem_agent",
    "Write a poem based on a theme.",
    client,
    tools,
)

# Build graph: __start__ -> poem_agent. After each agent step, tools_or_next
# sends control back to poem_agent while the last message carries a
# tool_call_id (a tool result), and to __end__ otherwise.
graph_builder = StateGraph(MessagesState)
graph_builder.add_node(agent_with_tools.name, agent_with_tools.action)
graph_builder.add_edge("__start__", agent_with_tools.name)
route_targets = {"__tools__": agent_with_tools.name, "__next__": "__end__"}
graph_builder.add_conditional_edges(agent_with_tools.name, tools_or_next, route_targets)
graph = graph_builder.compile()

# Render the compiled graph inline.
display(graph)


In [None]:
messages = [langchain.messages.HumanMessage("Write a short poem. Add theme as title.")]
output = graph.invoke({"messages": messages})

# Print a numbered transcript. Messages that request tools show those calls;
# every other message shows its content.
for position, msg in enumerate(output["messages"], start=1):
    requested_calls = getattr(msg, "tool_calls", None)
    shown = requested_calls if requested_calls else msg.content
    print(f"{position}: {msg.type.upper()}: {shown}")
