diff --git a/dimos/agents/mcp/fixtures/test_image.json b/dimos/agents/mcp/fixtures/test_image.json index 0e4816b8ee..060ed6ff54 100644 --- a/dimos/agents/mcp/fixtures/test_image.json +++ b/dimos/agents/mcp/fixtures/test_image.json @@ -11,10 +11,6 @@ } ] }, - { - "content": "I've taken a picture. Let me analyze and describe it for you.\nThe image features an expansive outdoor stadium. From the camera's perspective, the word 'stadium' best matches the image. Is there anything else you'd like to know or do?", - "tool_calls": [] - }, { "content": "The image shows a group of people sitting and enjoying their time at an outdoor cafe. Therefore, the word 'cafe' best matches the image.", "tool_calls": [] diff --git a/dimos/agents/mcp/mcp_client.py b/dimos/agents/mcp/mcp_client.py index 75b532e9cc..06a983a207 100644 --- a/dimos/agents/mcp/mcp_client.py +++ b/dimos/agents/mcp/mcp_client.py @@ -16,15 +16,17 @@ from queue import Empty, Queue from threading import Event, RLock, Thread import time -from typing import Any +from typing import Annotated, Any, cast import uuid import httpx -from langchain.agents import create_agent -from langchain_core.messages import HumanMessage +from langchain.agents import AgentState, create_agent +from langchain_core.messages import AIMessage, HumanMessage, ToolCall, ToolMessage from langchain_core.messages.base import BaseMessage -from langchain_core.tools import StructuredTool +from langchain_core.tools import InjectedToolCallId, StructuredTool +from langgraph.graph.message import add_messages from langgraph.graph.state import CompiledStateGraph +from langgraph.types import Command from reactivex.disposable import Disposable from dimos.agents.mcp import tool_stream @@ -41,6 +43,61 @@ logger = setup_logger() +def _fix_parallel_tool_batches(messages: list[BaseMessage]) -> list[BaseMessage]: + """Reorder interleaved [Tool, Human, Tool, Human, ...] runs that + follow a parallel-tool-call AIMessage into [Tool, Tool, ..., Human, + Human, ...] so OpenAI's "all parallel tool responses must be + contiguous" rule is satisfied. + + Image-carrying HumanMessages emitted by the MCP tool wrapper are + tagged with `additional_kwargs["tool_call_id"]` matching the + originating tool call, which is how we pair each Human with its + parallel batch. + """ + out = list(messages) + i = 0 + while i < len(out): + msg = out[i] + tool_calls = getattr(msg, "tool_calls", None) or [] + if isinstance(msg, AIMessage) and len(tool_calls) >= 2: + expected_ids = {tc.get("id") for tc in tool_calls if tc.get("id")} + tool_msgs: list[ToolMessage] = [] + other_msgs: list[BaseMessage] = [] + j = i + 1 + while j < len(out): + m = out[j] + if isinstance(m, ToolMessage) and m.tool_call_id in expected_ids: + tool_msgs.append(m) + j += 1 + elif ( + isinstance(m, HumanMessage) + and getattr(m, "additional_kwargs", {}).get("tool_call_id") in expected_ids + ): + other_msgs.append(m) + j += 1 + else: + break + if tool_msgs and other_msgs and {m.tool_call_id for m in tool_msgs} == expected_ids: + out[i + 1 : j] = [*tool_msgs, *other_msgs] + i += 1 + return out + + +def _reorder_tool_responses( + left: list[BaseMessage], right: list[BaseMessage] | BaseMessage +) -> list[BaseMessage]: + """Standard add_messages merge, then fix any parallel-tool batches.""" + # add_messages is typed against langgraph's permissive Messages union; + # list[BaseMessage] is invariant so we cast at the boundary. + merged = cast("list[BaseMessage]", add_messages(cast("Any", left), cast("Any", right))) + return _fix_parallel_tool_batches(merged) + + +class _OrderedAgentState(AgentState[Any]): + # Override the messages reducer to keep parallel ToolMessages contiguous. + messages: Annotated[list[BaseMessage], _reorder_tool_responses] # type: ignore[misc] + + class McpClientConfig(ModuleConfig): system_prompt: str | None = SYSTEM_PROMPT model: str = "gpt-4o" @@ -168,24 +225,42 @@ def _mcp_tool_to_langchain(self, mcp_tool: dict[str, Any]) -> StructuredTool: description = mcp_tool.get("description", "") input_schema = mcp_tool.get("inputSchema", {"type": "object", "properties": {}}) - def call_tool(**kwargs: Any) -> str: + def call_tool( + tool_call_id: Annotated[str, InjectedToolCallId], + **kwargs: Any, + ) -> str | Command[Any]: result = self._mcp_tool_call(name, kwargs) content = result.get("content", []) - parts = [c.get("text", "") for c in content if c.get("type") == "text"] - text = "\n".join(parts) - - # Images need to be added to the history separately because they - # cannot be included in the tool response for OpenAI models and - # probably others. - for item in content: - if item.get("type") != "text": - uuid_ = str(uuid.uuid4()) - text += f"Tool call started with UUID: {uuid_}. You will be updated with the result soon." - _append_image_to_history(self, name, uuid_, item) - - return text + text = "\n".join(c.get("text", "") for c in content if c.get("type") == "text") + image_blocks = [c for c in content if c.get("type") != "text"] + + if not image_blocks: + return text + + # Vision content can't be embedded inside a ToolMessage for OpenAI + # (and others), so we use Command to append a follow-up HumanMessage + # carrying the image blocks within the same agent turn. + # + # The HumanMessage is tagged with `additional_kwargs["tool_call_id"]` + # so `_fix_parallel_tool_batches` can pair it with the right + # ToolMessage when multiple parallel tool calls return images + # in one batch (OpenAI requires the parallel ToolMessages to + # stay contiguous). + summary = text or f"{name} returned {len(image_blocks)} non-text artefact(s)." + intro = f"Artefacts returned by '{name}' (image follows):" + return Command( + update={ + "messages": [ + ToolMessage(content=summary, tool_call_id=tool_call_id), + HumanMessage( + content=[{"type": "text", "text": intro}, *image_blocks], + additional_kwargs={"tool_call_id": tool_call_id}, + ), + ] + } + ) - return StructuredTool( + return _McpStructuredTool( name=name, description=description, func=call_tool, @@ -223,6 +298,7 @@ def on_system_modules(self, _modules: list[RPCClient]) -> None: model=model, tools=tools, system_prompt=self.config.system_prompt, + state_schema=cast("type[AgentState[Any]]", _OrderedAgentState), ) if not self._thread.is_alive(): self._thread.start() @@ -330,21 +406,63 @@ def _process_message( pretty_print_langchain_message(msg) self.agent.publish(msg) + # The graph applies _reorder_tool_responses to its internal channel, + # but stream_mode="updates" emits raw node outputs in completion + # order — and langgraph does not re-run reducers when an initial + # state dict is fed back into stream() on the next turn. Mirror the + # reducer here so _history matches what the graph would produce. + self._history = _fix_parallel_tool_batches(self._history) + if self._message_queue.empty(): self.agent_idle.publish(True) -def _append_image_to_history( - mcp_client: McpClient, func_name: str, uuid_: str, result: Any -) -> None: - mcp_client.add_message( - HumanMessage( - content=[ - { - "type": "text", - "text": f"This is the artefact for the '{func_name}' tool with UUID:={uuid_}.", - }, - result, - ] +class _McpStructuredTool(StructuredTool): + """StructuredTool that propagates `tool_call_id` to MCP tools whose + `args_schema` is a JSON-Schema dict. + + Langchain auto-injects `InjectedToolCallId` only when `args_schema` is + a Pydantic model; MCP servers supply JSON-Schema dicts (the server's + authoritative contract for the LLM), so we have to bridge ourselves. + + The bridge uses only public Runnable surface — `invoke` / `ainvoke` + accept a `ToolCall` dict as documented input, and we copy its `id` + field into the function's kwargs before delegating. + """ + + def invoke( + self, + input: str | dict[Any, Any] | ToolCall, + config: Any = None, + **kwargs: Any, + ) -> Any: + return super().invoke(_inject_tool_call_id(input), config=config, **kwargs) + + async def ainvoke( + self, + input: str | dict[Any, Any] | ToolCall, + config: Any = None, + **kwargs: Any, + ) -> Any: + return await super().ainvoke(_inject_tool_call_id(input), config=config, **kwargs) + + +def _inject_tool_call_id( + input: str | dict[Any, Any] | ToolCall, +) -> dict[Any, Any]: + """Copy the `ToolCall.id` field into `args.tool_call_id` so the inner + `call_tool` closure receives it as a kwarg. JSON-Schema dicts don't + validate against extra keys, so this is transparent to the schema. + + Raises ValueError on any invocation that isn't a `ToolCall`-shaped + dict with a non-null `id` — MCP tools have no other valid call path. + """ + if not (isinstance(input, dict) and input.get("type") == "tool_call"): + raise ValueError( + "MCP tool must be invoked via a ToolCall (a dict with " + "type='tool_call' and a non-null id), not a bare input." ) - ) + tool_call_id = input.get("id") + if tool_call_id is None: + raise ValueError("MCP tool ToolCall is missing a non-null id.") + return {**input, "args": {**(input.get("args") or {}), "tool_call_id": tool_call_id}} diff --git a/dimos/agents/mcp/test_mcp_client_unit.py b/dimos/agents/mcp/test_mcp_client_unit.py index ea26f7c54f..14788e321e 100644 --- a/dimos/agents/mcp/test_mcp_client_unit.py +++ b/dimos/agents/mcp/test_mcp_client_unit.py @@ -13,15 +13,21 @@ # limitations under the License. from __future__ import annotations +import asyncio import json from queue import Empty, Queue from unittest.mock import MagicMock, patch -from langchain_core.messages import HumanMessage +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.messages.base import BaseMessage +from langgraph.types import Command import pytest -from dimos.agents.mcp.mcp_client import McpClient +from dimos.agents.mcp.mcp_client import ( + McpClient, + _fix_parallel_tool_batches, + _reorder_tool_responses, +) from dimos.utils.sequential_ids import SequentialIds @@ -64,18 +70,54 @@ def _mock_post(url: str, **kwargs: object) -> MagicMock: }, }, }, + { + "name": "take_picture", + "description": "Take a picture", + "inputSchema": {"type": "object", "properties": {}}, + }, + { + "name": "narrate_picture", + "description": "Take a picture and describe what's in it", + "inputSchema": {"type": "object", "properties": {}}, + }, ] } elif method == "tools/call": name = body["params"]["name"] args = body["params"].get("arguments", {}) if name == "add": - text = str(args.get("x", 0) + args.get("y", 0)) + result = { + "content": [{"type": "text", "text": str(args.get("x", 0) + args.get("y", 0))}] + } elif name == "greet": - text = f"Hello, {args.get('name', 'world')}!" + result = {"content": [{"type": "text", "text": f"Hello, {args.get('name', 'world')}!"}]} + elif name == "take_picture": + # Simulates `dimos.msgs.sensor_msgs.Image.agent_encode()` output. + result = { + "content": [ + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,FAKEPAYLOAD"}, + } + ] + } + elif name == "narrate_picture": + # Tool that returns both prose AND an image (e.g. a VLM + # describing what it sees). Exercises the `summary = text` + # branch of the Command-building path — the fallback + # "{name} returned N artefact(s)" sentinel must NOT be used + # when the tool already provided real text. + result = { + "content": [ + {"type": "text", "text": "I see a chair and a window."}, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,FAKEPAYLOAD"}, + }, + ] + } else: - text = "Skill not found" - result = {"content": [{"type": "text", "text": text}]} + result = {"content": [{"type": "text", "text": "Skill not found"}]} else: resp = MagicMock() resp.status_code = 200 @@ -113,9 +155,7 @@ def mcp_client() -> McpClient: def test_fetch_tools_from_mcp_server(mcp_client: McpClient) -> None: tools = mcp_client._fetch_tools() - assert len(tools) == 2 - assert tools[0].name == "add" - assert tools[1].name == "greet" + assert [t.name for t in tools] == ["add", "greet", "take_picture", "narrate_picture"] def test_tool_invocation_via_mcp(mcp_client: McpClient) -> None: @@ -123,8 +163,130 @@ def test_tool_invocation_via_mcp(mcp_client: McpClient) -> None: add_tool = next(t for t in tools if t.name == "add") greet_tool = next(t for t in tools if t.name == "greet") - assert add_tool.func(x=2, y=3) == "5" - assert greet_tool.func(name="Alice") == "Hello, Alice!" + # tool_call_id is an InjectedToolCallId argument; the LangGraph tool node + # supplies it at runtime, but here we call .func directly so we pass it + # explicitly. + assert add_tool.func(tool_call_id="tc-1", x=2, y=3) == "5" + assert greet_tool.func(tool_call_id="tc-2", name="Alice") == "Hello, Alice!" + + +def test_image_tool_returns_langgraph_command(mcp_client: McpClient) -> None: + """Non-text MCP content rides back as a `Command` that appends a + ``ToolMessage`` + image-bearing ``HumanMessage`` to the agent state. + + Replaces the previous side-channel (`add_message` after a UUID + placeholder), which forced an extra agent turn to deliver the image. + """ + tools = mcp_client._fetch_tools() + picture_tool = next(t for t in tools if t.name == "take_picture") + + out = picture_tool.func(tool_call_id="tc-image") + + assert isinstance(out, Command) + messages = out.update["messages"] + assert len(messages) == 2 + + tool_msg, human_msg = messages + assert isinstance(tool_msg, ToolMessage) + assert tool_msg.tool_call_id == "tc-image" + + assert isinstance(human_msg, HumanMessage) + # Tagged so the state reducer can pair this HumanMessage with the + # corresponding ToolMessage when several tool calls run in parallel. + assert human_msg.additional_kwargs.get("tool_call_id") == "tc-image" + blocks = human_msg.content + assert isinstance(blocks, list) + # First block is the intro text; the rest carry the image_url payload. + assert blocks[0]["type"] == "text" + assert any( + b.get("type") == "image_url" and "FAKEPAYLOAD" in b["image_url"]["url"] for b in blocks[1:] + ) + + +def test_image_tool_with_text_uses_real_text_as_tool_message(mcp_client: McpClient) -> None: + """When a tool returns BOTH text and image content, the ToolMessage + carries the tool's actual narration — not the + "{name} returned N artefact(s)" fallback sentinel. The image still + rides back on the follow-up HumanMessage as usual. + """ + tools = mcp_client._fetch_tools() + narrate_tool = next(t for t in tools if t.name == "narrate_picture") + + out = narrate_tool.func(tool_call_id="tc-narrate") + + assert isinstance(out, Command) + tool_msg, human_msg = out.update["messages"] + + assert isinstance(tool_msg, ToolMessage) + assert tool_msg.content == "I see a chair and a window." + assert "artefact" not in str(tool_msg.content) + + assert isinstance(human_msg, HumanMessage) + blocks = human_msg.content + assert isinstance(blocks, list) + assert any( + b.get("type") == "image_url" and "FAKEPAYLOAD" in b["image_url"]["url"] for b in blocks[1:] + ) + + +def test_structured_tool_invocation_injects_tool_call_id(mcp_client: McpClient) -> None: + """End-to-end: invoking via the ToolCall path lets the wrapper grab + `tool_call_id` even though `args_schema` is a JSON-Schema dict — the + behaviour Langchain only ships for Pydantic schemas out of the box. + """ + tools = mcp_client._fetch_tools() + picture_tool = next(t for t in tools if t.name == "take_picture") + + result = picture_tool.invoke( + { + "name": "take_picture", + "args": {}, + "id": "tc-via-invoke", + "type": "tool_call", + } + ) + + assert isinstance(result, Command) + messages = result.update["messages"] + assert messages[0].tool_call_id == "tc-via-invoke" + + +def test_structured_tool_ainvoke_injects_tool_call_id(mcp_client: McpClient) -> None: + """Async mirror of `test_structured_tool_invocation_injects_tool_call_id`: + langgraph's tool node may dispatch via `ainvoke` (e.g. under `astream`), + so the async path must inject `tool_call_id` the same way the sync path + does. A langchain release that changes ainvoke's call convention should + fail this test rather than silently drop the id. + """ + tools = mcp_client._fetch_tools() + picture_tool = next(t for t in tools if t.name == "take_picture") + + result = asyncio.run( + picture_tool.ainvoke( + { + "name": "take_picture", + "args": {}, + "id": "tc-via-ainvoke", + "type": "tool_call", + } + ) + ) + + assert isinstance(result, Command) + messages = result.update["messages"] + assert messages[0].tool_call_id == "tc-via-ainvoke" + + +def test_structured_tool_invocation_without_toolcall_raises(mcp_client: McpClient) -> None: + """Bare-dict invocation (no ToolCall envelope) must fail loud, so a + future langchain change that bypasses our `invoke` override is caught + by tests instead of silently dropping `tool_call_id`. + """ + tools = mcp_client._fetch_tools() + picture_tool = next(t for t in tools if t.name == "take_picture") + + with pytest.raises(ValueError, match="ToolCall"): + picture_tool.invoke({"name": "take_picture", "args": {}}) def test_mcp_request_error_propagation(mcp_client: McpClient) -> None: @@ -208,6 +370,169 @@ def test_tool_stream_progress_frame_becomes_human_message(mcp_client: McpClient) assert str(msg.content) == "[tool:follow_person] Found a person" +def _ai_with_parallel_calls(call_ids: list[str]) -> AIMessage: + return AIMessage( + content="", + tool_calls=[{"name": f"t{i}", "args": {}, "id": cid} for i, cid in enumerate(call_ids)], + ) + + +def _image_human(tool_call_id: str) -> HumanMessage: + return HumanMessage( + content=[{"type": "text", "text": "img"}], + additional_kwargs={"tool_call_id": tool_call_id}, + ) + + +def test_fix_parallel_tool_batches_reorders_interleaved_responses() -> None: + """[AI(parallel), Tool₁, Human₁, Tool₂, Human₂] should become + [AI, Tool₁, Tool₂, Human₁, Human₂] so OpenAI doesn't reject the next + turn for non-contiguous parallel tool responses.""" + messages: list[BaseMessage] = [ + _ai_with_parallel_calls(["a", "b"]), + ToolMessage(content="summary-a", tool_call_id="a"), + _image_human("a"), + ToolMessage(content="summary-b", tool_call_id="b"), + _image_human("b"), + ] + + out = _fix_parallel_tool_batches(messages) + + assert isinstance(out[0], AIMessage) + assert isinstance(out[1], ToolMessage) and out[1].tool_call_id == "a" + assert isinstance(out[2], ToolMessage) and out[2].tool_call_id == "b" + assert isinstance(out[3], HumanMessage) + assert out[3].additional_kwargs["tool_call_id"] == "a" + assert isinstance(out[4], HumanMessage) + assert out[4].additional_kwargs["tool_call_id"] == "b" + + +def test_fix_parallel_tool_batches_leaves_single_tool_call_alone() -> None: + """Single tool calls already satisfy the contiguity rule — don't touch them.""" + messages: list[BaseMessage] = [ + AIMessage( + content="", + tool_calls=[{"name": "t", "args": {}, "id": "solo"}], + ), + ToolMessage(content="summary", tool_call_id="solo"), + _image_human("solo"), + ] + + out = _fix_parallel_tool_batches(messages) + assert out == messages + + +def test_fix_parallel_tool_batches_leaves_already_ordered_alone() -> None: + """[AI, Tool₁, Tool₂, Human₁, Human₂] is already valid; don't reshuffle it.""" + messages: list[BaseMessage] = [ + _ai_with_parallel_calls(["a", "b"]), + ToolMessage(content="sa", tool_call_id="a"), + ToolMessage(content="sb", tool_call_id="b"), + _image_human("a"), + _image_human("b"), + ] + + out = _fix_parallel_tool_batches(messages) + assert out == messages + + +def test_fix_parallel_tool_batches_skips_untagged_human_messages() -> None: + """A plain HumanMessage with no `tool_call_id` tag terminates the run — + we won't reorder past it because we can't safely attribute it to a + parallel call.""" + plain_human = HumanMessage(content="just talking") + messages: list[BaseMessage] = [ + _ai_with_parallel_calls(["a", "b"]), + ToolMessage(content="sa", tool_call_id="a"), + plain_human, + ToolMessage(content="sb", tool_call_id="b"), + ] + + out = _fix_parallel_tool_batches(messages) + # Untouched: we stopped scanning at the plain human, so no rewrite. + assert out == messages + + +def test_reorder_tool_responses_merges_then_fixes() -> None: + """The reducer runs add_messages first, then applies the fix — so an + incoming Command-style append of [Tool, Human] for the second parallel + call lands contiguously after the first batch.""" + left: list[BaseMessage] = [ + _ai_with_parallel_calls(["a", "b"]), + ToolMessage(content="sa", tool_call_id="a"), + _image_human("a"), + ] + right: list[BaseMessage] = [ + ToolMessage(content="sb", tool_call_id="b"), + _image_human("b"), + ] + + out = _reorder_tool_responses(left, right) + + tool_ids = [m.tool_call_id for m in out if isinstance(m, ToolMessage)] + human_ids = [ + m.additional_kwargs.get("tool_call_id") for m in out if isinstance(m, HumanMessage) + ] + assert tool_ids == ["a", "b"] + assert human_ids == ["a", "b"] + # And critically: both ToolMessages come before either HumanMessage. + first_human = next(i for i, m in enumerate(out) if isinstance(m, HumanMessage)) + last_tool = max(i for i, m in enumerate(out) if isinstance(m, ToolMessage)) + assert last_tool < first_human + + +def test_process_message_normalizes_history_after_parallel_tool_batch( + mcp_client: McpClient, +) -> None: + """Regression: stream_mode="updates" yields node outputs in completion + order, so when two parallel image-returning tools each emit a + [ToolMessage, HumanMessage] Command, self._history ends up interleaved + as [Tool₁, Human₁, Tool₂, Human₂]. The graph's reducer reorders its + own channel state, but langgraph does NOT re-run reducers when an + initial state dict is fed back into stream() on the next turn — so + without an explicit fix-up here, OpenAI rejects the next user turn + for non-contiguous parallel ToolMessages. + """ + mcp_client._history = [] + mcp_client._message_queue = Queue() + mcp_client.agent_idle = MagicMock() + mcp_client.agent = MagicMock() + + ai = _ai_with_parallel_calls(["a", "b"]) + tool_a = ToolMessage(content="sa", tool_call_id="a") + human_a = _image_human("a") + tool_b = ToolMessage(content="sb", tool_call_id="b") + human_b = _image_human("b") + + # Mirror what langgraph's `stream_mode="updates"` looks like when two + # parallel ToolNode invocations finish out-of-order with respect to + # their image follow-ups: each Command landed [Tool, Human] in the + # raw node output stream. + fake_graph = MagicMock() + fake_graph.stream.return_value = iter( + [ + {"agent": {"messages": [ai]}}, + {"tools": {"messages": [tool_a, human_a]}}, + {"tools": {"messages": [tool_b, human_b]}}, + ] + ) + + user_msg = HumanMessage(content="look around") + mcp_client._process_message(fake_graph, user_msg) + + # Critical post-condition: both ToolMessages are contiguous, and only + # then come the image HumanMessages. Without the fix-up, _history + # would be [user, ai, tool_a, human_a, tool_b, human_b]. + tool_positions = [i for i, m in enumerate(mcp_client._history) if isinstance(m, ToolMessage)] + human_positions = [ + i + for i, m in enumerate(mcp_client._history) + if isinstance(m, HumanMessage) and m.additional_kwargs.get("tool_call_id") + ] + assert tool_positions == sorted(tool_positions) + assert max(tool_positions) < min(human_positions) + + def test_mcp_tool_call_sends_progress_token(mcp_client: McpClient) -> None: """Every `tools/call` request carries a `_meta.progressToken`.""" captured: dict[str, object] = {}