Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 36 additions & 8 deletions ext/dapr-ext-workflow/dapr/ext/workflow/aio/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@

from __future__ import annotations

import asyncio
import logging
import time
import uuid
from typing import Optional, Set

from dapr.ext.workflow.aio.dapr_workflow_client import DaprWorkflowClient
from dapr.ext.workflow.mcp import _MCP_METHOD_LIST_TOOLS, MCP_WORKFLOW_PREFIX, _DaprMCPClientBase
from dapr.ext.workflow.mcp import (
_MCP_METHOD_LIST_TOOLS,
_SCHEDULE_RETRY_INTERVAL_SECONDS,
MCP_WORKFLOW_PREFIX,
_DaprMCPClientBase,
_is_transient_schedule_error,
)
from dapr.ext.workflow.workflow_state import WorkflowStatus

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -84,15 +92,35 @@ async def connect(self, mcpserver_name: str) -> None:

logger.debug('Scheduling %s (instance=%s)', workflow_name, instance_id)

await self._wf_client.schedule_new_workflow(
workflow=workflow_name,
input={'mcpServerName': mcpserver_name},
instance_id=instance_id,
)

deadline = time.monotonic() + self._timeout
while True:
try:
await self._wf_client.schedule_new_workflow(
workflow=workflow_name,
input={'mcpServerName': mcpserver_name},
instance_id=instance_id,
)
break
except Exception as exc: # noqa: BLE001 — classified by helper
if not _is_transient_schedule_error(exc):
raise
sleep_for = min(_SCHEDULE_RETRY_INTERVAL_SECONDS, deadline - time.monotonic())
if sleep_for <= 0:
raise
logger.debug('schedule_new_workflow returned transient error %s; retrying', exc)
await asyncio.sleep(sleep_for)

remaining = deadline - time.monotonic()
if remaining <= 0:
raise RuntimeError(
f"ListTools workflow for MCPServer '{mcpserver_name}' "
f'timed out after {self._timeout}s'
)
# wait_for_workflow_completion treats timeout=0 as "wait forever",
# so floor the gRPC timeout at 1s when sub-second remaining survives.
state = await self._wf_client.wait_for_workflow_completion(
instance_id=instance_id,
timeout_in_seconds=self._timeout,
timeout_in_seconds=max(int(remaining), 1),
fetch_payloads=True,
)

Expand Down
60 changes: 53 additions & 7 deletions ext/dapr-ext-workflow/dapr/ext/workflow/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@

import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set

import grpc
from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient
from dapr.ext.workflow.workflow_state import WorkflowStatus

Expand All @@ -50,6 +52,30 @@
_MCP_METHOD_LIST_TOOLS = '.ListTools'
_MCP_METHOD_CALL_TOOL = '.CallTool'

# gRPC status codes that _is_transient_schedule_error treats as retryable.
_TRANSIENT_GRPC_CODES = frozenset(
    (
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.UNAVAILABLE,
    )
)
# Longest pause, in seconds, between schedule_new_workflow attempts; connect()
# shortens it to whatever is left before its deadline.
_SCHEDULE_RETRY_INTERVAL_SECONDS = 0.5


def _is_transient_schedule_error(exc: BaseException) -> bool:
    """Return True if a schedule_new_workflow failure should be retried.

    Follows the ``__cause__`` chain starting at *exc* so that both a raw
    ``grpc.RpcError`` and one wrapped by a higher layer (``raise ... from``)
    are recognised. The error counts as transient when any exception in the
    chain is a ``grpc.RpcError`` whose ``code()`` is in
    ``_TRANSIENT_GRPC_CODES``.

    Args:
        exc: The exception raised by the schedule call.

    Returns:
        True when a transient gRPC status is found in the chain, else False.
    """
    # Walk iteratively and remember visited exceptions by identity so that a
    # cyclic (a -> b -> a) or very long __cause__ chain terminates instead of
    # exhausting the interpreter stack.
    seen = set()
    current = exc
    while current is not None and id(current) not in seen:
        seen.add(id(current))
        if isinstance(current, grpc.RpcError):
            # Look code() up defensively and only call it when it is callable.
            code = getattr(current, 'code', None)
            if callable(code) and code() in _TRANSIENT_GRPC_CODES:
                return True
        current = getattr(current, '__cause__', None)
    return False


# TODO(@sicoyle): see if I can use the mcp pkg class instead for this?
@dataclass(frozen=True)
Expand Down Expand Up @@ -210,15 +236,35 @@ def connect(self, mcpserver_name: str) -> None:

logger.debug('Scheduling %s (instance=%s)', workflow_name, instance_id)

self._wf_client.schedule_new_workflow(
workflow=workflow_name,
input={'mcpServerName': mcpserver_name},
instance_id=instance_id,
)

deadline = time.monotonic() + self._timeout
while True:
try:
self._wf_client.schedule_new_workflow(
workflow=workflow_name,
input={'mcpServerName': mcpserver_name},
instance_id=instance_id,
)
break
except Exception as exc: # noqa: BLE001 — classified by helper
if not _is_transient_schedule_error(exc):
raise
sleep_for = min(_SCHEDULE_RETRY_INTERVAL_SECONDS, deadline - time.monotonic())
if sleep_for <= 0:
raise
logger.debug('schedule_new_workflow returned transient error %s; retrying', exc)
time.sleep(sleep_for)

remaining = deadline - time.monotonic()
if remaining <= 0:
raise RuntimeError(
f"ListTools workflow for MCPServer '{mcpserver_name}' "
f'timed out after {self._timeout}s'
)
# wait_for_workflow_completion treats timeout=0 as "wait forever",
# so floor the gRPC timeout at 1s when sub-second remaining survives.
state = self._wf_client.wait_for_workflow_completion(
instance_id=instance_id,
timeout_in_seconds=self._timeout,
timeout_in_seconds=max(int(remaining), 1),
fetch_payloads=True,
)

Expand Down
167 changes: 166 additions & 1 deletion ext/dapr-ext-workflow/tests/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,26 @@
import json
import unittest
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch

import grpc
from dapr.ext.workflow._durabletask import client
from dapr.ext.workflow.aio.mcp import DaprMCPClient as AioDaprMCPClient
from dapr.ext.workflow.mcp import MCP_WORKFLOW_PREFIX, DaprMCPClient, MCPToolDef
from dapr.ext.workflow.workflow_state import WorkflowState


class _StubRpcError(grpc.RpcError):
    """Minimal ``grpc.RpcError`` stand-in that reports a fixed status code.

    Lets a test raise an RpcError carrying whichever ``grpc.StatusCode`` it
    needs.
    """

    def __init__(self, status_code: grpc.StatusCode):
        # Remember the code that code() should hand back.
        self._status_code = status_code
        super().__init__()

    def code(self) -> grpc.StatusCode:
        """Return the status code supplied at construction."""
        return self._status_code

def _make_completed_state(output_json: dict) -> WorkflowState:
"""Create a WorkflowState that simulates a COMPLETED workflow."""
inner = client.WorkflowState(
Expand Down Expand Up @@ -385,6 +397,159 @@ async def test_connect_caches_tools(self):
self.assertEqual(tools[1].name, 'get_forecast')


class TestDaprMCPClientConnectRetry(unittest.TestCase):
    """Checks how the sync connect() reacts when scheduling raises gRPC errors."""

    # Patch targets for the clock functions used by the sync client module.
    _SLEEP = 'dapr.ext.workflow.mcp.time.sleep'
    _MONOTONIC = 'dapr.ext.workflow.mcp.time.monotonic'

    def setUp(self):
        # Each test gets its own workflow-client double.
        self.wf = MagicMock()

    def _use_sample_tools(self):
        """Make the completion wait return the sample ListTools payload."""
        self.wf.wait_for_workflow_completion.return_value = _make_completed_state(
            SAMPLE_LIST_TOOLS_RESPONSE
        )

    def _client(self, timeout):
        """Build a DaprMCPClient wired to this test's workflow-client double."""
        return DaprMCPClient(timeout_in_seconds=timeout, wf_client=self.wf)

    def test_retries_then_succeeds_on_cancelled(self):
        """Two CANCELLED failures followed by a success still connect."""
        cancelled = grpc.StatusCode.CANCELLED
        self.wf.schedule_new_workflow.side_effect = [
            _StubRpcError(cancelled),
            _StubRpcError(cancelled),
            'inst-1',
        ]
        self._use_sample_tools()

        mcp_client = self._client(30)
        with patch(self._SLEEP):
            mcp_client.connect('weather')

        self.assertEqual(self.wf.schedule_new_workflow.call_count, 3)
        self.assertEqual(len(mcp_client.get_all_tools()), 2)

    def test_retries_on_unavailable(self):
        """An UNAVAILABLE failure is retried just like CANCELLED."""
        self.wf.schedule_new_workflow.side_effect = [
            _StubRpcError(grpc.StatusCode.UNAVAILABLE),
            'inst-1',
        ]
        self._use_sample_tools()

        mcp_client = self._client(30)
        with patch(self._SLEEP):
            mcp_client.connect('weather')

        self.assertEqual(self.wf.schedule_new_workflow.call_count, 2)

    def test_non_transient_propagates_immediately(self):
        """PERMISSION_DENIED surfaces after one attempt and no sleep."""
        self.wf.schedule_new_workflow.side_effect = _StubRpcError(
            grpc.StatusCode.PERMISSION_DENIED
        )

        mcp_client = self._client(30)
        with patch(self._SLEEP) as sleep_mock, self.assertRaises(grpc.RpcError):
            mcp_client.connect('weather')

        self.assertEqual(self.wf.schedule_new_workflow.call_count, 1)
        sleep_mock.assert_not_called()

    def test_deadline_exhausted_raises_last_error(self):
        """Once the budget is spent mid-retry, the gRPC error is re-raised."""
        self.wf.schedule_new_workflow.side_effect = _StubRpcError(grpc.StatusCode.CANCELLED)

        mcp_client = self._client(1)
        # The fake clock jumps from 0.0 straight past the 1s deadline, so the
        # test never has to sleep for real.
        with (
            patch(self._SLEEP),
            patch(self._MONOTONIC, side_effect=[0.0, 2.0]),
            self.assertRaises(grpc.RpcError),
        ):
            mcp_client.connect('weather')

    def test_budget_exhausted_after_schedule_succeeds(self):
        """Scheduling succeeds only after the budget is gone: connect() must
        raise a timeout and never call wait_for_workflow_completion (a
        timeout of 0 means 'wait forever' in the underlying client)."""
        self.wf.schedule_new_workflow.side_effect = [
            _StubRpcError(grpc.StatusCode.CANCELLED),
            'inst-1',
        ]

        mcp_client = self._client(1)
        # Fake clock readings:
        #   0.0 -> deadline becomes 1.0
        #   0.4 -> sleep_for = 0.5, still inside the budget
        #   2.0 -> remaining after the loop is -1.0, so connect() raises
        with (
            patch(self._SLEEP),
            patch(self._MONOTONIC, side_effect=[0.0, 0.4, 2.0]),
            self.assertRaises(RuntimeError) as ctx,
        ):
            mcp_client.connect('weather')

        self.assertIn('timed out', str(ctx.exception))
        self.wf.wait_for_workflow_completion.assert_not_called()


class TestAioDaprMCPClientConnectRetry(unittest.IsolatedAsyncioTestCase):
    """Async mirror of the sync connect() retry tests."""

    # Patch targets for the clock functions used by the async client module.
    _SLEEP = 'dapr.ext.workflow.aio.mcp.asyncio.sleep'
    _MONOTONIC = 'dapr.ext.workflow.aio.mcp.time.monotonic'

    @staticmethod
    def _wf_double(schedule_outcomes):
        """Return an AsyncMock whose schedule_new_workflow yields the given outcomes."""
        wf = AsyncMock()
        wf.schedule_new_workflow.side_effect = schedule_outcomes
        return wf

    async def test_retries_then_succeeds_on_cancelled(self):
        """One CANCELLED failure followed by a success still connects."""
        wf = self._wf_double([_StubRpcError(grpc.StatusCode.CANCELLED), 'inst-1'])
        wf.wait_for_workflow_completion.return_value = _make_completed_state(
            SAMPLE_LIST_TOOLS_RESPONSE
        )

        mcp_client = AioDaprMCPClient(timeout_in_seconds=30, wf_client=wf)
        with patch(self._SLEEP, new=AsyncMock()):
            await mcp_client.connect('weather')

        self.assertEqual(wf.schedule_new_workflow.await_count, 2)
        self.assertEqual(len(mcp_client.get_all_tools()), 2)

    async def test_deadline_exhausted_raises(self):
        """Once the budget is spent mid-retry, the gRPC error is re-raised."""
        wf = self._wf_double(_StubRpcError(grpc.StatusCode.CANCELLED))

        mcp_client = AioDaprMCPClient(timeout_in_seconds=1, wf_client=wf)
        # The fake clock jumps from 0.0 straight past the 1s deadline.
        with (
            patch(self._SLEEP, new=AsyncMock()),
            patch(self._MONOTONIC, side_effect=[0.0, 2.0]),
            self.assertRaises(grpc.RpcError),
        ):
            await mcp_client.connect('weather')

    async def test_budget_exhausted_after_schedule_succeeds(self):
        """Scheduling succeeds after the budget is gone: raise, never wait."""
        wf = self._wf_double([_StubRpcError(grpc.StatusCode.CANCELLED), 'inst-1'])

        mcp_client = AioDaprMCPClient(timeout_in_seconds=1, wf_client=wf)
        # Fake clock readings: 0.0 sets the deadline to 1.0, 0.4 keeps the
        # retry inside the budget, 2.0 leaves nothing for the completion wait.
        with (
            patch(self._SLEEP, new=AsyncMock()),
            patch(self._MONOTONIC, side_effect=[0.0, 0.4, 2.0]),
            self.assertRaises(RuntimeError) as ctx,
        ):
            await mcp_client.connect('weather')

        self.assertIn('timed out', str(ctx.exception))
        wf.wait_for_workflow_completion.assert_not_awaited()


class TestMCPWorkflowPrefix(unittest.TestCase):
"""Tests for the workflow naming constant."""

Expand Down
Loading