Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=langgraph_swarm/
lint lint_diff:
[ "$(PYTHON_FILES)" = "" ] || uv run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || uv run ruff check $(PYTHON_FILES) --diff
# [ "$(PYTHON_FILES)" = "" ] || uv run mypy $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || uv run mypy $(PYTHON_FILES)

format format_diff:
[ "$(PYTHON_FILES)" = "" ] || uv run ruff check --fix $(PYTHON_FILES)
Expand Down
7 changes: 6 additions & 1 deletion langgraph_swarm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from langgraph_swarm.handoff import create_handoff_tool
from langgraph_swarm.swarm import SwarmState, add_active_agent_router, create_swarm

__all__ = ["SwarmState", "add_active_agent_router", "create_handoff_tool", "create_swarm"]
__all__ = [
"SwarmState",
"add_active_agent_router",
"create_handoff_tool",
"create_swarm",
]
14 changes: 10 additions & 4 deletions langgraph_swarm/handoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def create_handoff_tool(
def handoff_to_agent(
state: Annotated[dict, InjectedState],
tool_call_id: Annotated[str, InjectedToolCallId],
):
) -> Command:
tool_message = ToolMessage(
content=f"Successfully transferred to {agent_name}",
name=name,
Expand All @@ -56,14 +56,19 @@ def handoff_to_agent(
return Command(
goto=agent_name,
graph=Command.PARENT,
update={"messages": state["messages"] + [tool_message], "active_agent": agent_name},
update={
"messages": state["messages"] + [tool_message],
"active_agent": agent_name,
},
)

handoff_to_agent.metadata = {METADATA_KEY_HANDOFF_DESTINATION: agent_name}
return handoff_to_agent


def get_handoff_destinations(agent: CompiledStateGraph, tool_node_name: str = "tools") -> list[str]:
def get_handoff_destinations(
agent: CompiledStateGraph, tool_node_name: str = "tools"
) -> list[str]:
"""Get a list of destinations from agent's handoff tools."""
nodes = agent.get_graph().nodes
if tool_node_name not in nodes:
Expand All @@ -77,5 +82,6 @@ def get_handoff_destinations(agent: CompiledStateGraph, tool_node_name: str = "t
return [
tool.metadata[METADATA_KEY_HANDOFF_DESTINATION]
for tool in tools
if tool.metadata is not None and METADATA_KEY_HANDOFF_DESTINATION in tool.metadata
if tool.metadata is not None
and METADATA_KEY_HANDOFF_DESTINATION in tool.metadata
]
20 changes: 13 additions & 7 deletions langgraph_swarm/swarm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal, Optional, Union, get_args, get_origin
from typing import Literal, Optional, Union, cast, get_args, get_origin

from langgraph.graph import START, MessagesState, StateGraph
from langgraph.pregel import Pregel
Expand Down Expand Up @@ -30,7 +30,8 @@ def _update_state_schema_agent_names(
# Check if the annotation is str or Optional[str]
is_str_type = active_agent_annotation is str
is_optional_str = (
get_origin(active_agent_annotation) is Union and get_args(active_agent_annotation)[0] is str
get_origin(active_agent_annotation) is Union
and get_args(active_agent_annotation)[0] is str
)

# We only update if the 'active_agent' is a str or Optional[str]
Expand All @@ -48,7 +49,7 @@ def _update_state_schema_agent_names(

# If it was Optional[str], make it Optional[Literal[...]]
if is_optional_str:
updated_schema.__annotations__["active_agent"] = Optional[literal_type]
updated_schema.__annotations__["active_agent"] = Optional[literal_type] # noqa: UP045
else:
updated_schema.__annotations__["active_agent"] = literal_type

Expand Down Expand Up @@ -135,8 +136,8 @@ def add(a: int, b: int) -> int:
msg,
)

def route_to_active_agent(state: dict):
return state.get("active_agent", default_active_agent)
def route_to_active_agent(state: dict) -> str:
return cast("str", state.get("active_agent", default_active_agent))

builder.add_conditional_edges(START, route_to_active_agent, path_map=route_to)
return builder
Expand Down Expand Up @@ -223,8 +224,13 @@ def add(a: int, b: int) -> int:
for agent in agents:
builder.add_node(
agent.name,
agent,
destinations=tuple(get_handoff_destinations(agent)),
# We need to update the type signatures in add_node to match
# the fact that more flexible Pregel objects are allowed.
agent, # type: ignore[arg-type]
destinations=tuple(
# Need to update implementation to support Pregel objects
get_handoff_destinations(agent) # type: ignore[arg-type]
),
)

return builder
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ python_files = ["test_*.py"]
python_functions = ["test_*"]

[tool.ruff]
line-length = 100
line-length = 88
target-version = "py310"

[tool.ruff.lint]
Expand Down
5 changes: 0 additions & 5 deletions tests/test_import.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,2 @@
def test_import() -> None:
"""Test that the code can be imported."""
from langgraph_swarm import ( # noqa: F401
add_active_agent_router,
create_handoff_tool,
create_swarm,
)
29 changes: 23 additions & 6 deletions tests/test_swarm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import BaseTool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent

from langgraph_swarm import create_handoff_tool, create_swarm

if TYPE_CHECKING:
from langchain_core.runnables.config import RunnableConfig


class FakeChatModel(BaseChatModel):
idx: int = 0
Expand All @@ -21,13 +28,19 @@ def _generate(
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: CallbackManagerForLLMRun | None = None,
**kwargs,
**kwargs: Any,
) -> ChatResult:
generation = ChatGeneration(message=self.responses[self.idx])
self.idx += 1
return ChatResult(generations=[generation])

def bind_tools(self, tools: list[any]) -> "FakeChatModel":
def bind_tools(
self,
tools: Sequence[dict[str, Any] | type | Callable[..., Any] | BaseTool],
*,
tool_choice: str | None = None,
**kwargs: Any,
) -> "FakeChatModel":
return self


Expand Down Expand Up @@ -80,7 +93,7 @@ def test_basic_swarm() -> None:
),
]

model = FakeChatModel(responses=recorded_messages)
model = FakeChatModel(responses=recorded_messages) # type: ignore[arg-type]

def add(a: int, b: int) -> int:
"""Add two numbers."""
Expand Down Expand Up @@ -109,9 +122,11 @@ def add(a: int, b: int) -> int:
workflow = create_swarm([alice, bob], default_active_agent="Alice")
app = workflow.compile(checkpointer=checkpointer)

config = {"configurable": {"thread_id": "1"}}
config: RunnableConfig = {"configurable": {"thread_id": "1"}}
turn_1 = app.invoke(
{"messages": [{"role": "user", "content": "i'd like to speak to Bob"}]},
{ # type: ignore[arg-type]
"messages": [{"role": "user", "content": "i'd like to speak to Bob"}]
},
config,
)

Expand All @@ -122,7 +137,9 @@ def add(a: int, b: int) -> int:
assert turn_1["active_agent"] == "Bob"

turn_2 = app.invoke(
{"messages": [{"role": "user", "content": "what's 5 + 7?"}]},
{ # type: ignore[arg-type]
"messages": [{"role": "user", "content": "what's 5 + 7?"}]
},
config,
)

Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.