Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions dimos/agents/mcp/fixtures/test_image.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
}
]
},
{
"content": "I've taken a picture. Let me analyze and describe it for you.\nThe image features an expansive outdoor stadium. From the camera's perspective, the word 'stadium' best matches the image. Is there anything else you'd like to know or do?",
"tool_calls": []
},
{
"content": "The image shows a group of people sitting and enjoying their time at an outdoor cafe. Therefore, the word 'cafe' best matches the image.",
"tool_calls": []
Expand Down
182 changes: 150 additions & 32 deletions dimos/agents/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@
from queue import Empty, Queue
from threading import Event, RLock, Thread
import time
from typing import Any
from typing import Annotated, Any, cast
import uuid

import httpx
from langchain.agents import create_agent
from langchain_core.messages import HumanMessage
from langchain.agents import AgentState, create_agent
from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.tools import StructuredTool
from langchain_core.tools import InjectedToolCallId, StructuredTool
from langgraph.graph.message import add_messages
from langgraph.graph.state import CompiledStateGraph
from langgraph.types import Command
from reactivex.disposable import Disposable

from dimos.agents.mcp import tool_stream
Expand All @@ -41,6 +43,61 @@
logger = setup_logger()


def _fix_parallel_tool_batches(messages: list[BaseMessage]) -> list[BaseMessage]:
"""Reorder interleaved [Tool, Human, Tool, Human, ...] runs that
follow a parallel-tool-call AIMessage into [Tool, Tool, ..., Human,
Human, ...] so OpenAI's "all parallel tool responses must be
contiguous" rule is satisfied.

Image-carrying HumanMessages emitted by the MCP tool wrapper are
tagged with `additional_kwargs["tool_call_id"]` matching the
originating tool call, which is how we pair each Human with its
parallel batch.
"""
out = list(messages)
i = 0
while i < len(out):
msg = out[i]
tool_calls = getattr(msg, "tool_calls", None) or []
if isinstance(msg, AIMessage) and len(tool_calls) >= 2:
expected_ids = {tc.get("id") for tc in tool_calls if tc.get("id")}
tool_msgs: list[ToolMessage] = []
other_msgs: list[BaseMessage] = []
j = i + 1
while j < len(out):
m = out[j]
if isinstance(m, ToolMessage) and m.tool_call_id in expected_ids:
tool_msgs.append(m)
j += 1
elif (
isinstance(m, HumanMessage)
and getattr(m, "additional_kwargs", {}).get("tool_call_id") in expected_ids
):
other_msgs.append(m)
j += 1
else:
break
if tool_msgs and other_msgs and {m.tool_call_id for m in tool_msgs} == expected_ids:
out[i + 1 : j] = [*tool_msgs, *other_msgs]
i += 1
return out


def _reorder_tool_responses(
left: list[BaseMessage], right: list[BaseMessage] | BaseMessage
) -> list[BaseMessage]:
"""Standard add_messages merge, then fix any parallel-tool batches."""
# add_messages is typed against langgraph's permissive Messages union;
# list[BaseMessage] is invariant so we cast at the boundary.
merged = cast("list[BaseMessage]", add_messages(cast("Any", left), cast("Any", right)))
return _fix_parallel_tool_batches(merged)


class _OrderedAgentState(AgentState[Any]):
# Override the messages reducer to keep parallel ToolMessages contiguous.
messages: Annotated[list[BaseMessage], _reorder_tool_responses] # type: ignore[misc]


class McpClientConfig(ModuleConfig):
system_prompt: str | None = SYSTEM_PROMPT
model: str = "gpt-4o"
Expand Down Expand Up @@ -168,24 +225,42 @@ def _mcp_tool_to_langchain(self, mcp_tool: dict[str, Any]) -> StructuredTool:
description = mcp_tool.get("description", "")
input_schema = mcp_tool.get("inputSchema", {"type": "object", "properties": {}})

def call_tool(**kwargs: Any) -> str:
def call_tool(
tool_call_id: Annotated[str, InjectedToolCallId],
**kwargs: Any,
) -> str | Command[Any]:
result = self._mcp_tool_call(name, kwargs)
content = result.get("content", [])
parts = [c.get("text", "") for c in content if c.get("type") == "text"]
text = "\n".join(parts)

# Images need to be added to the history separately because they
# cannot be included in the tool response for OpenAI models and
# probably others.
for item in content:
if item.get("type") != "text":
uuid_ = str(uuid.uuid4())
text += f"Tool call started with UUID: {uuid_}. You will be updated with the result soon."
_append_image_to_history(self, name, uuid_, item)

return text
text = "\n".join(c.get("text", "") for c in content if c.get("type") == "text")
image_blocks = [c for c in content if c.get("type") != "text"]

if not image_blocks:
return text

# Vision content can't be embedded inside a ToolMessage for OpenAI
# (and others), so we use Command to append a follow-up HumanMessage
# carrying the image blocks within the same agent turn.
#
# The HumanMessage is tagged with `additional_kwargs["tool_call_id"]`
# so `_fix_parallel_tool_batches` can pair it with the right
# ToolMessage when multiple parallel tool calls return images
# in one batch (OpenAI requires the parallel ToolMessages to
# stay contiguous).
summary = text or f"{name} returned {len(image_blocks)} non-text artefact(s)."
intro = f"Artefacts returned by '{name}' (image follows):"
return Command(
update={
"messages": [
ToolMessage(content=summary, tool_call_id=tool_call_id),
HumanMessage(
content=[{"type": "text", "text": intro}, *image_blocks],
additional_kwargs={"tool_call_id": tool_call_id},
),
]
}
)

return StructuredTool(
return _McpStructuredTool(
name=name,
description=description,
func=call_tool,
Expand Down Expand Up @@ -223,6 +298,7 @@ def on_system_modules(self, _modules: list[RPCClient]) -> None:
model=model,
tools=tools,
system_prompt=self.config.system_prompt,
state_schema=cast("type[AgentState[Any]]", _OrderedAgentState),
)
if not self._thread.is_alive():
self._thread.start()
Expand Down Expand Up @@ -330,21 +406,63 @@ def _process_message(
pretty_print_langchain_message(msg)
self.agent.publish(msg)

# The graph applies _reorder_tool_responses to its internal channel,
# but stream_mode="updates" emits raw node outputs in completion
# order — and langgraph does not re-run reducers when an initial
# state dict is fed back into stream() on the next turn. Mirror the
# reducer here so _history matches what the graph would produce.
self._history = _fix_parallel_tool_batches(self._history)

if self._message_queue.empty():
self.agent_idle.publish(True)


def _append_image_to_history(
mcp_client: McpClient, func_name: str, uuid_: str, result: Any
) -> None:
mcp_client.add_message(
HumanMessage(
content=[
{
"type": "text",
"text": f"This is the artefact for the '{func_name}' tool with UUID:={uuid_}.",
},
result,
]
class _McpStructuredTool(StructuredTool):
"""StructuredTool that propagates `tool_call_id` to MCP tools whose
`args_schema` is a JSON-Schema dict.

Langchain auto-injects `InjectedToolCallId` only when `args_schema` is
a Pydantic model; MCP servers supply JSON-Schema dicts (the server's
authoritative contract for the LLM), so we have to bridge ourselves.

The bridge uses only public Runnable surface — `invoke` / `ainvoke`
accept a `ToolCall` dict as documented input, and we copy its `id`
field into the function's kwargs before delegating.
"""

def invoke(
self,
input: str | dict[Any, Any] | ToolCall,
config: Any = None,
**kwargs: Any,
) -> Any:
return super().invoke(_inject_tool_call_id(input), config=config, **kwargs)

async def ainvoke(
self,
input: str | dict[Any, Any] | ToolCall,
config: Any = None,
**kwargs: Any,
) -> Any:
return await super().ainvoke(_inject_tool_call_id(input), config=config, **kwargs)
Comment thread
Mgczacki marked this conversation as resolved.


def _inject_tool_call_id(
input: str | dict[Any, Any] | ToolCall,
) -> dict[Any, Any]:
"""Copy the `ToolCall.id` field into `args.tool_call_id` so the inner
`call_tool` closure receives it as a kwarg. JSON-Schema dicts don't
validate against extra keys, so this is transparent to the schema.

Raises ValueError on any invocation that isn't a `ToolCall`-shaped
dict with a non-null `id` — MCP tools have no other valid call path.
"""
if not (isinstance(input, dict) and input.get("type") == "tool_call"):
raise ValueError(
"MCP tool must be invoked via a ToolCall (a dict with "
"type='tool_call' and a non-null id), not a bare input."
)
)
tool_call_id = input.get("id")
if tool_call_id is None:
raise ValueError("MCP tool ToolCall is missing a non-null id.")
return {**input, "args": {**(input.get("args") or {}), "tool_call_id": tool_call_id}}
Loading
Loading