diff --git a/src/fast_agent/acp/server/agent_acp_server.py b/src/fast_agent/acp/server/agent_acp_server.py index 6c5892ce..deeefedd 100644 --- a/src/fast_agent/acp/server/agent_acp_server.py +++ b/src/fast_agent/acp/server/agent_acp_server.py @@ -23,10 +23,7 @@ SetSessionModeRequest, SetSessionModeResponse, ) -from acp.agent.router import build_agent_router -from acp.connection import MethodHandler from acp.helpers import session_notification, update_agent_message_text -from acp.meta import AGENT_METHODS from acp.schema import ( AgentCapabilities, Implementation, @@ -54,7 +51,6 @@ enrich_with_environment_context, ) from fast_agent.interfaces import StreamingAgentProtocol -from fast_agent.llm.cancellation import CancellationToken from fast_agent.llm.model_database import ModelDatabase from fast_agent.mcp.helpers.content_helpers import is_text_content from fast_agent.types import LlmStopReason, PromptMessageExtended, RequestParams @@ -66,34 +62,6 @@ REFUSAL: StopReason = "refusal" -class ExtendedAgentSideConnection(AgentSideConnection): - """ - Extended AgentSideConnection that registers session/cancel as both request and notification. - - Some clients incorrectly send session/cancel as a request (with an id) instead of - a notification. This subclass adds the cancel handler to both routing tables for - compatibility. - """ - - def _create_handler(self, agent: ACPAgent) -> MethodHandler: - """Override to add cancel as both request and notification handler.""" - router = build_agent_router(agent) - - # Also register cancel as a request handler for compatibility with clients - # that incorrectly send it with an id - router._requests[AGENT_METHODS["session_cancel"]] = router._notifications.get( - AGENT_METHODS["session_cancel"] - ) - - async def handler(method: str, params: Any | None, is_notification: bool) -> Any: - if is_notification: - await router.dispatch_notification(method, params) - return None - return await router.dispatch_request(method, params) - - return handler - - def map_llm_stop_reason_to_acp(llm_stop_reason: LlmStopReason | None) -> StopReason: """ Map fast-agent LlmStopReason to ACP StopReason. 
@@ -221,8 +189,8 @@ def __init__( # Track sessions with active prompts to prevent overlapping requests (per ACP protocol) self._active_prompts: set[str] = set() - # Track cancellation tokens per session for cancel support - self._session_cancellation_tokens: dict[str, CancellationToken] = {} + # Track asyncio tasks per session for proper task-based cancellation + self._session_tasks: dict[str, asyncio.Task] = {} # Track current agent per session for ACP mode support self._session_current_agent: dict[str, str] = {} @@ -790,9 +758,10 @@ async def prompt(self, params: PromptRequest) -> PromptResponse: # Mark this session as having an active prompt self._active_prompts.add(session_id) - # Create a cancellation token for this prompt - cancellation_token = CancellationToken() - self._session_cancellation_tokens[session_id] = cancellation_token + # Track the current task for proper cancellation via asyncio.Task.cancel() + current_task = asyncio.current_task() + if current_task: + self._session_tasks[session_id] = current_task # Use try/finally to ensure session is always removed from active prompts try: @@ -944,7 +913,6 @@ def on_stream_chunk(chunk: str): result = await agent.generate( prompt_message, request_params=session_request_params, - cancellation_token=cancellation_token, ) response_text = result.last_text() or "No content generated" @@ -1055,11 +1023,19 @@ def on_stream_chunk(chunk: str): return PromptResponse( stopReason=acp_stop_reason, ) + except asyncio.CancelledError: + # Task was cancelled - return appropriate response + logger.info( + "Prompt cancelled by user", + name="acp_prompt_cancelled", + session_id=session_id, + ) + return PromptResponse(stopReason="cancelled") finally: - # Always remove session from active prompts and cleanup cancellation token + # Always remove session from active prompts and cleanup task async with self._session_lock: self._active_prompts.discard(session_id) - self._session_cancellation_tokens.pop(session_id, None) + self._session_tasks.pop(session_id, None) logger.debug( "Removed session from active prompts", name="acp_prompt_complete", @@ -1073,6 +1049,9 @@ async def cancel(self, params: CancelNotification) -> None: This cancels any in-progress prompt for the specified session. Per ACP protocol, we should stop all LLM requests and tool invocations as soon as possible. + + Uses asyncio.Task.cancel() for proper async cancellation, which raises + asyncio.CancelledError in the running task. 
""" session_id = params.sessionId @@ -1082,14 +1061,14 @@ async def cancel(self, params: CancelNotification) -> None: session_id=session_id, ) - # Get the cancellation token for this session and signal cancellation + # Get the task for this session and cancel it async with self._session_lock: - cancellation_token = self._session_cancellation_tokens.get(session_id) - if cancellation_token: - cancellation_token.cancel("user_cancelled") + task = self._session_tasks.get(session_id) + if task and not task.done(): + task.cancel() logger.info( - "Cancellation signaled for session", - name="acp_cancel_signaled", + "Task cancelled for session", + name="acp_cancel_task", session_id=session_id, ) else: @@ -1116,8 +1095,7 @@ async def run_async(self) -> None: # Note: AgentSideConnection expects (writer, reader) order # - input_stream (writer) = where agent writes TO client # - output_stream (reader) = where agent reads FROM client - # Use ExtendedAgentSideConnection for cancel request/notification compatibility - connection = ExtendedAgentSideConnection( + connection = AgentSideConnection( lambda conn: self, writer, # input_stream = StreamWriter for agent output reader, # output_stream = StreamReader for agent input @@ -1222,8 +1200,8 @@ async def _cleanup_sessions(self) -> None: # Clean up session current agent mapping self._session_current_agent.clear() - # Clear cancellation tokens - self._session_cancellation_tokens.clear() + # Clear tasks + self._session_tasks.clear() # Clear stored prompt contexts self._session_prompt_context.clear() diff --git a/src/fast_agent/agents/llm_agent.py b/src/fast_agent/agents/llm_agent.py index c1b1c3c5..0de1213b 100644 --- a/src/fast_agent/agents/llm_agent.py +++ b/src/fast_agent/agents/llm_agent.py @@ -18,7 +18,6 @@ from fast_agent.agents.llm_decorator import LlmDecorator, ModelT from fast_agent.constants import FAST_AGENT_ERROR_CHANNEL from fast_agent.context import Context -from fast_agent.llm.cancellation import CancellationToken from fast_agent.mcp.helpers.content_helpers import get_text from fast_agent.types import PromptMessageExtended, RequestParams from fast_agent.types.llm_stop_reason import LlmStopReason @@ -238,7 +237,6 @@ async def generate_impl( messages: List[PromptMessageExtended], request_params: RequestParams | None = None, tools: List[Tool] | None = None, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Enhanced generate implementation that resets tool call tracking. 
@@ -272,7 +270,7 @@ async def generate_impl( try: result, summary = await self._generate_with_summary( - messages, request_params, tools, cancellation_token + messages, request_params, tools ) finally: if remove_listener: @@ -288,7 +286,7 @@ async def generate_impl( await self.show_assistant_message(result, additional_message=summary_text) else: result, summary = await self._generate_with_summary( - messages, request_params, tools, cancellation_token + messages, request_params, tools ) summary_text = ( diff --git a/src/fast_agent/agents/llm_decorator.py b/src/fast_agent/agents/llm_decorator.py index 91127b30..c82f23fb 100644 --- a/src/fast_agent/agents/llm_decorator.py +++ b/src/fast_agent/agents/llm_decorator.py @@ -51,7 +51,6 @@ LLMFactoryProtocol, StreamingAgentProtocol, ) -from fast_agent.llm.cancellation import CancellationToken from fast_agent.llm.model_database import ModelDatabase from fast_agent.llm.provider_types import Provider from fast_agent.llm.usage_tracking import UsageAccumulator @@ -289,7 +288,6 @@ async def generate( ], request_params: RequestParams | None = None, tools: list[Tool] | None = None, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Create a completion with the LLM using the provided messages. @@ -305,7 +303,6 @@ async def generate( - List of any combination of the above request_params: Optional parameters to configure the request tools: Optional list of tools available to the LLM - cancellation_token: Optional token to cancel the operation Returns: The LLM's response as a PromptMessageExtended @@ -318,7 +315,7 @@ async def generate( with self._tracer.start_as_current_span(f"Agent: '{self._name}' generate"): return await self.generate_impl( - multipart_messages, final_request_params, tools, cancellation_token + multipart_messages, final_request_params, tools ) async def generate_impl( @@ -326,7 +323,6 @@ async def generate_impl( messages: list[PromptMessageExtended], request_params: RequestParams | None = None, tools: list[Tool] | None = None, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Implementation method for generate. 
@@ -339,13 +335,12 @@ async def generate_impl( messages: Normalized list of PromptMessageExtended objects request_params: Optional parameters to configure the request tools: Optional list of tools available to the LLM - cancellation_token: Optional token to cancel the operation Returns: The LLM's response as a PromptMessageExtended """ response, _ = await self._generate_with_summary( - messages, request_params, tools, cancellation_token + messages, request_params, tools ) return response @@ -480,13 +475,12 @@ async def _generate_with_summary( messages: list[PromptMessageExtended], request_params: RequestParams | None = None, tools: list[Tool] | None = None, - cancellation_token: CancellationToken | None = None, ) -> tuple[PromptMessageExtended, RemovedContentSummary | None]: assert self._llm, "LLM is not attached" call_ctx = self._prepare_llm_call(messages, request_params) response = await self._llm.generate( - call_ctx.full_history, call_ctx.call_params, tools, cancellation_token + call_ctx.full_history, call_ctx.call_params, tools ) if call_ctx.persist_history: diff --git a/src/fast_agent/agents/tool_agent.py b/src/fast_agent/agents/tool_agent.py index 72d248fb..470cca53 100644 --- a/src/fast_agent/agents/tool_agent.py +++ b/src/fast_agent/agents/tool_agent.py @@ -12,7 +12,6 @@ ) from fast_agent.context import Context from fast_agent.core.logging.logger import get_logger -from fast_agent.llm.cancellation import CancellationToken from fast_agent.mcp.helpers.content_helpers import text_content from fast_agent.tools.elicitation import get_elicitation_fastmcp_tool from fast_agent.types import PromptMessageExtended, RequestParams @@ -80,7 +79,6 @@ async def generate_impl( messages: List[PromptMessageExtended], request_params: RequestParams | None = None, tools: List[Tool] | None = None, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Generate a response using the LLM, and handle tool calls if necessary. @@ -97,7 +95,6 @@ async def generate_impl( messages, request_params=request_params, tools=tools, - cancellation_token=cancellation_token, ) if LlmStopReason.TOOL_USE == result.stop_reason: diff --git a/src/fast_agent/agents/workflow/chain_agent.py b/src/fast_agent/agents/workflow/chain_agent.py index 7ae273b2..b15d8ec7 100644 --- a/src/fast_agent/agents/workflow/chain_agent.py +++ b/src/fast_agent/agents/workflow/chain_agent.py @@ -15,7 +15,6 @@ from fast_agent.core.logging.logger import get_logger from fast_agent.core.prompt import Prompt from fast_agent.interfaces import ModelT -from fast_agent.llm.cancellation import CancellationToken from fast_agent.types import PromptMessageExtended, RequestParams logger = get_logger(__name__) @@ -60,7 +59,6 @@ async def generate_impl( messages: List[PromptMessageExtended], request_params: Optional[RequestParams] = None, tools: List[Tool] | None = None, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Chain the request through multiple agents in sequence. 
diff --git a/src/fast_agent/agents/workflow/evaluator_optimizer.py b/src/fast_agent/agents/workflow/evaluator_optimizer.py index 2f32be0d..0cc242b2 100644 --- a/src/fast_agent/agents/workflow/evaluator_optimizer.py +++ b/src/fast_agent/agents/workflow/evaluator_optimizer.py @@ -19,7 +19,6 @@ from fast_agent.core.logging.logger import get_logger from fast_agent.core.prompt import Prompt from fast_agent.interfaces import AgentProtocol, ModelT -from fast_agent.llm.cancellation import CancellationToken from fast_agent.types import PromptMessageExtended, RequestParams logger = get_logger(__name__) @@ -109,7 +108,6 @@ async def generate_impl( messages: List[PromptMessageExtended], request_params: RequestParams | None = None, tools: List[Tool] | None = None, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Generate a response through evaluation-guided refinement. diff --git a/src/fast_agent/agents/workflow/iterative_planner.py b/src/fast_agent/agents/workflow/iterative_planner.py index 691d6dff..62a51cc3 100644 --- a/src/fast_agent/agents/workflow/iterative_planner.py +++ b/src/fast_agent/agents/workflow/iterative_planner.py @@ -23,7 +23,6 @@ from fast_agent.core.logging.logger import get_logger from fast_agent.core.prompt import Prompt from fast_agent.interfaces import AgentProtocol, ModelT -from fast_agent.llm.cancellation import CancellationToken from fast_agent.types import PromptMessageExtended, RequestParams logger = get_logger(__name__) @@ -241,7 +240,6 @@ async def generate_impl( messages: List[PromptMessageExtended], request_params: RequestParams | None = None, tools: List[Tool] | None = None, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Execute an orchestrated plan to process the input. diff --git a/src/fast_agent/agents/workflow/parallel_agent.py b/src/fast_agent/agents/workflow/parallel_agent.py index 374705f2..d4e4af51 100644 --- a/src/fast_agent/agents/workflow/parallel_agent.py +++ b/src/fast_agent/agents/workflow/parallel_agent.py @@ -9,7 +9,6 @@ from fast_agent.agents.llm_agent import LlmAgent from fast_agent.core.logging.logger import get_logger from fast_agent.interfaces import AgentProtocol, ModelT -from fast_agent.llm.cancellation import CancellationToken from fast_agent.types import PromptMessageExtended, RequestParams logger = get_logger(__name__) @@ -56,7 +55,6 @@ async def generate_impl( messages: List[PromptMessageExtended], request_params: Optional[RequestParams] = None, tools: List[Tool] | None = None, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Execute fan-out agents in parallel and aggregate their results with the fan-in agent. 
diff --git a/src/fast_agent/agents/workflow/router_agent.py b/src/fast_agent/agents/workflow/router_agent.py index 3ca05802..7d39dc86 100644 --- a/src/fast_agent/agents/workflow/router_agent.py +++ b/src/fast_agent/agents/workflow/router_agent.py @@ -17,7 +17,6 @@ from fast_agent.core.logging.logger import get_logger from fast_agent.core.prompt import Prompt from fast_agent.interfaces import FastAgentLLMProtocol, LLMFactoryProtocol, ModelT -from fast_agent.llm.cancellation import CancellationToken from fast_agent.types import PromptMessageExtended, RequestParams if TYPE_CHECKING: @@ -188,7 +187,6 @@ async def generate_impl( messages: List[PromptMessageExtended], request_params: Optional[RequestParams] = None, tools: List[Tool] | None = None, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Route the request to the most appropriate agent and return its response. diff --git a/src/fast_agent/llm/cancellation.py b/src/fast_agent/llm/cancellation.py index 514f62e7..6f5bd01b 100644 --- a/src/fast_agent/llm/cancellation.py +++ b/src/fast_agent/llm/cancellation.py @@ -1,86 +1,22 @@ """ Cancellation support for LLM provider calls. -This module provides a CancellationToken that can be used to cancel -in-flight LLM requests, particularly useful for: -- ESC key cancellation in interactive sessions -- ACP session/cancel protocol support -""" - -import asyncio -from typing import Optional - - -class CancellationToken: - """ - A token that can be used to cancel ongoing LLM operations. - - The token uses an asyncio.Event internally to signal cancellation. - Once cancelled, it remains in the cancelled state. - - Usage: - token = CancellationToken() - - # In the calling code: - result = await llm.generate(messages, cancellation_token=token) - - # To cancel (e.g., from ESC key handler or ACP cancel): - token.cancel() - - # In the LLM provider: - async for chunk in stream: - if token.is_cancelled: - break - # process chunk - """ +This module previously provided a CancellationToken class for cancellation. +Now cancellation is handled natively through asyncio.Task.cancel() which raises +asyncio.CancelledError at the next await point. - def __init__(self) -> None: - self._event = asyncio.Event() - self._cancel_reason: Optional[str] = None +Usage: + # Store the task when starting work + task = asyncio.current_task() - def cancel(self, reason: Optional[str] = None) -> None: - """ - Signal cancellation. + # To cancel: + task.cancel() # Raises asyncio.CancelledError in the task - Args: - reason: Optional reason for cancellation (e.g., "user_cancelled", "timeout") - """ - self._cancel_reason = reason or "cancelled" - self._event.set() - - @property - def is_cancelled(self) -> bool: - """Check if cancellation has been requested.""" - return self._event.is_set() - - @property - def cancel_reason(self) -> Optional[str]: - """Get the reason for cancellation, if any.""" - return self._cancel_reason - - async def wait_for_cancellation(self) -> None: - """Wait until cancellation is requested.""" - await self._event.wait() - - def reset(self) -> None: - """ - Reset the token to non-cancelled state. - - Note: This should be used with caution. Generally, it's better - to create a new token for each operation. - """ - self._event.clear() - self._cancel_reason = None - - -class CancellationError(Exception): - """ - Exception raised when an operation is cancelled. - - This is not an error condition - it indicates intentional cancellation - by the user or system. 
- """ + # In the LLM provider, CancelledError propagates naturally: + async for chunk in stream: + # task.cancel() will raise CancelledError here + process(chunk) +""" - def __init__(self, reason: str = "Operation cancelled"): - self.reason = reason - super().__init__(reason) +# This module is kept for documentation purposes. +# All cancellation is now handled via asyncio.Task.cancel() diff --git a/src/fast_agent/llm/fastagent_llm.py b/src/fast_agent/llm/fastagent_llm.py index a3bb95ed..9424ab9d 100644 --- a/src/fast_agent/llm/fastagent_llm.py +++ b/src/fast_agent/llm/fastagent_llm.py @@ -36,7 +36,6 @@ FastAgentLLMProtocol, ModelT, ) -from fast_agent.llm.cancellation import CancellationToken from fast_agent.llm.memory import Memory, SimpleMemory from fast_agent.llm.model_database import ModelDatabase from fast_agent.llm.provider_types import Provider @@ -185,7 +184,6 @@ async def generate( messages: list[PromptMessageExtended], request_params: RequestParams | None = None, tools: list[Tool] | None = None, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Generate a completion using normalized message lists. @@ -197,13 +195,12 @@ async def generate( messages: List of PromptMessageExtended objects request_params: Optional parameters to configure the LLM request tools: Optional list of tools available to the LLM - cancellation_token: Optional token to cancel the operation Returns: A PromptMessageExtended containing the Assistant response Raises: - CancellationError: If the operation is cancelled via the token + asyncio.CancelledError: If the operation is cancelled via task.cancel() """ # TODO -- create a "fast-agent" control role rather than magic strings @@ -230,7 +227,7 @@ async def generate( # Track timing for this generation start_time = time.perf_counter() assistant_response: PromptMessageExtended = await self._apply_prompt_provider_specific( - full_history, request_params, tools, cancellation_token=cancellation_token + full_history, request_params, tools ) end_time = time.perf_counter() duration_ms = round((end_time - start_time) * 1000, 2) @@ -258,7 +255,6 @@ async def _apply_prompt_provider_specific( request_params: RequestParams | None = None, tools: list[Tool] | None = None, is_template: bool = False, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Provider-specific implementation of apply_prompt_template. 
@@ -271,7 +267,6 @@ async def _apply_prompt_provider_specific( request_params: Optional parameters to configure the LLM request tools: Optional list of tools available to the LLM is_template: Whether this is a template application - cancellation_token: Optional token to cancel the operation Returns: String representation of the assistant's response if generated, diff --git a/src/fast_agent/llm/internal/passthrough.py b/src/fast_agent/llm/internal/passthrough.py index 89d41250..612c492e 100644 --- a/src/fast_agent/llm/internal/passthrough.py +++ b/src/fast_agent/llm/internal/passthrough.py @@ -6,7 +6,6 @@ from fast_agent.core.logging.logger import get_logger from fast_agent.core.prompt import Prompt -from fast_agent.llm.cancellation import CancellationToken from fast_agent.llm.fastagent_llm import ( FastAgentLLM, RequestParams, @@ -77,7 +76,6 @@ async def _apply_prompt_provider_specific( request_params: RequestParams | None = None, tools: list[Tool] | None = None, is_template: bool = False, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: # Add messages to history with proper is_prompt flag self.history.extend(multipart_messages, is_prompt=is_template) diff --git a/src/fast_agent/llm/internal/playback.py b/src/fast_agent/llm/internal/playback.py index cafb233b..356bb0b3 100644 --- a/src/fast_agent/llm/internal/playback.py +++ b/src/fast_agent/llm/internal/playback.py @@ -6,7 +6,6 @@ from fast_agent.core.exceptions import ModelConfigError from fast_agent.core.prompt import Prompt from fast_agent.interfaces import ModelT -from fast_agent.llm.cancellation import CancellationToken from fast_agent.llm.internal.passthrough import PassthroughLLM from fast_agent.llm.provider_types import Provider from fast_agent.llm.usage_tracking import create_turn_usage_from_messages @@ -65,7 +64,6 @@ async def generate( ], request_params: RequestParams | None = None, tools: list[Tool] | None = None, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Handle playback of messages in two modes: diff --git a/src/fast_agent/llm/provider/anthropic/llm_anthropic.py b/src/fast_agent/llm/provider/anthropic/llm_anthropic.py index 4e37c2d7..389da01f 100644 --- a/src/fast_agent/llm/provider/anthropic/llm_anthropic.py +++ b/src/fast_agent/llm/provider/anthropic/llm_anthropic.py @@ -1,3 +1,4 @@ +import asyncio import json from typing import Any, Type, Union, cast @@ -28,7 +29,6 @@ from fast_agent.core.prompt import Prompt from fast_agent.event_progress import ProgressAction from fast_agent.interfaces import ModelT -from fast_agent.llm.cancellation import CancellationError, CancellationToken from fast_agent.llm.fastagent_llm import ( FastAgentLLM, RequestParams, @@ -238,7 +238,6 @@ async def _process_stream( self, stream: AsyncMessageStream, model: str, - cancellation_token: CancellationToken | None = None, ) -> Message: """Process the streaming response and display real-time token usage.""" # Track estimated output tokens by counting text chunks @@ -247,11 +246,8 @@ async def _process_stream( try: # Process the raw event stream to get token counts + # Cancellation is handled via asyncio.Task.cancel() which raises CancelledError async for event in stream: - # Check for cancellation before processing each event - if cancellation_token and cancellation_token.is_cancelled: - logger.info("Stream cancelled by user") - raise CancellationError(cancellation_token.cancel_reason or "cancelled") if ( event.type == "content_block_start" and hasattr(event, 
"content_block") @@ -489,7 +485,6 @@ async def _anthropic_completion( pre_messages: list[MessageParam] | None = None, history: list[PromptMessageExtended] | None = None, current_extended: PromptMessageExtended | None = None, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Process a query using an LLM and available tools. @@ -573,9 +568,10 @@ async def _anthropic_completion( try: async with anthropic.messages.stream(**arguments) as stream: # Process the stream - response = await self._process_stream(stream, model, cancellation_token) - except CancellationError as e: - logger.info(f"Anthropic completion cancelled: {e.reason}") + response = await self._process_stream(stream, model) + except asyncio.CancelledError as e: + reason = str(e) if e.args else "cancelled" + logger.info(f"Anthropic completion cancelled: {reason}") # Return a response indicating cancellation return Prompt.assistant( TextContent(type="text", text=""), @@ -662,7 +658,6 @@ async def _apply_prompt_provider_specific( request_params: RequestParams | None = None, tools: list[Tool] | None = None, is_template: bool = False, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Provider-specific prompt application. @@ -683,7 +678,6 @@ async def _apply_prompt_provider_specific( pre_messages=None, history=multipart_messages, current_extended=last_message, - cancellation_token=cancellation_token, ) else: # For assistant messages: Return the last message content as text diff --git a/src/fast_agent/llm/provider/bedrock/llm_bedrock.py b/src/fast_agent/llm/provider/bedrock/llm_bedrock.py index a856d62b..b0194cd4 100644 --- a/src/fast_agent/llm/provider/bedrock/llm_bedrock.py +++ b/src/fast_agent/llm/provider/bedrock/llm_bedrock.py @@ -18,7 +18,6 @@ from fast_agent.core.logging.logger import get_logger from fast_agent.event_progress import ProgressAction from fast_agent.interfaces import ModelT -from fast_agent.llm.cancellation import CancellationError, CancellationToken from fast_agent.llm.fastagent_llm import FastAgentLLM from fast_agent.llm.provider.bedrock.multipart_converter_bedrock import BedrockConverter from fast_agent.llm.provider_types import Provider @@ -1013,7 +1012,6 @@ async def _process_stream( self, stream_response, model: str, - cancellation_token: CancellationToken | None = None, ) -> BedrockMessage: """Process streaming response from Bedrock.""" estimated_tokens = 0 @@ -1023,11 +1021,8 @@ async def _process_stream( usage = {"input_tokens": 0, "output_tokens": 0} try: + # Cancellation is handled via asyncio.Task.cancel() which raises CancelledError for event in stream_response["stream"]: - # Check for cancellation before processing each event - if cancellation_token and cancellation_token.is_cancelled: - self.logger.info("Stream cancelled by user") - raise CancellationError(cancellation_token.cancel_reason or "cancelled") if "messageStart" in event: # Message started @@ -1225,7 +1220,6 @@ async def _bedrock_completion( tools: list[Tool] | None = None, pre_messages: list[BedrockMessageParam] | None = None, history: list[PromptMessageExtended] | None = None, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Process a query using Bedrock and available tools. 
@@ -1573,7 +1567,7 @@ async def _bedrock_completion( attempted_streaming = True response = client.converse_stream(**converse_args) processed_response = await self._process_stream( - response, model, cancellation_token + response, model ) except (ClientError, BotoCoreError) as e: # Check if this is a reasoning-related error @@ -1605,7 +1599,7 @@ async def _bedrock_completion( else: response = client.converse_stream(**converse_args) processed_response = await self._process_stream( - response, model, cancellation_token + response, model ) else: # Not a reasoning error, re-raise @@ -1721,7 +1715,7 @@ async def _bedrock_completion( else: response = client.converse_stream(**converse_args) processed_response = await self._process_stream( - response, model, cancellation_token + response, model ) if not caps.schema and has_tools: caps.schema = ToolSchemaType(schema_choice) @@ -1879,7 +1873,6 @@ async def _apply_prompt_provider_specific( request_params: RequestParams | None = None, tools: list[Tool] | None = None, is_template: bool = False, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Provider-specific prompt application. @@ -1907,7 +1900,6 @@ async def _apply_prompt_provider_specific( tools, pre_messages=None, history=multipart_messages, - cancellation_token=cancellation_token, ) def _generate_simplified_schema(self, model: Type[ModelT]) -> str: diff --git a/src/fast_agent/llm/provider/google/llm_google_native.py b/src/fast_agent/llm/provider/google/llm_google_native.py index f0460dc3..31be2dd0 100644 --- a/src/fast_agent/llm/provider/google/llm_google_native.py +++ b/src/fast_agent/llm/provider/google/llm_google_native.py @@ -17,7 +17,6 @@ from fast_agent.core.exceptions import ProviderKeyError from fast_agent.core.prompt import Prompt -from fast_agent.llm.cancellation import CancellationError, CancellationToken from fast_agent.llm.fastagent_llm import FastAgentLLM # Import the new converter class @@ -115,7 +114,6 @@ async def _stream_generate_content( contents: list[types.Content], config: types.GenerateContentConfig, client: genai.Client, - cancellation_token: CancellationToken | None = None, ) -> types.GenerateContentResponse | None: """Stream Gemini responses and return the final aggregated completion.""" try: @@ -136,16 +134,13 @@ async def _stream_generate_content( ) return None - return await self._consume_google_stream( - response_stream, model=model, cancellation_token=cancellation_token - ) + return await self._consume_google_stream(response_stream, model=model) async def _consume_google_stream( self, response_stream, *, model: str, - cancellation_token: CancellationToken | None = None, ) -> types.GenerateContentResponse | None: """Consume the async streaming iterator and aggregate the final response.""" estimated_tokens = 0 @@ -157,11 +152,8 @@ async def _consume_google_stream( last_chunk: types.GenerateContentResponse | None = None try: + # Cancellation is handled via asyncio.Task.cancel() which raises CancelledError async for chunk in response_stream: - # Check for cancellation before processing each chunk - if cancellation_token and cancellation_token.is_cancelled: - self.logger.info("Stream cancelled by user") - raise CancellationError(cancellation_token.cancel_reason or "cancelled") last_chunk = chunk if getattr(chunk, "usage_metadata", None): @@ -337,7 +329,6 @@ async def _google_completion( *, response_mime_type: str | None = None, response_schema: object | None = None, - cancellation_token: CancellationToken | None = None, ) -> 
PromptMessageExtended: """ Process a query using Google's generate_content API and available tools. @@ -386,7 +377,6 @@ async def _google_completion( contents=conversation_history, config=generate_content_config, client=client, - cancellation_token=cancellation_token, ) if api_response is None: api_response = await client.aio.models.generate_content( @@ -490,7 +480,6 @@ async def _apply_prompt_provider_specific( request_params: RequestParams | None = None, tools: list[McpTool] | None = None, is_template: bool = False, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Provider-specific prompt application. @@ -552,7 +541,6 @@ async def _apply_prompt_provider_specific( conversation_history, request_params=request_params, tools=tools, - cancellation_token=cancellation_token, ) def _convert_extended_messages_to_provider( diff --git a/src/fast_agent/llm/provider/openai/llm_openai.py b/src/fast_agent/llm/provider/openai/llm_openai.py index c02a9ce7..02217f93 100644 --- a/src/fast_agent/llm/provider/openai/llm_openai.py +++ b/src/fast_agent/llm/provider/openai/llm_openai.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any from mcp import Tool @@ -24,7 +25,6 @@ from fast_agent.core.logging.logger import get_logger from fast_agent.core.prompt import Prompt from fast_agent.event_progress import ProgressAction -from fast_agent.llm.cancellation import CancellationError, CancellationToken from fast_agent.llm.fastagent_llm import FastAgentLLM, RequestParams from fast_agent.llm.model_database import ModelDatabase from fast_agent.llm.provider.openai.multipart_converter_openai import OpenAIConverter, OpenAIMessage @@ -207,7 +207,6 @@ async def _process_stream( self, stream, model: str, - cancellation_token: CancellationToken | None = None, ): """Process the streaming response and display real-time token usage.""" # Track estimated output tokens by counting text chunks @@ -221,7 +220,7 @@ async def _process_stream( Provider.GOOGLE_OAI, ] if stream_mode == "manual" or provider_requires_manual: - return await self._process_stream_manual(stream, model, cancellation_token) + return await self._process_stream_manual(stream, model) # Use ChatCompletionStreamState helper for accumulation (OpenAI only) state = ChatCompletionStreamState() @@ -232,11 +231,8 @@ async def _process_stream( notified_tool_indices: set[int] = set() # Process the stream chunks + # Cancellation is handled via asyncio.Task.cancel() which raises CancelledError async for chunk in stream: - # Check for cancellation before processing each chunk - if cancellation_token and cancellation_token.is_cancelled: - _logger.info("Stream cancelled by user") - raise CancellationError(cancellation_token.cancel_reason or "cancelled") # Handle chunk accumulation state.handle_chunk(chunk) @@ -437,7 +433,6 @@ async def _process_stream_manual( self, stream, model: str, - cancellation_token: CancellationToken | None = None, ): """Manual stream processing for providers like Ollama that may not work with ChatCompletionStreamState.""" @@ -460,11 +455,8 @@ async def _process_stream_manual( notified_tool_indices: set[int] = set() # Process the stream chunks manually + # Cancellation is handled via asyncio.Task.cancel() which raises CancelledError async for chunk in stream: - # Check for cancellation before processing each chunk - if cancellation_token and cancellation_token.is_cancelled: - self.logger.info("Stream cancelled by user") - raise CancellationError(cancellation_token.cancel_reason or "cancelled") # Process streaming 
events for tool calls if chunk.choices: @@ -701,7 +693,6 @@ async def _openai_completion( message: list[OpenAIMessage] | None, request_params: RequestParams | None = None, tools: list[Tool] | None = None, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Process a query using an LLM and available tools. @@ -759,9 +750,10 @@ async def _openai_completion( async with self._openai_client() as client: stream = await client.chat.completions.create(**arguments) # Process the stream - response = await self._process_stream(stream, model_name, cancellation_token) - except CancellationError as e: - self.logger.info(f"OpenAI completion cancelled: {e.reason}") + response = await self._process_stream(stream, model_name) + except asyncio.CancelledError as e: + reason = str(e) if e.args else "cancelled" + self.logger.info(f"OpenAI completion cancelled: {reason}") # Return a response indicating cancellation return Prompt.assistant( TextContent(type="text", text=""), @@ -916,7 +908,6 @@ async def _apply_prompt_provider_specific( request_params: RequestParams | None = None, tools: list[Tool] | None = None, is_template: bool = False, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: """ Provider-specific prompt application. @@ -936,9 +927,7 @@ async def _apply_prompt_provider_specific( if not converted_messages: converted_messages = [{"role": "user", "content": ""}] - return await self._openai_completion( - converted_messages, req_params, tools, cancellation_token - ) + return await self._openai_completion(converted_messages, req_params, tools) def _prepare_api_request( self, messages, tools: list[ChatCompletionToolParam] | None, request_params: RequestParams diff --git a/tests/integration/tool_loop/test_tool_loop.py b/tests/integration/tool_loop/test_tool_loop.py index ea0c632f..921976e3 100644 --- a/tests/integration/tool_loop/test_tool_loop.py +++ b/tests/integration/tool_loop/test_tool_loop.py @@ -7,7 +7,6 @@ from fast_agent.agents.tool_agent import ToolAgent from fast_agent.constants import FAST_AGENT_ERROR_CHANNEL from fast_agent.core.prompt import Prompt -from fast_agent.llm.cancellation import CancellationToken from fast_agent.llm.internal.passthrough import PassthroughLLM from fast_agent.llm.request_params import RequestParams from fast_agent.mcp.prompt_message_extended import PromptMessageExtended @@ -21,7 +20,6 @@ async def _apply_prompt_provider_specific( request_params: RequestParams | None = None, tools: list[Tool] | None = None, is_template: bool = False, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: tool_calls = {} tool_calls["my_id"] = CallToolRequest( @@ -104,7 +102,6 @@ async def _apply_prompt_provider_specific( request_params: RequestParams | None = None, tools: list[Tool] | None = None, is_template: bool = False, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: self.call_count += 1 tool_calls = { diff --git a/tests/unit/fast_agent/agents/test_agent_history_binding.py b/tests/unit/fast_agent/agents/test_agent_history_binding.py index ad5703b9..32919a2f 100644 --- a/tests/unit/fast_agent/agents/test_agent_history_binding.py +++ b/tests/unit/fast_agent/agents/test_agent_history_binding.py @@ -4,7 +4,6 @@ from fast_agent.agents.agent_types import AgentConfig from fast_agent.agents.llm_agent import LlmAgent from fast_agent.core.prompt import Prompt -from fast_agent.llm.cancellation import CancellationToken from fast_agent.llm.fastagent_llm import 
FastAgentLLM from fast_agent.llm.provider_types import Provider from fast_agent.llm.request_params import RequestParams @@ -22,7 +21,6 @@ async def _apply_prompt_provider_specific( request_params: RequestParams | None = None, tools=None, is_template: bool = False, - cancellation_token: CancellationToken | None = None, ) -> PromptMessageExtended: self.last_messages = list(multipart_messages) return Prompt.assistant("ok") diff --git a/tests/unit/fast_agent/agents/test_llm_content_filter.py b/tests/unit/fast_agent/agents/test_llm_content_filter.py index 04cab9ef..f321e9cc 100644 --- a/tests/unit/fast_agent/agents/test_llm_content_filter.py +++ b/tests/unit/fast_agent/agents/test_llm_content_filter.py @@ -40,7 +40,7 @@ def model_name(self) -> str | None: def provider(self) -> Provider: return self._provider - async def generate(self, messages, request_params=None, tools=None, cancellation_token=None): + async def generate(self, messages, request_params=None, tools=None): self.generated_messages = messages self._message_history = messages return PromptMessageExtended( diff --git a/tests/unit/fast_agent/llm/providers/test_llm_openai_history.py b/tests/unit/fast_agent/llm/providers/test_llm_openai_history.py index 3637be00..e406dd46 100644 --- a/tests/unit/fast_agent/llm/providers/test_llm_openai_history.py +++ b/tests/unit/fast_agent/llm/providers/test_llm_openai_history.py @@ -13,7 +13,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.captured = None - async def _openai_completion(self, message, request_params=None, tools=None, cancellation_token=None): + async def _openai_completion(self, message, request_params=None, tools=None): self.captured = message return Prompt.assistant("ok")
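For reference alongside the updated tests, a sketch of the provider-style handling introduced in llm_anthropic.py and llm_openai.py above, where the completion coroutine absorbs asyncio.CancelledError and returns a placeholder result instead of re-raising; completion_with_stream and slow_stream are illustrative names only, not fast-agent code:

    import asyncio


    async def completion_with_stream(stream) -> str:
        """Illustrative analogue of the providers' *_completion methods."""
        try:
            parts: list[str] = []
            async for chunk in stream:
                parts.append(chunk)
            return "".join(parts)
        except asyncio.CancelledError as exc:
            # Analogue of the new provider handlers: derive a reason, log-or-return
            # a "cancelled" result, and do not re-raise.
            reason = str(exc) if exc.args else "cancelled"
            return f"<cancelled: {reason}>"


    async def slow_stream():
        for i in range(50):
            await asyncio.sleep(0.05)
            yield f"token {i} "


    async def main() -> None:
        task = asyncio.create_task(completion_with_stream(slow_stream()))
        await asyncio.sleep(0.12)
        task.cancel()
        result = await task  # completes normally: CancelledError was absorbed
        print(result)


    asyncio.run(main())

Since the CancelledError is swallowed inside the task, await task completes normally with the placeholder, mirroring how the providers return an empty assistant message rather than letting the cancellation propagate to the caller.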