From 03885e579a8b026aa3ca13840a6c0867a8fdd5e4 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 28 May 2026 08:13:43 -0500 Subject: [PATCH] fix(mcp): retry connect calls on transient grpc errs (#1062) * fix(mcp): retry connect calls on transient grpc errs Signed-off-by: Samantha Coyle * style: comment cleanup Signed-off-by: Samantha Coyle * fix: address copilot feedback Signed-off-by: Samantha Coyle * style: appease linter Signed-off-by: Samantha Coyle --------- Signed-off-by: Samantha Coyle (cherry picked from commit 8571c3ec0eb4261b81684212e4a853c5479d3877) Signed-off-by: dapr-bot --- .../dapr/ext/workflow/aio/mcp.py | 44 ++++- .../dapr/ext/workflow/mcp.py | 60 ++++++- .../tests/test_mcp_client.py | 167 +++++++++++++++++- 3 files changed, 255 insertions(+), 16 deletions(-) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/mcp.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/mcp.py index 7d98f3b8a..e13e68a2e 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/mcp.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/mcp.py @@ -15,12 +15,20 @@ from __future__ import annotations +import asyncio import logging +import time import uuid from typing import Optional, Set from dapr.ext.workflow.aio.dapr_workflow_client import DaprWorkflowClient -from dapr.ext.workflow.mcp import _MCP_METHOD_LIST_TOOLS, MCP_WORKFLOW_PREFIX, _DaprMCPClientBase +from dapr.ext.workflow.mcp import ( + _MCP_METHOD_LIST_TOOLS, + _SCHEDULE_RETRY_INTERVAL_SECONDS, + MCP_WORKFLOW_PREFIX, + _DaprMCPClientBase, + _is_transient_schedule_error, +) from dapr.ext.workflow.workflow_state import WorkflowStatus logger = logging.getLogger(__name__) @@ -84,15 +92,35 @@ async def connect(self, mcpserver_name: str) -> None: logger.debug('Scheduling %s (instance=%s)', workflow_name, instance_id) - await self._wf_client.schedule_new_workflow( - workflow=workflow_name, - input={'mcpServerName': mcpserver_name}, - instance_id=instance_id, - ) - + deadline = time.monotonic() + self._timeout + while True: + try: + await self._wf_client.schedule_new_workflow( + workflow=workflow_name, + input={'mcpServerName': mcpserver_name}, + instance_id=instance_id, + ) + break + except Exception as exc: # noqa: BLE001 — classified by helper + if not _is_transient_schedule_error(exc): + raise + sleep_for = min(_SCHEDULE_RETRY_INTERVAL_SECONDS, deadline - time.monotonic()) + if sleep_for <= 0: + raise + logger.debug('schedule_new_workflow returned transient error %s; retrying', exc) + await asyncio.sleep(sleep_for) + + remaining = deadline - time.monotonic() + if remaining <= 0: + raise RuntimeError( + f"ListTools workflow for MCPServer '{mcpserver_name}' " + f'timed out after {self._timeout}s' + ) + # wait_for_workflow_completion treats timeout=0 as "wait forever", + # so floor the gRPC timeout at 1s when sub-second remaining survives. state = await self._wf_client.wait_for_workflow_completion( instance_id=instance_id, - timeout_in_seconds=self._timeout, + timeout_in_seconds=max(int(remaining), 1), fetch_payloads=True, ) diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/mcp.py b/ext/dapr-ext-workflow/dapr/ext/workflow/mcp.py index 4c53006db..11276a050 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/mcp.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/mcp.py @@ -33,10 +33,12 @@ import json import logging +import time import uuid from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Set +import grpc from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient from dapr.ext.workflow.workflow_state import WorkflowStatus @@ -50,6 +52,30 @@ _MCP_METHOD_LIST_TOOLS = '.ListTools' _MCP_METHOD_CALL_TOOL = '.CallTool' +_TRANSIENT_GRPC_CODES = frozenset( + { + grpc.StatusCode.CANCELLED, + grpc.StatusCode.UNAVAILABLE, + } +) +_SCHEDULE_RETRY_INTERVAL_SECONDS = 0.5 + + +def _is_transient_schedule_error(exc: BaseException) -> bool: + """True if a schedule_new_workflow failure should be retried. + + Walks ``__cause__`` so we catch both raw ``grpc.RpcError`` and any + durabletask-layer wrapping. + """ + if isinstance(exc, grpc.RpcError): + code = getattr(exc, 'code', None) + if callable(code) and code() in _TRANSIENT_GRPC_CODES: + return True + cause = getattr(exc, '__cause__', None) + if cause is not None and cause is not exc: + return _is_transient_schedule_error(cause) + return False + # TODO(@sicoyle): see if I can use the mcp pkg class instead for this? @dataclass(frozen=True) @@ -210,15 +236,35 @@ def connect(self, mcpserver_name: str) -> None: logger.debug('Scheduling %s (instance=%s)', workflow_name, instance_id) - self._wf_client.schedule_new_workflow( - workflow=workflow_name, - input={'mcpServerName': mcpserver_name}, - instance_id=instance_id, - ) - + deadline = time.monotonic() + self._timeout + while True: + try: + self._wf_client.schedule_new_workflow( + workflow=workflow_name, + input={'mcpServerName': mcpserver_name}, + instance_id=instance_id, + ) + break + except Exception as exc: # noqa: BLE001 — classified by helper + if not _is_transient_schedule_error(exc): + raise + sleep_for = min(_SCHEDULE_RETRY_INTERVAL_SECONDS, deadline - time.monotonic()) + if sleep_for <= 0: + raise + logger.debug('schedule_new_workflow returned transient error %s; retrying', exc) + time.sleep(sleep_for) + + remaining = deadline - time.monotonic() + if remaining <= 0: + raise RuntimeError( + f"ListTools workflow for MCPServer '{mcpserver_name}' " + f'timed out after {self._timeout}s' + ) + # wait_for_workflow_completion treats timeout=0 as "wait forever", + # so floor the gRPC timeout at 1s when sub-second remaining survives. state = self._wf_client.wait_for_workflow_completion( instance_id=instance_id, - timeout_in_seconds=self._timeout, + timeout_in_seconds=max(int(remaining), 1), fetch_payloads=True, ) diff --git a/ext/dapr-ext-workflow/tests/test_mcp_client.py b/ext/dapr-ext-workflow/tests/test_mcp_client.py index 418d6d464..c5ceded45 100644 --- a/ext/dapr-ext-workflow/tests/test_mcp_client.py +++ b/ext/dapr-ext-workflow/tests/test_mcp_client.py @@ -16,14 +16,26 @@ import json import unittest from datetime import datetime -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch +import grpc from dapr.ext.workflow._durabletask import client from dapr.ext.workflow.aio.mcp import DaprMCPClient as AioDaprMCPClient from dapr.ext.workflow.mcp import MCP_WORKFLOW_PREFIX, DaprMCPClient, MCPToolDef from dapr.ext.workflow.workflow_state import WorkflowState +class _StubRpcError(grpc.RpcError): + """Test double for grpc.RpcError with a configurable status code.""" + + def __init__(self, status_code: grpc.StatusCode): + super().__init__() + self._status_code = status_code + + def code(self) -> grpc.StatusCode: + return self._status_code + + def _make_completed_state(output_json: dict) -> WorkflowState: """Create a WorkflowState that simulates a COMPLETED workflow.""" inner = client.WorkflowState( @@ -385,6 +397,159 @@ async def test_connect_caches_tools(self): self.assertEqual(tools[1].name, 'get_forecast') +class TestDaprMCPClientConnectRetry(unittest.TestCase): + """Tests for connect()'s retry-on-transient-gRPC-error path.""" + + def test_retries_then_succeeds_on_cancelled(self): + """A CANCELLED schedule failure should be retried within the timeout budget.""" + mock_wf = MagicMock() + mock_wf.schedule_new_workflow.side_effect = [ + _StubRpcError(grpc.StatusCode.CANCELLED), + _StubRpcError(grpc.StatusCode.CANCELLED), + 'inst-1', + ] + mock_wf.wait_for_workflow_completion.return_value = _make_completed_state( + SAMPLE_LIST_TOOLS_RESPONSE + ) + + mcp_client = DaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf) + with patch('dapr.ext.workflow.mcp.time.sleep'): + mcp_client.connect('weather') + + self.assertEqual(mock_wf.schedule_new_workflow.call_count, 3) + self.assertEqual(len(mcp_client.get_all_tools()), 2) + + def test_retries_on_unavailable(self): + """UNAVAILABLE should also be treated as transient.""" + mock_wf = MagicMock() + mock_wf.schedule_new_workflow.side_effect = [ + _StubRpcError(grpc.StatusCode.UNAVAILABLE), + 'inst-1', + ] + mock_wf.wait_for_workflow_completion.return_value = _make_completed_state( + SAMPLE_LIST_TOOLS_RESPONSE + ) + + mcp_client = DaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf) + with patch('dapr.ext.workflow.mcp.time.sleep'): + mcp_client.connect('weather') + + self.assertEqual(mock_wf.schedule_new_workflow.call_count, 2) + + def test_non_transient_propagates_immediately(self): + """A non-CANCELLED/UNAVAILABLE error must not be retried.""" + mock_wf = MagicMock() + mock_wf.schedule_new_workflow.side_effect = _StubRpcError(grpc.StatusCode.PERMISSION_DENIED) + + mcp_client = DaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf) + with patch('dapr.ext.workflow.mcp.time.sleep') as sleep_mock: + with self.assertRaises(grpc.RpcError): + mcp_client.connect('weather') + + self.assertEqual(mock_wf.schedule_new_workflow.call_count, 1) + sleep_mock.assert_not_called() + + def test_deadline_exhausted_raises_last_error(self): + """When the timeout budget runs out mid-retry, propagate the last error.""" + mock_wf = MagicMock() + mock_wf.schedule_new_workflow.side_effect = _StubRpcError(grpc.StatusCode.CANCELLED) + + mcp_client = DaprMCPClient(timeout_in_seconds=1, wf_client=mock_wf) + # Patch monotonic to advance past the deadline immediately so we don't + # actually sleep for a second in tests. + with ( + patch('dapr.ext.workflow.mcp.time.sleep'), + patch( + 'dapr.ext.workflow.mcp.time.monotonic', + side_effect=[0.0, 2.0], + ), + ): + with self.assertRaises(grpc.RpcError): + mcp_client.connect('weather') + + def test_budget_exhausted_after_schedule_succeeds(self): + """If retries burn the budget but schedule eventually succeeds, raise + without calling wait_for_workflow_completion (timeout=0 means + 'wait forever' in the underlying client).""" + mock_wf = MagicMock() + mock_wf.schedule_new_workflow.side_effect = [ + _StubRpcError(grpc.StatusCode.CANCELLED), + 'inst-1', + ] + + mcp_client = DaprMCPClient(timeout_in_seconds=1, wf_client=mock_wf) + # monotonic: 0.0 → deadline = 1.0; 0.4 → sleep_for = 0.5 (still in budget); + # 2.0 → post-loop remaining = -1.0 → raise. + with ( + patch('dapr.ext.workflow.mcp.time.sleep'), + patch( + 'dapr.ext.workflow.mcp.time.monotonic', + side_effect=[0.0, 0.4, 2.0], + ), + ): + with self.assertRaises(RuntimeError) as ctx: + mcp_client.connect('weather') + self.assertIn('timed out', str(ctx.exception)) + mock_wf.wait_for_workflow_completion.assert_not_called() + + +class TestAioDaprMCPClientConnectRetry(unittest.IsolatedAsyncioTestCase): + """Async counterpart of TestDaprMCPClientConnectRetry.""" + + async def test_retries_then_succeeds_on_cancelled(self): + mock_wf = AsyncMock() + mock_wf.schedule_new_workflow.side_effect = [ + _StubRpcError(grpc.StatusCode.CANCELLED), + 'inst-1', + ] + mock_wf.wait_for_workflow_completion.return_value = _make_completed_state( + SAMPLE_LIST_TOOLS_RESPONSE + ) + + mcp_client = AioDaprMCPClient(timeout_in_seconds=30, wf_client=mock_wf) + with patch('dapr.ext.workflow.aio.mcp.asyncio.sleep', new=AsyncMock()): + await mcp_client.connect('weather') + + self.assertEqual(mock_wf.schedule_new_workflow.await_count, 2) + self.assertEqual(len(mcp_client.get_all_tools()), 2) + + async def test_deadline_exhausted_raises(self): + mock_wf = AsyncMock() + mock_wf.schedule_new_workflow.side_effect = _StubRpcError(grpc.StatusCode.CANCELLED) + + mcp_client = AioDaprMCPClient(timeout_in_seconds=1, wf_client=mock_wf) + with ( + patch('dapr.ext.workflow.aio.mcp.asyncio.sleep', new=AsyncMock()), + patch( + 'dapr.ext.workflow.aio.mcp.time.monotonic', + side_effect=[0.0, 2.0], + ), + ): + with self.assertRaises(grpc.RpcError): + await mcp_client.connect('weather') + + async def test_budget_exhausted_after_schedule_succeeds(self): + """Async mirror of the fail-fast-after-schedule-success guard.""" + mock_wf = AsyncMock() + mock_wf.schedule_new_workflow.side_effect = [ + _StubRpcError(grpc.StatusCode.CANCELLED), + 'inst-1', + ] + + mcp_client = AioDaprMCPClient(timeout_in_seconds=1, wf_client=mock_wf) + with ( + patch('dapr.ext.workflow.aio.mcp.asyncio.sleep', new=AsyncMock()), + patch( + 'dapr.ext.workflow.aio.mcp.time.monotonic', + side_effect=[0.0, 0.4, 2.0], + ), + ): + with self.assertRaises(RuntimeError) as ctx: + await mcp_client.connect('weather') + self.assertIn('timed out', str(ctx.exception)) + mock_wf.wait_for_workflow_completion.assert_not_awaited() + + class TestMCPWorkflowPrefix(unittest.TestCase): """Tests for the workflow naming constant."""