diff --git a/src/fast_agent/acp/server/agent_acp_server.py b/src/fast_agent/acp/server/agent_acp_server.py index 771515c8..74ac58bd 100644 --- a/src/fast_agent/acp/server/agent_acp_server.py +++ b/src/fast_agent/acp/server/agent_acp_server.py @@ -83,6 +83,9 @@ def __init__( self.sessions: dict[str, AgentInstance] = {} self._session_lock = asyncio.Lock() + # Track sessions with active prompts to prevent overlapping requests (per ACP protocol) + self._active_prompts: set[str] = set() + # Terminal runtime tracking (for cleanup) self._session_terminal_runtimes: dict[str, ACPTerminalRuntime] = {} @@ -236,6 +239,9 @@ async def prompt(self, params: PromptRequest) -> PromptResponse: Extracts the prompt text, sends it to the fast-agent agent, and sends the response back to the client via sessionUpdate notifications. + + Per ACP protocol, only one prompt can be active per session at a time. If a prompt + is already in progress for this session, this will immediately return a refusal. """ session_id = params.sessionId @@ -245,165 +251,190 @@ async def prompt(self, params: PromptRequest) -> PromptResponse: session_id=session_id, ) - # Get the agent instance for this session + # Check for overlapping prompt requests (per ACP protocol requirement) async with self._session_lock: - instance = self.sessions.get(session_id) + if session_id in self._active_prompts: + logger.warning( + "Overlapping prompt request detected - refusing", + name="acp_prompt_overlap", + session_id=session_id, + ) + # Return immediate refusal - ACP protocol requires sequential prompts per session + return PromptResponse(stopReason=REFUSAL) - if not instance: - logger.error( - "ACP prompt error: session not found", - name="acp_prompt_error", - session_id=session_id, - ) - # Return an error response - return PromptResponse(stopReason=REFUSAL) + # Mark this session as having an active prompt + self._active_prompts.add(session_id) - # Extract text content from the prompt - text_parts = [] - for content_block in params.prompt: - if hasattr(content_block, "type") and content_block.type == "text": - text_parts.append(content_block.text) + # Use try/finally to ensure session is always removed from active prompts + try: + # Get the agent instance for this session + async with self._session_lock: + instance = self.sessions.get(session_id) + + if not instance: + logger.error( + "ACP prompt error: session not found", + name="acp_prompt_error", + session_id=session_id, + ) + # Return an error response + return PromptResponse(stopReason=REFUSAL) - prompt_text = "\n".join(text_parts) + # Extract text content from the prompt + text_parts = [] + for content_block in params.prompt: + if hasattr(content_block, "type") and content_block.type == "text": + text_parts.append(content_block.text) - logger.info( - "Sending prompt to fast-agent", - name="acp_prompt_send", - session_id=session_id, - agent=self.primary_agent_name, - prompt_length=len(prompt_text), - ) + prompt_text = "\n".join(text_parts) - # Send to the fast-agent agent with streaming support - try: - if self.primary_agent_name: - agent = instance.agents[self.primary_agent_name] - - # Set up streaming if connection is available and agent supports it - stream_listener = None - remove_listener: Callable[[], None] | None = None - if self._connection and isinstance(agent, StreamingAgentProtocol): - update_lock = asyncio.Lock() - - async def send_stream_update(chunk: str): - """Send sessionUpdate with accumulated text so far.""" - if not chunk: - return - try: - async with update_lock: - message_chunk = update_agent_message_text(chunk) - notification = session_notification(session_id, message_chunk) - await self._connection.sessionUpdate(notification) - except Exception as e: - logger.error( - f"Error sending stream update: {e}", - name="acp_stream_error", - exc_info=True, + logger.info( + "Sending prompt to fast-agent", + name="acp_prompt_send", + session_id=session_id, + agent=self.primary_agent_name, + prompt_length=len(prompt_text), + ) + + # Send to the fast-agent agent with streaming support + try: + if self.primary_agent_name: + agent = instance.agents[self.primary_agent_name] + + # Set up streaming if connection is available and agent supports it + stream_listener = None + remove_listener: Callable[[], None] | None = None + if self._connection and isinstance(agent, StreamingAgentProtocol): + update_lock = asyncio.Lock() + + async def send_stream_update(chunk: str): + """Send sessionUpdate with accumulated text so far.""" + if not chunk: + return + try: + async with update_lock: + message_chunk = update_agent_message_text(chunk) + notification = session_notification(session_id, message_chunk) + await self._connection.sessionUpdate(notification) + except Exception as e: + logger.error( + f"Error sending stream update: {e}", + name="acp_stream_error", + exc_info=True, + ) + + def on_stream_chunk(chunk: str): + """ + Sync callback from fast-agent streaming. + Sends each chunk as it arrives to the ACP client. + """ + logger.debug( + f"Stream chunk received: {len(chunk)} chars", + name="acp_stream_chunk", + session_id=session_id, + chunk_length=len(chunk), ) - def on_stream_chunk(chunk: str): - """ - Sync callback from fast-agent streaming. - Sends each chunk as it arrives to the ACP client. - """ - logger.debug( - f"Stream chunk received: {len(chunk)} chars", - name="acp_stream_chunk", + # Send update asynchronously (don't await in sync callback) + asyncio.create_task(send_stream_update(chunk)) + + # Register the stream listener and keep the cleanup function + stream_listener = on_stream_chunk + remove_listener = agent.add_stream_listener(stream_listener) + + logger.info( + "Streaming enabled for prompt", + name="acp_streaming_enabled", session_id=session_id, - chunk_length=len(chunk), ) - # Send update asynchronously (don't await in sync callback) - asyncio.create_task(send_stream_update(chunk)) + try: + # This will trigger streaming callbacks as chunks arrive + response_text = await agent.send(prompt_text) - # Register the stream listener and keep the cleanup function - stream_listener = on_stream_chunk - remove_listener = agent.add_stream_listener(stream_listener) - - logger.info( - "Streaming enabled for prompt", - name="acp_streaming_enabled", - session_id=session_id, - ) - - try: - # This will trigger streaming callbacks as chunks arrive - response_text = await agent.send(prompt_text) - - logger.info( - "Received complete response from fast-agent", - name="acp_prompt_response", - session_id=session_id, - response_length=len(response_text), - ) + logger.info( + "Received complete response from fast-agent", + name="acp_prompt_response", + session_id=session_id, + response_length=len(response_text), + ) - # Always send final update with complete response - # (streaming sends chunks during execution, this is the final complete message) - if self._connection and response_text: - try: - message_chunk = update_agent_message_text(response_text) - notification = session_notification(session_id, message_chunk) - await self._connection.sessionUpdate(notification) - logger.info( - "Sent final sessionUpdate with complete response", - name="acp_final_update", - session_id=session_id, - ) - except Exception as e: - logger.error( - f"Error sending final update: {e}", - name="acp_final_update_error", - exc_info=True, - ) + # Always send final update with complete response + # (streaming sends chunks during execution, this is the final complete message) + if self._connection and response_text: + try: + message_chunk = update_agent_message_text(response_text) + notification = session_notification(session_id, message_chunk) + await self._connection.sessionUpdate(notification) + logger.info( + "Sent final sessionUpdate with complete response", + name="acp_final_update", + session_id=session_id, + ) + except Exception as e: + logger.error( + f"Error sending final update: {e}", + name="acp_final_update_error", + exc_info=True, + ) + + except Exception as send_error: + # Make sure listener is cleaned up even on error + if stream_listener and remove_listener: + try: + remove_listener() + logger.info( + "Removed stream listener after error", + name="acp_streaming_cleanup_error", + session_id=session_id, + ) + except Exception: + logger.exception("Failed to remove ACP stream listener after error") + # Re-raise the original error + raise send_error + + finally: + # Clean up stream listener (if not already cleaned up in except) + if stream_listener and remove_listener: + try: + remove_listener() + except Exception: + logger.exception("Failed to remove ACP stream listener") + else: + logger.info( + "Removed stream listener", + name="acp_streaming_cleanup", + session_id=session_id, + ) + + else: + logger.error("No primary agent available") + except Exception as e: + logger.error( + f"Error processing prompt: {e}", + name="acp_prompt_error", + exc_info=True, + ) + import sys + import traceback - except Exception as send_error: - # Make sure listener is cleaned up even on error - if stream_listener and remove_listener: - try: - remove_listener() - logger.info( - "Removed stream listener after error", - name="acp_streaming_cleanup_error", - session_id=session_id, - ) - except Exception: - logger.exception("Failed to remove ACP stream listener after error") - # Re-raise the original error - raise send_error - - finally: - # Clean up stream listener (if not already cleaned up in except) - if stream_listener and remove_listener: - try: - remove_listener() - except Exception: - logger.exception("Failed to remove ACP stream listener") - else: - logger.info( - "Removed stream listener", - name="acp_streaming_cleanup", - session_id=session_id, - ) + print(f"ERROR processing prompt: {e}", file=sys.stderr) + traceback.print_exc(file=sys.stderr) + raise - else: - logger.error("No primary agent available") - except Exception as e: - logger.error( - f"Error processing prompt: {e}", - name="acp_prompt_error", - exc_info=True, + # Return success + return PromptResponse( + stopReason=END_TURN, + ) + finally: + # Always remove session from active prompts, even on error + async with self._session_lock: + self._active_prompts.discard(session_id) + logger.debug( + "Removed session from active prompts", + name="acp_prompt_complete", + session_id=session_id, ) - import sys - import traceback - - print(f"ERROR processing prompt: {e}", file=sys.stderr) - traceback.print_exc(file=sys.stderr) - raise - - # Return success - return PromptResponse( - stopReason=END_TURN, - ) async def run_async(self) -> None: """ diff --git a/tests/integration/acp/test_acp_basic.py b/tests/integration/acp/test_acp_basic.py index 321e4a14..c9ed9d7a 100644 --- a/tests/integration/acp/test_acp_basic.py +++ b/tests/integration/acp/test_acp_basic.py @@ -89,3 +89,66 @@ async def _wait_for_notifications(client: TestClient, timeout: float = 2.0) -> N return await asyncio.sleep(0.05) raise AssertionError("Expected streamed session updates") + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_acp_overlapping_prompts_are_refused() -> None: + """ + Test that overlapping prompt requests for the same session are refused. + + Per ACP protocol, only one prompt can be active per session at a time. + If a second prompt arrives while one is in progress, it should be immediately + refused with stopReason="refusal". + """ + client = TestClient() + + async with spawn_agent_process(lambda _: client, *FAST_AGENT_CMD) as (connection, _process): + # Initialize + init_request = InitializeRequest( + protocolVersion=1, + clientCapabilities=ClientCapabilities( + fs={"readTextFile": True, "writeTextFile": True}, + terminal=False, + ), + clientInfo=Implementation(name="pytest-client", version="0.0.1"), + ) + init_response = await connection.initialize(init_request) + assert init_response.protocolVersion == 1 + + # Create session + session_response = await connection.newSession( + NewSessionRequest(mcpServers=[], cwd=str(TEST_DIR)) + ) + session_id = session_response.sessionId + assert session_id + + # Send two prompts truly concurrently (no sleep between them) + # This ensures they both arrive before either completes + prompt1_task = asyncio.create_task( + connection.prompt( + PromptRequest(sessionId=session_id, prompt=[text_block("first prompt")]) + ) + ) + + # Send immediately without waiting - ensures actual overlap + prompt2_task = asyncio.create_task( + connection.prompt( + PromptRequest(sessionId=session_id, prompt=[text_block("overlapping prompt")]) + ) + ) + + # Wait for both to complete + prompt1_response, prompt2_response = await asyncio.gather(prompt1_task, prompt2_task) + + # One should succeed, one should be refused + # (We don't know which one arrives first due to async scheduling) + responses = [prompt1_response.stopReason, prompt2_response.stopReason] + assert "end_turn" in responses, "One prompt should succeed" + assert "refusal" in responses, "One prompt should be refused" + + # After both complete, a new prompt should succeed + prompt3_response = await connection.prompt( + PromptRequest(sessionId=session_id, prompt=[text_block("third prompt")]) + ) + assert prompt3_response.stopReason == END_TURN