Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ classifiers = [
requires-python = ">=3.13.5,<3.14"
dependencies = [
"fastapi>=0.115.6",
"mcp==1.19.0",
"mcp==1.20.0",
"opentelemetry-distro>=0.55b0",
"opentelemetry-exporter-otlp-proto-http>=1.7.0",
"pydantic-settings>=2.7.0",
Expand Down
132 changes: 111 additions & 21 deletions src/fast_agent/agents/mcp_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Expand Down Expand Up @@ -43,7 +44,7 @@
from fast_agent.core.logging.logger import get_logger
from fast_agent.interfaces import FastAgentLLMProtocol
from fast_agent.mcp.common import get_resource_name, get_server_name, is_namespaced_name
from fast_agent.mcp.mcp_aggregator import MCPAggregator, ServerStatus
from fast_agent.mcp.mcp_aggregator import MCPAggregator, NamespacedTool, ServerStatus
from fast_agent.skills.registry import format_skills_for_prompt
from fast_agent.tools.elicitation import (
get_elicitation_tool,
Expand Down Expand Up @@ -761,14 +762,11 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend
)

# Select display/highlight names
display_tool_name = tool_name
highlight_name = tool_name
if namespaced_tool is not None:
display_tool_name = namespaced_tool.namespaced_tool_name
highlight_name = namespaced_tool.namespaced_tool_name
elif candidate_namespaced_tool is not None:
display_tool_name = candidate_namespaced_tool.namespaced_tool_name
highlight_name = candidate_namespaced_tool.namespaced_tool_name
display_tool_name = (
(namespaced_tool or candidate_namespaced_tool).namespaced_tool_name
if (namespaced_tool or candidate_namespaced_tool) is not None
else tool_name
)

tool_available = (
tool_name == HUMAN_INPUT_TOOL_NAME
Expand All @@ -788,14 +786,6 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend
)
break

# Find the index of the current tool in available_tools for highlighting
highlight_index = None
try:
highlight_index = available_tools.index(highlight_name)
except ValueError:
# Tool not found in list, no highlighting
pass

metadata: dict[str, Any] | None = None
if (
self._shell_runtime_enabled
Expand All @@ -804,10 +794,18 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend
):
metadata = self._shell_runtime.metadata(tool_args.get("command"))

display_tool_name, bottom_items, highlight_index = self._prepare_tool_display(
tool_name=tool_name,
namespaced_tool=namespaced_tool,
candidate_namespaced_tool=candidate_namespaced_tool,
local_tool=local_tool,
fallback_order=self._unique_preserving_order(available_tools),
)

self.display.show_tool_call(
name=self._name,
tool_args=tool_args,
bottom_items=available_tools,
bottom_items=bottom_items,
tool_name=display_tool_name,
highlight_index=highlight_index,
max_item_length=12,
Expand Down Expand Up @@ -849,6 +847,73 @@ async def run_tools(self, request: PromptMessageExtended) -> PromptMessageExtend

return self._finalize_tool_results(tool_results, tool_loop_error=tool_loop_error)

def _prepare_tool_display(
self,
*,
tool_name: str,
namespaced_tool: "NamespacedTool | None",
candidate_namespaced_tool: "NamespacedTool | None",
local_tool: Any | None,
fallback_order: list[str],
) -> tuple[str, list[str] | None, int | None]:
"""
Determine how we present tool metadata for the console display.

Returns a tuple of (display_tool_name, bottom_items, highlight_index).
"""
active_namespaced = namespaced_tool or candidate_namespaced_tool
display_tool_name = (
active_namespaced.namespaced_tool_name if active_namespaced is not None else tool_name
)

bottom_items: list[str] | None = None
highlight_target: str | None = None

if active_namespaced is not None:
server_tools = self._aggregator._server_to_tool_map.get(
active_namespaced.server_name, []
)
if server_tools:
bottom_items = self._unique_preserving_order(
tool_entry.tool.name for tool_entry in server_tools
)
highlight_target = active_namespaced.tool.name
elif local_tool is not None:
bottom_items = self._unique_preserving_order(self._execution_tools.keys())
highlight_target = tool_name
elif tool_name == HUMAN_INPUT_TOOL_NAME:
bottom_items = [HUMAN_INPUT_TOOL_NAME]
highlight_target = HUMAN_INPUT_TOOL_NAME

highlight_index: int | None = None
if bottom_items and highlight_target:
try:
highlight_index = bottom_items.index(highlight_target)
except ValueError:
highlight_index = None

if bottom_items is None and fallback_order:
bottom_items = fallback_order
fallback_target = display_tool_name if display_tool_name in bottom_items else tool_name
try:
highlight_index = bottom_items.index(fallback_target)
except ValueError:
highlight_index = None

return display_tool_name, bottom_items, highlight_index

@staticmethod
def _unique_preserving_order(items: Iterable[str]) -> list[str]:
"""Return a list of unique items while preserving original order."""
seen: set[str] = set()
result: list[str] = []
for item in items:
if item in seen:
continue
seen.add(item)
result.append(item)
return result

async def apply_prompt_template(self, prompt_result: GetPromptResult, prompt_name: str) -> str:
"""
Apply a prompt template as persistent context that will be included in all future conversations.
Expand Down Expand Up @@ -1028,11 +1093,17 @@ async def show_assistant_message(
# Get the list of MCP servers (if not provided)
if bottom_items is None:
if self._aggregator and self._aggregator.server_names:
server_names = self._aggregator.server_names
server_names = list(self._aggregator.server_names)
else:
server_names = []
else:
server_names = bottom_items
server_names = list(bottom_items)

server_names = self._unique_preserving_order(server_names)

shell_label = self._shell_server_label()
if shell_label:
server_names = [shell_label, *(name for name in server_names if name != shell_label)]

# Extract servers from tool calls in the message for highlighting
if highlight_items is None:
Expand Down Expand Up @@ -1065,13 +1136,23 @@ def _extract_servers_from_message(self, message: PromptMessageExtended) -> List[
Returns:
List of server names that were called
"""
servers = []
servers: list[str] = []

# Check if message has tool calls
if message.tool_calls:
for tool_request in message.tool_calls.values():
tool_name = tool_request.params.name

if (
self._shell_runtime_enabled
and self._shell_runtime.tool
and tool_name == self._shell_runtime.tool.name
):
shell_label = self._shell_server_label()
if shell_label and shell_label not in servers:
servers.append(shell_label)
continue

# Use aggregator's mapping to find the server for this tool
if tool_name in self._aggregator._namespaced_tool_map:
namespaced_tool = self._aggregator._namespaced_tool_map[tool_name]
Expand All @@ -1080,6 +1161,15 @@ def _extract_servers_from_message(self, message: PromptMessageExtended) -> List[

return servers

def _shell_server_label(self) -> str | None:
"""Return the display label for the local shell runtime."""
if not self._shell_runtime_enabled or not self._shell_runtime.tool:
return None

runtime_info = self._shell_runtime.runtime_info()
runtime_name = runtime_info.get("name")
return runtime_name or "shell"

async def _parse_resource_name(self, name: str, resource_type: str) -> tuple[str, str]:
"""Delegate resource name parsing to the aggregator."""
return await self._aggregator._parse_resource_name(name, resource_type)
Expand Down
9 changes: 5 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading