Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions langgraph_swarm/handoff.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,35 @@
import re
from typing import Annotated
from dataclasses import is_dataclass
from typing import Annotated, Any

from langchain_core.messages import ToolMessage
from langchain_core.tools import BaseTool, InjectedToolCallId, tool
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import InjectedState, ToolNode
from langgraph.types import Command
from pydantic import BaseModel


def _get_field(obj: Any, key: str) -> Any:
"""Get a field from an object.

This function retrieves a field from a dictionary, dataclass, or Pydantic model.

Args:
obj: The object from which to retrieve the field.
key: The key or attribute name of the field to retrieve.

Returns:
The value of the specified field.

"""
if isinstance(obj, dict):
return obj[key]
if is_dataclass(obj) or isinstance(obj, BaseModel):
return getattr(obj, key)
msg = f"Unsupported type for state: {type(obj)}"
raise TypeError(msg)


WHITESPACE_RE = re.compile(r"\s+")
METADATA_KEY_HANDOFF_DESTINATION = "__handoff_destination"
Expand Down Expand Up @@ -45,7 +69,10 @@ def create_handoff_tool(

@tool(name, description=description)
def handoff_to_agent(
state: Annotated[dict, InjectedState],
# Annotation is typed as Any instead of StateLike. StateLike
# trigger validation issues from Pydantic / langchain_core interaction.
# https://github.com/langchain-ai/langchain/issues/32067
state: Annotated[Any, InjectedState],
tool_call_id: Annotated[str, InjectedToolCallId],
) -> Command:
tool_message = ToolMessage(
Expand All @@ -57,7 +84,7 @@ def handoff_to_agent(
goto=agent_name,
graph=Command.PARENT,
update={
"messages": state["messages"] + [tool_message],
"messages": [*_get_field(state, "messages"), tool_message],
Copy link
Preview

Copilot AI Jul 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The unpacking syntax [*_get_field(state, "messages"), tool_message] could be clearer. Consider using _get_field(state, "messages") + [tool_message] for better readability, especially since the original code used concatenation.

Suggested change
"messages": [*_get_field(state, "messages"), tool_message],
"messages": _get_field(state, "messages") + [tool_message],

Copilot uses AI. Check for mistakes.

"active_agent": agent_name,
},
)
Expand Down
123 changes: 123 additions & 0 deletions tests/test_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from langchain_core.tools import BaseTool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from langgraph.prebuilt.chat_agent_executor import AgentStatePydantic

from langgraph_swarm import create_handoff_tool, create_swarm

Expand Down Expand Up @@ -149,3 +150,125 @@ def add(a: int, b: int) -> int:
assert turn_2["messages"][-2].content == "12"
assert turn_2["messages"][-1].content == recorded_messages[4].content
assert turn_2["active_agent"] == "Alice"


def test_basic_swarm_pydantic() -> None:
"""Test a basic swarm with Pydantic state schema."""

class SwarmState(AgentStatePydantic):
"""State schema for the multi-agent swarm."""

# NOTE: this state field is optional and is not expected to be provided by the
# user.
# If a user does provide it, the graph will start from the specified active
# agent.
# If active agent is typed as a `str`, we turn it into enum of all active agent
# names.
active_agent: str | None = None

recorded_messages = [
AIMessage(
content="",
name="Alice",
tool_calls=[
{
"name": "transfer_to_bob",
"args": {},
"id": "call_1LlFyjm6iIhDjdn7juWuPYr4",
},
],
),
AIMessage(
content="Ahoy, matey! Bob the pirate be at yer service. What be ye needin' "
"help with today on the high seas? Arrr!",
name="Bob",
),
AIMessage(
content="",
name="Bob",
tool_calls=[
{
"name": "transfer_to_alice",
"args": {},
"id": "call_T6pNmo2jTfZEK3a9avQ14f8Q",
},
],
),
AIMessage(
content="",
name="Alice",
tool_calls=[
{
"name": "add",
"args": {
"a": 5,
"b": 7,
},
"id": "call_4kLYO1amR2NfhAxfECkALCr1",
},
],
),
AIMessage(
content="The sum of 5 and 7 is 12.",
name="Alice",
),
]

model = FakeChatModel(responses=recorded_messages) # type: ignore[arg-type]

def add(a: int, b: int) -> int:
"""Add two numbers."""
return a + b

alice = create_react_agent(
model,
[add, create_handoff_tool(agent_name="Bob")],
prompt="You are Alice, an addition expert.",
name="Alice",
state_schema=SwarmState,
)

bob = create_react_agent(
model,
[
create_handoff_tool(
agent_name="Alice",
description="Transfer to Alice, she can help with math",
),
],
prompt="You are Bob, you speak like a pirate.",
name="Bob",
state_schema=SwarmState,
)

checkpointer = MemorySaver()
workflow = create_swarm([alice, bob], default_active_agent="Alice")
app = workflow.compile(checkpointer=checkpointer)

config: RunnableConfig = {"configurable": {"thread_id": "1"}}
turn_1 = app.invoke(
{ # type: ignore[arg-type]
"messages": [{"role": "user", "content": "i'd like to speak to Bob"}]
},
config,
)

# Verify turn 1 results
assert len(turn_1["messages"]) == 4
assert turn_1["messages"][-2].content == "Successfully transferred to Bob"
assert turn_1["messages"][-1].content == recorded_messages[1].content
assert turn_1["active_agent"] == "Bob"

turn_2 = app.invoke(
{ # type: ignore[arg-type]
"messages": [{"role": "user", "content": "what's 5 + 7?"}]
},
config,
)

# Verify turn 2 results
assert len(turn_2["messages"]) == 10
assert turn_2["messages"][-4].content == "Successfully transferred to Alice"
assert turn_2["messages"][-2].content == "12"
assert turn_2["messages"][-1].content == recorded_messages[4].content
assert turn_2["active_agent"] == "Alice"