78 changes: 28 additions & 50 deletions src/fast_agent/acp/server/agent_acp_server.py
@@ -23,10 +23,7 @@
SetSessionModeRequest,
SetSessionModeResponse,
)
from acp.agent.router import build_agent_router
from acp.connection import MethodHandler
from acp.helpers import session_notification, update_agent_message_text
from acp.meta import AGENT_METHODS
from acp.schema import (
AgentCapabilities,
Implementation,
@@ -54,7 +51,6 @@
enrich_with_environment_context,
)
from fast_agent.interfaces import StreamingAgentProtocol
from fast_agent.llm.cancellation import CancellationToken
from fast_agent.llm.model_database import ModelDatabase
from fast_agent.mcp.helpers.content_helpers import is_text_content
from fast_agent.types import LlmStopReason, PromptMessageExtended, RequestParams
@@ -66,34 +62,6 @@
REFUSAL: StopReason = "refusal"


class ExtendedAgentSideConnection(AgentSideConnection):
"""
Extended AgentSideConnection that registers session/cancel as both request and notification.

Some clients incorrectly send session/cancel as a request (with an id) instead of
a notification. This subclass adds the cancel handler to both routing tables for
compatibility.
"""

def _create_handler(self, agent: ACPAgent) -> MethodHandler:
"""Override to add cancel as both request and notification handler."""
router = build_agent_router(agent)

# Also register cancel as a request handler for compatibility with clients
# that incorrectly send it with an id
router._requests[AGENT_METHODS["session_cancel"]] = router._notifications.get(
AGENT_METHODS["session_cancel"]
)

async def handler(method: str, params: Any | None, is_notification: bool) -> Any:
if is_notification:
await router.dispatch_notification(method, params)
return None
return await router.dispatch_request(method, params)

return handler


def map_llm_stop_reason_to_acp(llm_stop_reason: LlmStopReason | None) -> StopReason:
"""
Map fast-agent LlmStopReason to ACP StopReason.
@@ -221,8 +189,8 @@ def __init__(
# Track sessions with active prompts to prevent overlapping requests (per ACP protocol)
self._active_prompts: set[str] = set()

# Track cancellation tokens per session for cancel support
self._session_cancellation_tokens: dict[str, CancellationToken] = {}
# Track asyncio tasks per session for proper task-based cancellation
self._session_tasks: dict[str, asyncio.Task] = {}

# Track current agent per session for ACP mode support
self._session_current_agent: dict[str, str] = {}
@@ -790,9 +758,10 @@ async def prompt(self, params: PromptRequest) -> PromptResponse:
# Mark this session as having an active prompt
self._active_prompts.add(session_id)

# Create a cancellation token for this prompt
cancellation_token = CancellationToken()
self._session_cancellation_tokens[session_id] = cancellation_token
# Track the current task for proper cancellation via asyncio.Task.cancel()
current_task = asyncio.current_task()
if current_task:
self._session_tasks[session_id] = current_task

# Use try/finally to ensure session is always removed from active prompts
try:
Expand Down Expand Up @@ -944,7 +913,6 @@ def on_stream_chunk(chunk: str):
result = await agent.generate(
prompt_message,
request_params=session_request_params,
cancellation_token=cancellation_token,
)
response_text = result.last_text() or "No content generated"

@@ -1055,11 +1023,19 @@ def on_stream_chunk(chunk: str):
return PromptResponse(
stopReason=acp_stop_reason,
)
except asyncio.CancelledError:
# Task was cancelled - return appropriate response
logger.info(
"Prompt cancelled by user",
name="acp_prompt_cancelled",
session_id=session_id,
)
return PromptResponse(stopReason="cancelled")
finally:
# Always remove session from active prompts and cleanup cancellation token
# Always remove session from active prompts and cleanup task
async with self._session_lock:
self._active_prompts.discard(session_id)
self._session_cancellation_tokens.pop(session_id, None)
self._session_tasks.pop(session_id, None)
logger.debug(
"Removed session from active prompts",
name="acp_prompt_complete",
@@ -1073,6 +1049,9 @@ async def cancel(self, params: CancelNotification) -> None:
This cancels any in-progress prompt for the specified session.
Per ACP protocol, we should stop all LLM requests and tool invocations
as soon as possible.

Uses asyncio.Task.cancel() for proper async cancellation, which raises
asyncio.CancelledError in the running task.
"""
session_id = params.sessionId

@@ -1082,14 +1061,14 @@
session_id=session_id,
)

# Get the cancellation token for this session and signal cancellation
# Get the task for this session and cancel it
async with self._session_lock:
cancellation_token = self._session_cancellation_tokens.get(session_id)
if cancellation_token:
cancellation_token.cancel("user_cancelled")
task = self._session_tasks.get(session_id)
if task and not task.done():
task.cancel()
logger.info(
"Cancellation signaled for session",
name="acp_cancel_signaled",
"Task cancelled for session",
name="acp_cancel_task",
session_id=session_id,
)
else:
@@ -1116,8 +1095,7 @@ async def run_async(self) -> None:
# Note: AgentSideConnection expects (writer, reader) order
# - input_stream (writer) = where agent writes TO client
# - output_stream (reader) = where agent reads FROM client
# Use ExtendedAgentSideConnection for cancel request/notification compatibility
connection = ExtendedAgentSideConnection(
connection = AgentSideConnection(
lambda conn: self,
writer, # input_stream = StreamWriter for agent output
reader, # output_stream = StreamReader for agent input
Expand Down Expand Up @@ -1222,8 +1200,8 @@ async def _cleanup_sessions(self) -> None:
# Clean up session current agent mapping
self._session_current_agent.clear()

# Clear cancellation tokens
self._session_cancellation_tokens.clear()
# Clear tasks
self._session_tasks.clear()

# Clear stored prompt contexts
self._session_prompt_context.clear()
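The diff above replaces the custom CancellationToken plumbing with plain asyncio task cancellation: prompt() records its own asyncio.current_task() keyed by session, cancel() looks that task up and calls Task.cancel(), and the resulting asyncio.CancelledError is caught and mapped to a "cancelled" stop reason. A minimal, self-contained sketch of the same pattern follows; the class and the stop-reason strings are simplified stand-ins, not the actual fast-agent/ACP types.

import asyncio


class PromptServer:
    """Simplified stand-in for the ACP server above; illustrative only."""

    def __init__(self) -> None:
        self._session_tasks: dict[str, asyncio.Task] = {}
        self._session_lock = asyncio.Lock()

    async def prompt(self, session_id: str) -> str:
        # Record the task running this prompt so cancel() can find it.
        task = asyncio.current_task()
        if task:
            self._session_tasks[session_id] = task
        try:
            await asyncio.sleep(10)  # stand-in for the LLM call
            return "end_turn"
        except asyncio.CancelledError:
            # Task.cancel() raises CancelledError at the await above;
            # report a normal "cancelled" stop reason instead of re-raising.
            return "cancelled"
        finally:
            # Mirror the finally block in the diff: always drop the task entry.
            async with self._session_lock:
                self._session_tasks.pop(session_id, None)

    async def cancel(self, session_id: str) -> None:
        # Look up the session's task and cancel it if still running.
        async with self._session_lock:
            task = self._session_tasks.get(session_id)
            if task and not task.done():
                task.cancel()


async def main() -> None:
    server = PromptServer()
    prompt_task = asyncio.create_task(server.prompt("s1"))
    await asyncio.sleep(0.1)
    await server.cancel("s1")
    print(await prompt_task)  # -> cancelled


asyncio.run(main())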
6 changes: 2 additions & 4 deletions src/fast_agent/agents/llm_agent.py
@@ -18,7 +18,6 @@
from fast_agent.agents.llm_decorator import LlmDecorator, ModelT
from fast_agent.constants import FAST_AGENT_ERROR_CHANNEL
from fast_agent.context import Context
from fast_agent.llm.cancellation import CancellationToken
from fast_agent.mcp.helpers.content_helpers import get_text
from fast_agent.types import PromptMessageExtended, RequestParams
from fast_agent.types.llm_stop_reason import LlmStopReason
@@ -238,7 +237,6 @@ async def generate_impl(
messages: List[PromptMessageExtended],
request_params: RequestParams | None = None,
tools: List[Tool] | None = None,
cancellation_token: CancellationToken | None = None,
) -> PromptMessageExtended:
"""
Enhanced generate implementation that resets tool call tracking.
@@ -272,7 +270,7 @@ async def generate_impl(

try:
result, summary = await self._generate_with_summary(
messages, request_params, tools, cancellation_token
messages, request_params, tools
)
finally:
if remove_listener:
@@ -288,7 +286,7 @@ async def generate_impl(
await self.show_assistant_message(result, additional_message=summary_text)
else:
result, summary = await self._generate_with_summary(
messages, request_params, tools, cancellation_token
messages, request_params, tools
)

summary_text = (
12 changes: 3 additions & 9 deletions src/fast_agent/agents/llm_decorator.py
@@ -51,7 +51,6 @@
LLMFactoryProtocol,
StreamingAgentProtocol,
)
from fast_agent.llm.cancellation import CancellationToken
from fast_agent.llm.model_database import ModelDatabase
from fast_agent.llm.provider_types import Provider
from fast_agent.llm.usage_tracking import UsageAccumulator
@@ -289,7 +288,6 @@ async def generate(
],
request_params: RequestParams | None = None,
tools: list[Tool] | None = None,
cancellation_token: CancellationToken | None = None,
) -> PromptMessageExtended:
"""
Create a completion with the LLM using the provided messages.
@@ -305,7 +303,6 @@
- List of any combination of the above
request_params: Optional parameters to configure the request
tools: Optional list of tools available to the LLM
cancellation_token: Optional token to cancel the operation

Returns:
The LLM's response as a PromptMessageExtended
@@ -318,15 +315,14 @@

with self._tracer.start_as_current_span(f"Agent: '{self._name}' generate"):
return await self.generate_impl(
multipart_messages, final_request_params, tools, cancellation_token
multipart_messages, final_request_params, tools
)

async def generate_impl(
self,
messages: list[PromptMessageExtended],
request_params: RequestParams | None = None,
tools: list[Tool] | None = None,
cancellation_token: CancellationToken | None = None,
) -> PromptMessageExtended:
"""
Implementation method for generate.
@@ -339,13 +335,12 @@
messages: Normalized list of PromptMessageExtended objects
request_params: Optional parameters to configure the request
tools: Optional list of tools available to the LLM
cancellation_token: Optional token to cancel the operation

Returns:
The LLM's response as a PromptMessageExtended
"""
response, _ = await self._generate_with_summary(
messages, request_params, tools, cancellation_token
messages, request_params, tools
)
return response

@@ -480,13 +475,12 @@ async def _generate_with_summary(
messages: list[PromptMessageExtended],
request_params: RequestParams | None = None,
tools: list[Tool] | None = None,
cancellation_token: CancellationToken | None = None,
) -> tuple[PromptMessageExtended, RemovedContentSummary | None]:
assert self._llm, "LLM is not attached"
call_ctx = self._prepare_llm_call(messages, request_params)

response = await self._llm.generate(
call_ctx.full_history, call_ctx.call_params, tools, cancellation_token
call_ctx.full_history, call_ctx.call_params, tools
)

if call_ctx.persist_history:
3 changes: 0 additions & 3 deletions src/fast_agent/agents/tool_agent.py
@@ -12,7 +12,6 @@
)
from fast_agent.context import Context
from fast_agent.core.logging.logger import get_logger
from fast_agent.llm.cancellation import CancellationToken
from fast_agent.mcp.helpers.content_helpers import text_content
from fast_agent.tools.elicitation import get_elicitation_fastmcp_tool
from fast_agent.types import PromptMessageExtended, RequestParams
@@ -80,7 +79,6 @@ async def generate_impl(
messages: List[PromptMessageExtended],
request_params: RequestParams | None = None,
tools: List[Tool] | None = None,
cancellation_token: CancellationToken | None = None,
) -> PromptMessageExtended:
"""
Generate a response using the LLM, and handle tool calls if necessary.
@@ -97,7 +95,6 @@
messages,
request_params=request_params,
tools=tools,
cancellation_token=cancellation_token,
)

if LlmStopReason.TOOL_USE == result.stop_reason:
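The same parameter deletion repeats through every workflow agent below because asyncio cancellation needs no explicit plumbing: the CancelledError raised by Task.cancel() propagates through nested awaits on its own, so intermediate methods such as the tool loop above no longer have to forward a token. A small illustrative sketch (all names hypothetical):

import asyncio


async def inner_llm_call() -> str:
    await asyncio.sleep(5)  # stand-in for a provider request
    return "response"


async def generate_impl() -> str:
    # No token parameter needed: cancelling the surrounding task raises
    # CancelledError right here, at the await.
    return await inner_llm_call()


async def main() -> None:
    task = asyncio.create_task(generate_impl())
    await asyncio.sleep(0.1)
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        print("generation cancelled")


asyncio.run(main())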
2 changes: 0 additions & 2 deletions src/fast_agent/agents/workflow/chain_agent.py
@@ -15,7 +15,6 @@
from fast_agent.core.logging.logger import get_logger
from fast_agent.core.prompt import Prompt
from fast_agent.interfaces import ModelT
from fast_agent.llm.cancellation import CancellationToken
from fast_agent.types import PromptMessageExtended, RequestParams

logger = get_logger(__name__)
@@ -60,7 +59,6 @@ async def generate_impl(
messages: List[PromptMessageExtended],
request_params: Optional[RequestParams] = None,
tools: List[Tool] | None = None,
cancellation_token: CancellationToken | None = None,
) -> PromptMessageExtended:
"""
Chain the request through multiple agents in sequence.
2 changes: 0 additions & 2 deletions src/fast_agent/agents/workflow/evaluator_optimizer.py
@@ -19,7 +19,6 @@
from fast_agent.core.logging.logger import get_logger
from fast_agent.core.prompt import Prompt
from fast_agent.interfaces import AgentProtocol, ModelT
from fast_agent.llm.cancellation import CancellationToken
from fast_agent.types import PromptMessageExtended, RequestParams

logger = get_logger(__name__)
@@ -109,7 +108,6 @@ async def generate_impl(
messages: List[PromptMessageExtended],
request_params: RequestParams | None = None,
tools: List[Tool] | None = None,
cancellation_token: CancellationToken | None = None,
) -> PromptMessageExtended:
"""
Generate a response through evaluation-guided refinement.
2 changes: 0 additions & 2 deletions src/fast_agent/agents/workflow/iterative_planner.py
@@ -23,7 +23,6 @@
from fast_agent.core.logging.logger import get_logger
from fast_agent.core.prompt import Prompt
from fast_agent.interfaces import AgentProtocol, ModelT
from fast_agent.llm.cancellation import CancellationToken
from fast_agent.types import PromptMessageExtended, RequestParams

logger = get_logger(__name__)
@@ -241,7 +240,6 @@ async def generate_impl(
messages: List[PromptMessageExtended],
request_params: RequestParams | None = None,
tools: List[Tool] | None = None,
cancellation_token: CancellationToken | None = None,
) -> PromptMessageExtended:
"""
Execute an orchestrated plan to process the input.
2 changes: 0 additions & 2 deletions src/fast_agent/agents/workflow/parallel_agent.py
@@ -9,7 +9,6 @@
from fast_agent.agents.llm_agent import LlmAgent
from fast_agent.core.logging.logger import get_logger
from fast_agent.interfaces import AgentProtocol, ModelT
from fast_agent.llm.cancellation import CancellationToken
from fast_agent.types import PromptMessageExtended, RequestParams

logger = get_logger(__name__)
@@ -56,7 +55,6 @@ async def generate_impl(
messages: List[PromptMessageExtended],
request_params: Optional[RequestParams] = None,
tools: List[Tool] | None = None,
cancellation_token: CancellationToken | None = None,
) -> PromptMessageExtended:
"""
Execute fan-out agents in parallel and aggregate their results with the fan-in agent.
2 changes: 0 additions & 2 deletions src/fast_agent/agents/workflow/router_agent.py
@@ -17,7 +17,6 @@
from fast_agent.core.logging.logger import get_logger
from fast_agent.core.prompt import Prompt
from fast_agent.interfaces import FastAgentLLMProtocol, LLMFactoryProtocol, ModelT
from fast_agent.llm.cancellation import CancellationToken
from fast_agent.types import PromptMessageExtended, RequestParams

if TYPE_CHECKING:
@@ -188,7 +187,6 @@ async def generate_impl(
messages: List[PromptMessageExtended],
request_params: Optional[RequestParams] = None,
tools: List[Tool] | None = None,
cancellation_token: CancellationToken | None = None,
) -> PromptMessageExtended:
"""
Route the request to the most appropriate agent and return its response.
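With the token parameter gone from every generate()/generate_impl() signature in the files above, a caller that needs to abort or time-limit a generation now cancels the task running it. A hypothetical usage sketch, not an API from this PR; agent and messages stand in for any updated agent and its input:

import asyncio


async def generate_with_timeout(agent, messages, timeout: float):
    # No cancellation_token to thread through any more: run the call as
    # a task and cancel the task itself if it outlives the deadline.
    task = asyncio.create_task(agent.generate(messages))
    try:
        return await asyncio.wait_for(task, timeout)
    except asyncio.TimeoutError:
        # wait_for has already cancelled the task for us.
        return None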