Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent"
]
requires-python = ">=3.13.5"
requires-python = ">=3.13.5,<3.14"
dependencies = [
"fastapi>=0.115.6",
"mcp==1.16.0",
Expand Down
2 changes: 1 addition & 1 deletion src/fast_agent/agents/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ async def show_assistant_message(
combined += segment
additional_message_text = combined

message_text = message.last_text() or ""
message_text = message

# Use provided name/model or fall back to defaults
display_name = name if name is not None else self.name
Expand Down
13 changes: 13 additions & 0 deletions src/fast_agent/mcp/mcp_agent_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

from datetime import timedelta
from time import perf_counter
from typing import TYPE_CHECKING

from mcp import ClientSession, ServerNotification
Expand Down Expand Up @@ -207,6 +208,7 @@ async def send_request(
) -> ReceiveResultT:
logger.debug("send_request: request=", data=request.model_dump())
request_id = getattr(self, "_request_id", None)
start_time = perf_counter()
try:
result = await super().send_request(
request=request,
Expand All @@ -220,6 +222,7 @@ async def send_request(
data=result.model_dump() if result is not None else "no response returned",
)
self._attach_transport_channel(request_id, result)
self._attach_transport_elapsed(result, perf_counter() - start_time)
return result
except Exception as e:
# Handle connection errors cleanly
Expand Down Expand Up @@ -250,6 +253,16 @@ def _attach_transport_channel(self, request_id, result) -> None:
# If result cannot be mutated, ignore silently
pass

@staticmethod
def _attach_transport_elapsed(result, elapsed: float | None) -> None:
if result is None or elapsed is None:
return
try:
setattr(result, "transport_elapsed", max(elapsed, 0.0))
except Exception:
# Ignore if result is immutable
pass

async def _received_notification(self, notification: ServerNotification) -> None:
"""
Can be overridden by subclasses to handle a notification without needing
Expand Down
73 changes: 58 additions & 15 deletions src/fast_agent/mcp/mcp_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
from fast_agent.mcp.transport_tracking import TransportChannelMetrics

if TYPE_CHECKING:
from mcp.client.auth import OAuthClientProvider

from fast_agent.context import Context
from fast_agent.mcp_server_registry import ServerRegistry

Expand Down Expand Up @@ -65,6 +67,38 @@ def _add_none_to_context(context_manager):
return StreamingContextAdapter(context_manager)


def _prepare_headers_and_auth(
server_config: MCPServerSettings,
) -> tuple[dict[str, str], Optional["OAuthClientProvider"], set[str]]:
"""
Prepare request headers and determine if OAuth authentication should be used.

Returns a copy of the headers, an OAuth auth provider when applicable, and the set
of user-supplied authorization header keys.
"""
headers: dict[str, str] = dict(server_config.headers or {})
auth_header_keys = {"authorization", "x-hf-authorization"}
user_provided_auth_keys = {key for key in headers if key.lower() in auth_header_keys}

# OAuth is only relevant for SSE/HTTP transports and should be skipped when the
# user has already supplied explicit Authorization headers.
if server_config.transport not in ("sse", "http") or user_provided_auth_keys:
return headers, None, user_provided_auth_keys

oauth_auth = build_oauth_provider(server_config)
if oauth_auth is not None:
# Scrub Authorization headers so OAuth-managed credentials are the only ones sent.
for header_name in (
"Authorization",
"authorization",
"X-HF-Authorization",
"x-hf-authorization",
):
headers.pop(header_name, None)

return headers, oauth_auth, user_provided_auth_keys


class ServerConnection:
"""
Represents a long-lived MCP server connection, including:
Expand Down Expand Up @@ -113,7 +147,9 @@ def __init__(
self.server_implementation: Implementation | None = None
self.client_capabilities: dict | None = None
self.server_instructions_available: bool = False
self.server_instructions_enabled: bool = server_config.include_instructions if server_config else True
self.server_instructions_enabled: bool = (
server_config.include_instructions if server_config else True
)
self.session_id: str | None = None
self._get_session_id_cb: GetSessionIdCallback | None = None
self.transport_metrics: TransportChannelMetrics | None = None
Expand Down Expand Up @@ -404,7 +440,9 @@ async def launch_server(

logger.debug(f"{server_name}: Found server configuration=", data=config.model_dump())

transport_metrics = TransportChannelMetrics() if config.transport in ("http", "stdio") else None
transport_metrics = (
TransportChannelMetrics() if config.transport in ("http", "stdio") else None
)

def transport_context_factory():
if config.transport == "stdio":
Expand All @@ -425,7 +463,9 @@ def transport_context_factory():

channel_hook = transport_metrics.record_event if transport_metrics else None
return _add_none_to_context(
tracking_stdio_client(server_params, channel_hook=channel_hook, errlog=error_handler)
tracking_stdio_client(
server_params, channel_hook=channel_hook, errlog=error_handler
)
)
elif config.transport == "sse":
if not config.url:
Expand All @@ -434,12 +474,12 @@ def transport_context_factory():
)
# Suppress MCP library error spam
self._suppress_mcp_sse_errors()
oauth_auth = build_oauth_provider(config)
# If using OAuth, strip any pre-existing Authorization headers to avoid conflicts
headers = dict(config.headers or {})
if oauth_auth is not None:
headers.pop("Authorization", None)
headers.pop("X-HF-Authorization", None)
headers, oauth_auth, user_auth_keys = _prepare_headers_and_auth(config)
if user_auth_keys:
logger.debug(
f"{server_name}: Using user-specified auth header(s); skipping OAuth provider.",
user_auth_headers=sorted(user_auth_keys),
)
return _add_none_to_context(
sse_client(
config.url,
Expand All @@ -453,19 +493,22 @@ def transport_context_factory():
raise ValueError(
f"Server '{server_name}' uses http transport but no url is specified"
)
oauth_auth = build_oauth_provider(config)
headers = dict(config.headers or {})
if oauth_auth is not None:
headers.pop("Authorization", None)
headers.pop("X-HF-Authorization", None)
headers, oauth_auth, user_auth_keys = _prepare_headers_and_auth(config)
if user_auth_keys:
logger.debug(
f"{server_name}: Using user-specified auth header(s); skipping OAuth provider.",
user_auth_headers=sorted(user_auth_keys),
)
channel_hook = None
if transport_metrics is not None:

def channel_hook(event):
try:
transport_metrics.record_event(event)
except Exception: # pragma: no cover - defensive guard
logger.debug(
"%s: transport metrics hook failed", server_name,
"%s: transport metrics hook failed",
server_name,
exc_info=True,
)

Expand Down
54 changes: 52 additions & 2 deletions src/fast_agent/ui/console_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from rich.panel import Panel
from rich.text import Text

from fast_agent.constants import REASONING
from fast_agent.ui import console
from fast_agent.ui.mcp_ui_utils import UILink
from fast_agent.ui.mermaid_utils import (
Expand Down Expand Up @@ -144,6 +145,25 @@ def __init__(self, config=None) -> None:
self._markup = config.logger.enable_markup if config else True
self._escape_xml = True

@staticmethod
def _format_elapsed(elapsed: float) -> str:
"""Format elapsed seconds for display."""
if elapsed < 0:
elapsed = 0.0
if elapsed < 0.001:
return "<1ms"
if elapsed < 1:
return f"{elapsed * 1000:.0f}ms"
if elapsed < 10:
return f"{elapsed:.2f}s"
if elapsed < 60:
return f"{elapsed:.1f}s"
minutes, seconds = divmod(elapsed, 60)
if minutes < 60:
return f"{int(minutes)}m {seconds:02.0f}s"
hours, minutes = divmod(int(minutes), 60)
return f"{hours}h {minutes:02d}m"

def display_message(
self,
content: Any,
Expand All @@ -156,6 +176,7 @@ def display_message(
is_error: bool = False,
truncate_content: bool = True,
additional_message: Text | None = None,
pre_content: Text | None = None,
) -> None:
"""
Unified method to display formatted messages to the console.
Expand All @@ -170,6 +191,8 @@ def display_message(
max_item_length: Optional max length for bottom metadata items (with ellipsis)
is_error: For tool results, whether this is an error (uses red color)
truncate_content: Whether to truncate long content
additional_message: Optional Rich Text appended after the main content
pre_content: Optional Rich Text shown before the main content
"""
# Get configuration for this message type
config = MESSAGE_CONFIGS[message_type]
Expand All @@ -191,6 +214,8 @@ def display_message(
self._create_combined_separator_status(left, right_info)

# Display the content
if pre_content and pre_content.plain:
console.console.print(pre_content, markup=self._markup)
self._display_content(
content, truncate_content, is_error, message_type, check_markdown_markers=False
)
Expand Down Expand Up @@ -544,7 +569,7 @@ def show_tool_result(self, result: CallToolResult, name: str | None = None) -> N

# Build transport channel info for bottom bar
channel = getattr(result, "transport_channel", None)
bottom_metadata = None
bottom_metadata_items: List[str] = []
if channel:
# Format channel info for bottom bar
if channel == "post-json":
Expand All @@ -560,7 +585,13 @@ def show_tool_result(self, result: CallToolResult, name: str | None = None) -> N
else:
transport_info = channel.upper()

bottom_metadata = [transport_info]
bottom_metadata_items.append(transport_info)

elapsed = getattr(result, "transport_elapsed", None)
if isinstance(elapsed, (int, float)):
bottom_metadata_items.append(self._format_elapsed(float(elapsed)))

bottom_metadata = bottom_metadata_items or None

# Build right info (without channel info)
right_info = f"[dim]tool result - {status}[/dim]"
Expand Down Expand Up @@ -724,8 +755,26 @@ async def show_assistant_message(
# Extract text from PromptMessageExtended if needed
from fast_agent.types import PromptMessageExtended

pre_content: Text | None = None

if isinstance(message_text, PromptMessageExtended):
display_text = message_text.last_text() or ""

channels = message_text.channels or {}
reasoning_blocks = channels.get(REASONING) or []
if reasoning_blocks:
from fast_agent.mcp.helpers.content_helpers import get_text

reasoning_segments = []
for block in reasoning_blocks:
text = get_text(block)
if text:
reasoning_segments.append(text)

if reasoning_segments:
joined = "\n".join(reasoning_segments)
if joined.strip():
pre_content = Text(joined, style="dim default")
else:
display_text = message_text

Expand All @@ -743,6 +792,7 @@ async def show_assistant_message(
max_item_length=max_item_length,
truncate_content=False, # Assistant messages shouldn't be truncated
additional_message=additional_message,
pre_content=pre_content,
)

# Handle mermaid diagrams separately (after the main message)
Expand Down
Loading
Loading